Quickstart: Train a HeteroGNN on case14¶
This notebook walks through the smallest end-to-end LUMINA workflow:
- Load the IEEE 14-bus OPFData group
- Build a small OPFHeteroGNN
- Train for a few epochs and plot the loss curve
Expected dataset layout:
<root>/OPFData/raw/dataset_release_1/<case>_<group>.tar.gz
<root>/OPFData/processed/dataset_release_1/<case>/group_<id>.pt
In [ ]:
Copied!
# Core dependencies: PyTorch plus the LUMINA dataset/loader/model/loss helpers.
import torch
from lumina.dataset.opf.opf_dataset import OPFDataset
from lumina.dataset.opf.transforms import to_float32
from lumina.loader.opf.opf_loader import DataLoader
from lumina.model.opf.hetero_model import OPFHeteroGNN
from lumina.model.opf.losses import OPFLossManager
# Prefer the GPU when available; every tensor/model below is moved to DEVICE.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE
# Core dependencies: PyTorch plus the LUMINA dataset/loader/model/loss helpers.
import torch
from lumina.dataset.opf.opf_dataset import OPFDataset
from lumina.dataset.opf.transforms import to_float32
from lumina.loader.opf.opf_loader import DataLoader
from lumina.model.opf.hetero_model import OPFHeteroGNN
from lumina.model.opf.losses import OPFLossManager
# Prefer the GPU when available; every tensor/model below is moved to DEVICE.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE
1. Load the dataset¶
Point root at the directory that contains OPFData/. The dataset will look for the preprocessed .pt file under OPFData/processed/dataset_release_1/pglib_opf_case14_ieee/.
In [ ]:
Copied!
# Directory that CONTAINS OPFData/ -- edit this before running.
DATA_ROOT = '/path/to/datasets' # parent of OPFData/
# Loads group 0 of the IEEE 14-bus case from
# OPFData/processed/dataset_release_1/pglib_opf_case14_ieee/.
ds = OPFDataset(
    root=DATA_ROOT,
    case_name='pglib_opf_case14_ieee',
    group_id=0,
    transform=to_float32, # cast features to float32 to match model weights
)
print(f'samples: {len(ds)}')
print(ds[0])
# Directory that CONTAINS OPFData/ -- edit this before running.
DATA_ROOT = '/path/to/datasets' # parent of OPFData/
# Loads group 0 of the IEEE 14-bus case from
# OPFData/processed/dataset_release_1/pglib_opf_case14_ieee/.
ds = OPFDataset(
    root=DATA_ROOT,
    case_name='pglib_opf_case14_ieee',
    group_id=0,
    transform=to_float32, # cast features to float32 to match model weights
)
print(f'samples: {len(ds)}')
print(ds[0])
In [ ]:
Copied!
# Hold out the final 10% of samples for validation (90/10 split).
split_idx = int(len(ds) * 0.9)
train_ds = ds[:split_idx]
val_ds = ds[split_idx:]
# 64 graphs per batch; only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64)
# Train/val split (90/10)
n = len(ds)  # total samples in the loaded group
n_train = int(0.9 * n)  # first 90% train, remaining 10% validation
train_ds, val_ds = ds[:n_train], ds[n_train:]
# 64 graphs per batch; only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64)
2. Build the model¶
OPFHeteroGNN exposes the standard PyG hetero backends. SAGE is a good default for case14.
In [ ]:
Copied!
# Read one sample to size the model: one input width per node type.
sample = ds[0]
input_channels = {nt: sample[nt].x.size(-1) for nt in sample.node_types}
model = OPFHeteroGNN(
    metadata=sample.metadata(),  # node/edge type schema of the hetero graph
    input_channels=input_channels,
    hidden_channels=64,
    num_layers=3,
    backend='sage',  # PyG hetero backend; SAGE is the suggested default for case14
).to(DEVICE)
model
# Read one sample to size the model: one input width per node type.
sample = ds[0]
input_channels = {nt: sample[nt].x.size(-1) for nt in sample.node_types}
model = OPFHeteroGNN(
    metadata=sample.metadata(),  # node/edge type schema of the hetero graph
    input_channels=input_channels,
    hidden_channels=64,
    num_layers=3,
    backend='sage',  # PyG hetero backend; SAGE is the suggested default for case14
).to(DEVICE)
model
3. Train¶
Use plain MSE here. See the physics-informed loss notebook for a richer loss setup.
In [ ]:
Copied!
# AdamW with mild weight decay; plain MSE objective (see the
# physics-informed loss notebook for richer losses).
optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
loss_mgr = OPFLossManager(loss_type='mse')

