
Evaluator API

Constraint Evaluator

ACOPFConstraintEvaluator

Bases: Module

Comprehensive constraint violation evaluator for ACOPF problems.

Evaluates bound, power-flow, and line-flow constraint violations given model predictions and network parameters (admittance matrix, limits).

Supported constraint categories
  • Bound constraints: voltage magnitude/angle, active/reactive generation.
  • Power flow constraints: active and reactive power balance at each bus.
  • Line flow constraints: thermal limits on transmission lines.
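
As a quick orientation, here is a minimal usage sketch. The 3-bus/2-generator toy system and all tensor values below are illustrative assumptions, not values shipped with the library:

import torch

from lumina.evaluator.opf.evaluator import ACOPFConstraintEvaluator

# Hypothetical toy system: 3 buses, 2 generators.
n_bus, n_gen = 3, 2
evaluator = ACOPFConstraintEvaluator(
    voltage_limits={'vmin': torch.full((n_bus,), 0.94),
                    'vmax': torch.full((n_bus,), 1.06)},
    generation_limits={'pmin': torch.zeros(n_gen), 'pmax': torch.ones(n_gen),
                       'qmin': -torch.ones(n_gen), 'qmax': torch.ones(n_gen)},
    device=torch.device('cpu'),
)

# Predictions are concatenated node outputs: 'bus' -> [VA, VM], 'generator' -> [PG, QG].
predictions = {
    'bus': torch.tensor([[0.00, 1.02], [0.10, 0.90], [0.05, 1.00]]),
    'generator': torch.tensor([[0.5, 0.2], [1.3, -0.1]]),
}
violations = evaluator.evaluate_bound_constraints(predictions)
# Scalar tensor: sum of the per-category mean violation magnitudes.
print(float(violations['total_bound_violations']))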

Parameters:

voltage_limits (dict[str, Tensor], default None)
    Dict with 'vmin', 'vmax' tensors of shape [n_bus].
generation_limits (dict[str, Tensor], default None)
    Dict with 'pmin', 'pmax', 'qmin', 'qmax' tensors of shape [n_gen].
line_limits (Tensor, default None)
    Line flow limits [n_lines].
Y_real (Tensor, default None)
    Real part of admittance matrix [n_bus, n_bus].
Y_imag (Tensor, default None)
    Imaginary part of admittance matrix [n_bus, n_bus].
edge_index (Tensor, default None)
    Line connectivity [2, n_lines].
base_mva (float, default 100.0)
    Base MVA for power scaling.
slack_bus_indices (list[int], default None)
    Indices of slack buses.
device (device, default None)
    Computation device.
Source code in lumina/evaluator/opf/evaluator.py
class ACOPFConstraintEvaluator(nn.Module):
    """Comprehensive constraint violation evaluator for ACOPF problems.

    Evaluates bound, power-flow, and line-flow constraint violations given
    model predictions and network parameters (admittance matrix, limits).

    Supported constraint categories:
        - Bound constraints: voltage magnitude/angle, active/reactive generation.
        - Power flow constraints: active and reactive power balance at each bus.
        - Line flow constraints: thermal limits on transmission lines.

    Args:
        voltage_limits (dict[str, torch.Tensor], optional): Dict with ``'vmin'``,
            ``'vmax'`` tensors of shape ``[n_bus]``.
        generation_limits (dict[str, torch.Tensor], optional): Dict with
            ``'pmin'``, ``'pmax'``, ``'qmin'``, ``'qmax'`` tensors of
            shape ``[n_gen]``.
        line_limits (torch.Tensor, optional): Line flow limits ``[n_lines]``.
        Y_real (torch.Tensor, optional): Real part of admittance matrix
            ``[n_bus, n_bus]``.
        Y_imag (torch.Tensor, optional): Imaginary part of admittance matrix
            ``[n_bus, n_bus]``.
        edge_index (torch.Tensor, optional): Line connectivity ``[2, n_lines]``.
        base_mva (float): Base MVA for power scaling.
        slack_bus_indices (list[int], optional): Indices of slack buses.
        device (torch.device, optional): Computation device.
    """

    def __init__(
        self,
        voltage_limits: Optional[Dict[str, torch.Tensor]] = None,
        generation_limits: Optional[Dict[str, torch.Tensor]] = None,
        line_limits: Optional[torch.Tensor] = None,
        Y_real: Optional[torch.Tensor] = None,
        Y_imag: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        base_mva: float = 100.0,
        slack_bus_indices: Optional[List[int]] = None,
        device: Optional[torch.device] = None
    ):
        """
        Initialize the constraint evaluator.

        Args:
            voltage_limits: Dict with 'vmin', 'vmax' voltage limits [n_bus]
            generation_limits: Dict with 'pmin', 'pmax', 'qmin', 'qmax' [n_gen]
            line_limits: Line flow limits [n_lines]
            Y_real: Real part of admittance matrix [n_bus, n_bus]
            Y_imag: Imaginary part of admittance matrix [n_bus, n_bus]
            edge_index: Line connectivity [2, n_lines]
            base_mva: Base MVA for power scaling
            slack_bus_indices: Indices of slack buses (default: [0])
            device: Computation device
        """
        super().__init__()

        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.base_mva = base_mva

        # Store network parameters
        self.voltage_limits = voltage_limits
        self.generation_limits = generation_limits
        self.line_limits = line_limits
        self.Y_real = Y_real
        self.Y_imag = Y_imag
        self.edge_index = edge_index
        self.slack_bus_indices = slack_bus_indices or [0]
        self.total_real_power_demand = 0.0

        # Move tensors to device if provided
        self._move_to_device()

    def _move_to_device(self):
        """Move all tensors to the specified device."""
        if self.voltage_limits:
            for key, tensor in self.voltage_limits.items():
                if tensor is not None:
                    self.voltage_limits[key] = tensor.to(self.device)

        if self.generation_limits:
            for key, tensor in self.generation_limits.items():
                if tensor is not None:
                    self.generation_limits[key] = tensor.to(self.device)

        if self.line_limits is not None:
            self.line_limits = self.line_limits.to(self.device)

        if self.Y_real is not None:
            self.Y_real = self.Y_real.to(self.device)

        if self.Y_imag is not None:
            self.Y_imag = self.Y_imag.to(self.device)

        if self.edge_index is not None:
            self.edge_index = self.edge_index.to(self.device)

    def _group_node_predictions(self, node_pred: torch.Tensor, node_batch: torch.Tensor) -> torch.Tensor:
        """Group node-level predictions into per-sample tensors.

        Args:
            node_pred: Tensor of shape [total_nodes, feat]
            node_batch: LongTensor of shape [total_nodes] with sample indices

        Returns:
            grouped: Tensor of shape [batch_size, n_nodes_per_sample, feat]
        """
        # Ensure on same device
        node_pred = node_pred.to(self.device)
        node_batch = node_batch.to(self.device)

        if node_batch.numel() == 0:
            return node_pred.unsqueeze(0)

        # Promote 1-D predictions to [total_nodes, 1] so the padding and
        # per-sample assignment below work for both 1-D and 2-D inputs
        if node_pred.dim() == 1:
            node_pred = node_pred.unsqueeze(1)

        batch_size = int(node_batch.max().item()) + 1
        counts = torch.bincount(node_batch.cpu())
        # If number of nodes per sample varies, use the max and pad as needed
        n_nodes = int(counts.max().item())

        feat = node_pred.size(1)
        grouped = node_pred.new_zeros((batch_size, n_nodes, feat))

        for i in range(batch_size):
            mask = (node_batch == i)
            grp = node_pred[mask]
            if grp.size(0) == 0:
                continue
            if grp.size(0) < n_nodes:
                pad = node_pred.new_zeros((n_nodes - grp.size(0), feat))
                grp = torch.cat([grp, pad], dim=0)
            elif grp.size(0) > n_nodes:
                grp = grp[:n_nodes]
            grouped[i] = grp

        return grouped

    def set_network_parameters(
        self,
        voltage_limits: Optional[Dict[str, torch.Tensor]] = None,
        generation_limits: Optional[Dict[str, torch.Tensor]] = None,
        line_limits: Optional[torch.Tensor] = None,
        Y_real: Optional[torch.Tensor] = None,
        Y_imag: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None
    ):
        """Update network parameters for constraint evaluation.

        Any argument that is not ``None`` replaces the corresponding stored
        parameter.  All tensors are moved to the evaluator's device after
        assignment.

        Args:
            voltage_limits (dict[str, torch.Tensor], optional): Updated voltage
                limits.
            generation_limits (dict[str, torch.Tensor], optional): Updated
                generation limits.
            line_limits (torch.Tensor, optional): Updated line flow limits.
            Y_real (torch.Tensor, optional): Updated real admittance matrix.
            Y_imag (torch.Tensor, optional): Updated imaginary admittance matrix.
            edge_index (torch.Tensor, optional): Updated edge connectivity.
        """
        if voltage_limits is not None:
            self.voltage_limits = voltage_limits
        if generation_limits is not None:
            self.generation_limits = generation_limits
        if line_limits is not None:
            self.line_limits = line_limits
        if Y_real is not None:
            self.Y_real = Y_real
        if Y_imag is not None:
            self.Y_imag = Y_imag
        if edge_index is not None:
            self.edge_index = edge_index

        self._move_to_device()

    def evaluate_bound_constraints(
        self,
        predictions: Dict[str, torch.Tensor],
        batch_data=None,
        return_individual: bool = True
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Evaluate bound constraint violations for voltage and generation.

        Computes ReLU-based violation magnitudes for voltage magnitude
        bounds, active power generation bounds, and reactive power generation
        bounds.

        Args:
            predictions (dict[str, torch.Tensor]): Model predictions with
                ``'bus'`` ``[total_bus, 2]`` (VA, VM) and ``'generator'``
                ``[total_gen, 2]`` (PG, QG) keys.
            batch_data: Batch data (unused, reserved for future extensions).
            return_individual (bool): If ``True``, include per-direction
                violation components (e.g. ``vm_low_violations``,
                ``vm_high_violations``).

        Returns:
            dict[str, torch.Tensor]: Violation metrics including
                ``'total_bound_violations'`` and optionally individual
                component violations.
        """
        violations = {}

        # Extract predictions (concatenated over the batch: [total_nodes, feat])
        # Model outputs bus as [VA, VM], generator as [PG, QG]
        bus_pred = predictions['bus']  # [total_bus_nodes, 2] -> [VA, VM]
        gen_pred = predictions['generator']  # [total_gen_nodes, 2] -> [PG, QG]

        # Voltage magnitude violations (angle bounds are not evaluated here)
        if self.voltage_limits is not None:
            vm = bus_pred[..., 1]  # Voltage magnitude (concatenated)

            if 'vmin' in self.voltage_limits and 'vmax' in self.voltage_limits:
                vmin = self.voltage_limits['vmin']
                vmax = self.voltage_limits['vmax']

                # If vmin/vmax are per-case (n_bus), expand to concatenated per-batch vector
                try:
                    n_bus = int(vmin.numel())
                    if vm.numel() != n_bus:
                        # If vm length is a multiple of n_bus, repeat per sample
                        if vm.numel() % n_bus == 0:
                            warnings.warn(
                                "Repeated voltage limits across samples because bus vector length does not match base limit length. "
                                "If this batch contains variable-size cases, VM bounds may be assigned to wrong buses."
                            )
                            batch_size = vm.numel() // n_bus
                            vmin_cat = vmin.to(self.device).repeat(batch_size)
                            vmax_cat = vmax.to(self.device).repeat(batch_size)
                        else:
                            # Fallback: broadcast where possible
                            warnings.warn(
                                "Broadcasted voltage limits by repetition/truncation because bus vector length is not a multiple "
                                "of base limit length; this can hide mixed-topology misalignment."
                            )
                            vmin_cat = vmin.to(self.device).reshape(-1).repeat(vm.numel() //
                                                                               vmin.numel() + 1)[:vm.numel()]
                            vmax_cat = vmax.to(self.device).reshape(-1).repeat(vm.numel() //
                                                                               vmax.numel() + 1)[:vm.numel()]
                    else:
                        vmin_cat = vmin.to(self.device)
                        vmax_cat = vmax.to(self.device)
                except Exception:
                    vmin_cat = vmin.to(self.device)
                    vmax_cat = vmax.to(self.device)

                # Voltage magnitude violations
                vm_low_viol = torch.relu(vmin_cat - vm)
                vm_high_viol = torch.relu(vm - vmax_cat)
                vm_violations = vm_low_viol + vm_high_viol

                violations['voltage_magnitude'] = vm_violations.mean()

                if return_individual:
                    violations['vm_low_violations'] = vm_low_viol.mean()
                    violations['vm_high_violations'] = vm_high_viol.mean()

        # Generation limit violations
        if self.generation_limits is not None:
            pg = gen_pred[..., 0]  # Active power generation (concatenated)
            qg = gen_pred[..., 1]  # Reactive power generation (concatenated)

            # Active power violations
            if 'pmin' in self.generation_limits and 'pmax' in self.generation_limits:
                pmin = self.generation_limits['pmin']
                pmax = self.generation_limits['pmax']

                try:
                    n_gen = int(pmin.numel())
                    if pg.numel() != n_gen:
                        # Repeat per sample if sizes align
                        if pg.numel() % n_gen == 0:
                            warnings.warn(
                                "Repeated active-power limits across samples because generator vector length does not match base limit length. "
                                "If this batch contains variable-size cases, PG bounds may be assigned to wrong generators."
                            )
                            batch_size_g = pg.numel() // n_gen
                            pmin_cat = pmin.to(self.device).repeat(batch_size_g)
                            pmax_cat = pmax.to(self.device).repeat(batch_size_g)
                        else:
                            warnings.warn(
                                "Broadcasted active-power limits by repetition/truncation because generator vector length is not a multiple "
                                "of base limit length; this can hide mixed-topology misalignment."
                            )
                            pmin_cat = pmin.to(self.device).reshape(-1).repeat(pg.numel() //
                                                                               pmin.numel() + 1)[:pg.numel()]
                            pmax_cat = pmax.to(self.device).reshape(-1).repeat(pg.numel() //
                                                                               pmax.numel() + 1)[:pg.numel()]
                    else:
                        pmin_cat = pmin.to(self.device)
                        pmax_cat = pmax.to(self.device)
                except Exception:
                    pmin_cat = pmin.to(self.device)
                    pmax_cat = pmax.to(self.device)

                pg_low_viol = torch.relu(pmin_cat - pg)
                pg_high_viol = torch.relu(pg - pmax_cat)
                pg_violations = pg_low_viol + pg_high_viol

                violations['active_power_generation'] = pg_violations.mean()

                if return_individual:
                    violations['pg_low_violations'] = pg_low_viol.mean()
                    violations['pg_high_violations'] = pg_high_viol.mean()

            # Reactive power violations
            if 'qmin' in self.generation_limits and 'qmax' in self.generation_limits:
                qmin = self.generation_limits['qmin']
                qmax = self.generation_limits['qmax']

                try:
                    n_genq = int(qmin.numel())
                    if qg.numel() != n_genq:
                        if qg.numel() % n_genq == 0:
                            warnings.warn(
                                "Repeated reactive-power limits across samples because generator vector length does not match base limit length. "
                                "If this batch contains variable-size cases, QG bounds may be assigned to wrong generators."
                            )
                            batch_size_q = qg.numel() // n_genq
                            qmin_cat = qmin.to(self.device).repeat(batch_size_q)
                            qmax_cat = qmax.to(self.device).repeat(batch_size_q)
                        else:
                            warnings.warn(
                                "Broadcasted reactive-power limits by repetition/truncation because generator vector length is not a multiple "
                                "of base limit length; this can hide mixed-topology misalignment."
                            )
                            qmin_cat = qmin.to(self.device).reshape(-1).repeat(qg.numel() //
                                                                               qmin.numel() + 1)[:qg.numel()]
                            qmax_cat = qmax.to(self.device).reshape(-1).repeat(qg.numel() //
                                                                               qmax.numel() + 1)[:qg.numel()]
                    else:
                        qmin_cat = qmin.to(self.device)
                        qmax_cat = qmax.to(self.device)
                except Exception:
                    qmin_cat = qmin.to(self.device)
                    qmax_cat = qmax.to(self.device)

                qg_low_viol = torch.relu(qmin_cat - qg)
                qg_high_viol = torch.relu(qg - qmax_cat)
                qg_violations = qg_low_viol + qg_high_viol

                violations['reactive_power_generation'] = qg_violations.mean()

                if return_individual:
                    violations['qg_low_violations'] = qg_low_viol.mean()
                    violations['qg_high_violations'] = qg_high_viol.mean()

        # Total bound constraint violation
        total_bound_violation = sum(v for k, v in violations.items()
                                    if not k.endswith('_violations') and isinstance(v, torch.Tensor))
        violations['total_bound_violations'] = total_bound_violation

        return violations

    def evaluate_all_constraints(
        self,
        predictions: Dict[str, torch.Tensor],
        batch_data,
        normalize: bool = True,
        return_individual: bool = True,
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Evaluate all constraint violations comprehensively.

        Currently delegates to ``evaluate_bound_constraints`` and prefixes
        results with ``bound_``.  Power-flow and line-flow evaluations are
        planned but not yet active.

        Args:
            predictions (dict[str, torch.Tensor]): Model predictions keyed
                by node type.
            batch_data: Batch data containing network information.
            normalize (bool): Whether to normalize violations (reserved for
                future use).
            return_individual (bool): Whether to include individual violation
                components.

        Returns:
            dict[str, torch.Tensor]: All violation metrics, keyed with a
                ``bound_`` prefix.
        """
        all_violations = {}

        # Evaluate bound constraints
        bound_violations = self.evaluate_bound_constraints(predictions, batch_data, return_individual)
        all_violations.update({f"bound_{k}": v for k, v in bound_violations.items()})

        # Compute total constraint violation
        # violation_keys = ['bound_total_bound_violations', 'real_power_flow_violations', 'reactive_power_flow_violations', 'line_flow_violations']
        # total_violation = sum(all_violations.get(key, torch.tensor(0.0, device=self.device))
        #                       for key in violation_keys)
        # all_violations['total_constraint_violations'] = total_violation

        return all_violations

    def get_violation_summary(
        self,
        violations: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        """Convert violation tensors to a plain-float summary dictionary.

        Scalar tensors are converted via ``.item()``; multi-element tensors
        are reduced by mean.

        Args:
            violations (dict[str, torch.Tensor]): Dictionary of violation
                tensors as returned by ``evaluate_all_constraints``.

        Returns:
            dict[str, float]: Violation values as Python floats.
        """
        summary = {}

        for key, value in violations.items():
            if isinstance(value, torch.Tensor):
                summary[key] = value.item() if value.numel() == 1 else value.mean().item()
            else:
                summary[key] = float(value)

        return summary

__init__(voltage_limits: Optional[Dict[str, torch.Tensor]] = None, generation_limits: Optional[Dict[str, torch.Tensor]] = None, line_limits: Optional[torch.Tensor] = None, Y_real: Optional[torch.Tensor] = None, Y_imag: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None, base_mva: float = 100.0, slack_bus_indices: Optional[List[int]] = None, device: Optional[torch.device] = None)

Initialize the constraint evaluator.

Parameters:

voltage_limits (Optional[Dict[str, Tensor]], default None)
    Dict with 'vmin', 'vmax' voltage limits [n_bus].
generation_limits (Optional[Dict[str, Tensor]], default None)
    Dict with 'pmin', 'pmax', 'qmin', 'qmax' [n_gen].
line_limits (Optional[Tensor], default None)
    Line flow limits [n_lines].
Y_real (Optional[Tensor], default None)
    Real part of admittance matrix [n_bus, n_bus].
Y_imag (Optional[Tensor], default None)
    Imaginary part of admittance matrix [n_bus, n_bus].
edge_index (Optional[Tensor], default None)
    Line connectivity [2, n_lines].
base_mva (float, default 100.0)
    Base MVA for power scaling.
slack_bus_indices (Optional[List[int]], default None)
    Indices of slack buses.
device (Optional[device], default None)
    Computation device.
Source code in lumina/evaluator/opf/evaluator.py
def __init__(
    self,
    voltage_limits: Optional[Dict[str, torch.Tensor]] = None,
    generation_limits: Optional[Dict[str, torch.Tensor]] = None,
    line_limits: Optional[torch.Tensor] = None,
    Y_real: Optional[torch.Tensor] = None,
    Y_imag: Optional[torch.Tensor] = None,
    edge_index: Optional[torch.Tensor] = None,
    base_mva: float = 100.0,
    slack_bus_indices: Optional[List[int]] = None,
    device: Optional[torch.device] = None
):
    """
    Initialize the constraint evaluator.

    Args:
        voltage_limits: Dict with 'vmin', 'vmax' voltage limits [n_bus]
        generation_limits: Dict with 'pmin', 'pmax', 'qmin', 'qmax' [n_gen]
        line_limits: Line flow limits [n_lines]
        Y_real: Real part of admittance matrix [n_bus, n_bus]
        Y_imag: Imaginary part of admittance matrix [n_bus, n_bus]
        edge_index: Line connectivity [2, n_lines]
        base_mva: Base MVA for power scaling
        slack_bus_indices: Indices of slack buses (default: [0])
        device: Computation device
    """
    super().__init__()

    self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.base_mva = base_mva

    # Store network parameters
    self.voltage_limits = voltage_limits
    self.generation_limits = generation_limits
    self.line_limits = line_limits
    self.Y_real = Y_real
    self.Y_imag = Y_imag
    self.edge_index = edge_index
    self.slack_bus_indices = slack_bus_indices or [0]
    self.total_real_power_demand = 0.0

    # Move tensors to device if provided
    self._move_to_device()

set_network_parameters(voltage_limits: Optional[Dict[str, torch.Tensor]] = None, generation_limits: Optional[Dict[str, torch.Tensor]] = None, line_limits: Optional[torch.Tensor] = None, Y_real: Optional[torch.Tensor] = None, Y_imag: Optional[torch.Tensor] = None, edge_index: Optional[torch.Tensor] = None)

Update network parameters for constraint evaluation.

Any argument that is not None replaces the corresponding stored parameter. All tensors are moved to the evaluator's device after assignment.
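
Continuing the sketch from the class introduction (values illustrative), a typical update when switching operating cases looks like:

# Only the passed arguments are replaced; everything else is kept.
evaluator.set_network_parameters(
    voltage_limits={'vmin': torch.full((3,), 0.95),
                    'vmax': torch.full((3,), 1.05)},
    line_limits=torch.full((4,), 1.5),
)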

Parameters:

voltage_limits (dict[str, Tensor], default None)
    Updated voltage limits.
generation_limits (dict[str, Tensor], default None)
    Updated generation limits.
line_limits (Tensor, default None)
    Updated line flow limits.
Y_real (Tensor, default None)
    Updated real admittance matrix.
Y_imag (Tensor, default None)
    Updated imaginary admittance matrix.
edge_index (Tensor, default None)
    Updated edge connectivity.
Source code in lumina/evaluator/opf/evaluator.py
def set_network_parameters(
    self,
    voltage_limits: Optional[Dict[str, torch.Tensor]] = None,
    generation_limits: Optional[Dict[str, torch.Tensor]] = None,
    line_limits: Optional[torch.Tensor] = None,
    Y_real: Optional[torch.Tensor] = None,
    Y_imag: Optional[torch.Tensor] = None,
    edge_index: Optional[torch.Tensor] = None
):
    """Update network parameters for constraint evaluation.

    Any argument that is not ``None`` replaces the corresponding stored
    parameter.  All tensors are moved to the evaluator's device after
    assignment.

    Args:
        voltage_limits (dict[str, torch.Tensor], optional): Updated voltage
            limits.
        generation_limits (dict[str, torch.Tensor], optional): Updated
            generation limits.
        line_limits (torch.Tensor, optional): Updated line flow limits.
        Y_real (torch.Tensor, optional): Updated real admittance matrix.
        Y_imag (torch.Tensor, optional): Updated imaginary admittance matrix.
        edge_index (torch.Tensor, optional): Updated edge connectivity.
    """
    if voltage_limits is not None:
        self.voltage_limits = voltage_limits
    if generation_limits is not None:
        self.generation_limits = generation_limits
    if line_limits is not None:
        self.line_limits = line_limits
    if Y_real is not None:
        self.Y_real = Y_real
    if Y_imag is not None:
        self.Y_imag = Y_imag
    if edge_index is not None:
        self.edge_index = edge_index

    self._move_to_device()

evaluate_bound_constraints(predictions: Dict[str, torch.Tensor], batch_data=None, return_individual: bool = True) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]

Evaluate bound constraint violations for voltage and generation.

Computes ReLU-based violation magnitudes for voltage magnitude bounds, active power generation bounds, and reactive power generation bounds.
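
Continuing the sketch from the class introduction, a typical call and the keys it produces:

violations = evaluator.evaluate_bound_constraints(predictions, return_individual=True)
# Per-category means: 'voltage_magnitude', 'active_power_generation',
# 'reactive_power_generation'. With return_individual=True, per-direction
# components such as 'vm_low_violations' / 'vm_high_violations' are included.
total = violations['total_bound_violations']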

Parameters:

predictions (dict[str, Tensor], required)
    Model predictions with 'bus' [total_bus, 2] (VA, VM) and
    'generator' [total_gen, 2] (PG, QG) keys.
batch_data (default None)
    Batch data (unused, reserved for future extensions).
return_individual (bool, default True)
    If True, include per-direction violation components
    (e.g. vm_low_violations, vm_high_violations).

Returns:

Dict[str, Union[Tensor, Dict[str, Tensor]]]
    Violation metrics including 'total_bound_violations' and optionally
    individual component violations.

Source code in lumina/evaluator/opf/evaluator.py
def evaluate_bound_constraints(
    self,
    predictions: Dict[str, torch.Tensor],
    batch_data=None,
    return_individual: bool = True
) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Evaluate bound constraint violations for voltage and generation.

    Computes ReLU-based violation magnitudes for voltage magnitude
    bounds, active power generation bounds, and reactive power generation
    bounds.

    Args:
        predictions (dict[str, torch.Tensor]): Model predictions with
            ``'bus'`` ``[total_bus, 2]`` (VA, VM) and ``'generator'``
            ``[total_gen, 2]`` (PG, QG) keys.
        batch_data: Batch data (unused, reserved for future extensions).
        return_individual (bool): If ``True``, include per-direction
            violation components (e.g. ``vm_low_violations``,
            ``vm_high_violations``).

    Returns:
        dict[str, torch.Tensor]: Violation metrics including
            ``'total_bound_violations'`` and optionally individual
            component violations.
    """
    violations = {}

    # Extract predictions (concatenated over the batch: [total_nodes, feat])
    # Model outputs bus as [VA, VM], generator as [PG, QG]
    bus_pred = predictions['bus']  # [total_bus_nodes, 2] -> [VA, VM]
    gen_pred = predictions['generator']  # [total_gen_nodes, 2] -> [PG, QG]

    # Voltage magnitude violations (angle bounds are not evaluated here)
    if self.voltage_limits is not None:
        vm = bus_pred[..., 1]  # Voltage magnitude (concatenated)

        if 'vmin' in self.voltage_limits and 'vmax' in self.voltage_limits:
            vmin = self.voltage_limits['vmin']
            vmax = self.voltage_limits['vmax']

            # If vmin/vmax are per-case (n_bus), expand to concatenated per-batch vector
            try:
                n_bus = int(vmin.numel())
                if vm.numel() != n_bus:
                    # If vm length is a multiple of n_bus, repeat per sample
                    if vm.numel() % n_bus == 0:
                        warnings.warn(
                            "Repeated voltage limits across samples because bus vector length does not match base limit length. "
                            "If this batch contains variable-size cases, VM bounds may be assigned to wrong buses."
                        )
                        batch_size = vm.numel() // n_bus
                        vmin_cat = vmin.to(self.device).repeat(batch_size)
                        vmax_cat = vmax.to(self.device).repeat(batch_size)
                    else:
                        # Fallback: broadcast where possible
                        warnings.warn(
                            "Broadcasted voltage limits by repetition/truncation because bus vector length is not a multiple "
                            "of base limit length; this can hide mixed-topology misalignment."
                        )
                        vmin_cat = vmin.to(self.device).reshape(-1).repeat(vm.numel() //
                                                                           vmin.numel() + 1)[:vm.numel()]
                        vmax_cat = vmax.to(self.device).reshape(-1).repeat(vm.numel() //
                                                                           vmax.numel() + 1)[:vm.numel()]
                else:
                    vmin_cat = vmin.to(self.device)
                    vmax_cat = vmax.to(self.device)
            except Exception:
                vmin_cat = vmin.to(self.device)
                vmax_cat = vmax.to(self.device)

            # Voltage magnitude violations
            vm_low_viol = torch.relu(vmin_cat - vm)
            vm_high_viol = torch.relu(vm - vmax_cat)
            vm_violations = vm_low_viol + vm_high_viol

            violations['voltage_magnitude'] = vm_violations.mean()

            if return_individual:
                violations['vm_low_violations'] = vm_low_viol.mean()
                violations['vm_high_violations'] = vm_high_viol.mean()

    # Generation limit violations
    if self.generation_limits is not None:
        pg = gen_pred[..., 0]  # Active power generation (concatenated)
        qg = gen_pred[..., 1]  # Reactive power generation (concatenated)

        # Active power violations
        if 'pmin' in self.generation_limits and 'pmax' in self.generation_limits:
            pmin = self.generation_limits['pmin']
            pmax = self.generation_limits['pmax']

            try:
                n_gen = int(pmin.numel())
                if pg.numel() != n_gen:
                    # Repeat per sample if sizes align
                    if pg.numel() % n_gen == 0:
                        warnings.warn(
                            "Repeated active-power limits across samples because generator vector length does not match base limit length. "
                            "If this batch contains variable-size cases, PG bounds may be assigned to wrong generators."
                        )
                        batch_size_g = pg.numel() // n_gen
                        pmin_cat = pmin.to(self.device).repeat(batch_size_g)
                        pmax_cat = pmax.to(self.device).repeat(batch_size_g)
                    else:
                        warnings.warn(
                            "Broadcasted active-power limits by repetition/truncation because generator vector length is not a multiple "
                            "of base limit length; this can hide mixed-topology misalignment."
                        )
                        pmin_cat = pmin.to(self.device).reshape(-1).repeat(pg.numel() //
                                                                           pmin.numel() + 1)[:pg.numel()]
                        pmax_cat = pmax.to(self.device).reshape(-1).repeat(pg.numel() //
                                                                           pmax.numel() + 1)[:pg.numel()]
                else:
                    pmin_cat = pmin.to(self.device)
                    pmax_cat = pmax.to(self.device)
            except Exception:
                pmin_cat = pmin.to(self.device)
                pmax_cat = pmax.to(self.device)

            pg_low_viol = torch.relu(pmin_cat - pg)
            pg_high_viol = torch.relu(pg - pmax_cat)
            pg_violations = pg_low_viol + pg_high_viol

            violations['active_power_generation'] = pg_violations.mean()

            if return_individual:
                violations['pg_low_violations'] = pg_low_viol.mean()
                violations['pg_high_violations'] = pg_high_viol.mean()

        # Reactive power violations
        if 'qmin' in self.generation_limits and 'qmax' in self.generation_limits:
            qmin = self.generation_limits['qmin']
            qmax = self.generation_limits['qmax']

            try:
                n_genq = int(qmin.numel())
                if qg.numel() != n_genq:
                    if qg.numel() % n_genq == 0:
                        warnings.warn(
                            "Repeated reactive-power limits across samples because generator vector length does not match base limit length. "
                            "If this batch contains variable-size cases, QG bounds may be assigned to wrong generators."
                        )
                        batch_size_q = qg.numel() // n_genq
                        qmin_cat = qmin.to(self.device).repeat(batch_size_q)
                        qmax_cat = qmax.to(self.device).repeat(batch_size_q)
                    else:
                        warnings.warn(
                            "Broadcasted reactive-power limits by repetition/truncation because generator vector length is not a multiple "
                            "of base limit length; this can hide mixed-topology misalignment."
                        )
                        qmin_cat = qmin.to(self.device).reshape(-1).repeat(qg.numel() //
                                                                           qmin.numel() + 1)[:qg.numel()]
                        qmax_cat = qmax.to(self.device).reshape(-1).repeat(qg.numel() //
                                                                           qmax.numel() + 1)[:qg.numel()]
                else:
                    qmin_cat = qmin.to(self.device)
                    qmax_cat = qmax.to(self.device)
            except Exception:
                qmin_cat = qmin.to(self.device)
                qmax_cat = qmax.to(self.device)

            qg_low_viol = torch.relu(qmin_cat - qg)
            qg_high_viol = torch.relu(qg - qmax_cat)
            qg_violations = qg_low_viol + qg_high_viol

            violations['reactive_power_generation'] = qg_violations.mean()

            if return_individual:
                violations['qg_low_violations'] = qg_low_viol.mean()
                violations['qg_high_violations'] = qg_high_viol.mean()

    # Total bound constraint violation
    total_bound_violation = sum(v for k, v in violations.items()
                                if not k.endswith('_violations') and isinstance(v, torch.Tensor))
    violations['total_bound_violations'] = total_bound_violation

    return violations

evaluate_all_constraints(predictions: Dict[str, torch.Tensor], batch_data, normalize: bool = True, return_individual: bool = True) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]

Evaluate all constraint violations comprehensively.

Currently delegates to evaluate_bound_constraints and prefixes results with bound_. Power-flow and line-flow evaluations are planned but not yet active.
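
Because the bound evaluation ignores batch_data, a sketch like the following works today (batch_data is accepted positionally and may be None):

all_violations = evaluator.evaluate_all_constraints(predictions, batch_data=None)
# Keys carry the 'bound_' prefix, e.g. 'bound_voltage_magnitude',
# 'bound_total_bound_violations'.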

Parameters:

predictions (dict[str, Tensor], required)
    Model predictions keyed by node type.
batch_data (required)
    Batch data containing network information.
normalize (bool, default True)
    Whether to normalize violations (reserved for future use).
return_individual (bool, default True)
    Whether to include individual violation components.

Returns:

Dict[str, Union[Tensor, Dict[str, Tensor]]]
    All violation metrics, keyed with a bound_ prefix.

Source code in lumina/evaluator/opf/evaluator.py
def evaluate_all_constraints(
    self,
    predictions: Dict[str, torch.Tensor],
    batch_data,
    normalize: bool = True,
    return_individual: bool = True,
) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Evaluate all constraint violations comprehensively.

    Currently delegates to ``evaluate_bound_constraints`` and prefixes
    results with ``bound_``.  Power-flow and line-flow evaluations are
    planned but not yet active.

    Args:
        predictions (dict[str, torch.Tensor]): Model predictions keyed
            by node type.
        batch_data: Batch data containing network information.
        normalize (bool): Whether to normalize violations (reserved for
            future use).
        return_individual (bool): Whether to include individual violation
            components.

    Returns:
        dict[str, torch.Tensor]: All violation metrics, keyed with a
            ``bound_`` prefix.
    """
    all_violations = {}

    # Evaluate bound constraints
    bound_violations = self.evaluate_bound_constraints(predictions, batch_data, return_individual)
    all_violations.update({f"bound_{k}": v for k, v in bound_violations.items()})

    # Compute total constraint violation
    # violation_keys = ['bound_total_bound_violations', 'real_power_flow_violations', 'reactive_power_flow_violations', 'line_flow_violations']
    # total_violation = sum(all_violations.get(key, torch.tensor(0.0, device=self.device))
    #                       for key in violation_keys)
    # all_violations['total_constraint_violations'] = total_violation

    return all_violations

get_violation_summary(violations: Dict[str, torch.Tensor]) -> Dict[str, float]

Convert violation tensors to a plain-float summary dictionary.

Scalar tensors are converted via .item(); multi-element tensors are reduced by mean.
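
Continuing the sketch above (numbers illustrative):

summary = evaluator.get_violation_summary(all_violations)
# e.g. {'bound_voltage_magnitude': 0.012, 'bound_total_bound_violations': 0.045, ...}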

Parameters:

violations (dict[str, Tensor], required)
    Dictionary of violation tensors as returned by evaluate_all_constraints.

Returns:

Dict[str, float]
    Violation values as Python floats.

Source code in lumina/evaluator/opf/evaluator.py
def get_violation_summary(
    self,
    violations: Dict[str, torch.Tensor]
) -> Dict[str, float]:
    """Convert violation tensors to a plain-float summary dictionary.

    Scalar tensors are converted via ``.item()``; multi-element tensors
    are reduced by mean.

    Args:
        violations (dict[str, torch.Tensor]): Dictionary of violation
            tensors as returned by ``evaluate_all_constraints``.

    Returns:
        dict[str, float]: Violation values as Python floats.
    """
    summary = {}

    for key, value in violations.items():
        if isinstance(value, torch.Tensor):
            summary[key] = value.item() if value.numel() == 1 else value.mean().item()
        else:
            summary[key] = float(value)

    return summary

create_constraint_evaluator(case_data: Dict, device: Optional[torch.device] = None) -> ACOPFConstraintEvaluator

Create a constraint evaluator from a case-data dictionary.

Convenience factory that extracts voltage_limits, generation_limits, line_limits, Y_real, Y_imag, edge_index, base_mva, and slack_bus_indices from case_data and passes them to ACOPFConstraintEvaluator.
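
A minimal sketch, assuming a 14-bus case dictionary (values illustrative). Keys absent from case_data are looked up with .get and default to None:

case_data = {
    'voltage_limits': {'vmin': torch.full((14,), 0.94),
                       'vmax': torch.full((14,), 1.06)},
    'base_mva': 100.0,
    'slack_bus_indices': [0],
}
evaluator = create_constraint_evaluator(case_data, device=torch.device('cpu'))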

Parameters:

case_data (dict, required)
    Dictionary containing network parameters. Expected keys mirror the
    ACOPFConstraintEvaluator constructor arguments.
device (device, default None)
    Computation device.

Returns:

ACOPFConstraintEvaluator
    Fully configured evaluator instance.

Source code in lumina/evaluator/opf/evaluator.py
def create_constraint_evaluator(
    case_data: Dict,
    device: Optional[torch.device] = None
) -> ACOPFConstraintEvaluator:
    """Create a constraint evaluator from a case-data dictionary.

    Convenience factory that extracts ``voltage_limits``, ``generation_limits``,
    ``line_limits``, ``Y_real``, ``Y_imag``, ``edge_index``, ``base_mva``, and
    ``slack_bus_indices`` from *case_data* and passes them to
    ``ACOPFConstraintEvaluator``.

    Args:
        case_data (dict): Dictionary containing network parameters. Expected
            keys mirror the ``ACOPFConstraintEvaluator`` constructor arguments.
        device (torch.device, optional): Computation device.

    Returns:
        ACOPFConstraintEvaluator: Fully configured evaluator instance.
    """
    # Extract parameters from case data
    voltage_limits = case_data.get('voltage_limits')
    generation_limits = case_data.get('generation_limits')
    line_limits = case_data.get('line_limits')
    Y_real = case_data.get('Y_real')
    Y_imag = case_data.get('Y_imag')
    edge_index = case_data.get('edge_index')
    base_mva = case_data.get('base_mva', 100.0)
    slack_bus_indices = case_data.get('slack_bus_indices')

    return ACOPFConstraintEvaluator(
        voltage_limits=voltage_limits,
        generation_limits=generation_limits,
        line_limits=line_limits,
        Y_real=Y_real,
        Y_imag=Y_imag,
        edge_index=edge_index,
        base_mva=base_mva,
        device=device,
        slack_bus_indices=slack_bus_indices,
    )

Modeler

Modeler

Modeler wraps model loading, prediction, and evaluation logic for OPF tasks.

This class serves as a user-friendly wrapper that encapsulates the use of trained lumina models: configuration setup, weight loading, batch prediction, and evaluation of saved predictions.

Parameters:

device (device, required)
    Device to run model inference on (e.g., "cpu" or "cuda").
fail_on_missing (bool, default False)
    If True, raise when expected model keys are missing during checkpoint load.
verbose (bool, default True)
    If True, print diagnostic messages during checkpoint loading.
base_mva (float, default 100.0)
    Base MVA used by the evaluator.
slack_bus_indices (str, default '0')
    Comma-separated slack bus indices. Converted to a list of ints and
    stored on the instance.

Attributes:

device (device)
    See Args.
fail_on_missing (bool)
    See Args.
verbose (bool)
    See Args.
base_mva (float)
    See Args.
slack_bus_indices (List[int])
    Slack bus indices parsed from the slack_bus_indices constructor argument.
model (Optional[Module])
    Loaded and prepared model (None until load_model is called).
config_data (Optional[dict])
    Parsed model configuration loaded during load_model.

Example

>>> modeler = Modeler(torch.device("cpu"), slack_bus_indices="0,1")
>>> config = json.load(open("config.json"))
>>> state_dict = load_file("model.safetensors")
>>> modeler.load_model(config, state_dict)
>>> loader = DataLoader(OPFDataset(root="./opf_data", case_name="pglib_opf_case14_ieee"), batch_size=1)
>>> preds = modeler.run_predictions(loader)
>>> stats = modeler.evaluate_from_predictions(preds, cache_key="pglib_opf_case14_ieee")
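
The checkpoint-key remapping that load_model relies on can also be exercised directly; this reuses the example from the method's own docstring:

>>> Modeler.convert_checkpoint_key_to_model_key("<bus___ac_line___weight>")
"('bus', 'ac_line', 'weight')"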

Source code in lumina/evaluator/opf/utils.py
class Modeler:
    """
    Modeler wraps model loading, prediction, and evaluation logic for OPF tasks.

    This class serves as a user-friendly wrapper that encapsulates the use of
    trained lumina models: configuration setup, weight loading, batch
    prediction, and evaluation of saved predictions.

    Args:
        device (torch.device): Device to run model inference on (e.g., "cpu" or "cuda").
        fail_on_missing (bool, optional): If True, raise when expected model keys
            are missing during checkpoint load. Defaults to False.
        verbose (bool, optional): If True, print diagnostic messages during
            checkpoint loading. Defaults to True.
        base_mva (float, optional): Base MVA used by the evaluator. Defaults to 100.0.
        slack_bus_indices (str, optional): Comma-separated slack bus indices
            (default: "0"). Converted to a list of ints and stored on the instance.

    Attributes:
        device (torch.device): See Args.
        fail_on_missing (bool): See Args.
        verbose (bool): See Args.
        base_mva (float): See Args.
        slack_bus_indices (List[int]): Slack bus indices parsed from the
            `slack_bus_indices` constructor argument.
        model (Optional[torch.nn.Module]): Loaded and prepared model (None until `load_model` is called).
        config_data (Optional[dict]): Parsed model configuration loaded during `load_model`.

    Example:
        >>> modeler = Modeler(torch.device("cpu"), slack_bus_indices="0,1")
        >>> config = json.load(open("config.json"))
        >>> state_dict = load_file("model.safetensors")
        >>> modeler.load_model(config, state_dict)
        >>> loader = DataLoader(OPFDataset(root="./opf_data", case_name="pglib_opf_case14_ieee"), batch_size=1)
        >>> preds = modeler.run_predictions(loader)
        >>> stats = modeler.evaluate_from_predictions(preds, cache_key="pglib_opf_case14_ieee")
    """

    def __init__(
        self,
        device: torch.device,
        *,
        fail_on_missing: bool = False,
        verbose: bool = True,
        base_mva: float = 100.0,
        slack_bus_indices: str = "0",
    ):
        self.device = device
        self.fail_on_missing = fail_on_missing
        self.verbose = verbose
        self.base_mva = base_mva
        self.slack_bus_indices = [int(x) for x in slack_bus_indices.split(",") if x.strip() != ""]
        self.model: Optional[torch.nn.Module] = None
        self.config_data = None

    # -- checkpoint key conversion and loading -----------------------------------------------------------------
    @staticmethod
    def convert_checkpoint_key_to_model_key(key: str) -> str:
        """
        Convert checkpoint keys to model keys by transforming
        underscore-delimited items to tuple string representation.

        Args:
            key: Current key with triple underscore delimiters inside angle brackets

        Returns:
            String with angle bracket contents converted to tuple representation

        Example:
            >>> Modeler.convert_checkpoint_key_to_model_key("<bus___ac_line___weight>")
            "('bus', 'ac_line', 'weight')"
        """

        pattern = r"<([^>]+)>"

        def replacer(match):
            parts = match.group(1).split('___')
            return f"('{parts[0]}', '{parts[1]}', '{parts[2]}')"

        return re.sub(pattern, replacer, key)

    def load_checkpoint_into_model(
        self,
        model: torch.nn.Module,
        checkpoint_dict,
        *,
        fail_on_missing: bool = False,
        verbose: bool = True,
    ):
        """
        Load a checkpoint dictionary into a model and report missing/unexpected keys.

        This method remaps checkpoint keys to model keys using
        `convert_checkpoint_key_to_model_key` and then calls `load_state_dict`
        with `strict=False` to allow partial loads.

        Args:
            model (torch.nn.Module): The model to populate.
            checkpoint_dict (dict): Mapping of checkpoint keys to tensors.
            fail_on_missing (bool, optional): If True, raise ValueError when
                missing keys remain after the load. Defaults to False.
            verbose (bool, optional): If True, print missing/unexpected keys.

        Returns:
            dict: A dictionary with keys "missing_keys" and "unexpected_keys",
                each mapping to a list of key names observed.

        Raises:
            ValueError: If `fail_on_missing` is True and missing keys are found.

        Example:
            >>> result = modeler.load_checkpoint_into_model(model, ckpt_dict)
            >>> print(result["missing_keys"])
        """

        model_state = model.state_dict()
        used_keys = set()
        missing_keys = []

        remapped_state = {}
        for model_key in model_state.keys():
            ck = self.convert_checkpoint_key_to_model_key(model_key)
            if ck in checkpoint_dict:
                remapped_state[model_key] = checkpoint_dict[ck]
                used_keys.add(ck)

        unexpected_keys = [k for k in checkpoint_dict.keys() if k not in used_keys]

        load_result = model.load_state_dict(remapped_state, strict=False)
        missing_keys = list(load_result.missing_keys)
        unexpected_keys.extend(list(load_result.unexpected_keys))

        if verbose and (missing_keys or unexpected_keys):
            print(f"[CHECKPOINT LOAD] Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}")
        if fail_on_missing and missing_keys:
            raise ValueError(f"Missing keys during load: {missing_keys}")

        return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}

    # -- model loading ---------------------------------------------------------------------------------------
    def load_model(self, config_data: dict, state_dict: dict):
        """
        Construct a hetero OPF model from provided configuration and state dict.

        Note:
            Downloads and file I/O for the configuration and safetensors are
            expected to be performed outside this method; the parsed `config_data`
            and in-memory `state_dict` should be passed here.

        Args:
            config_data (dict): Parsed JSON configuration describing model metadata
                and architecture.
            state_dict (dict): Raw state dictionary as returned by `safetensors.torch.load_file`.

        Returns:
            Tuple[torch.nn.Module, dict]: The constructed model (in eval mode)
            and the config_data used to build it.

        Raises:
            ValueError: If `fail_on_missing` is True and required keys are missing
                from the checkpoint (raised from `load_checkpoint_into_model`).

        Example:
            >>> config = json.load(open("config.json"))
            >>> state = load_file("model.safetensors")
            >>> model, cfg = modeler.load_model(config, state)
        """
        # Convert metadata edge keys from strings to tuples if needed
        if 'edges' in config_data.get('metadata', {}):
            edges_dict = {}
            for key, value in config_data['metadata']['edges'].items():
                if isinstance(key, str) and key.startswith('('):
                    key = ast.literal_eval(key)
                edges_dict[key] = value
            config_data['metadata']['edges'] = edges_dict

        model_type = resolve_hetero_model_type(
            model_type=config_data.get("model"),
            model_class_path=config_data.get("model_class"),
            default="HeteroGNN",
        )
        model_class, model_kwargs, _, used_fallback = build_hetero_model_spec(
            model_type=model_type,
            metadata=config_data["metadata"],
            input_channels=config_data["input_channels"],
            models_config=config_data.get("config", {}).get("models", {}),
            out_channels=config_data.get("out_channels", 2),
        )
        if used_fallback and self.verbose:
            print(f"[MODEL LOAD] Config for {model_type} not found; using HeteroGNN config.")

        model = model_class(**model_kwargs).to(self.device)

        # state_dict is the raw output of safetensors.load_file; remap its keys
        checkpoint_dict = {self.convert_checkpoint_key_to_model_key(k): v for k, v in state_dict.items()}

        self.load_checkpoint_into_model(
            model,
            checkpoint_dict,
            fail_on_missing=self.fail_on_missing,
            verbose=self.verbose,
        )

        model.eval()
        self.model = model
        self.config_data = config_data
        return model, config_data


    def load_model_from_training_checkpoint(
        self,
        ckpt_path: Union[str, "os.PathLike[str]"],
        *,
        strict: bool = True,
    ) -> torch.nn.Module:
        """Load a model from a training checkpoint file.

        Training checkpoint formats differ from HuggingFace safetensor
        serialization. This method expects a checkpoint with at least:
        ``model_class`` (fully-qualified class name), ``model_kwargs``
        (constructor arguments), and ``model_state_dict`` or ``model_state``
        (weight tensors).

        Args:
            ckpt_path (str | os.PathLike): Path to the ``.pt`` checkpoint file.
            strict (bool): Whether to enforce strict key matching in
                ``load_state_dict``.

        Returns:
            torch.nn.Module: Reconstructed model (not DDP-wrapped), in eval mode,
                moved to the instance's device.

        Raises:
            ValueError: If ``model_class`` in the checkpoint is not a valid
                fully-qualified Python class path, or if the checkpoint
                contains no state dict.
        """
        ckpt: Dict[str, Any] = torch.load(ckpt_path, map_location=self.device)
        class_path = ckpt.get("model_class")
        model_kwargs = ckpt.get("model_kwargs", {})

        state_dict = ckpt.get("model_state")
        if state_dict is None:
            state_dict = ckpt.get("model_state_dict")
        if state_dict is None:
            raise ValueError(
                f"Checkpoint '{ckpt_path}' contains neither 'model_state' nor 'model_state_dict'."
            )

        # Strip the 'module.' prefix that DistributedDataParallel prepends to parameter names.
        normalized_state_dict = {
            (key[len("module."):] if key.startswith("module.") else key): val
            for key, val in state_dict.items()
        }

        # N.B. a model registry would be a safer long-term alternative to dynamic import.
        module_name, _, cls_name = (class_path or "").rpartition(".")
        if not module_name:
            raise ValueError(
                f"Invalid model class in checkpoint: '{class_path}'. "
                "Expected fully-qualified path like 'pkg.module.ClassName'."
            )

        module = importlib.import_module(module_name)
        cls = getattr(module, cls_name)

        model: torch.nn.Module = cls(**model_kwargs)
        model.load_state_dict(normalized_state_dict, strict=strict)
        model = model.to(torch.device(self.device))

        model.eval()
        self.model = model
        return model


    # -- helpers for tensors --------------------------------------------------------------------------------
    @staticmethod
    def to_float32(batch):
        """
        Convert node features/targets and edge attributes to float32.

        Iterates through all node types and edge types in the batch,
        converting 'x' (features), 'y' (targets), and 'edge_attr' tensors
        to float32 precision.

        Args:
            batch: A batch data object from DataLoader containing node_types
                   and edge_types attributes

        Returns:
            The same batch object with `.x`, `.y` and `.edge_attr` converted to float32
            where present.

        Note:
            Modifies the input batch in-place and returns it.
            Assumes batch has 'node_types' and 'edge_types' properties,
            typical of heterogeneous graph data structures.

        Example:
            >>> batch = modeler.to_float32(batch)
        """
        for node_type in batch.node_types:
            if getattr(batch[node_type], 'x', None) is not None:
                batch[node_type].x = batch[node_type].x.float()
            if getattr(batch[node_type], 'y', None) is not None:
                batch[node_type].y = batch[node_type].y.float()

        for edge_type in batch.edge_types:
            if getattr(batch[edge_type], 'edge_attr', None) is not None:
                batch[edge_type].edge_attr = batch[edge_type].edge_attr.float()

        return batch

    # -- limit/params derivation -------------------------------------------------------------------------------
    @staticmethod
    def derive_voltage_limits(bus_x: torch.Tensor, device: torch.device):
        """
        Derive per-bus voltage limits from bus feature matrix.

        If bus feature tensor contains columns for vmin/vmax (columns 1 and 2),
        those are used; otherwise sensible defaults (0.95/1.05) are returned.

        Args:
            bus_x (torch.Tensor): Bus node feature matrix or None.
            device (torch.device): Device on which returned tensors should be allocated.

        Returns:
            dict: Dictionary with keys 'vmin' and 'vmax' mapping to 1-D tensors of length n_bus.

        Example:
            >>> vlims = Modeler.derive_voltage_limits(bus_x, torch.device("cpu"))
            >>> vlims['vmin'].shape
        """
        if bus_x is not None and bus_x.size(1) >= 3:
            vmin = bus_x[:, 1].to(device)
            vmax = bus_x[:, 2].to(device)
            return {'vmin': vmin, 'vmax': vmax}
        n_bus = bus_x.size(0) if bus_x is not None else 0
        return {
            'vmin': torch.full((n_bus,), 0.95, device=device),
            'vmax': torch.full((n_bus,), 1.05, device=device),
        }

    @staticmethod
    def derive_generation_limits(gen_x: torch.Tensor, device: torch.device):
        """
        Derive generator P/Q limits from generator feature matrix.

        Heuristic mapping based on dataset feature layout:
          - pmin: column 2
          - pmax: column 3
          - qmin: column 5
          - qmax: column 6

        Args:
            gen_x (torch.Tensor): Generator node feature matrix or None.
            device (torch.device): Device for output tensors.

        Returns:
            Optional[dict]: Dictionary with keys 'pmin', 'pmax', 'qmin', 'qmax' mapping
                to 1-D tensors of length n_gen, or None if gen_x is None/empty.

        Example:
            >>> glims = Modeler.derive_generation_limits(gen_x, torch.device("cpu"))
            >>> glims['pmax'].shape
        """
        if gen_x is None or gen_x.numel() == 0:
            return None

        n_gen = gen_x.size(0)

        def col_or_default(idx: int, default: float):
            return gen_x[:, idx].to(device) if gen_x.size(1) > idx else torch.full((n_gen,), default, device=device)

        pmin = col_or_default(2, 0.0)
        pmax = col_or_default(3, 2.0)
        qmin = col_or_default(5, -1.0)
        qmax = col_or_default(6, 1.0)

        return {'pmin': pmin, 'pmax': pmax, 'qmin': qmin, 'qmax': qmax}

    def derive_line_params(self, batch, device: torch.device, cache_key: str = None):
        """
        Build line limits and dense admittance matrices (Y_real, Y_imag) from ac_line edges.

        The method reads the ('bus', 'ac_line', 'bus') edge type attributes and
        computes the per-line thermal limits and the dense Y matrix for the network.
        Results are cached in the module-level `_LINE_CACHE` when a cache_key is provided.

        Args:
            batch: Batched data object containing 'ac_line' edge attributes.
            device (torch.device): Device for intermediate tensors.
            cache_key (str, optional): Key to use for caching results in `_LINE_CACHE`.

        Returns:
            Tuple[torch.Tensor or None, torch.Tensor or None, torch.Tensor or None, torch.Tensor or None]:
                (line_limits, Y_real, Y_imag, edge_index) where any element may be None
                if required data is not present in `batch`.

        Example:
            >>> line_limits, Yr, Yi, idx = modeler.derive_line_params(batch, torch.device("cpu"))
        """
        global _LINE_CACHE
        if cache_key and cache_key in _LINE_CACHE:
            return _LINE_CACHE[cache_key]

        if ('bus', 'ac_line', 'bus') not in batch.edge_types:
            return None, None, None, None

        edge_index = batch[('bus', 'ac_line', 'bus')].edge_index.to(device)
        edge_attr = batch[('bus', 'ac_line', 'bus')].edge_attr.to(device)
        n_bus = batch['bus'].x.size(0)

        # Line limits: use rate_a (first thermal limit)
        line_limits = edge_attr[:, 6] if edge_attr.size(1) > 6 else torch.ones(edge_index.size(1), device=device)

        # Build admittance matrix
        Y_real = torch.zeros((n_bus, n_bus), device=device, dtype=torch.float32)
        Y_imag = torch.zeros((n_bus, n_bus), device=device, dtype=torch.float32)

        for k in range(edge_index.size(1)):
            i = int(edge_index[0, k])
            j = int(edge_index[1, k])

            r = edge_attr[k, 4].item() if edge_attr.size(1) > 4 else 0.0
            x = edge_attr[k, 5].item() if edge_attr.size(1) > 5 else 0.0
            b_shunt = edge_attr[k, 2].item() if edge_attr.size(1) > 2 else 0.0

            if r == 0.0 and x == 0.0:
                continue

            z = complex(r, x)
            y_series = 1.0 / z
            y_shunt = complex(0.0, b_shunt / 2.0)

            g = y_series.real
            b_series = y_series.imag

            # Off-diagonal entries use the series admittance only (pi-model);
            # line charging does not appear in the mutual terms.
            Y_real[i, j] -= g
            Y_real[j, i] -= g
            Y_imag[i, j] -= b_series
            Y_imag[j, i] -= b_series

            # Diagonal contributions: series admittance plus half the
            # line-charging susceptance at each end.
            Y_real[i, i] += g
            Y_imag[i, i] += b_series + y_shunt.imag
            Y_real[j, j] += g
            Y_imag[j, j] += b_series + y_shunt.imag

        result = (line_limits, Y_real, Y_imag, edge_index)
        if cache_key:
            _LINE_CACHE[cache_key] = result
        return result

    # -- evaluator construction ------------------------------------------------------------------------------
    def build_constraint_evaluator(self, batch, device: torch.device, cache_key: str = None):
        """
        Build an ACOPFConstraintEvaluator using limits derived from the dataset.

        Args:
            batch: Batched data object containing node/edge features required by the evaluator.
            device (torch.device): Device on which evaluator tensors will be placed.
            cache_key (str, optional): Cache key passed to `derive_line_params` to enable reuse.

        Returns:
            ACOPFConstraintEvaluator: Configured evaluator instance ready to run constraint checks.

        Example:
            >>> evaluator = modeler.build_constraint_evaluator(batch, torch.device("cpu"), cache_key="case14")
        """
        bus_x = batch['bus'].x if hasattr(batch['bus'], 'x') else None
        gen_x = batch['generator'].x if 'generator' in batch.node_types and hasattr(batch['generator'], 'x') else None

        voltage_limits = self.derive_voltage_limits(bus_x, device)
        generation_limits = self.derive_generation_limits(gen_x, device)

        line_limits, Y_real, Y_imag, edge_index = self.derive_line_params(batch, device, cache_key=cache_key)

        return ACOPFConstraintEvaluator(
            voltage_limits=voltage_limits,
            generation_limits=generation_limits,
            line_limits=line_limits,
            Y_real=Y_real,
            Y_imag=Y_imag,
            edge_index=edge_index,
            base_mva=self.base_mva,
            device=device,
        )

    # -- prediction and evaluation stages --------------------------------------------------------------------
    def predict_batch(self, batch, minmax_scaling: bool = True):
        """
        Run a forward pass on a single batch and return predictions (on CPU) with the batch (on CPU).

        Args:
            batch: Batched data object containing inputs for the model.
            minmax_scaling (bool, optional): Whether to apply min-max scaling in the model's forward pass.

        Returns:
            Tuple[dict, object]: A tuple (predictions_cpu, batch_cpu) where `predictions_cpu` maps output
            names to CPU tensors and `batch_cpu` is the input batch moved to CPU.

        Raises:
            RuntimeError: If the model has not been loaded via `load_model()`.

        Example:
            >>> preds, batch_cpu = modeler.predict_batch(batch)
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        batch = self.to_float32(batch).to(self.device)

        predictions = self.model(
            batch.x_dict,
            batch.edge_index_dict,
            batch.edge_attr_dict if hasattr(batch, 'edge_attr_dict') else None,
            minmax_scaling=minmax_scaling,
        )

        # Move predictions to CPU and detach to allow storing
        predictions_cpu = {}
        for k, v in predictions.items():
            if isinstance(v, torch.Tensor):
                predictions_cpu[k] = v.detach().cpu()
            else:
                predictions_cpu[k] = v

        # Move batch to CPU for later evaluation/storage. Keep a copy (not on device).
        batch_cpu = batch.to(torch.device('cpu'))
        return predictions_cpu, batch_cpu

    def run_predictions(self, loader: Iterable, max_batches: Optional[int] = None, minmax_scaling: bool = True):
        """
        Run predictions over a data loader and return collected prediction/batch pairs.

        This method separates the forward pass from evaluation so predictions can
        be stored or evaluated later (e.g., on CPU-only machines).

        Args:
            loader (Iterable): Iterable data loader yielding batches.
            max_batches (Optional[int], optional): Limit on number of batches to process. Defaults to None (process all).
            minmax_scaling (bool, optional): Passed to `predict_batch`. Defaults to True.

        Returns:
            List[Tuple[dict, object]]: List of (predictions_cpu, batch_cpu) tuples.

        Raises:
            RuntimeError: If the model has not been loaded via `load_model()`.

        Example:
            >>> pairs = modeler.run_predictions(loader, max_batches=10)
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        pred_batch_pairs: List[Tuple[dict, object]] = []
        total_batches = None
        try:
            total_batches = len(loader)
        except TypeError:
            total_batches = None
        if total_batches is not None and max_batches is not None:
            total_batches = min(total_batches, max_batches)

        progress_iter = tqdm(loader, total=total_batches, desc="Predicting samples")
        for batch_idx, batch in enumerate(progress_iter):
            preds, batch_cpu = self.predict_batch(batch, minmax_scaling=minmax_scaling)
            pred_batch_pairs.append((preds, batch_cpu))

            progress_iter.set_postfix(predictions=len(pred_batch_pairs), refresh=False)
            if max_batches is not None and (batch_idx + 1) >= max_batches:
                progress_iter.write(f"Reached max_batches={max_batches}.")
                break
        progress_iter.close()
        return pred_batch_pairs

    def evaluate_from_predictions(
        self,
        pred_batch_pairs: List[Tuple[dict, object]],
        normalize: bool = True,
        cache_key: Optional[str] = None,
    ):
        """
        Evaluate constraints using previously computed predictions and their corresponding batches.

        Args:
            pred_batch_pairs (List[Tuple[dict, object]]): List of (predictions_cpu, batch_cpu) tuples
                produced by `run_predictions`.
            normalize (bool, optional): Whether to normalize violations in the evaluator. Defaults to True.
            cache_key (Optional[str], optional): Cache key to pass to `derive_line_params` for reusing line matrices.

        Returns:
            dict: Aggregated statistics keyed by violation name. Each value is a dict with keys:
                - 'mean' (float): Weighted mean violation
                - 'var' (float): Weighted variance of the violation
                - 'weight' (float): Total sample weight used for aggregation

        Raises:
            ValueError: If `pred_batch_pairs` is empty.

        Example:
            >>> stats = modeler.evaluate_from_predictions(pred_batch_pairs, cache_key="case14")
        """
        if len(pred_batch_pairs) == 0:
            raise ValueError("No predictions provided for evaluation.")

        # Run evaluation on CPU to avoid requiring GPU at evaluation time
        eval_device = torch.device('cpu')

        accum_sum = {}
        accum_sq = {}
        accum_weight = {}
        batches_seen = 0

        progress_iter = tqdm(pred_batch_pairs, desc="Evaluating predictions")
        for predictions, batch in progress_iter:
            # Batch is already on CPU; ensure types
            batch = self.to_float32(batch).to(eval_device)

            evaluator = self.build_constraint_evaluator(batch, device=eval_device, cache_key=cache_key)
            evaluator.slack_bus_indices = self.slack_bus_indices

            # predictions currently CPU tensors; evaluator will operate on same device (cpu)
            violations = evaluator.evaluate_all_constraints(
                predictions=predictions,
                batch_data=batch,
                normalize=normalize,
                return_individual=False,
            )
            summary = evaluator.get_violation_summary(violations)

            sample_weight = batch['bus'].batch.max().item() + 1 if hasattr(batch['bus'], 'batch') else 1
            for key, value in summary.items():
                v = float(value)
                accum_sum[key] = accum_sum.get(key, 0.0) + v * sample_weight
                accum_sq[key] = accum_sq.get(key, 0.0) + v * v * sample_weight
                accum_weight[key] = accum_weight.get(key, 0.0) + sample_weight

            batches_seen += 1
            progress_iter.set_postfix(batches=batches_seen, refresh=False)

        progress_iter.close()

        # compute mean/var
        stats = {}
        if batches_seen > 0:
            for key in sorted(accum_sum.keys()):
                weight = accum_weight.get(key, 0.0)
                if weight == 0:
                    continue
                mean = accum_sum[key] / weight
                mean_sq = accum_sq[key] / weight
                var = mean_sq - mean * mean
                stats[key] = {"mean": mean, "var": var, "weight": weight}
        return stats

convert_checkpoint_key_to_model_key(key: str) -> str staticmethod

Convert checkpoint keys to model keys by transforming underscore-delimited items to tuple string representation.

Parameters:

    key (str, required): Current key with triple underscore delimiters inside angle brackets.

Returns:

    str: String with angle bracket contents converted to tuple representation.

Example:

    >>> Modeler.convert_checkpoint_key_to_model_key("<bus___ac_line___weight>")
    "('bus', 'ac_line', 'weight')"

Source code in lumina/evaluator/opf/utils.py
@staticmethod
def convert_checkpoint_key_to_model_key(key: str) -> str:
    """
    Convert checkpoint keys to model keys by transforming
    underscore-delimited items to tuple string representation.

    Args:
        key: Current key with triple underscore delimiters inside angle brackets

    Returns:
        String with angle bracket contents converted to tuple representation

    Example:
        >>> Modeler.convert_checkpoint_key_to_model_key("<bus___ac_line___weight>")
        "('bus', 'ac_line', 'weight')"
    """

    pattern = r"<([^>]+)>"

    def replacer(match):
        parts = match.group(1).split('___')
        return f"('{parts[0]}', '{parts[1]}', '{parts[2]}')"

    return re.sub(pattern, replacer, key)
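
A quick sanity check of the remapping, using a fabricated parameter name that follows the <src___rel___dst> convention above:

key = "convs.0.<bus___ac_line___bus>.lin.weight"
print(Modeler.convert_checkpoint_key_to_model_key(key))
# convs.0.('bus', 'ac_line', 'bus').lin.weight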

load_checkpoint_into_model(model: torch.nn.Module, checkpoint_dict, *, fail_on_missing: bool = False, verbose: bool = True)

Load a checkpoint dictionary into a model and report missing/unexpected keys.

This method remaps checkpoint keys to model keys using convert_checkpoint_key_to_model_key and then calls load_state_dict with strict=False to allow partial loads.

Parameters:

    model (Module, required): The model to populate.
    checkpoint_dict (dict, required): Mapping of checkpoint keys to tensors.
    fail_on_missing (bool, default False): If True, raise ValueError when missing keys remain after the load.
    verbose (bool, default True): If True, print missing/unexpected keys.

Returns:

    dict: A dictionary with keys "missing_keys" and "unexpected_keys", each mapping to a list of key names observed.

Raises:

    ValueError: If fail_on_missing is True and missing keys are found.

Example:

    >>> result = modeler.load_checkpoint_into_model(model, ckpt_dict)
    >>> print(result["missing_keys"])

Source code in lumina/evaluator/opf/utils.py
def load_checkpoint_into_model(
    self,
    model: torch.nn.Module,
    checkpoint_dict,
    *,
    fail_on_missing: bool = False,
    verbose: bool = True,
):
    """
    Load a checkpoint dictionary into a model and report missing/unexpected keys.

    This method remaps checkpoint keys to model keys using
    `convert_checkpoint_key_to_model_key` and then calls `load_state_dict`
    with `strict=False` to allow partial loads.

    Args:
        model (torch.nn.Module): The model to populate.
        checkpoint_dict (dict): Mapping of checkpoint keys to tensors.
        fail_on_missing (bool, optional): If True, raise ValueError when
            missing keys remain after the load. Defaults to False.
        verbose (bool, optional): If True, print missing/unexpected keys.

    Returns:
        dict: A dictionary with keys "missing_keys" and "unexpected_keys",
            each mapping to a list of key names observed.

    Raises:
        ValueError: If `fail_on_missing` is True and missing keys are found.

    Example:
        >>> result = modeler.load_checkpoint_into_model(model, ckpt_dict)
        >>> print(result["missing_keys"])
    """

    model_state = model.state_dict()
    used_keys = set()
    missing_keys = []

    remapped_state = {}
    for model_key in model_state.keys():
        ck = self.convert_checkpoint_key_to_model_key(model_key)
        if ck in checkpoint_dict:
            remapped_state[model_key] = checkpoint_dict[ck]
            used_keys.add(ck)

    unexpected_keys = [k for k in checkpoint_dict.keys() if k not in used_keys]

    load_result = model.load_state_dict(remapped_state, strict=False)
    missing_keys = list(load_result.missing_keys)
    unexpected_keys.extend(list(load_result.unexpected_keys))

    if verbose and (missing_keys or unexpected_keys):
        print(f"[CHECKPOINT LOAD] Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}")
    if fail_on_missing and missing_keys:
        raise ValueError(f"Missing keys during load: {missing_keys}")

    return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}

load_model(config_data: dict, state_dict: dict)

Construct a hetero OPF model from provided configuration and state dict.

Note:

    Downloads and file I/O for the configuration and safetensors are expected to be performed outside this method; the parsed config_data and in-memory state_dict should be passed here.

Parameters:

    config_data (dict, required): Parsed JSON configuration describing model metadata and architecture.
    state_dict (dict, required): Raw state dictionary as returned by safetensors.torch.load_file.

Returns:

    Tuple[torch.nn.Module, dict]: The constructed model (in eval mode) and the config_data used to build it.

Raises:

    ValueError: If self.fail_on_missing is True and required keys are missing from the checkpoint (raised from load_checkpoint_into_model).

Example:

    >>> config = json.load(open("config.json"))
    >>> state = load_file("model.safetensors")
    >>> model, cfg = modeler.load_model(config, state)

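An end-to-end loading sketch, assuming the config and safetensors weights were already downloaded to local paths (hypothetical filenames):

import json
from safetensors.torch import load_file

config_data = json.load(open("config.json"))   # parsed model/architecture metadata
state_dict = load_file("model.safetensors")    # raw tensors, keys still in checkpoint form

model, cfg = modeler.load_model(config_data, state_dict)
print(type(model).__name__, "with", sum(p.numel() for p in model.parameters()), "parameters")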

load_model_from_training_checkpoint(ckpt_path: Union[str, os.PathLike[str]], *, strict: bool = True) -> torch.nn.Module

Load a model from a training checkpoint file.

Training checkpoint formats differ from HuggingFace safetensor serialization. This method expects a checkpoint with at least: model_class (fully-qualified class name), model_kwargs (constructor arguments), and model_state_dict or model_state (weight tensors).

Parameters:

    ckpt_path (str | os.PathLike, required): Path to the .pt checkpoint file.
    strict (bool, default True): Whether to enforce strict key matching in load_state_dict.

Returns:

    torch.nn.Module: Reconstructed model (not DDP-wrapped), in eval mode, moved to the instance's device.

Raises:

    ValueError: If model_class in the checkpoint is not a valid fully-qualified Python class path, or if the checkpoint contains no state dict.

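A minimal sketch of writing a checkpoint this loader can consume. torch.nn.Linear stands in for a real model class, and the three keys shown are the only ones the loader reads (anything else, such as optimizer state, is ignored):

import torch

model = torch.nn.Linear(4, 2)  # toy stand-in for a trained model
ckpt = {
    "model_class": "torch.nn.Linear",                       # fully-qualified import path
    "model_kwargs": {"in_features": 4, "out_features": 2},  # constructor arguments
    "model_state_dict": model.state_dict(),                 # or "model_state"
}
torch.save(ckpt, "checkpoint.pt")
# load_model_from_training_checkpoint("checkpoint.pt") now imports torch.nn,
# constructs Linear(4, 2), and loads the saved weights.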

to_float32(batch) staticmethod

Convert node features/targets and edge attributes to float32.

Iterates through all node types and edge types in the batch, converting 'x' (features), 'y' (targets), and 'edge_attr' tensors to float32 precision.

Parameters:

    batch (required): A batch data object from DataLoader containing node_types and edge_types attributes.

Returns:

    The same batch object with .x, .y and .edge_attr converted to float32 where present.

Note:

    Modifies the input batch in-place and returns it. Assumes batch has 'node_types' and 'edge_types' properties, typical of heterogeneous graph data structures.

Example:

    >>> batch = modeler.to_float32(batch)

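A small self-contained sketch using PyTorch Geometric's HeteroData (fabricated shapes) to show the in-place downcast:

import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data['bus'].x = torch.rand(4, 3, dtype=torch.float64)
data['bus', 'ac_line', 'bus'].edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
data['bus', 'ac_line', 'bus'].edge_attr = torch.rand(3, 9, dtype=torch.float64)

data = Modeler.to_float32(data)
print(data['bus'].x.dtype, data['bus', 'ac_line', 'bus'].edge_attr.dtype)
# torch.float32 torch.float32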

derive_voltage_limits(bus_x: torch.Tensor, device: torch.device) staticmethod

Derive per-bus voltage limits from bus feature matrix.

If the bus feature tensor contains columns for vmin/vmax (columns 1 and 2), those are used; otherwise sensible defaults (0.95/1.05) are returned.

Parameters:

    bus_x (Tensor, required): Bus node feature matrix or None.
    device (device, required): Device on which returned tensors should be allocated.

Returns:

    dict: Dictionary with keys 'vmin' and 'vmax' mapping to 1-D tensors of length n_bus.

Example:

    >>> vlims = Modeler.derive_voltage_limits(bus_x, torch.device("cpu"))
    >>> vlims['vmin'].shape


derive_generation_limits(gen_x: torch.Tensor, device: torch.device) staticmethod

Derive generator P/Q limits from generator feature matrix.

Heuristic mapping based on dataset feature layout:

    - pmin: column 2
    - pmax: column 3
    - qmin: column 5
    - qmax: column 6

Parameters:

    gen_x (Tensor, required): Generator node feature matrix or None.
    device (device, required): Device for output tensors.

Returns:

    Optional[dict]: Dictionary with keys 'pmin', 'pmax', 'qmin', 'qmax' mapping to 1-D tensors of length n_gen, or None if gen_x is None/empty.

Example:

    >>> glims = Modeler.derive_generation_limits(gen_x, torch.device("cpu"))
    >>> glims['pmax'].shape

Source code in lumina/evaluator/opf/utils.py
@staticmethod
def derive_generation_limits(gen_x: torch.Tensor, device: torch.device):
    """
    Derive generator P/Q limits from generator feature matrix.

    Heuristic mapping based on dataset feature layout:
      - pmin: column 2
      - pmax: column 3
      - qmin: column 5
      - qmax: column 6

    Args:
        gen_x (torch.Tensor): Generator node feature matrix or None.
        device (torch.device): Device for output tensors.

    Returns:
        Optional[dict]: Dictionary with keys 'pmin', 'pmax', 'qmin', 'qmax' mapping
            to 1-D tensors of length n_gen, or None if gen_x is None/empty.

    Example:
        >>> glims = Modeler.derive_generation_limits(gen_x, torch.device("cpu"))
        >>> glims['pmax'].shape
    """
    if gen_x is None or gen_x.numel() == 0:
        return None

    n_gen = gen_x.size(0)

    def col_or_default(idx: int, default: float):
        return gen_x[:, idx].to(device) if gen_x.size(1) > idx else torch.full((n_gen,), default, device=device)

    pmin = col_or_default(2, 0.0)
    pmax = col_or_default(3, 2.0)
    qmin = col_or_default(5, -1.0)
    qmax = col_or_default(6, 1.0)

    return {'pmin': pmin, 'pmax': pmax, 'qmin': qmin, 'qmax': qmax}
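A quick check of the column heuristic on a single fabricated generator row laid out as [mbase, pg, pmin, pmax, qg, qmin, qmax]:

import torch

gen_x = torch.tensor([[100.0, 0.5, 0.0, 2.0, 0.1, -0.5, 0.5]])
glims = Modeler.derive_generation_limits(gen_x, torch.device("cpu"))
print(glims['pmin'], glims['pmax'], glims['qmin'], glims['qmax'])
# tensor([0.]) tensor([2.]) tensor([-0.5000]) tensor([0.5000])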

derive_line_params(batch, device: torch.device, cache_key: str = None)

Build line limits and dense admittance matrices (Y_real, Y_imag) from ac_line edges.

The method reads the ('bus', 'ac_line', 'bus') edge type attributes and computes the per-line thermal limits and the dense Y matrix for the network. Results are cached in the module-level _LINE_CACHE when a cache_key is provided.

Parameters:

    batch (required): Batched data object containing 'ac_line' edge attributes.
    device (device, required): Device for intermediate tensors.
    cache_key (str, default None): Key to use for caching results in _LINE_CACHE.

Returns:

    Tuple[torch.Tensor or None, torch.Tensor or None, torch.Tensor or None, torch.Tensor or None]: (line_limits, Y_real, Y_imag, edge_index) where any element may be None if required data is not present in batch.

Example:

    >>> line_limits, Yr, Yi, idx = modeler.derive_line_params(batch, torch.device("cpu"))

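A worked check of the pi-model stamp above for a single line between buses 0 and 1, with fabricated per-unit values r = 0.01, x = 0.1 and total charging b = 0.02:

r, x, b_shunt = 0.01, 0.1, 0.02
y_series = 1.0 / complex(r, x)
print(round(y_series.real, 4), round(y_series.imag, 4))  # 0.9901 -9.901

# Off-diagonals carry the series admittance only:
#   Y_real[0, 1] = Y_real[1, 0] = -0.9901
#   Y_imag[0, 1] = Y_imag[1, 0] = +9.901
# Each diagonal adds half the line charging on top of the series term:
#   Y_real[0, 0] = Y_real[1, 1] = +0.9901
#   Y_imag[0, 0] = Y_imag[1, 1] = -9.901 + 0.01 = -9.891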

build_constraint_evaluator(batch, device: torch.device, cache_key: str = None)

Build an ACOPFConstraintEvaluator using limits derived from the dataset.

Parameters:

    batch (required): Batched data object containing node/edge features required by the evaluator.
    device (device, required): Device on which evaluator tensors will be placed.
    cache_key (str, default None): Cache key passed to derive_line_params to enable reuse.

Returns:

    ACOPFConstraintEvaluator: Configured evaluator instance ready to run constraint checks.

Example:

    >>> evaluator = modeler.build_constraint_evaluator(batch, torch.device("cpu"), cache_key="case14")


predict_batch(batch, minmax_scaling: bool = True)

Run a forward pass on a single batch and return predictions (on CPU) together with the batch (on CPU).

Parameters:

    batch (required): Batched data object containing inputs for the model.
    minmax_scaling (bool, default True): Whether to apply min-max scaling in the model's forward pass.

Returns:

    Tuple[dict, object]: A tuple (predictions_cpu, batch_cpu) where predictions_cpu maps output names to CPU tensors and batch_cpu is the input batch moved to CPU.

Raises:

    RuntimeError: If the model has not been loaded via load_model().

Example:

    >>> preds, batch_cpu = modeler.predict_batch(batch)


run_predictions(loader: Iterable, max_batches: Optional[int] = None, minmax_scaling: bool = True)

Run predictions over a data loader and return collected prediction/batch pairs.

This method separates the forward pass from evaluation so predictions can be stored or evaluated later (e.g., on CPU-only machines).

Parameters:

    loader (Iterable, required): Iterable data loader yielding batches.
    max_batches (Optional[int], default None): Limit on number of batches to process; None processes all.
    minmax_scaling (bool, default True): Passed to predict_batch.

Returns:

    List[Tuple[dict, object]]: List of (predictions_cpu, batch_cpu) tuples.

Raises:

    RuntimeError: If the model has not been loaded via load_model().

Example:

    >>> pairs = modeler.run_predictions(loader, max_batches=10)


evaluate_from_predictions(pred_batch_pairs: List[Tuple[dict, object]], normalize: bool = True, cache_key: Optional[str] = None)

Evaluate constraints using previously computed predictions and their corresponding batches.

Parameters:

    pred_batch_pairs (List[Tuple[dict, object]], required): List of (predictions_cpu, batch_cpu) tuples produced by run_predictions.
    normalize (bool, default True): Whether to normalize violations in the evaluator.
    cache_key (Optional[str], default None): Cache key to pass to derive_line_params for reusing line matrices.

Returns:

    dict: Aggregated statistics keyed by violation name. Each value is a dict with keys:

        - 'mean' (float): Weighted mean violation
        - 'var' (float): Weighted variance of the violation
        - 'weight' (float): Total sample weight used for aggregation

Raises:

    ValueError: If pred_batch_pairs is empty.

Example:

    >>> stats = modeler.evaluate_from_predictions(pred_batch_pairs, cache_key="case14")

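A minimal two-stage sketch, assuming a loaded `modeler` and a `loader` of hetero batches. The aggregated variance can come out slightly negative from floating-point cancellation, so it is clamped before taking the square root:

pairs = modeler.run_predictions(loader, max_batches=10)               # forward passes
stats = modeler.evaluate_from_predictions(pairs, cache_key="case14")  # CPU-side constraint checks

for name, s in stats.items():
    std = max(s["var"], 0.0) ** 0.5
    print(f"{name}: mean={s['mean']:.4e} std={std:.4e} weight={s['weight']:.0f}")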

Utility Functions

extract_network_parameters_from_batch(batch, device: torch.device = None) -> Dict

Extract network parameters from a batch of OPF data.

Parameters:

    batch (required): Batch from OPFDataset containing heterogeneous graph data.
    device (device, default None): Target device for tensors.

Returns:

    Dict: Dictionary containing extracted network parameters.

Source code in lumina/evaluator/opf/utils.py
def extract_network_parameters_from_batch(batch, device: torch.device = None) -> Dict:
    """
    Extract network parameters from a batch of OPF data.

    Args:
        batch: Batch from OPFDataset containing heterogeneous graph data
        device: Target device for tensors

    Returns:
        Dictionary containing extracted network parameters
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    extracted_data = {}

    try:
        # Extract load data (pd, qd)
        if 'load' in batch.x_dict:
            load_data = batch['load'].x  # Shape: [n_loads, features]
            extracted_data['pd'] = load_data[:, 0].to(device)  # Active power demand
            extracted_data['qd'] = load_data[:, 1].to(device)  # Reactive power demand

            # Get load bus indices from edge connections
            if ('load', 'load_link', 'bus') in batch.edge_index_dict:
                # Load to bus connections - get bus indices
                load_bus_edges = batch[('load', 'load_link', 'bus')].edge_index
                extracted_data['load_bus_indices'] = load_bus_edges[1, :].to(device)  # Bus indices
            elif ('bus', 'load_link', 'load') in batch.edge_index_dict:
                # Bus to load connections - get bus indices
                bus_load_edges = batch[('bus', 'load_link', 'load')].edge_index
                extracted_data['load_bus_indices'] = bus_load_edges[0, :].to(device)  # Bus indices

        # Extract generator bus indices
        if ('generator', 'generator_link', 'bus') in batch.edge_index_dict:
            gen_bus_edges = batch[('generator', 'generator_link', 'bus')].edge_index
            extracted_data['gen_bus_indices'] = gen_bus_edges[1, :].to(device)  # Bus indices
        elif ('bus', 'generator_link', 'generator') in batch.edge_index_dict:
            bus_gen_edges = batch[('bus', 'generator_link', 'generator')].edge_index
            extracted_data['gen_bus_indices'] = bus_gen_edges[0, :].to(device)  # Bus indices

        # Extract line edge indices for thermal limits
        if ('bus', 'ac_line', 'bus') in batch.edge_index_dict:
            line_edges = batch[('bus', 'ac_line', 'bus')].edge_index
            extracted_data['line_edge_index'] = line_edges.to(device)

            # Extract line limits from edge attributes if available
            if hasattr(batch[('bus', 'ac_line', 'bus')], 'edge_attr'):
                line_attr = batch[('bus', 'ac_line', 'bus')].edge_attr
                if line_attr.size(1) > 6:  # Assuming thermal limit is 7th column (index 6)
                    extracted_data['line_limits'] = line_attr[:, 6].to(device)

    except Exception as e:
        warnings.warn(f"Error extracting network parameters from batch: {e}")

    return extracted_data
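
A usage sketch, assuming `batch` comes from an OPFDataset loader; which keys are present depends on the node and edge types the batch carries:

import torch

params = extract_network_parameters_from_batch(batch, torch.device("cpu"))
for name, tensor in params.items():
    print(name, tuple(tensor.shape))
# e.g. pd (n_loads,), qd (n_loads,), load_bus_indices (n_loads,),
#      gen_bus_indices (n_gens,), line_edge_index (2, n_lines), line_limits (n_lines,)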

extract_voltage_and_generation_limits_from_batch(batch, device: torch.device = None) -> Tuple[Dict, Dict]

Extract voltage and generation limits from batch data.

Parameters:

    batch (required): Batch from OPFDataset.
    device (device, default None): Target device for tensors.

Returns:

    Tuple[Dict, Dict]: Tuple of (voltage_limits dict, generation_limits dict).

Source code in lumina/evaluator/opf/utils.py
def extract_voltage_and_generation_limits_from_batch(batch, device: torch.device = None) -> Tuple[Dict, Dict]:
    """
    Extract voltage and generation limits from batch data.

    Args:
        batch: Batch from OPFDataset
        device: Target device for tensors

    Returns:
        Tuple of (voltage_limits dict, generation_limits dict)
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    voltage_limits = {}
    generation_limits = {}

    try:
        # Extract voltage limits from bus data
        if 'bus' in batch.x_dict:
            bus_data = batch['bus'].x
            # Assuming bus features include [base_kv, vmin, vmax, bus_type_onehot...]
            if bus_data.size(1) >= 3:
                voltage_limits['vmin'] = bus_data[:, 1].to(device)
                voltage_limits['vmax'] = bus_data[:, 2].to(device)

        # Extract generation limits from generator data
        if 'generator' in batch.x_dict:
            gen_data = batch['generator'].x
            # Assuming generator features include [mbase, pg, pmin, pmax, qg, qmin, qmax, vg, costs...]
            if gen_data.size(1) >= 7:
                generation_limits['pmin'] = gen_data[:, 2].to(device)
                generation_limits['pmax'] = gen_data[:, 3].to(device)
                generation_limits['qmin'] = gen_data[:, 5].to(device)
                generation_limits['qmax'] = gen_data[:, 6].to(device)

    except Exception as e:
        warnings.warn(f"Error extracting limits from batch: {e}")

    return voltage_limits, generation_limits

extract_generation_costs_from_batch(batch, device: torch.device = None) -> Optional[torch.Tensor]

Extract generation cost coefficients from batch data.

Parameters:

    batch (required): Batch from OPFDataset.
    device (device, default None): Target device for tensors.

Returns:

    Optional[Tensor]: Tensor of generation cost coefficients or None.

Source code in lumina/evaluator/opf/utils.py
def extract_generation_costs_from_batch(batch, device: torch.device = None) -> Optional[torch.Tensor]:
    """
    Extract generation cost coefficients from batch data.

    Args:
        batch: Batch from OPFDataset
        device: Target device for tensors

    Returns:
        Tensor of generation cost coefficients or None
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    try:
        if 'generator' in batch.x_dict:
            gen_data = batch['generator'].x
            # Assuming cost coefficients are the last 3 features
            if gen_data.size(1) >= 3:
                return gen_data[:, -3:].to(device)

    except Exception as e:
        warnings.warn(f"Error extracting generation costs from batch: {e}")

    return None

denormalize_predictions(predictions: Dict[str, torch.Tensor], batch, voltage_range: Tuple[float, float] = (0.95, 1.05), angle_range: Tuple[float, float] = (-180, 180), power_range: Tuple[float, float] = (0, 100)) -> Dict[str, torch.Tensor]

Denormalize model predictions from [0,1] range to physical units.

Parameters:

    predictions (Dict[str, Tensor], required): Normalized predictions from model.
    batch (required): Batch containing limit information.
    voltage_range (Tuple[float, float], default (0.95, 1.05)): Default voltage magnitude range (per unit).
    angle_range (Tuple[float, float], default (-180, 180)): Default voltage angle range (degrees).
    power_range (Tuple[float, float], default (0, 100)): Default power range (MW/MVAr), used only when actual limits are absent from the batch.

Returns:

    Dict[str, Tensor]: Denormalized predictions in physical units.

Source code in lumina/evaluator/opf/utils.py
def denormalize_predictions(
    predictions: Dict[str, torch.Tensor],
    batch,
    voltage_range: Tuple[float, float] = (0.95, 1.05),
    angle_range: Tuple[float, float] = (-180, 180),
    power_range: Tuple[float, float] = (0, 100)  # Will be replaced by actual limits
) -> Dict[str, torch.Tensor]:
    """
    Denormalize model predictions from [0,1] range to physical units.

    Args:
        predictions: Normalized predictions from model
        batch: Batch containing limit information
        voltage_range: Default voltage magnitude range (per unit)
        angle_range: Default voltage angle range (degrees)
        power_range: Default power range (MW/MVAr)

    Returns:
        Denormalized predictions in physical units
    """
    denorm_predictions = {}

    # Denormalize bus predictions (voltage magnitude and angle)
    if 'bus' in predictions:
        bus_pred = predictions['bus'].clone()

        # Extract actual limits from batch if available
        if 'bus' in batch.x_dict and batch['bus'].x.size(1) >= 3:
            vmin = batch['bus'].x[:, 1]
            vmax = batch['bus'].x[:, 2]

            # Denormalize voltage magnitude
            vm_denorm = bus_pred[..., 0] * (vmax - vmin) + vmin

        else:
            # Use default range
            vm_denorm = bus_pred[..., 0] * (voltage_range[1] - voltage_range[0]) + voltage_range[0]

        # Denormalize voltage angle
        va_denorm = bus_pred[..., 1] * (angle_range[1] - angle_range[0]) + angle_range[0]

        denorm_predictions['bus'] = torch.stack([vm_denorm, va_denorm], dim=-1)

    # Denormalize generator predictions (active and reactive power)
    if 'generator' in predictions:
        gen_pred = predictions['generator'].clone()

        # Extract actual limits from batch if available
        if 'generator' in batch.x_dict and batch['generator'].x.size(1) >= 7:
            pmin = batch['generator'].x[:, 2]
            pmax = batch['generator'].x[:, 3]
            qmin = batch['generator'].x[:, 5]
            qmax = batch['generator'].x[:, 6]

            # Denormalize active power
            pg_denorm = gen_pred[..., 0] * (pmax - pmin) + pmin

            # Denormalize reactive power
            qg_denorm = gen_pred[..., 1] * (qmax - qmin) + qmin

        else:
            # Use default range
            pg_denorm = gen_pred[..., 0] * (power_range[1] - power_range[0]) + power_range[0]
            qg_denorm = gen_pred[..., 1] * (power_range[1] - power_range[0]) + power_range[0]

        denorm_predictions['generator'] = torch.stack([pg_denorm, qg_denorm], dim=-1)

    return denorm_predictions
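
A worked check of the min-max inverse used throughout: with vmin = 0.95 and vmax = 1.05, normalized values 0, 0.5 and 1 map back to the ends and midpoint of the physical range:

import torch

vmin, vmax = torch.tensor([0.95]), torch.tensor([1.05])
norm = torch.tensor([0.0, 0.5, 1.0])
print(norm * (vmax - vmin) + vmin)  # tensor([0.9500, 1.0000, 1.0500])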