Dataset exploration¶
OPFData samples are PyG HeteroData graphs with several node types (bus, generator, load, shunt) and edge types (ac_line, transformer, generator_link, load_link, shunt_link). This notebook walks through what's inside a single sample and how to inspect feature/target tensors.
In [ ]:
Copied!
from lumina.dataset.opf.opf_dataset import OPFDataset

# Placeholder path: replace with the dataset root on your machine.
DATA_ROOT = '/path/to/datasets'

# Build the dataset for the 14-bus case, group 0, from explicit keyword settings.
dataset_kwargs = dict(
    root=DATA_ROOT,
    case_name='pglib_opf_case14_ieee',
    group_id=0,
)
ds = OPFDataset(**dataset_kwargs)

# Take the first sample; the bare expression on the last line displays its repr.
sample = ds[0]
sample
from lumina.dataset.opf.opf_dataset import OPFDataset

# Placeholder path: replace with the dataset root on your machine.
DATA_ROOT = '/path/to/datasets'

# Build the dataset for the 14-bus case, group 0, from explicit keyword settings.
dataset_kwargs = dict(
    root=DATA_ROOT,
    case_name='pglib_opf_case14_ieee',
    group_id=0,
)
ds = OPFDataset(**dataset_kwargs)

# Take the first sample; the bare expression on the last line displays its repr.
sample = ds[0]
sample
Node types¶
In [ ]:
Copied!
# One summary line per node type: row count, feature width and target width.
for ntype in sample.node_types:
    store = sample[ntype]
    x, y = store.get('x'), store.get('y')

    # A node type without an `x` tensor is reported with zero rows and zero features.
    if x is None:
        n, fdim = 0, 0
    else:
        n, fdim = x.shape[0], x.shape[1]

    # Likewise, a missing `y` tensor is reported as zero target columns.
    tdim = 0 if y is None else y.shape[1]

    print(f'{ntype:>10s} count={n:4d} features={fdim:3d} targets={tdim:3d}')
# One summary line per node type: row count, feature width and target width.
for ntype in sample.node_types:
    store = sample[ntype]
    x, y = store.get('x'), store.get('y')

    # A node type without an `x` tensor is reported with zero rows and zero features.
    if x is None:
        n, fdim = 0, 0
    else:
        n, fdim = x.shape[0], x.shape[1]

    # Likewise, a missing `y` tensor is reported as zero target columns.
    tdim = 0 if y is None else y.shape[1]

    print(f'{ntype:>10s} count={n:4d} features={fdim:3d} targets={tdim:3d}')
Edge types¶
In [ ]:
Copied!
# One summary line per edge type: number of edges and width of the edge attributes.
for etype in sample.edge_types:
    store = sample[etype]
    ei = store.edge_index          # edges are counted along dimension 1
    ea = store.get('edge_attr')    # None when the edge type has no attributes
    print(f'{str(etype):60s} edges={ei.shape[1]:4d} attr_dim={ea.shape[1] if ea is not None else 0}')
# One summary line per edge type: number of edges and width of the edge attributes.
for etype in sample.edge_types:
    store = sample[etype]
    ei = store.edge_index          # edges are counted along dimension 1
    ea = store.get('edge_attr')    # None when the edge type has no attributes
    print(f'{str(etype):60s} edges={ei.shape[1]:4d} attr_dim={ea.shape[1] if ea is not None else 0}')
Schema reference¶
The Pydantic schemas in lumina/dataset/opf/schema.py document each feature and target column. They're the source of truth for column ordering and units.
In [ ]:
Copied!
from lumina.dataset.opf import schema

# For each listed schema model, print its name followed by one line per field:
# the field name and the first line of its description.
for s in [schema.JSONBus, schema.JSONGenerator, schema.JSONLoad]:
    print(f'-- {s.__name__} --')
    for name, field in s.model_fields.items():
        # A field whose description is None or '' splits into an empty list, so
        # indexing [0] would raise IndexError. Fall back to an empty string.
        desc = next(iter((field.description or '').splitlines()), '')
        print(f' {name:24s} {desc}')
    # Blank line between schema models.
    print()