EPOCHS = 5
history = {'train': [], 'val': []}


def _epoch_pass(loader, training):
    """One full pass over `loader`; returns the mean batch loss.

    When `training` is True, performs backward/clip/step per batch.
    """
    batch_losses = []
    for batch in loader:
        batch = batch.to(DEVICE)
        if training:
            optim.zero_grad()
        pred = model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict)
        loss, _ = loss_mgr.compute_loss(pred, batch)
        if training:
            loss.backward()
            # Clip gradients to stabilize early training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(batch_losses)


for epoch in range(EPOCHS):
    model.train()
    history['train'].append(_epoch_pass(train_loader, training=True))
    model.eval()
    with torch.no_grad():  # validation: no gradient tracking
        history['val'].append(_epoch_pass(val_loader, training=False))
    print(f'epoch {epoch:2d} train {history["train"][-1]:.4e} val {history["val"][-1]:.4e}')

# Save a Modeler-compatible checkpoint so 03_evaluation_walkthrough.ipynb can load it.
torch.save({
    'model_class': 'lumina.model.opf.hetero_model.OPFHeteroGNN',
    'model_kwargs': {
        'metadata': sample.metadata(),
        'input_channels': input_channels,
        'hidden_channels': 64,
        'num_layers': 3,
        'backend': 'sage',
    },
    'model_state_dict': model.state_dict(),
}, 'case14_quickstart.pt')
print('saved checkpoint -> case14_quickstart.pt')
# AdamW with mild weight decay; plain MSE objective (see the
# physics-informed loss notebook for richer losses).
optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
loss_mgr = OPFLossManager(loss_type='mse')
EPOCHS = 5
history = {'train': [], 'val': []}
for epoch in range(EPOCHS):
    model.train()
    train_losses = []
    for batch in train_loader:
        batch = batch.to(DEVICE)
        optim.zero_grad()
        # Forward pass over the heterogeneous graph batch.
        pred = model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict)
        loss, _ = loss_mgr.compute_loss(pred, batch)
        loss.backward()
        # Clip gradients to stabilize early training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        train_losses.append(loss.item())
    # Validation pass: eval mode, no gradient tracking.
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(DEVICE)
            pred = model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict)
            val_loss, _ = loss_mgr.compute_loss(pred, batch)
            val_losses.append(val_loss.item())
    # Record epoch-mean losses for the plot below.
    history['train'].append(sum(train_losses) / len(train_losses))
    history['val'].append(sum(val_losses) / len(val_losses))
    print(f'epoch {epoch:2d} train {history["train"][-1]:.4e} val {history["val"][-1]:.4e}')
# Save a Modeler-compatible checkpoint so 03_evaluation_walkthrough.ipynb can load it.
torch.save({
    'model_class': 'lumina.model.opf.hetero_model.OPFHeteroGNN',
    'model_kwargs': {
        'metadata': sample.metadata(),
        'input_channels': input_channels,
        'hidden_channels': 64,
        'num_layers': 3,
        'backend': 'sage',
    },
    'model_state_dict': model.state_dict(),
}, 'case14_quickstart.pt')
print('saved checkpoint -> case14_quickstart.pt')
In [ ]:
Copied!
import matplotlib.pyplot as plt

# One curve per split; dict insertion order draws train first, then val.
for split_name in ('train', 'val'):
    plt.plot(history[split_name], label=split_name)
plt.xlabel('epoch')
plt.ylabel('MSE')
plt.legend()
plt.grid(True)
plt.title('case14 training curve')
import matplotlib.pyplot as plt
# Per-epoch mean MSE for both splits, on the same axes.
plt.plot(history['train'], label='train')
plt.plot(history['val'], label='val')
plt.xlabel('epoch'); plt.ylabel('MSE'); plt.legend(); plt.grid(True)
plt.title('case14 training curve')
Next steps¶
- Dataset exploration — what's inside an OPFData sample
- Evaluation walkthrough — measure constraint violations
- Physics-informed loss demo — beyond MSE
- DDP training (local GPUs) — scale to multi-GPU