Skip to content

Utilities API

Graph Utilities

OPFHomoWrapper

Wrapper using PyG's to_homogeneous() for OPF hetero-to-homo conversion.

Converts HeteroData objects into Data objects suitable for homogeneous GNN training, preserving node/edge type indicators and optionally storing padded edge attributes.

Parameters:

Name Type Description Default
add_node_type bool

Whether to add node_type tensor to the converted data.

True
add_edge_type bool

Whether to add edge_type tensor to the converted data.

True
dummy_values bool

Whether to fill missing attributes with zeros.

True
attach_full_edge_attr bool

Whether to compute and store edge_attr_full (all edge attributes padded to a uniform width).

False
Source code in lumina/utils/graph_utils.py
class OPFHomoWrapper:
    """Wrapper using PyG's ``to_homogeneous()`` for OPF hetero-to-homo conversion.

    Converts ``HeteroData`` objects into ``Data`` objects suitable for
    homogeneous GNN training, preserving node/edge type indicators and
    optionally storing padded edge attributes.

    Args:
        add_node_type (bool): Whether to add ``node_type`` tensor to the
            converted data.
        add_edge_type (bool): Whether to add ``edge_type`` tensor to the
            converted data.
        dummy_values (bool): Whether to fill missing attributes with zeros.
        attach_full_edge_attr (bool): Whether to compute and store
            ``edge_attr_full`` (all edge attributes padded to a uniform
            width).
    """

    def __init__(self,
                 add_node_type: bool = True,
                 add_edge_type: bool = True,
                 dummy_values: bool = True,
                 attach_full_edge_attr: bool = False):
        self.add_node_type = add_node_type
        self.add_edge_type = add_edge_type
        self.dummy_values = dummy_values
        self.attach_full_edge_attr = attach_full_edge_attr

    def convert(self, hetero_data: HeteroData) -> Data:
        """Convert heterogeneous OPF data to homogeneous format via PyG.

        Args:
            hetero_data (HeteroData): Input heterogeneous graph.

        Returns:
            Data: Homogeneous graph with ``node_type_names`` and
                ``edge_type_names`` attributes attached.
        """
        converted = hetero_data.to_homogeneous(
            add_node_type=self.add_node_type,
            add_edge_type=self.add_edge_type,
            dummy_values=self.dummy_values
        )
        # Record human-readable type names so downstream code can map the
        # integer node_type / edge_type indicators back to OPF components.
        converted.node_type_names = list(getattr(hetero_data, "node_types", []))
        converted.edge_type_names = [
            f"{source}::{relation}::{target}"
            for (source, relation, target) in getattr(hetero_data, "edge_types", [])
        ]
        if self.attach_full_edge_attr:
            self._attach_full_edge_attr(converted, hetero_data)

        return converted

    @staticmethod
    def _iter_populated_edge_stores(hetero_data: HeteroData):
        """Yield ``(edge_index, edge_attr)`` for every edge type with >= 1 edge."""
        for edge_type in getattr(hetero_data, "edge_types", []):
            store = hetero_data[edge_type]
            index = getattr(store, "edge_index", None)
            if index is None or index.size(1) == 0:
                continue
            yield index, getattr(store, "edge_attr", None)

    def _attach_full_edge_attr(self, homo_data: Data, hetero_data: HeteroData):
        # First pass: widest attribute dimension across all populated edge types.
        target_width = 0
        for _, attr in self._iter_populated_edge_stores(hetero_data):
            if torch.is_tensor(attr):
                width = int(attr.size(1)) if attr.dim() > 1 else 1
                if width > target_width:
                    target_width = width

        if target_width == 0:
            # No populated edge type carries tensor attributes; nothing to attach.
            return

        # Second pass: bring every type's attributes to exactly target_width,
        # zero-padding narrow blocks and truncating wide ones.
        pieces = []
        for index, attr in self._iter_populated_edge_stores(hetero_data):
            if torch.is_tensor(attr):
                block = attr.view(-1, 1) if attr.dim() == 1 else attr
                width = block.size(1)
                if width < target_width:
                    pad = torch.zeros(
                        block.size(0),
                        target_width - width,
                        dtype=block.dtype,
                        device=block.device,
                    )
                    block = torch.cat([block, pad], dim=1)
                elif width > target_width:
                    block = block[:, :target_width]
            else:
                # Attribute-less edge types get all-zero rows of full width.
                block = torch.zeros(
                    index.size(1),
                    target_width,
                    dtype=torch.float32,
                    device=index.device,
                )
            pieces.append(block)

        if pieces:
            homo_data.edge_attr_full = torch.cat(pieces, dim=0)

convert(hetero_data: HeteroData) -> Data

Convert heterogeneous OPF data to homogeneous format via PyG.

Parameters:

Name Type Description Default
hetero_data HeteroData

Input heterogeneous graph.

required

Returns:

Name Type Description
Data Data

Homogeneous graph with node_type_names and edge_type_names attributes attached.

Source code in lumina/utils/graph_utils.py
def convert(self, hetero_data: HeteroData) -> Data:
    """Convert heterogeneous OPF data to homogeneous format via PyG.

    Args:
        hetero_data (HeteroData): Input heterogeneous graph.

    Returns:
        Data: Homogeneous graph with ``node_type_names`` and
            ``edge_type_names`` attributes attached.
    """
    converted = hetero_data.to_homogeneous(
        add_node_type=self.add_node_type,
        add_edge_type=self.add_edge_type,
        dummy_values=self.dummy_values
    )
    # Keep human-readable type names so integer indicators can be mapped back.
    converted.node_type_names = list(getattr(hetero_data, "node_types", []))
    converted.edge_type_names = [
        f"{source}::{relation}::{target}"
        for (source, relation, target) in getattr(hetero_data, "edge_types", [])
    ]
    if self.attach_full_edge_attr:
        self._attach_full_edge_attr(converted, hetero_data)

    return converted

OPFHeteroWrapper

Wrapper that runs a homogeneous GNN on heterogeneous OPF inputs.

Instead of using to_hetero() (which can cause torch.fx issues), this wrapper converts heterogeneous inputs to a simplified homogeneous format, runs the underlying model, and returns outputs keyed by node type.

Parameters:

Name Type Description Default
model Module

Homogeneous GNN model instance.

required
metadata tuple

Graph metadata (node_types, edge_types).

required
aggr str

Aggregation method for combining embeddings from different relations (currently unused; reserved for future use).