from lumina.dataset.opf import schema

# For each listed schema model, print its name followed by one line per field:
# the field name and the first line of its description.
for s in [schema.JSONBus, schema.JSONGenerator, schema.JSONLoad]:
    print(f'-- {s.__name__} --')
    for name, field in s.model_fields.items():
        # A field whose description is None or '' splits into an empty list, so
        # indexing [0] would raise IndexError. Fall back to an empty string.
        desc = next(iter((field.description or '').splitlines()), '')
        print(f' {name:24s} {desc}')
    # Blank line between schema models.
    print()
Visualize the bus topology¶
Convert the bus -> ac_line -> bus subgraph to NetworkX for a quick layout.
In [ ]:
Copied!
import networkx as nx
import matplotlib.pyplot as plt

# Build an undirected graph from the bus-to-bus AC-line edge list.
ei = sample['bus', 'ac_line', 'bus'].edge_index.numpy()
src, dst = ei[0], ei[1]
G = nx.Graph()
G.add_edges_from(zip(src, dst))

# Compute node positions once and reuse them for every drawing call.
pos = nx.kamada_kawai_layout(G)

fig, ax = plt.subplots(figsize=(6, 5))
# Draw order is kept as nodes, then labels, then edges.
nx.draw_networkx_nodes(G, pos, ax=ax, node_size=300, node_color='#4F86F7')
nx.draw_networkx_labels(G, pos, ax=ax, font_size=9)
nx.draw_networkx_edges(G, pos, ax=ax, alpha=0.6)

ax.set_title('case14 bus topology')
ax.axis('off')
import networkx as nx
import matplotlib.pyplot as plt

# Build an undirected graph from the bus-to-bus AC-line edge list.
ei = sample['bus', 'ac_line', 'bus'].edge_index.numpy()
src, dst = ei[0], ei[1]
G = nx.Graph()
G.add_edges_from(zip(src, dst))

# Compute node positions once and reuse them for every drawing call.
pos = nx.kamada_kawai_layout(G)

fig, ax = plt.subplots(figsize=(6, 5))
# Draw order is kept as nodes, then labels, then edges.
nx.draw_networkx_nodes(G, pos, ax=ax, node_size=300, node_color='#4F86F7')
nx.draw_networkx_labels(G, pos, ax=ax, font_size=9)
nx.draw_networkx_edges(G, pos, ax=ax, alpha=0.6)

ax.set_title('case14 bus topology')
ax.axis('off')
Distribution of bus voltages (targets)¶
In [ ]:
Copied!
import torch

# Stack the bus targets of every sample in the dataset into one tensor.
# Stacking requires all samples to have bus targets of the same shape.
y_bus = torch.stack([ds[i]['bus'].y for i in range(len(ds))])

# Last-axis column 0 is plotted as voltage magnitude, column 1 as voltage angle.
# NOTE(review): column meaning is taken from the variable names; confirm against the schema.
vm = y_bus[..., 0].flatten().numpy()
va = y_bus[..., 1].flatten().numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))
panels = [
    (vm, 'voltage magnitude', 'p.u.'),
    (va, 'voltage angle', 'rad'),
]
for ax, (values, title, xlabel) in zip(axes, panels):
    ax.hist(values, bins=60)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.grid(True, alpha=0.3)
import torch

# Stack the bus targets of every sample in the dataset into one tensor.
# Stacking requires all samples to have bus targets of the same shape.
y_bus = torch.stack([ds[i]['bus'].y for i in range(len(ds))])

# Last-axis column 0 is plotted as voltage magnitude, column 1 as voltage angle.
# NOTE(review): column meaning is taken from the variable names; confirm against the schema.
vm = y_bus[..., 0].flatten().numpy()
va = y_bus[..., 1].flatten().numpy()

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))
panels = [
    (vm, 'voltage magnitude', 'p.u.'),
    (va, 'voltage angle', 'rad'),
]
for ax, (values, title, xlabel) in zip(axes, panels):
    ax.hist(values, bins=60)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.grid(True, alpha=0.3)