'sum'
Source code in lumina/utils/graph_utils.py
class OPFHeteroWrapper:
    """Wrapper that runs a homogeneous GNN on heterogeneous OPF inputs.

    Instead of using ``to_hetero()`` (which can cause ``torch.fx`` issues),
    this wrapper converts heterogeneous inputs to a simplified homogeneous
    format, runs the underlying model, and returns outputs keyed by node
    type.

    Args:
        model (torch.nn.Module): Homogeneous GNN model instance.
        metadata (tuple): Graph metadata ``(node_types, edge_types)``.
        aggr (str): Aggregation method for combining embeddings from
            different relations (currently unused; reserved for future use).
    """

    def __init__(self, model, metadata, aggr: str = 'sum'):
        self.homo_model = model
        self.metadata = metadata
        self.aggr = aggr

        # Instead of using to_hetero(), we work with homogeneous converted data.
        # This is more stable and avoids torch.fx issues.

    def __call__(self, x_dict, edge_index_dict, edge_attr_dict=None):
        """Forward pass: converts hetero inputs to homo, runs model, returns per-type dict.

        Args:
            x_dict (dict[str, torch.Tensor]): Per-node-type feature tensors.
            edge_index_dict (dict[tuple, torch.Tensor]): Per-edge-type
                connectivity tensors.
            edge_attr_dict (dict[tuple, torch.Tensor], optional): Per-edge-type
                attribute tensors.

        Returns:
            dict[str, torch.Tensor]: Model output keyed by the primary node type.
        """
        # Pick the primary node type: 'bus' when present (typical for OPF
        # graphs), otherwise the first node type in insertion order.
        primary_type = 'bus' if 'bus' in x_dict else next(iter(x_dict))
        x = x_dict[primary_type]

        # Use the first non-empty edge type as the homogeneous connectivity.
        edge_index = None
        edge_attr = None
        for edge_type, ei in edge_index_dict.items():
            if ei.size(1) > 0:
                edge_index = ei
                if edge_attr_dict and edge_type in edge_attr_dict:
                    edge_attr = edge_attr_dict[edge_type]
                break

        # Fallback: no usable edges -> self-loops so message passing is
        # defined. Fix: allocate on the same device as the node features so
        # the wrapper also works when inputs live on GPU (the original built
        # these tensors on CPU unconditionally).
        if edge_index is None or edge_index.size(1) == 0:
            num_nodes = x.size(0)
            loop = torch.arange(num_nodes, device=x.device)
            edge_index = torch.stack([loop, loop], dim=0)
            edge_attr = torch.zeros(num_nodes, 1, device=x.device)  # minimal edge attributes

        # NOTE(review): this checks for an attribute named
        # 'return_node_embeddings' on the model object and then forwards it
        # as a keyword argument -- confirm the underlying models actually
        # expose such an attribute.
        if hasattr(self.homo_model, 'return_node_embeddings'):
            output = self.homo_model(x, edge_index, edge_attr, return_node_embeddings=True)
        else:
            output = self.homo_model(x, edge_index, edge_attr)

        # Dictionary output keyed by the primary node type, for compatibility
        # with heterogeneous-model call sites.
        return {primary_type: output}

    def parameters(self):
        """Return the wrapped model's parameters (for optimizer construction)."""
        return self.homo_model.parameters()

    def train(self):
        """Set the wrapped model to training mode.

        Returns:
            OPFHeteroWrapper: ``self``, to allow call chaining.
        """
        self.homo_model.train()
        return self

    def eval(self):
        """Set the wrapped model to evaluation mode.

        Returns:
            OPFHeteroWrapper: ``self``, to allow call chaining.
        """
        self.homo_model.eval()
        return self

    def to(self, device):
        """Move the wrapped model to ``device``.

        Args:
            device: Target ``torch.device`` or device string.

        Returns:
            OPFHeteroWrapper: ``self``, with the model moved.
        """
        self.homo_model = self.homo_model.to(device)
        return self

__call__(x_dict, edge_index_dict, edge_attr_dict=None)

Forward pass: converts hetero inputs to homo, runs model, returns per-type dict.

Parameters:

Name Type Description Default
x_dict dict[str, Tensor]

Per-node-type feature tensors.

required
edge_index_dict dict[tuple, Tensor]

Per-edge-type connectivity tensors.

required
edge_attr_dict dict[tuple, Tensor]

Per-edge-type attribute tensors.

None

Returns:

Type Description

dict[str, torch.Tensor]: Model output keyed by the primary node type.

Source code in lumina/utils/graph_utils.py
def __call__(self, x_dict, edge_index_dict, edge_attr_dict=None):
    """Forward pass: converts hetero inputs to homo, runs model, returns per-type dict.

    Args:
        x_dict (dict[str, torch.Tensor]): Per-node-type feature tensors.
        edge_index_dict (dict[tuple, torch.Tensor]): Per-edge-type
            connectivity tensors.
        edge_attr_dict (dict[tuple, torch.Tensor], optional): Per-edge-type
            attribute tensors.

    Returns:
        dict[str, torch.Tensor]: Model output keyed by the primary node type.
    """
    # Pick the primary node type: 'bus' when present (typical for OPF
    # graphs), otherwise the first node type in insertion order.
    primary_type = 'bus' if 'bus' in x_dict else next(iter(x_dict))
    x = x_dict[primary_type]

    # Use the first non-empty edge type as the homogeneous connectivity.
    edge_index = None
    edge_attr = None
    for edge_type, ei in edge_index_dict.items():
        if ei.size(1) > 0:
            edge_index = ei
            if edge_attr_dict and edge_type in edge_attr_dict:
                edge_attr = edge_attr_dict[edge_type]
            break

    # Fallback: no usable edges -> self-loops so message passing is defined.
    # Fix: allocate on the same device as the node features so GPU inputs do
    # not hit a device mismatch (original built these tensors on CPU).
    if edge_index is None or edge_index.size(1) == 0:
        num_nodes = x.size(0)
        loop = torch.arange(num_nodes, device=x.device)
        edge_index = torch.stack([loop, loop], dim=0)
        edge_attr = torch.zeros(num_nodes, 1, device=x.device)  # minimal edge attributes

    # NOTE(review): this checks for an attribute named 'return_node_embeddings'
    # on the model object before forwarding it as a kwarg -- confirm the
    # underlying models actually expose such an attribute.
    if hasattr(self.homo_model, 'return_node_embeddings'):
        output = self.homo_model(x, edge_index, edge_attr, return_node_embeddings=True)
    else:
        output = self.homo_model(x, edge_index, edge_attr)

    # Dictionary output keyed by the primary node type, for compatibility
    # with heterogeneous-model call sites.
    return {primary_type: output}

parameters()

Get model parameters.

Source code in lumina/utils/graph_utils.py
def parameters(self):
    """Get model parameters.

    Returns:
        The wrapped homogeneous model's parameter iterator, suitable for
        passing directly to an optimizer.
    """
    return self.homo_model.parameters()

train()

Set model to training mode.

Source code in lumina/utils/graph_utils.py
def train(self):
    """Set model to training mode.

    Returns:
        OPFHeteroWrapper: ``self``, to allow call chaining.
    """
    self.homo_model.train()
    return self

eval()

Set model to evaluation mode.

Source code in lumina/utils/graph_utils.py
def eval(self):
    """Set model to evaluation mode.

    Returns:
        OPFHeteroWrapper: ``self``, to allow call chaining.
    """
    self.homo_model.eval()
    return self

to(device)

Move model to device.

Source code in lumina/utils/graph_utils.py
def to(self, device):
    """Move model to device.

    Args:
        device: Target ``torch.device`` or device string.

    Returns:
        OPFHeteroWrapper: ``self``, with the wrapped model moved.
    """
    self.homo_model = self.homo_model.to(device)
    return self

HomoOPFDataset

Dataset wrapper that converts OPF hetero graphs to homogeneous on-the-fly.

Wraps an existing OPF dataset and applies OPFHomoWrapper.convert() to each sample at access time. Non-finite targets are optionally sanitized (replaced with zeros) and flagged via a y_mask attribute.

Parameters:

Name Type Description Default
opf_dataset

Original OPFDataset or Subset instance.

required
add_node_type bool

Whether to include node type information.

True
add_edge_type bool

Whether to include edge type information.

True
sanitize_targets bool

If True, replace non-finite target values with zeros and attach a y_mask boolean tensor.

True
log_bad_targets bool

If True, print a warning when non-finite targets are detected (rank-0 only).

True
max_bad_target_logs int

Maximum number of bad-target warnings to emit per dataset instance.

1
Source code in lumina/utils/graph_utils.py
class HomoOPFDataset:
    """Dataset wrapper that converts OPF hetero graphs to homogeneous on-the-fly.

    Wraps an existing OPF dataset and applies ``OPFHomoWrapper.convert()``
    to each sample at access time. Non-finite targets are optionally
    sanitized (replaced with zeros) and flagged via a ``y_mask`` attribute.

    Args:
        opf_dataset: Original OPFDataset or ``Subset`` instance.
        add_node_type (bool): Whether to include node type information.
        add_edge_type (bool): Whether to include edge type information.
        sanitize_targets (bool): If ``True``, replace non-finite target values
            with zeros and attach a ``y_mask`` boolean tensor.
        log_bad_targets (bool): If ``True``, print a warning when non-finite
            targets are detected (rank-0 only).
        max_bad_target_logs (int): Maximum number of bad-target warnings to
            emit per dataset instance.
    """

    def __init__(
        self,
        opf_dataset,
        add_node_type: bool = True,
        add_edge_type: bool = True,
        sanitize_targets: bool = True,
        log_bad_targets: bool = True,
        max_bad_target_logs: int = 1,
    ):
        self.opf_dataset = opf_dataset
        self.converter = OPFHomoWrapper(add_node_type=add_node_type,
                                        add_edge_type=add_edge_type)
        self.sanitize_targets = sanitize_targets
        self.log_bad_targets = log_bad_targets
        self.max_bad_target_logs = max_bad_target_logs
        self._bad_target_logs = 0  # warnings emitted so far by this instance

    def __len__(self):
        return len(self.opf_dataset)

    def __getitem__(self, idx):
        # Convert lazily, then sanitize targets in-place on the converted copy.
        sample = self.converter.convert(self.opf_dataset[idx])
        self._sanitize_targets(sample, idx)
        return sample

    def _sanitize_targets(self, homo_data, idx):
        targets = getattr(homo_data, "y", None)
        if not torch.is_tensor(targets):
            return

        # A row is "good" only when every one of its entries is finite.
        finite = torch.isfinite(targets)
        good_rows = finite.all(dim=-1) if finite.ndim > 1 else finite

        if bool(good_rows.all().item()):
            return  # every target row is finite; nothing to do

        if self.sanitize_targets:
            if targets.ndim == 0:
                cleaned = torch.zeros_like(targets)
            else:
                cleaned = targets.clone()
                cleaned[~good_rows] = 0
            homo_data.y = cleaned

        # Always record which rows were valid, even when not sanitizing.
        homo_data.y_mask = good_rows.to(dtype=torch.bool)

        if self._should_log_bad_targets():
            bad_count = int((~good_rows).sum().item())
            total = int(good_rows.numel())
            action = "sanitized" if self.sanitize_targets else "left as-is"
            print(
                f"[HomoOPFDataset] Non-finite targets in sample {idx}: "
                f"{bad_count}/{total} rows {action}; stored y_mask."
            )
            self._bad_target_logs += 1

    def _should_log_bad_targets(self):
        # Respect the opt-out flag and the per-instance warning budget.
        if not self.log_bad_targets or self._bad_target_logs >= self.max_bad_target_logs:
            return False
        # Only rank 0 logs in distributed runs; a malformed RANK counts as 0.
        try:
            return int(os.environ.get("RANK", "0")) == 0
        except ValueError:
            return True

    @property
    def num_node_features(self):
        # Inspect a converted sample to learn the node feature width.
        sample = self[0]
        return sample.x.size(1) if hasattr(sample, 'x') else 0

    @property
    def num_edge_features(self):
        # Inspect a converted sample to learn the edge feature width.
        sample = self[0]
        return sample.edge_attr.size(1) if hasattr(sample, 'edge_attr') else 0

    @property
    def num_node_types(self):
        # Largest node_type id + 1, inferred from a converted sample.
        sample = self[0]
        return int(sample.node_type.max().item()) + 1 if hasattr(sample, 'node_type') else 1

    @property
    def num_edge_types(self):
        # Largest edge_type id + 1, inferred from a converted sample.
        sample = self[0]
        return int(sample.edge_type.max().item()) + 1 if hasattr(sample, 'edge_type') else 1

HeteroToHomoConverter

Converts heterogeneous OPF graphs to homogeneous format with feature projection.

Unlike OPFHomoWrapper (which delegates to PyG's to_homogeneous), this class manually projects node and edge features to fixed-width unified feature spaces, builds global node/edge indices, and preserves type indicators.

Conversion strategy
  1. Project or pad/truncate all node features to node_dim.
  2. Add integer node_type indicators.
  3. Convert all edge types to a single type with edge_type indicators.
  4. Preserve targets for appropriate node/edge types.

Parameters:

Name Type Description Default
node_dim int

Target dimension for unified node features.

64
edge_dim int

Target dimension for unified edge features.

32
use_node_type_embedding bool

Whether to add learnable node type embeddings (reserved for future use).

True
use_edge_type_embedding bool

Whether to add learnable edge type embeddings (reserved for future use).

True
Source code in lumina/utils/graph_utils.py
class HeteroToHomoConverter:
    """Converts heterogeneous OPF graphs to homogeneous format with feature projection.

    Unlike ``OPFHomoWrapper`` (which delegates to PyG's ``to_homogeneous``),
    this class manually projects node and edge features to fixed-width
    unified feature spaces, builds global node/edge indices, and preserves
    type indicators.

    Conversion strategy:
        1. Project or pad/truncate all node features to ``node_dim``.
        2. Add integer ``node_type`` indicators.
        3. Convert all edge types to a single type with ``edge_type`` indicators.
        4. Preserve targets for appropriate node/edge types.

    Args:
        node_dim (int): Target dimension for unified node features.
        edge_dim (int): Target dimension for unified edge features.
        use_node_type_embedding (bool): Whether to add learnable node type
            embeddings (reserved for future use).
        use_edge_type_embedding (bool): Whether to add learnable edge type
            embeddings (reserved for future use).
    """

    def __init__(self,
                 node_dim: int = 64,
                 edge_dim: int = 32,
                 use_node_type_embedding: bool = True,
                 use_edge_type_embedding: bool = True):
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.use_node_type_embedding = use_node_type_embedding
        self.use_edge_type_embedding = use_edge_type_embedding

        # Node type mapping (fixed OPF component vocabulary)
        self.node_types = ['bus', 'generator', 'load', 'shunt']
        self.node_type_to_id = {nt: i for i, nt in enumerate(self.node_types)}

        # Edge type mapping (relation names, src/dst types ignored for the id)
        self.edge_types = ['ac_line', 'transformer', 'generator_link', 'load_link', 'shunt_link']
        self.edge_type_to_id = {et: i for i, et in enumerate(self.edge_types)}

        self.num_node_types = len(self.node_types)
        self.num_edge_types = len(self.edge_types)

    @staticmethod
    def _fit_width(features: torch.Tensor, width: int) -> torch.Tensor:
        """Zero-pad or truncate ``features`` along dim 1 to exactly ``width``.

        1-D inputs are first viewed as a single column. Truncation is plain
        slicing; a learned projection would be preferable in practice.
        """
        if features.dim() == 1:
            features = features.view(-1, 1)
        current = features.size(1)
        if current < width:
            pad = torch.zeros(features.size(0), width - current,
                              dtype=features.dtype, device=features.device)
            return torch.cat([features, pad], dim=1)
        if current > width:
            return features[:, :width]
        return features

    @staticmethod
    def _pad_and_concat_targets(targets):
        """Pad per-type target tensors to a common width and concatenate.

        Fix over the original: 1-D ``y`` / ``edge_label`` tensors are viewed
        as single-column matrices instead of crashing at ``.size(1)``.

        Args:
            targets (list[torch.Tensor]): Per-type target tensors.

        Returns:
            torch.Tensor | None: Concatenated ``[total, max_dim]`` targets, or
                ``None`` when no tensor in the list carries any values.
        """
        if not targets or not any(t.numel() > 0 for t in targets):
            return None
        normalized = [t.view(-1, 1) if t.dim() == 1 else t for t in targets]
        # Width is taken only from non-empty tensors (empty ones may be dummies).
        max_dim = max(t.size(1) for t in normalized if t.numel() > 0)
        padded = []
        for t in normalized:
            if t.size(1) < max_dim:
                pad = torch.zeros(t.size(0), max_dim - t.size(1),
                                  dtype=t.dtype, device=t.device)
                t = torch.cat([t, pad], dim=1)
            else:
                t = t[:, :max_dim]
            padded.append(t)
        return torch.cat(padded, dim=0)

    def convert(self, hetero_data: HeteroData) -> Data:
        """Convert a heterogeneous graph to a homogeneous graph.

        Args:
            hetero_data (HeteroData): Input heterogeneous graph with
                per-type node/edge features and targets.

        Returns:
            Data: Homogeneous graph with unified ``x``, ``edge_index``,
                ``edge_attr``, ``node_type``, ``edge_type``, and optional
                ``y`` / ``edge_label`` attributes.
        """
        # Step 1: Build unified node features and the hetero->homo index map
        node_features, node_types, node_targets, node_mapping = self._unify_nodes(hetero_data)

        # Step 2: Build unified edge features and global edge indices
        edge_index, edge_attr, edge_types, edge_targets = self._unify_edges(hetero_data, node_mapping)

        # Step 3: Assemble the homogeneous data object
        homo_data = Data(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_attr,
            node_type=node_types,
            edge_type=edge_types,
        )

        # Attach targets only when present
        if node_targets is not None:
            homo_data.y = node_targets
        if edge_targets is not None:
            homo_data.edge_label = edge_targets

        # Carry over graph-level properties
        if hasattr(hetero_data, 'baseMVA'):
            homo_data.baseMVA = hetero_data.baseMVA
        if hasattr(hetero_data, 'objective'):
            homo_data.objective = hetero_data.objective

        return homo_data

    def _unify_nodes(self, hetero_data: HeteroData) -> Tuple[torch.Tensor,
                                                             torch.Tensor, Optional[torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Unify all node types into a single node feature matrix.

        Returns:
            node_features: Unified node feature matrix [total_nodes, node_dim]
            node_types: Node type indicators [total_nodes]
            node_targets: Unified node targets [total_nodes, target_dim] or None
            node_mapping: Mapping from hetero node indices to homo node indices
        """
        all_node_features = []
        all_node_types = []
        all_node_targets = []
        node_mapping = {}
        current_node_idx = 0

        for node_type in self.node_types:
            if node_type not in hetero_data.node_types:
                continue

            node_data = hetero_data[node_type]
            node_x = node_data.x  # [num_nodes, original_dim]; assumes x exists
            num_nodes = node_x.size(0)

            # Pad/truncate features to the unified node width
            all_node_features.append(self._fit_width(node_x, self.node_dim))

            # Integer node type indicators
            node_type_id = self.node_type_to_id[node_type]
            all_node_types.append(torch.full((num_nodes,), node_type_id,
                                             dtype=torch.long, device=node_x.device))

            # Node targets; dummy zeros keep row counts aligned across types
            if hasattr(node_data, 'y') and node_data.y is not None:
                all_node_targets.append(node_data.y)
            else:
                all_node_targets.append(torch.zeros(num_nodes, 2, dtype=torch.float32,
                                                    device=node_x.device))

            # Hetero-local index -> global homogeneous index
            node_mapping[node_type] = torch.arange(
                current_node_idx, current_node_idx + num_nodes,
                dtype=torch.long, device=node_x.device)
            current_node_idx += num_nodes

        node_features = torch.cat(all_node_features, dim=0)
        node_types = torch.cat(all_node_types, dim=0)
        node_targets = self._pad_and_concat_targets(all_node_targets)

        return node_features, node_types, node_targets, node_mapping

    def _unify_edges(self,
                     hetero_data: HeteroData,
                     node_mapping: Dict[str,
                                        torch.Tensor]) -> Tuple[torch.Tensor,
                                                                torch.Tensor,
                                                                torch.Tensor,
                                                                Optional[torch.Tensor]]:
        """
        Unify all edge types into homogeneous edge representation.

        Returns:
            edge_index: Unified edge indices [2, total_edges]
            edge_attr: Unified edge features [total_edges, edge_dim]
            edge_types: Edge type indicators [total_edges]
            edge_targets: Unified edge targets [total_edges, target_dim] or None
        """
        all_edge_indices = []
        all_edge_attrs = []
        all_edge_types = []
        all_edge_targets = []

        # Only relations in the known vocabulary with mapped endpoints are kept
        for edge_type_tuple in hetero_data.edge_types:
            src_type, edge_type, dst_type = edge_type_tuple

            if edge_type not in self.edge_types:
                continue

            edge_data = hetero_data[edge_type_tuple]
            if not hasattr(edge_data, 'edge_index'):
                continue

            edge_index = edge_data.edge_index  # [2, num_edges]
            num_edges = edge_index.size(1)
            if num_edges == 0:
                continue

            src_mapping = node_mapping.get(src_type)
            dst_mapping = node_mapping.get(dst_type)
            if src_mapping is None or dst_mapping is None:
                continue

            # Re-index endpoints into the global homogeneous node space
            homo_edge_index = torch.zeros_like(edge_index)
            homo_edge_index[0] = src_mapping[edge_index[0]]
            homo_edge_index[1] = dst_mapping[edge_index[1]]
            all_edge_indices.append(homo_edge_index)

            # Pad/truncate attributes to the unified edge width
            if hasattr(edge_data, 'edge_attr') and edge_data.edge_attr is not None:
                all_edge_attrs.append(self._fit_width(edge_data.edge_attr, self.edge_dim))
            else:
                all_edge_attrs.append(torch.zeros(num_edges, self.edge_dim,
                                                  dtype=torch.float32, device=edge_index.device))

            # Integer edge type indicators
            edge_type_id = self.edge_type_to_id[edge_type]
            all_edge_types.append(torch.full((num_edges,), edge_type_id,
                                             dtype=torch.long, device=edge_index.device))

            # Edge targets; dummy zeros keep row counts aligned across types
            if hasattr(edge_data, 'edge_label') and edge_data.edge_label is not None:
                all_edge_targets.append(edge_data.edge_label)
            else:
                all_edge_targets.append(torch.zeros(num_edges, 4, dtype=torch.float32,
                                                    device=edge_index.device))

        if not all_edge_indices:
            # No usable edges: return empty tensors on the node-mapping device
            device = next(iter(node_mapping.values())).device
            return (torch.empty((2, 0), dtype=torch.long, device=device),
                    torch.empty((0, self.edge_dim), dtype=torch.float32, device=device),
                    torch.empty((0,), dtype=torch.long, device=device),
                    None)

        edge_index = torch.cat(all_edge_indices, dim=1)
        edge_attr = torch.cat(all_edge_attrs, dim=0)
        edge_types = torch.cat(all_edge_types, dim=0)
        edge_targets = self._pad_and_concat_targets(all_edge_targets)

        return edge_index, edge_attr, edge_types, edge_targets

convert(hetero_data: HeteroData) -> Data

Convert a heterogeneous graph to a homogeneous graph.

Parameters:

Name Type Description Default
hetero_data HeteroData

Input heterogeneous graph with per-type node/edge features and targets.

required

Returns:

Name Type Description
Data Data

Homogeneous graph with unified x, edge_index, edge_attr, node_type, edge_type, and optional y / edge_label attributes.

Source code in lumina/utils/graph_utils.py
def convert(self, hetero_data: HeteroData) -> Data:
    """Convert a heterogeneous graph to a homogeneous graph.

    Args:
        hetero_data (HeteroData): Input heterogeneous graph with
            per-type node/edge features and targets.

    Returns:
        Data: Homogeneous graph with unified ``x``, ``edge_index``,
            ``edge_attr``, ``node_type``, ``edge_type``, and optional
            ``y`` / ``edge_label`` attributes.
    """
    # Unify nodes first; the resulting mapping is needed to remap edges.
    x, node_type, y, mapping = self._unify_nodes(hetero_data)
    edge_index, edge_attr, edge_type, edge_label = self._unify_edges(hetero_data, mapping)

    homo = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        node_type=node_type,
        edge_type=edge_type,
    )

    # Attach supervision targets only when the source provided them.
    if y is not None:
        homo.y = y
    if edge_label is not None:
        homo.edge_label = edge_label

    # Carry over graph-level OPF properties when the source defines them.
    for prop in ('baseMVA', 'objective'):
        if hasattr(hetero_data, prop):
            setattr(homo, prop, getattr(hetero_data, prop))

    return homo

prepare_opf_training_data(dataset, train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, use_homogeneous: bool = True)

Split an OPF dataset into train/val/test subsets.

Creates random index-based subsets and optionally wraps each in HomoOPFDataset for homogeneous model training.

Parameters:

Name Type Description Default
dataset

OPFDataset instance with __len__ support.

required
train_ratio float

Fraction of data for training.

0.8
val_ratio float

Fraction of data for validation.

0.1
test_ratio float

Fraction of data for testing.

0.1
use_homogeneous bool

If True, wrap each subset in HomoOPFDataset for on-the-fly hetero-to-homo conversion.

True

Returns:

Name Type Description
tuple

(train_dataset, val_dataset, test_dataset) subsets.

Raises:

Type Description
AssertionError

If ratios do not sum to 1.0.

Source code in lumina/utils/graph_utils.py
def prepare_opf_training_data(dataset,
                              train_ratio: float = 0.8,
                              val_ratio: float = 0.1,
                              test_ratio: float = 0.1,
                              use_homogeneous: bool = True):
    """Split an OPF dataset into train/val/test subsets.

    Creates random index-based subsets and optionally wraps each in
    ``HomoOPFDataset`` for homogeneous model training.

    Args:
        dataset: OPFDataset instance with ``__len__`` support.
        train_ratio (float): Fraction of data for training.
        val_ratio (float): Fraction of data for validation.
        test_ratio (float): Fraction of data for testing.
        use_homogeneous (bool): If ``True``, wrap each subset in
            ``HomoOPFDataset`` for on-the-fly hetero-to-homo conversion.

    Returns:
        tuple: ``(train_dataset, val_dataset, test_dataset)`` subsets.
            The test subset receives every sample left over after the
            integer-truncated train/val splits, so the three sizes always
            sum to ``len(dataset)``.

    Raises:
        AssertionError: If ratios do not sum to 1.0.
    """
    # Tolerance absorbs float representation error (e.g. 0.8 + 0.1 + 0.1).
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1.0"

    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)

    # One shuffled permutation keeps the three splits random and disjoint.
    indices = torch.randperm(total_size)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    # Test takes the remainder, absorbing any truncation from int() above.
    test_indices = indices[train_size + val_size:]

    train_subset = torch.utils.data.Subset(dataset, train_indices)
    val_subset = torch.utils.data.Subset(dataset, val_indices)
    test_subset = torch.utils.data.Subset(dataset, test_indices)

    if not use_homogeneous:
        return train_subset, val_subset, test_subset

    # Wrap each subset for on-the-fly hetero-to-homo conversion.
    return (HomoOPFDataset(train_subset),
            HomoOPFDataset(val_subset),
            HomoOPFDataset(test_subset))

get_opf_metadata(dataset)

Extract graph metadata from an OPF dataset.

Retrieves the (node_types, edge_types) tuple needed by to_hetero() and other heterogeneous model utilities.

Parameters:

Name Type Description Default
dataset

OPFDataset instance supporting indexing.

required

Returns:

Name Type Description
tuple

(node_types, edge_types) where each element is a list of strings or tuples respectively.

Source code in lumina/utils/graph_utils.py
def get_opf_metadata(dataset):
    """Extract graph metadata from an OPF dataset.

    Retrieves the ``(node_types, edge_types)`` tuple needed by
    ``to_hetero()`` and other heterogeneous model utilities.

    Args:
        dataset: OPFDataset instance supporting indexing.

    Returns:
        tuple: ``(node_types, edge_types)`` where each element is a list
            of strings or tuples respectively.
    """
    # A single sample carries the full type schema of the dataset.
    sample = dataset[0]
    if not hasattr(sample, 'metadata'):
        # No metadata() accessor: read the type lists straight off the object.
        return (list(sample.node_types), list(sample.edge_types))
    return sample.metadata()

convert_opf_to_homo(hetero_data: HeteroData, node_dim: int = 64, edge_dim: int = 32) -> Data

Convert OPF heterogeneous data to homogeneous format.

Convenience function that creates a HeteroToHomoConverter and calls its convert method.

Parameters:

Name Type Description Default
hetero_data HeteroData

Input heterogeneous OPF graph.

required
node_dim int

Target unified node feature dimension.

64
edge_dim int

Target unified edge feature dimension.

32

Returns:

Name Type Description
Data Data

Homogeneous graph data object.

Source code in lumina/utils/graph_utils.py
def convert_opf_to_homo(hetero_data: HeteroData,
                        node_dim: int = 64,
                        edge_dim: int = 32) -> Data:
    """Convert OPF heterogeneous data to homogeneous format.

    Convenience wrapper: builds a ``HeteroToHomoConverter`` with the given
    target dimensions and delegates to its ``convert`` method.

    Args:
        hetero_data (HeteroData): Input heterogeneous OPF graph.
        node_dim (int): Target unified node feature dimension.
        edge_dim (int): Target unified edge feature dimension.

    Returns:
        Data: Homogeneous graph data object.
    """
    return HeteroToHomoConverter(node_dim=node_dim, edge_dim=edge_dim).convert(hetero_data)

Throughput Tracking

ThroughputTracker

Measures training throughput (samples/sec) over a configurable window.

After a configurable warmup period, records per-step timing for a fixed number of measurement steps, then computes and logs the mean throughput across all DDP ranks. Results are optionally logged to W&B and written as JSON metadata.

Parameters:

Name Type Description Default
config dict

Full training configuration. Throughput settings are read from config["training"]: throughput_enabled, throughput_warmup_steps, throughput_measure_steps.

required
world_size int

Total number of DDP processes.

required
global_rank int

Rank of the current process.

required
get_global_step callable

Zero-argument callable returning the current global step count (used as W&B x-axis).

required
wandb_enabled bool

Whether to log metrics to W&B.

False
Source code in lumina/utils/throughput.py
class ThroughputTracker:
    """Measures training throughput (samples/sec) over a configurable window.

    After a configurable warmup period, records per-step timing for a fixed
    number of measurement steps, then computes and logs the mean throughput
    across all DDP ranks.  Results are optionally logged to W&B and written
    as JSON metadata.

    Args:
        config (dict): Full training configuration. Throughput settings are
            read from ``config["training"]``: ``throughput_enabled``,
            ``throughput_warmup_steps``, ``throughput_measure_steps``.
        world_size (int): Total number of DDP processes.
        global_rank (int): Rank of the current process.
        get_global_step (callable): Zero-argument callable returning the
            current global step count (used as W&B x-axis).
        wandb_enabled (bool): Whether to log metrics to W&B.
    """

    def __init__(self, config, world_size, global_rank, get_global_step, wandb_enabled=False):
        self.config = config or {}
        self.world_size = world_size
        self.global_rank = global_rank
        # Fall back to a constant step so logging never depends on the caller.
        self._get_global_step = get_global_step or (lambda: 0)
        # Caller opt-in is honored only when the wandb package is importable.
        self.wandb_enabled = bool(wandb_enabled) and WANDB_AVAILABLE

        training_config = self.config.get("training", {})
        self.enabled = training_config.get("throughput_enabled", False)
        self.warmup_steps = max(0, int(training_config.get("throughput_warmup_steps", 100)))
        self.measure_steps = max(0, int(training_config.get("throughput_measure_steps", 200)))
        # A zero-length measurement window could never complete; disable outright.
        if self.measure_steps == 0:
            self.enabled = False

        # Measurement state, advanced by maybe_start_measurement()/on_step_end().
        self.has_run = False
        self.step_index = 0
        self.measure_started = False
        self.measure_count = 0
        self.samples = []
        self.metadata_written = False
        self.loader_config = self.config.get("loader", {})

    def set_wandb_enabled(self, enabled):
        """Enable or disable W&B logging for throughput metrics.

        Args:
            enabled (bool): ``True`` to enable W&B logging (only effective
                if ``wandb`` is installed).
        """
        self.wandb_enabled = bool(enabled) and WANDB_AVAILABLE

    def _global_step(self):
        # Defensive: a failing step callback must not crash metric logging.
        try:
            return int(self._get_global_step())
        except Exception:
            return 0

    def _get_git_hash(self):
        # This file lives in <repo>/lumina/utils/, so the repo root is two
        # parents above it.
        repo_root = Path(__file__).resolve().parents[2]
        try:
            output = subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                cwd=repo_root,
                stderr=subprocess.DEVNULL,
                text=True,
            )
            return output.strip()
        except Exception:
            # Not a git checkout, or git unavailable; metadata records null.
            return None

    def write_metadata(self):
        """Write environment and configuration metadata to a JSON file (rank-0 only)."""
        if not self.enabled:
            return
        # Only rank 0 writes, and at most once per run.
        if self.metadata_written or self.global_rank != 0:
            return
        logging_dir = self.config.get("logging_dir", ".")
        os.makedirs(logging_dir, exist_ok=True)
        # Distributed-launch environment variables worth recording for repro.
        env_keys = [
            "RANK",
            "LOCAL_RANK",
            "WORLD_SIZE",
            "MASTER_ADDR",
            "MASTER_PORT",
            "OMP_NUM_THREADS",
        ]
        metadata = {
            "git_hash": self._get_git_hash(),
            "config_yaml": yaml.safe_dump(self.config, sort_keys=False),
            "hostname": socket.gethostname(),
            "world_size": self.world_size,
            "env": {key: os.environ.get(key) for key in env_keys if key in os.environ},
            "torch_version": torch.__version__,
            "dist_backend": dist.get_backend() if dist.is_initialized() else None,
        }
        metadata_path = os.path.join(logging_dir, "throughput_metadata.json")
        with open(metadata_path, "w") as handle:
            json.dump(metadata, handle, indent=2)
        self.metadata_written = True
        if self.wandb_enabled:
            wandb.log({"throughput/metadata_path": metadata_path}, step=self._global_step())

    def maybe_start_measurement(self):
        """Begin measurement if warmup is complete and measurement has not yet run.

        Returns:
            bool: ``True`` if measurement is now active, ``False`` otherwise.
        """
        if not self.enabled or self.has_run:
            return False
        if self.measure_started:
            return True
        if self.step_index < self.warmup_steps:
            return False
        # Align all ranks so every process times the same window of steps.
        if dist.is_available() and dist.is_initialized() and self.world_size > 1:
            dist.barrier()
        self.measure_started = True
        self.measure_count = 0
        if self.global_rank == 0:
            print(
                f"Starting throughput measurement: warmup={self.warmup_steps}, "
                f"measure={self.measure_steps}"
            )
        self.write_metadata()
        return True

    def measure_active(self):
        """Return whether throughput measurement is currently in progress.

        Returns:
            bool: ``True`` when actively measuring.
        """
        return self.measure_started and not self.has_run

    def accelerator_synchronize(self):
        """Synchronize the current accelerator device for accurate timing."""
        # torch.accelerator only exists on newer torch builds; no-op otherwise.
        if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "synchronize"):
            torch.accelerator.synchronize()

    def get_batch_samples(self, batch):
        """Determine the number of samples in a batch.

        Args:
            batch: PyG batch object or plain tensor.

        Returns:
            int: Number of samples in the batch.
        """
        if hasattr(batch, "num_graphs"):
            return int(batch.num_graphs)
        if torch.is_tensor(batch):
            return int(batch.size(0))
        # Unknown batch type: fall back to the configured loader batch size.
        return int(self.loader_config.get("batch_size", 1))

    def record_step(self, step_metrics):
        """Record a single step's throughput metrics.

        Args:
            step_metrics (dict): Metrics for the step, must include
                ``'throughput/samples_per_sec'``.
        """
        self.samples.append(step_metrics)
        if self.wandb_enabled:
            wandb.log(step_metrics, step=self._global_step())

    def on_step_end(self, step_metrics=None):
        """Called at the end of each training step to record metrics and check completion.

        Automatically calls ``finalize`` once the measurement window is filled.

        Args:
            step_metrics (dict, optional): Throughput metrics for this step.
        """
        if not self.enabled or self.has_run:
            return
        self.step_index += 1
        if not self.measure_active():
            return
        if step_metrics is None:
            return
        self.record_step(step_metrics)
        self.measure_count += 1
        if self.measure_count >= self.measure_steps:
            self.finalize()

    def finalize(self, partial=False):
        """Compute and log the final throughput summary across all ranks.

        Gathers per-step samples/sec from all DDP ranks, computes the global
        mean, and logs to W&B if enabled.

        Args:
            partial (bool): If ``True``, indicates the measurement window
                was not fully completed.
        """
        if not self.measure_started:
            return
        if self.has_run:
            return
        # Make sure every rank has finished its measured steps before gathering.
        if dist.is_available() and dist.is_initialized() and self.world_size > 1:
            dist.barrier()
        if not self.samples:
            if self.global_rank == 0:
                print("Throughput measurement skipped: no samples collected.")
            self.has_run = True
            return

        samples_per_sec = [sample["throughput/samples_per_sec"] for sample in self.samples]
        global_samples_per_sec = samples_per_sec
        if dist.is_available() and dist.is_initialized() and self.world_size > 1:
            # Collect every rank's per-step list and flatten into one pool.
            gathered = [None for _ in range(self.world_size)]
            dist.all_gather_object(gathered, samples_per_sec)
            global_samples_per_sec = [value for sublist in gathered for value in sublist]

        sample_array = np.array(global_samples_per_sec)
        global_mean = float(sample_array.mean())

        summary = {
            "throughput/summary/mean_samples_per_sec": global_mean,
            "throughput/summary/partial": float(partial),
        }

        if self.global_rank == 0:
            status = "partial" if partial else "complete"
            print(f"Throughput measurement {status}: mean_samples_per_sec={global_mean:.3f}")
        if self.wandb_enabled:
            wandb.log(summary, step=self._global_step())

        self.has_run = True

set_wandb_enabled(enabled)

Enable or disable W&B logging for throughput metrics.

Parameters:

Name Type Description Default
enabled bool

True to enable W&B logging (only effective if wandb is installed).

required
Source code in lumina/utils/throughput.py
def set_wandb_enabled(self, enabled):
    """Toggle W&B logging for throughput metrics.

    Args:
        enabled (bool): ``True`` to enable W&B logging (only effective
            if ``wandb`` is installed).
    """
    # The caller's request is honored only when the wandb package is importable.
    self.wandb_enabled = WANDB_AVAILABLE and bool(enabled)

write_metadata()

Write environment and configuration metadata to a JSON file (rank-0 only).

Source code in lumina/utils/throughput.py
def write_metadata(self):
    """Write environment and configuration metadata to a JSON file (rank-0 only)."""
    if not self.enabled:
        return
    # Only rank 0 writes, and at most once per run.
    if self.metadata_written or self.global_rank != 0:
        return
    logging_dir = self.config.get("logging_dir", ".")
    os.makedirs(logging_dir, exist_ok=True)
    # Distributed-launch environment variables worth recording for repro.
    env_keys = [
        "RANK",
        "LOCAL_RANK",
        "WORLD_SIZE",
        "MASTER_ADDR",
        "MASTER_PORT",
        "OMP_NUM_THREADS",
    ]
    metadata = {
        "git_hash": self._get_git_hash(),
        "config_yaml": yaml.safe_dump(self.config, sort_keys=False),
        "hostname": socket.gethostname(),
        "world_size": self.world_size,
        "env": {key: os.environ.get(key) for key in env_keys if key in os.environ},
        "torch_version": torch.__version__,
        "dist_backend": dist.get_backend() if dist.is_initialized() else None,
    }
    metadata_path = os.path.join(logging_dir, "throughput_metadata.json")
    with open(metadata_path, "w") as handle:
        json.dump(metadata, handle, indent=2)
    self.metadata_written = True
    if self.wandb_enabled:
        wandb.log({"throughput/metadata_path": metadata_path}, step=self._global_step())

maybe_start_measurement()

Begin measurement if warmup is complete and measurement has not yet run.

Returns:

Name Type Description
bool

True if measurement is now active, False otherwise.

Source code in lumina/utils/throughput.py
def maybe_start_measurement(self):
    """Begin measurement if warmup is complete and measurement has not yet run.

    Returns:
        bool: ``True`` if measurement is now active, ``False`` otherwise.
    """
    if not self.enabled or self.has_run:
        return False
    if self.measure_started:
        return True
    if self.step_index < self.warmup_steps:
        return False
    # Align all ranks so every process times the same window of steps.
    if dist.is_available() and dist.is_initialized() and self.world_size > 1:
        dist.barrier()
    self.measure_started = True
    self.measure_count = 0
    if self.global_rank == 0:
        print(
            f"Starting throughput measurement: warmup={self.warmup_steps}, "
            f"measure={self.measure_steps}"
        )
    self.write_metadata()
    return True

measure_active()

Return whether throughput measurement is currently in progress.

Returns:

Name Type Description
bool

True when actively measuring.

Source code in lumina/utils/throughput.py
def measure_active(self):
    """Return whether throughput measurement is currently in progress.

    Returns:
        bool: ``True`` when actively measuring.
    """
    # Active means the window has started but has not yet been finalized.
    if self.has_run:
        return False
    return self.measure_started

accelerator_synchronize()

Synchronize the current accelerator device for accurate timing.

Source code in lumina/utils/throughput.py
def accelerator_synchronize(self):
    """Synchronize the current accelerator device for accurate timing.

    No-op on torch builds that predate the ``torch.accelerator`` API.
    """
    accel_api = getattr(torch, "accelerator", None)
    if accel_api is not None and hasattr(accel_api, "synchronize"):
        accel_api.synchronize()

get_batch_samples(batch)

Determine the number of samples in a batch.

Parameters:

Name Type Description Default
batch

PyG batch object or plain tensor.

required

Returns:

Name Type Description
int

Number of samples in the batch.

Source code in lumina/utils/throughput.py
def get_batch_samples(self, batch):
    """Determine the number of samples in a batch.

    Args:
        batch: PyG batch object or plain tensor.

    Returns:
        int: Number of samples in the batch.
    """
    if hasattr(batch, "num_graphs"):
        # PyG batch objects report their graph count directly.
        count = batch.num_graphs
    elif torch.is_tensor(batch):
        # Plain tensors: the first dimension is the batch dimension.
        count = batch.size(0)
    else:
        # Unknown batch type: fall back to the configured loader batch size.
        count = self.loader_config.get("batch_size", 1)
    return int(count)

record_step(step_metrics)

Record a single step's throughput metrics.

Parameters:

Name Type Description Default
step_metrics dict

Metrics for the step, must include 'throughput/samples_per_sec'.

required
Source code in lumina/utils/throughput.py
def record_step(self, step_metrics):
    """Append one step's throughput metrics and optionally log them to W&B.

    Args:
        step_metrics (dict): Metrics for the step, must include
            ``'throughput/samples_per_sec'``.
    """
    self.samples.append(step_metrics)
    if not self.wandb_enabled:
        return
    wandb.log(step_metrics, step=self._global_step())

on_step_end(step_metrics=None)

Called at the end of each training step to record metrics and check completion.

Automatically calls finalize once the measurement window is filled.

Parameters:

Name Type Description Default
step_metrics dict

Throughput metrics for this step.

None
Source code in lumina/utils/throughput.py
def on_step_end(self, step_metrics=None):
    """Advance step accounting and record this step's throughput metrics.

    Automatically calls ``finalize`` once the measurement window is filled.

    Args:
        step_metrics (dict, optional): Throughput metrics for this step.
    """
    if not self.enabled or self.has_run:
        return
    self.step_index += 1
    # Only count steps that fall inside an active window and carry metrics.
    if not self.measure_active() or step_metrics is None:
        return
    self.record_step(step_metrics)
    self.measure_count += 1
    if self.measure_count >= self.measure_steps:
        self.finalize()

finalize(partial=False)

Compute and log the final throughput summary across all ranks.

Gathers per-step samples/sec from all DDP ranks, computes the global mean, and logs to W&B if enabled.

Parameters:

Name Type Description Default
partial bool

If True, indicates the measurement window was not fully completed.

False
Source code in lumina/utils/throughput.py
def finalize(self, partial=False):
    """Compute and log the final throughput summary across all ranks.

    Gathers per-step samples/sec from all DDP ranks, computes the global
    mean, and logs to W&B if enabled.

    Args:
        partial (bool): If ``True``, indicates the measurement window
            was not fully completed.
    """
    if not self.measure_started:
        return
    if self.has_run:
        return
    # Make sure every rank has finished its measured steps before gathering.
    if dist.is_available() and dist.is_initialized() and self.world_size > 1:
        dist.barrier()
    if not self.samples:
        if self.global_rank == 0:
            print("Throughput measurement skipped: no samples collected.")
        self.has_run = True
        return

    samples_per_sec = [sample["throughput/samples_per_sec"] for sample in self.samples]
    global_samples_per_sec = samples_per_sec
    if dist.is_available() and dist.is_initialized() and self.world_size > 1:
        # Collect every rank's per-step list and flatten into one pool.
        gathered = [None for _ in range(self.world_size)]
        dist.all_gather_object(gathered, samples_per_sec)
        global_samples_per_sec = [value for sublist in gathered for value in sublist]

    sample_array = np.array(global_samples_per_sec)
    global_mean = float(sample_array.mean())

    summary = {
        "throughput/summary/mean_samples_per_sec": global_mean,
        "throughput/summary/partial": float(partial),
    }

    if self.global_rank == 0:
        status = "partial" if partial else "complete"
        print(f"Throughput measurement {status}: mean_samples_per_sec={global_mean:.3f}")
    if self.wandb_enabled:
        wandb.log(summary, step=self._global_step())

    self.has_run = True

Model Utilities

set_seed(seed)

Set random seeds for reproducibility across all backends.

Configures Python random, NumPy, PyTorch CPU, and (if available) PyTorch CUDA random number generators. Also sets cuDNN to deterministic mode.

Parameters:

Name Type Description Default
seed int

Random seed value.

required
Source code in lumina/utils/model.py
def set_seed(seed):
    """Set random seeds for reproducibility across all backends.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU). When CUDA
    is available, also seeds every CUDA device and switches cuDNN into
    deterministic mode with its autotuner disabled.

    Args:
        seed (int): Random seed value.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

dict_agg(stats, key, value, op='concat')

Aggregate a value into a dictionary entry by summing or concatenating.

If key already exists in stats, the value is combined with the existing entry using the specified operation. Otherwise, the value is stored directly.

Parameters:

Name Type Description Default
stats dict

Dictionary to update in place.

required
key str

Key in the dictionary.

required
value ndarray

Value to add or concatenate.

required
op str

Operation type -- 'sum' for element-wise addition, 'concat' for numpy.concatenate along axis 0.

'concat'

Raises:

Type Description
NotImplementedError

If op is not 'sum' or 'concat'.

Source code in lumina/utils/model.py
def dict_agg(stats, key, value, op='concat'):
    """Aggregate a value into a dictionary entry by summing or concatenating.

    If *key* already exists in *stats*, the value is combined with the
    existing entry using the specified operation.  Otherwise, the value is
    stored directly (the first insert does not validate *op*, matching the
    original behavior).

    Args:
        stats (dict): Dictionary to update in place.
        key (str): Key in the dictionary.
        value (numpy.ndarray): Value to add or concatenate.
        op (str): Operation type -- ``'sum'`` for element-wise addition,
            ``'concat'`` for ``numpy.concatenate`` along axis 0.

    Raises:
        NotImplementedError: If the key exists and *op* is not ``'sum'``
            or ``'concat'``.
    """
    # Modifies stats in place.
    if key not in stats:
        stats[key] = value
    elif op == 'sum':
        stats[key] += value
    elif op == 'concat':
        stats[key] = np.concatenate((stats[key], value), axis=0)
    else:
        raise NotImplementedError(f"Unsupported op: {op!r} (expected 'sum' or 'concat')")

I/O

check_dir(dir_path)

Ensure a directory exists, creating it (and parents) if necessary.

Parameters:

Name Type Description Default
dir_path str

Path to the directory to verify or create.

required
Source code in lumina/utils/io.py
def check_dir(dir_path):
    """Ensure a directory exists, creating it (and parents) if necessary.

    Uses ``os.makedirs(..., exist_ok=True)`` to avoid the race between a
    separate existence check and creation (another process could create the
    directory in between, making ``makedirs`` fail spuriously).

    Args:
        dir_path (str): Path to the directory to verify or create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)