
Dataset API

In-Memory Dataset

OPFDataset

Bases: InMemoryDataset

The heterogeneous OPF data from the "Large-scale Datasets for AC Optimal Power Flow with Topological Perturbations" paper (https://arxiv.org/abs/2406.07234).

OPFDataset is a large-scale dataset of solved optimal power flow problems, derived from the pglib-opf dataset (https://github.com/power-grid-lib/pglib-opf).

The physical topology of the grid is represented by the "bus" node type and the connecting AC lines and transformers. Additionally, "generator", "load", and "shunt" nodes are connected to "bus" nodes using a dedicated edge type each, e.g., "generator_link".

Edge direction corresponds to the properties of the line, e.g., b_fr is the line charging susceptance at the from (source/sender) bus.

Parameters:

- root (str, required): Root directory where the dataset should be saved.
- case_name (str, optional): The name of the original pglib-opf case. (default: "pglib_opf_case14_ieee")
- group_id (int, optional): The specific group to load. Each group contains 15,000 samples. Valid values are [0, 19]. (default: 0)
- transform (callable, optional): A function/transform that takes in a torch_geometric.data.HeteroData object and returns a transformed version. The data object will be transformed before every access. (default: None)
- pre_transform (callable, optional): A function/transform that takes in a torch_geometric.data.HeteroData object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)
- pre_filter (callable, optional): A function that takes in a torch_geometric.data.HeteroData object and returns a boolean value indicating whether the data object should be included in the final dataset. (default: None)
- force_reload (bool, optional): Whether to re-process the dataset. (default: False)
- keep_temp (bool, optional): Whether to keep the temporary files after processing. (default: False)
- n_jobs (int, optional): The number of jobs to use for parallel processing. If set to -1, all available cores are used. NOTE: for larger datasets, it is recommended to set this to a lower positive value to avoid memory issues. (default: -1)
- local_raw_folder (str, optional): Local folder to look for raw files. If None, files are downloaded from the internet. (default: None)

Examples:

>>> from lumina.dataset.opf.opf_dataset import OPFDataset, OPFMultiDataset
>>> dataset = OPFDataset(root='./data', case_name='pglib_opf_case14_ieee')
>>> # By default, only the first group (i.e., group 0) is loaded.
>>> # To load multiple groups, use OPFMultiDataset instead:
>>> dataset = OPFMultiDataset.from_case_groups(root='./data', case_name='pglib_opf_case14_ieee', group_ids=[0, 1, 2])
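For orientation, a minimal access sketch (assuming torch_geometric is installed; the batch size is illustrative, not a recommendation):

>>> data = dataset[0]                                  # a torch_geometric.data.HeteroData sample
>>> data['bus', 'ac_line', 'bus'].edge_index           # line connectivity, oriented from -> to bus
>>> dims = dataset.metadata()                          # feature widths per node/edge type
>>> from torch_geometric.loader import DataLoader
>>> loader = DataLoader(dataset, batch_size=32, shuffle=True)
>>> batch = next(iter(loader))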
Source code in lumina/dataset/opf/opf_dataset.py
class OPFDataset(InMemoryDataset):
    r"""The heterogeneous OPF data from the `"Large-scale Datasets for AC
    Optimal Power Flow with Topological Perturbations"
    <https://arxiv.org/abs/2406.07234>`_ paper.

    :class:`OPFDataset` is a large-scale dataset of solved optimal power flow
    problems, derived from the
    `pglib-opf <https://github.com/power-grid-lib/pglib-opf>`_ dataset.

    The physical topology of the grid is represented by the :obj:`"bus"` node
    type, and the connecting AC lines and transformers. Additionally,
    :obj:`"generator"`, :obj:`"load"`, and :obj:`"shunt"` nodes are connected
    to :obj:`"bus"` nodes using a dedicated edge type each, *e.g.*,
    :obj:`"generator_link"`.

    Edge direction corresponds to the properties of the line, *e.g.*,
    :obj:`b_fr` is the line charging susceptance at the :obj:`from`
    (source/sender) bus.

    Args:
        root (str): Root directory where the dataset should be saved.
        case_name (str, optional): The name of the original pglib-opf case.
            (default: :obj:`"pglib_opf_case14_ieee"`)
        group_id (int, optional): The specific group to load. Each group
            contains 15,000 samples. Valid values are [0, 19].
            (default: :obj:`0`)
        transform (callable, optional): A function/transform that takes in
            a :obj:`torch_geometric.data.HeteroData` object and returns a
            transformed version. The data object will be transformed before
            every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes
            in a :obj:`torch_geometric.data.HeteroData` object and returns
            a transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :obj:`torch_geometric.data.HeteroData` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        keep_temp (bool, optional): Whether to keep the temporary files
            after processing. (default: :obj:`False`)
        n_jobs (int, optional): The number of jobs to use for parallel
            processing. If set to :obj:`-1`, all available cores will be used.
            NOTE: for larger dataset, it is recommended to set this to a lower positive value
            to avoid memory issues. (default: :obj:`-1`)
        local_raw_folder (str, optional): Local folder to look for raw files.
            If :obj:`None`, files will be downloaded from the internet.
            (default: :obj:`None`)

    Examples:
        >>> from lumina.dataset.opf.opf_dataset import OPFDataset
        >>> dataset = OPFDataset(root='./data', case_name='pglib_opf_case14_ieee')
        >>> # By default, only first group (i.e., group 0) is loaded, if you want to load multiple groups,
        >>> #   please use OPFMultiDataset instead:
        >>> dataset = OPFMultiDataset.from_case_groups(root='./data', case_name='pglib_opf_case14_ieee', group_ids=[0,1,2])
    """
    url = "https://storage.googleapis.com/gridopt-dataset"

    def __init__(
        self,
        root: str,
        case_name: Literal[
            'pglib_opf_case14_ieee',
            'pglib_opf_case30_ieee',
            'pglib_opf_case57_ieee',
            'pglib_opf_case118_ieee',
            'pglib_opf_case500_goc',
            'pglib_opf_case2000_goc',
            'pglib_opf_case4661_sdet',
            'pglib_opf_case6470_rte',
            'pglib_opf_case10000_goc',
            'pglib_opf_case13659_pegase',
        ] = 'pglib_opf_case14_ieee',
        group_id: int = 0,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        keep_temp: bool = False,
        n_jobs: int = -1,
        local_raw_folder: str = None,
    ) -> None:

        self.case_name = case_name
        self.group_id = group_id

        self._raw_root = osp.join(root, 'OPFData/raw')
        self._processed_root = osp.join(root, 'OPFData/processed')
        self._release = 'dataset_release_1'
        self.n_jobs = n_jobs
        self.keep_temp = keep_temp

        # TODO: add admittance matrix Y to the dataset - This may be used in multi-case evaluation
        # current_dir = os.path.dirname(os.path.abspath(__file__))
        # matpower_file = osp.join(current_dir, "../data/pglib", f"{case_name}.m")
        # net = pp.converter.from_mpc(matpower_file)
        # ppc = pp.converter.pypower.to_ppc(net, init='flat')
        # self.Y, _, _ = pp.makeYbus_pypower(ppc['baseMVA'], ppc['bus'], ppc['branch'])
        # self.Y_real = torch.tensor(self.Y.real.todense())
        # self.Y_imag = torch.tensor(self.Y.imag.todense())
        # self.load_bus_indices = net.load.bus.values.astype(np.int32)

        self.local_raw_folder = local_raw_folder
        # NOTE: processing steps:
        #   1. check downloaded: if raw files are not ready, download from url
        #   2. check processed: if processed files are not exist, process raw files
        super().__init__(root,
                         transform,
                         pre_transform,
                         pre_filter,
                         force_reload=force_reload)

        # Load only the specified group
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

        if osp.exists(self._processed_root):
            try:
                current_mode = os.stat(self._processed_root).st_mode
                os.chmod(self._processed_root, current_mode | stat.S_IWGRP)
            except OSError as exc:
                warnings.warn(f'Failed to set group write permission on {self._processed_root}: {exc}')

    @property
    def raw_dir(self) -> str:
        r""" Raw data folder """
        return osp.join(self._raw_root, self._release)

    @property
    def processed_dir(self) -> str:
        r""" Processed data folder. """
        return osp.join(self._processed_root, self._release, self.case_name)

    @property
    def tmp_dir(self) -> str:
        # NOTE: new class attribute
        r""" Temporary data folder. """
        return osp.join(self.raw_dir,
                        "gridopt-dataset-tmp",
                        self._release,
                        self.case_name)

    @property
    def raw_file_names(self) -> List[str]:
        r""" Raw file names, which are stored locally. """
        return [f'{self.case_name}_{self.group_id}.tar.gz']

    @property
    def processed_file_names(self) -> List[str]:
        r""" Processed file names, which are stored locally. """
        return [f'group_{self.group_id}.pt']

    def download(self) -> None:
        r""" Download .tar.gz files """
        print("Download files")

        # NOTE: download a single file for the specified group_id
        self.download_and_extract(self.raw_file_names[0])

        # NOTE: keep the parallel code for future reference
        # results = Parallel(n_jobs=self.n_jobs, backend="multiprocessing")(
        #     delayed(self.download_and_extract)(name)
        #     for name in self.raw_file_names)

        print(f"Downloaded {self.raw_file_names[0]} to {self.raw_dir}")

    def download_and_extract(self, name: str) -> None:
        r""" Download and extract a .tar.gz file. """
        url = f'{self.url}/{self._release}/{name}'
        path = download_url(url, self.raw_dir)
        extract_tar(path, self.raw_dir)

    def process(self) -> None:
        r""" Process the raw files into a single file. """
        h5_files = [f for f in self.raw_paths if f.endswith('.h5')]

        if h5_files:
            print(f"HDF5 files detected: {h5_files}")
            try:
                self.process_hdf5_group(self.group_id)
            except Exception as e:
                print(f"Error processing HDF5 group {self.group_id}: {e}")
                raise e
            return

        if not osp.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

        try:
            self.process_json_group(self.group_id)
        except Exception as e:
            print(f"Error processing group {self.group_id}: {e}")
            raise e

        print(f"Processed group {self.group_id}")

        # NOTE: remove tmp_dir content to save local space
        if not self.keep_temp:
            shutil.rmtree(osp.join(self.raw_dir, 'gridopt-dataset-tmp'))

    def _post_process_and_save(self, data_list: List[Optional[Union[HeteroData, List[HeteroData]]]], group_id: int):
        """Filter, transform, collate, and save processed data to disk.

        Args:
            data_list (List[Optional[Union[HeteroData, List[HeteroData]]]]): Raw
                processed samples, which may contain ``None`` entries or nested
                lists. These are flattened, filtered, and transformed before
                collation.
            group_id (int): Group identifier used for the output filename.
        """
        flattened_list = []
        for item in data_list:
            if item is None:
                continue
            if isinstance(item, list):
                flattened_list.extend(item)
            else:
                flattened_list.append(item)

        data_list = flattened_list

        if self.pre_filter is not None or self.pre_transform is not None:
            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]
            if self.pre_transform is not None:
                data_list = [self.pre_transform(data) for data in data_list]

        self.data, self.slices = self.collate(data_list)
        torch.save((self._data, self.slices), osp.join(self.processed_dir, f'group_{group_id}.pt'))

    def process_json_group(self, group_id: int):
        r""" Process a single group of files, save processed data to disk.

        Args:
            group_id (int): Group id.
        """

        group_json_files = glob(osp.join(self.tmp_dir, f'group_{group_id}', '*.json'))
        #
        if len(group_json_files) < 15000:
            extract_tar(osp.join(self.raw_dir, self.raw_file_names[0]), self.raw_dir)
            group_json_files = glob(osp.join(self.tmp_dir, f'group_{group_id}', '*.json'))

        data_list = Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(process_json_file)(fn) for fn in tqdm.tqdm(group_json_files, desc=f"Group {group_id}"))

        self._post_process_and_save(data_list, group_id)

    def combine_datasets(self, file_paths: List[str]) -> List[HeteroData]:
        r""" Combine datasets from multiple files.

        Notes:
          - deprecated, kept for reference, please use `OPFMultiDataset` instead.
        """
        warnings.warn(
            "combine_datasets is deprecated and will be removed in a future version. "
            "Please use OPFMultiDataset instead for combining multiple datasets.",
            DeprecationWarning,
            stacklevel=2
        )

        combined_data = []
        for file_path in file_paths:
            with open(file_path, 'rb') as f:
                data = pickle.load(f)
            combined_data.extend(data)
        return combined_data

    def merge_group_files(self) -> None:
        r"""Merge group files into single train/val/test based on groups options.

        Notes:
          - deprecated, kept for reference, please use `OPFMultiDataset` instead.
        """
        warnings.warn(
            "merge_group_files is deprecated and will be removed in a future version. "
            "Please use OPFMultiDataset instead for combining multiple datasets.",
            DeprecationWarning,
            stacklevel=2
        )

        data_files = [osp.join(self.processed_dir, f'group_{self.group_id}.pkl')]
        combined_data = self.combine_datasets(data_files)
        self.data, self.slices = self.collate(combined_data)

    def metadata(self):
        r""" Returns the metadata of the dataset. """
        # return (
        #     ['bus', 'generator', 'load', 'shunt'],
        #     [
        #         ('bus', 'ac_line', 'bus'),
        #         ('bus', 'transformer', 'bus'),
        #         ('generator', 'generator_link', 'bus'),
        #         ('bus', 'generator_link', 'generator'),
        #         ('load', 'load_link', 'bus'),
        #         ('bus', 'load_link', 'load'),
        #         ('shunt', 'shunt_link', 'bus'),
        #         ('bus', 'shunt_link', 'shunt')
        #     ]
        # )

        return {
            "nodes": {"bus": self._data['bus'].x.size(1),
                      "generator": self._data['generator'].x.size(1),
                      "load": self._data['load'].x.size(1),
                      "shunt": self._data['shunt'].x.size(1)},
            "edges": {
                ('bus', 'ac_line', 'bus'): self._data['bus', 'ac_line', 'bus'].edge_attr.size(1),
                ('bus', 'transformer', 'bus'): self._data['bus', 'transformer', 'bus'].edge_attr.size(1),
                ('generator', 'generator_link', 'bus'): 0,
                ('bus', 'generator_link', 'generator'): 0,
                ('load', 'load_link', 'bus'): 0,
                ('bus', 'load_link', 'load'): 0,
                ('shunt', 'shunt_link', 'bus'): 0,
                ('bus', 'shunt_link', 'shunt'): 0
            }
        }

    def process_hdf5_group(self, group_id: int):
        r""" Process a single group of HDF5 files, save processed data to disk.

        Args:
            group_id (int): Group id.
        """
        raw_paths = self.raw_paths
        h5_files = [f for f in raw_paths if f.endswith('.h5')]

        if not h5_files:
            print(f"No HDF5 files found in {self.raw_dir}")
            return

        tasks = []
        for h5_file in h5_files:
            with h5py.File(h5_file, 'r') as f:
                for scenario_key in f.keys():
                    if 'grid' not in scenario_key:
                        tasks.append((h5_file, scenario_key))

        data_list = Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(_process_hdf5_scenario_from_path)(fn, key)
            for fn, key in tqdm.tqdm(tasks, desc=f"Group {group_id} HDF5")
        )

        self._post_process_and_save(data_list, group_id)
    def __repr__(self) -> str:
        r""" Returns the string representation of the dataset. """
        return (f'{self.__class__.__name__}({len(self)}, '
                f'case_name={self.case_name})')

raw_dir: str property

Raw data folder

processed_dir: str property

Processed data folder.

tmp_dir: str property

Temporary data folder.

raw_file_names: List[str] property

Raw file names, which are stored locally.

processed_file_names: List[str] property

Processed file names, which are stored locally.

download() -> None

Download .tar.gz files

Source code in lumina/dataset/opf/opf_dataset.py
def download(self) -> None:
    r""" Download .tar.gz files """
    print("Download files")

    # NOTE: download a single file for the specified group_id
    self.download_and_extract(self.raw_file_names[0])

    # NOTE: keep the parallel code for future reference
    # results = Parallel(n_jobs=self.n_jobs, backend="multiprocessing")(
    #     delayed(self.download_and_extract)(name)
    #     for name in self.raw_file_names)

    print(f"Downloaded {self.raw_file_names[0]} to {self.raw_dir}")

download_and_extract(name: str) -> None

Download and extract a .tar.gz file.

Source code in lumina/dataset/opf/opf_dataset.py
def download_and_extract(self, name: str) -> None:
    r""" Download and extract a .tar.gz file. """
    url = f'{self.url}/{self._release}/{name}'
    path = download_url(url, self.raw_dir)
    extract_tar(path, self.raw_dir)

process() -> None

Process the raw files into a single file.

Source code in lumina/dataset/opf/opf_dataset.py
def process(self) -> None:
    r""" Process the raw files into a single file. """
    h5_files = [f for f in self.raw_paths if f.endswith('.h5')]

    if h5_files:
        print(f"HDF5 files detected: {h5_files}")
        try:
            self.process_hdf5_group(self.group_id)
        except Exception as e:
            print(f"Error processing HDF5 group {self.group_id}: {e}")
            raise e
        return

    if not osp.exists(self.tmp_dir):
        os.makedirs(self.tmp_dir)

    try:
        self.process_json_group(self.group_id)
    except Exception as e:
        print(f"Error processing group {self.group_id}: {e}")
        raise e

    print(f"Processed group {self.group_id}")

    # NOTE: remove tmp_dir content to save local space
    if not self.keep_temp:
        shutil.rmtree(osp.join(self.raw_dir, 'gridopt-dataset-tmp'))

process_json_group(group_id: int)

Process a single group of files, save processed data to disk.

Parameters:

- group_id (int, required): Group id.
Source code in lumina/dataset/opf/opf_dataset.py
def process_json_group(self, group_id: int):
    r""" Process a single group of files, save processed data to disk.

    Args:
        group_id (int): Group id.
    """

    group_json_files = glob(osp.join(self.tmp_dir, f'group_{group_id}', '*.json'))
    #
    if len(group_json_files) < 15000:
        extract_tar(osp.join(self.raw_dir, self.raw_file_names[0]), self.raw_dir)
        group_json_files = glob(osp.join(self.tmp_dir, f'group_{group_id}', '*.json'))

    data_list = Parallel(n_jobs=self.n_jobs, backend="threading")(
        delayed(process_json_file)(fn) for fn in tqdm.tqdm(group_json_files, desc=f"Group {group_id}"))

    self._post_process_and_save(data_list, group_id)

combine_datasets(file_paths: List[str]) -> List[HeteroData]

Combine datasets from multiple files.

Notes
  • Deprecated; kept for reference. Please use OPFMultiDataset instead.
Source code in lumina/dataset/opf/opf_dataset.py
def combine_datasets(self, file_paths: List[str]) -> List[HeteroData]:
    r""" Combine datasets from multiple files.

    Notes:
      - deprecated, kept for reference, please use `OPFMultiDataset` instead.
    """
    warnings.warn(
        "combine_datasets is deprecated and will be removed in a future version. "
        "Please use OPFMultiDataset instead for combining multiple datasets.",
        DeprecationWarning,
        stacklevel=2
    )

    combined_data = []
    for file_path in file_paths:
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        combined_data.extend(data)
    return combined_data

merge_group_files() -> None

Merge group files into single train/val/test based on groups options.

Notes
  • Deprecated; kept for reference. Please use OPFMultiDataset instead.
Source code in lumina/dataset/opf/opf_dataset.py
def merge_group_files(self) -> None:
    r"""Merge group files into single train/val/test based on groups options.

    Notes:
      - deprecated, kept for reference, please use `OPFMultiDataset` instead.
    """
    warnings.warn(
        "merge_group_files is deprecated and will be removed in a future version. "
        "Please use OPFMultiDataset instead for combining multiple datasets.",
        DeprecationWarning,
        stacklevel=2
    )

    data_files = [osp.join(self.processed_dir, f'group_{self.group_id}.pkl')]
    combined_data = self.combine_datasets(data_files)
    self.data, self.slices = self.collate(combined_data)

metadata()

Returns the metadata of the dataset.
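A short sketch of how the returned dictionary might be consumed to size a model (key names follow the implementation shown below; the variable names are illustrative):

>>> dims = dataset.metadata()
>>> bus_in = dims["nodes"]["bus"]                       # bus feature width
>>> line_in = dims["edges"][("bus", "ac_line", "bus")]  # ac_line edge feature width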

Source code in lumina/dataset/opf/opf_dataset.py
def metadata(self):
    r""" Returns the metadata of the dataset. """
    # return (
    #     ['bus', 'generator', 'load', 'shunt'],
    #     [
    #         ('bus', 'ac_line', 'bus'),
    #         ('bus', 'transformer', 'bus'),
    #         ('generator', 'generator_link', 'bus'),
    #         ('bus', 'generator_link', 'generator'),
    #         ('load', 'load_link', 'bus'),
    #         ('bus', 'load_link', 'load'),
    #         ('shunt', 'shunt_link', 'bus'),
    #         ('bus', 'shunt_link', 'shunt')
    #     ]
    # )

    return {
        "nodes": {"bus": self._data['bus'].x.size(1),
                  "generator": self._data['generator'].x.size(1),
                  "load": self._data['load'].x.size(1),
                  "shunt": self._data['shunt'].x.size(1)},
        "edges": {
            ('bus', 'ac_line', 'bus'): self._data['bus', 'ac_line', 'bus'].edge_attr.size(1),
            ('bus', 'transformer', 'bus'): self._data['bus', 'transformer', 'bus'].edge_attr.size(1),
            ('generator', 'generator_link', 'bus'): 0,
            ('bus', 'generator_link', 'generator'): 0,
            ('load', 'load_link', 'bus'): 0,
            ('bus', 'load_link', 'load'): 0,
            ('shunt', 'shunt_link', 'bus'): 0,
            ('bus', 'shunt_link', 'shunt'): 0
        }
    }

process_hdf5_group(group_id: int)

Process a single group of HDF5 files, save processed data to disk.

Parameters:

- group_id (int, required): Group id.
Source code in lumina/dataset/opf/opf_dataset.py
def process_hdf5_group(self, group_id: int):
    r""" Process a single group of HDF5 files, save processed data to disk.

    Args:
        group_id (int): Group id.
    """
    raw_paths = self.raw_paths
    h5_files = [f for f in raw_paths if f.endswith('.h5')]

    if not h5_files:
        print(f"No HDF5 files found in {self.raw_dir}")
        return

    tasks = []
    for h5_file in h5_files:
        with h5py.File(h5_file, 'r') as f:
            for scenario_key in f.keys():
                if 'grid' not in scenario_key:
                    tasks.append((h5_file, scenario_key))

    data_list = Parallel(n_jobs=self.n_jobs, backend="threading")(
        delayed(_process_hdf5_scenario_from_path)(fn, key)
        for fn, key in tqdm.tqdm(tasks, desc=f"Group {group_id} HDF5")
    )

    self._post_process_and_save(data_list, group_id)

__repr__() -> str

Returns the string representation of the dataset.

Source code in lumina/dataset/opf/opf_dataset.py
def __repr__(self) -> str:
    r""" Returns the string representation of the dataset. """
    return (f'{self.__class__.__name__}({len(self)}, '
            f'case_name={self.case_name})')

OPFHomogeneousDataset

Bases: OPFDataset

OPFDataset variant that converts heterogeneous graphs to homogeneous ones during preprocessing.

Wraps each HeteroData sample through OPFHomoWrapper before saving, producing homogeneous Data objects with optional one-hot node/edge type indicators. Non-finite target values are sanitized and a y_mask is stored.

Parameters:

- *args: Positional arguments forwarded to OPFDataset.
- add_node_type (bool): Append one-hot node type features to x. (default: True)
- add_edge_type (bool): Append one-hot edge type features to edge_attr. (default: True)
- attach_full_edge_attr (bool): Attach the full concatenated edge attribute tensor as edge_attr_full. (default: False)
- sanitize_targets (bool): Replace non-finite target values with zero. (default: True)
- log_bad_targets (bool): Log a warning when non-finite targets are found. (default: True)
- max_bad_target_logs (int): Maximum number of bad-target warnings to emit. (default: 1)
- processed_suffix (str): Suffix appended to the release name for the processed directory. (default: "homo")
- **kwargs: Additional keyword arguments forwarded to OPFDataset.
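A minimal construction sketch (hedged: the root path and case name are illustrative; attribute behaviour follows the source below):

>>> ds = OPFHomogeneousDataset(root='./data', case_name='pglib_opf_case14_ieee',
...                            add_node_type=True, add_edge_type=True)
>>> sample = ds[0]              # a homogeneous torch_geometric.data.Data object
>>> sample.x, sample.edge_attr  # features with one-hot type indicators appended
>>> sample.y_mask               # present only if some targets were non-finite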
Source code in lumina/dataset/opf/opf_dataset.py
class OPFHomogeneousDataset(OPFDataset):
    r"""OPFDataset variant that converts heterogeneous graphs to homogeneous during preprocessing.

    Wraps each HeteroData sample through :class:`OPFHomoWrapper` before saving,
    producing homogeneous ``Data`` objects with optional one-hot node/edge type
    indicators. Non-finite target values are sanitized and a ``y_mask`` is stored.

    Args:
        *args: Positional arguments forwarded to :class:`OPFDataset`.
        add_node_type (bool): Append one-hot node type features to ``x``.
            (default: :obj:`True`)
        add_edge_type (bool): Append one-hot edge type features to ``edge_attr``.
            (default: :obj:`True`)
        attach_full_edge_attr (bool): Attach the full concatenated edge attribute
            tensor as ``edge_attr_full``. (default: :obj:`False`)
        sanitize_targets (bool): Replace non-finite target values with zero.
            (default: :obj:`True`)
        log_bad_targets (bool): Log a warning when non-finite targets are found.
            (default: :obj:`True`)
        max_bad_target_logs (int): Maximum number of bad-target warnings to emit.
            (default: :obj:`1`)
        processed_suffix (str): Suffix appended to the release name for the
            processed directory. (default: :obj:`"homo"`)
        **kwargs: Additional keyword arguments forwarded to :class:`OPFDataset`.
    """

    def __init__(
        self,
        *args,
        add_node_type: bool = True,
        add_edge_type: bool = True,
        attach_full_edge_attr: bool = False,
        sanitize_targets: bool = True,
        log_bad_targets: bool = True,
        max_bad_target_logs: int = 1,
        processed_suffix: str = "homo",
        **kwargs,
    ) -> None:
        self._homo_wrapper = OPFHomoWrapper(
            add_node_type=add_node_type,
            add_edge_type=add_edge_type,
            attach_full_edge_attr=attach_full_edge_attr,
        )
        self._sanitize_targets = bool(sanitize_targets)
        self._log_bad_targets = bool(log_bad_targets)
        self._max_bad_target_logs = int(max_bad_target_logs)
        self._bad_target_logs = 0
        self._processed_suffix = processed_suffix or "homo"

        user_pre_transform = kwargs.pop("pre_transform", None)

        def pre_transform(data):
            homo_data = self._homo_wrapper.convert(data)
            self._sanitize_homo_targets(homo_data)
            if user_pre_transform is not None:
                homo_data = user_pre_transform(homo_data)
            return homo_data

        super().__init__(*args, pre_transform=pre_transform, **kwargs)

    @property
    def processed_dir(self) -> str:
        if self._processed_suffix:
            release = f"{self._release}_{self._processed_suffix}"
        else:
            release = self._release
        return osp.join(self._processed_root, release, self.case_name)

    def _sanitize_homo_targets(self, homo_data):
        y = getattr(homo_data, "y", None)
        if not torch.is_tensor(y):
            return

        finite_mask = torch.isfinite(y)
        if finite_mask.ndim > 1:
            row_mask = finite_mask.all(dim=-1)
        else:
            row_mask = finite_mask

        if bool(row_mask.all().item()):
            return

        if self._sanitize_targets:
            if y.ndim == 0:
                y = torch.zeros_like(y)
            else:
                y = y.clone()
                y[~row_mask] = 0
            homo_data.y = y

        homo_data.y_mask = row_mask.to(dtype=torch.bool)

        if self._should_log_bad_targets():
            bad_count = int((~row_mask).sum().item())
            total = int(row_mask.numel())
            action = "sanitized" if self._sanitize_targets else "left as-is"
            print(
                f"[OPFHomogeneousDataset] Non-finite targets: "
                f"{bad_count}/{total} rows {action}; stored y_mask."
            )
            self._bad_target_logs += 1

    def _should_log_bad_targets(self):
        if not self._log_bad_targets:
            return False
        if self._bad_target_logs >= self._max_bad_target_logs:
            return False
        try:
            rank = int(os.environ.get("RANK", "0"))
        except ValueError:
            rank = 0
        return rank == 0

OPFMultiDataset

Bases: ConcatDataset

Multi-group OPF dataset that combines multiple OPFDataset instances using ConcatDataset.

This class allows combining multiple groups from the same case or different cases to create larger datasets for training.

Parameters:

- datasets (List[OPFDataset], required): List of OPFDataset instances to combine.

Examples:

>>> # For different cases
>>> d1 = OPFDataset(root, case_name="pglib_opf_case14_ieee", group_id=0)
>>> d2 = OPFDataset(root, case_name="pglib_opf_case30_ieee", group_id=0)
>>> multi_dataset = OPFMultiDataset([d1, d2])
>>> # For same case, multiple groups
>>> datasets = []
>>> for group_id in range(5):
...     ds = OPFDataset(root, case_name="pglib_opf_case14_ieee", group_id=group_id)
...     datasets.append(ds)
>>> multi_dataset = OPFMultiDataset(datasets)
>>> # Mixed: multiple groups from multiple cases
>>> case_mapping = {
...     "pglib_opf_case14_ieee": [0, 1, 2],
...     "pglib_opf_case30_ieee": [0, 1]
... }
>>> mixed_dataset = OPFMultiDataset.from_mixed_cases(root, case_mapping)
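The combined dataset behaves like a standard ConcatDataset and can be batched directly. A short sketch (assuming torch_geometric is installed; the batch size is illustrative):

>>> from torch_geometric.loader import DataLoader
>>> loader = DataLoader(multi_dataset, batch_size=64, shuffle=True)
>>> dims = multi_dataset.metadata()  # metadata of the first underlying dataset
>>> print(multi_dataset)             # e.g., OPFMultiDataset(<num_samples>, cases=[...])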
Source code in lumina/dataset/opf/opf_dataset.py
class OPFMultiDataset(ConcatDataset):
    r"""Multi-group OPF dataset that combines multiple OPFDataset instances using ConcatDataset.

    This class allows combining multiple groups from the same case or different cases
    to create larger datasets for training.

    Args:
        datasets (List[OPFDataset]): List of OPFDataset instances to combine.

    Examples:
        >>> # For different cases
        >>> d1 = OPFDataset(root, case_name="pglib_opf_case14_ieee", group_id=0)
        >>> d2 = OPFDataset(root, case_name="pglib_opf_case30_ieee", group_id=0)
        >>> multi_dataset = OPFMultiDataset([d1, d2])

        >>> # For same case, multiple groups
        >>> datasets = []
        >>> for group_id in range(5):
        ...     ds = OPFDataset(root, case_name="pglib_opf_case14_ieee", group_id=group_id)
        ...     datasets.append(ds)
        >>> multi_dataset = OPFMultiDataset(datasets)

        >>> # Mixed: multiple groups from multiple cases
        >>> case_mapping = {
        ...     "pglib_opf_case14_ieee": [0, 1, 2],
        ...     "pglib_opf_case30_ieee": [0, 1]
        ... }
        >>> mixed_dataset = OPFMultiDataset.from_mixed_cases(root, case_mapping)
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        self.datasets = datasets

    @classmethod
    def from_case_groups(
        cls,
        root: str,
        case_name: str,
        group_ids: List[int],
        dataset_cls=OPFDataset,
        **kwargs,
    ):
        r"""Create OPFMultiDataset from multiple groups of the same case.

        Args:
            root (str): Root directory where the dataset should be saved.
            case_name (str): The name of the original pglib-opf case.
            group_ids (List[int]): List of group IDs to load.
            dataset_cls: Dataset class to instantiate for each group.
            **kwargs: Additional arguments passed to OPFDataset constructor.

        Returns:
            OPFMultiDataset: Combined dataset from multiple groups.
        """
        datasets = []
        for group_id in group_ids:
            ds = dataset_cls(root=root, case_name=case_name, group_id=group_id, **kwargs)
            datasets.append(ds)
        return cls(datasets)

    @classmethod
    def from_multiple_cases(
        cls,
        root: str,
        case_configs: List[Dict],
        dataset_cls=OPFDataset,
        **kwargs,
    ):
        r"""Create OPFMultiDataset from multiple cases with their respective group IDs.

        Args:
            root (str): Root directory where the dataset should be saved.
            case_configs (List[Dict]): List of dictionaries, each containing 'case_name'
                and 'group_id' keys.
            dataset_cls: Dataset class to instantiate for each case.
            **kwargs: Additional arguments passed to OPFDataset constructor.

        Returns:
            OPFMultiDataset: Combined dataset from multiple cases.

        Example:
            >>> configs = [
            ...     {"case_name": "pglib_opf_case14_ieee", "group_id": 0},
            ...     {"case_name": "pglib_opf_case30_ieee", "group_id": 1},
            ... ]
            >>> multi_dataset = OPFMultiDataset.from_multiple_cases(root, configs)
        """
        datasets = []
        for config in case_configs:
            case_name = config.pop('case_name')
            group_id = config.pop('group_id')
            # Merge with additional kwargs
            dataset_kwargs = {**kwargs, **config}
            ds = dataset_cls(root=root, case_name=case_name, group_id=group_id, **dataset_kwargs)
            datasets.append(ds)
        return cls(datasets)

    @classmethod
    def from_mixed_cases(
        cls,
        root: str,
        case_group_mapping: Dict[str, List[int]],
        dataset_cls=OPFDataset,
        **kwargs,
    ):
        r"""Create OPFMultiDataset from multiple groups across different cases.

        This method allows loading multiple groups from multiple cases in a single call,
        which is useful when you want to combine several groups from different cases.

        Args:
            root (str): Root directory where the dataset should be saved.
            case_group_mapping (Dict[str, List[int]]): Dictionary mapping case names to
                lists of group IDs to load for each case.
            dataset_cls: Dataset class to instantiate for each group.
            **kwargs: Additional arguments passed to OPFDataset constructor.

        Returns:
            OPFMultiDataset: Combined dataset from multiple groups across multiple cases.

        Example:
            >>> # Load 3 groups from case14 and 2 groups from case30
            >>> case_mapping = {
            ...     "pglib_opf_case14_ieee": [0, 1, 2],
            ...     "pglib_opf_case30_ieee": [0, 1]
            ... }
            >>> multi_dataset = OPFMultiDataset.from_mixed_cases(root, case_mapping)
        """
        datasets = []
        for case_name, group_ids in case_group_mapping.items():
            for group_id in group_ids:
                ds = dataset_cls(root=root, case_name=case_name, group_id=group_id, **kwargs)
                datasets.append(ds)
        return cls(datasets)

    def metadata(self):
        r"""Returns the metadata of the first dataset (assuming all have same structure)."""
        return self.datasets[0].metadata()

    def __repr__(self) -> str:
        case_info = []
        for ds in self.datasets:
            case_info.append(f"{ds.case_name}[{ds.group_id}]")
        cases_str = ", ".join(case_info)
        return f'{self.__class__.__name__}({len(self)}, cases=[{cases_str}])'

from_case_groups(root: str, case_name: str, group_ids: List[int], dataset_cls=OPFDataset, **kwargs) classmethod

Create OPFMultiDataset from multiple groups of the same case.

Parameters:

- root (str, required): Root directory where the dataset should be saved.
- case_name (str, required): The name of the original pglib-opf case.
- group_ids (List[int], required): List of group IDs to load.
- dataset_cls: Dataset class to instantiate for each group. (default: OPFDataset)
- **kwargs: Additional arguments passed to the OPFDataset constructor.

Returns:

- OPFMultiDataset: Combined dataset from multiple groups.
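For example, a short sketch loading three groups of the 14-bus case (root path illustrative):

>>> multi_dataset = OPFMultiDataset.from_case_groups(
...     root='./data', case_name='pglib_opf_case14_ieee', group_ids=[0, 1, 2])
>>> len(multi_dataset)  # 3 groups x 15,000 samples each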

Source code in lumina/dataset/opf/opf_dataset.py
@classmethod
def from_case_groups(
    cls,
    root: str,
    case_name: str,
    group_ids: List[int],
    dataset_cls=OPFDataset,
    **kwargs,
):
    r"""Create OPFMultiDataset from multiple groups of the same case.

    Args:
        root (str): Root directory where the dataset should be saved.
        case_name (str): The name of the original pglib-opf case.
        group_ids (List[int]): List of group IDs to load.
        dataset_cls: Dataset class to instantiate for each group.
        **kwargs: Additional arguments passed to OPFDataset constructor.

    Returns:
        OPFMultiDataset: Combined dataset from multiple groups.
    """
    datasets = []
    for group_id in group_ids:
        ds = dataset_cls(root=root, case_name=case_name, group_id=group_id, **kwargs)
        datasets.append(ds)
    return cls(datasets)

from_multiple_cases(root: str, case_configs: List[Dict], dataset_cls=OPFDataset, **kwargs) classmethod

Create OPFMultiDataset from multiple cases with their respective group IDs.

Parameters:

- root (str, required): Root directory where the dataset should be saved.
- case_configs (List[Dict], required): List of dictionaries, each containing 'case_name' and 'group_id' keys.
- dataset_cls: Dataset class to instantiate for each case. (default: OPFDataset)
- **kwargs: Additional arguments passed to the OPFDataset constructor.

Returns:

- OPFMultiDataset: Combined dataset from multiple cases.

Example:

>>> configs = [
...     {"case_name": "pglib_opf_case14_ieee", "group_id": 0},
...     {"case_name": "pglib_opf_case30_ieee", "group_id": 1},
... ]
>>> multi_dataset = OPFMultiDataset.from_multiple_cases(root, configs)

Source code in lumina/dataset/opf/opf_dataset.py
@classmethod
def from_multiple_cases(
    cls,
    root: str,
    case_configs: List[Dict],
    dataset_cls=OPFDataset,
    **kwargs,
):
    r"""Create OPFMultiDataset from multiple cases with their respective group IDs.

    Args:
        root (str): Root directory where the dataset should be saved.
        case_configs (List[Dict]): List of dictionaries, each containing 'case_name'
            and 'group_id' keys.
        dataset_cls: Dataset class to instantiate for each case.
        **kwargs: Additional arguments passed to OPFDataset constructor.

    Returns:
        OPFMultiDataset: Combined dataset from multiple cases.

    Example:
        >>> configs = [
        ...     {"case_name": "pglib_opf_case14_ieee", "group_id": 0},
        ...     {"case_name": "pglib_opf_case30_ieee", "group_id": 1},
        ... ]
        >>> multi_dataset = OPFMultiDataset.from_multiple_cases(root, configs)
    """
    datasets = []
    for config in case_configs:
        case_name = config.pop('case_name')
        group_id = config.pop('group_id')
        # Merge with additional kwargs
        dataset_kwargs = {**kwargs, **config}
        ds = dataset_cls(root=root, case_name=case_name, group_id=group_id, **dataset_kwargs)
        datasets.append(ds)
    return cls(datasets)

from_mixed_cases(root: str, case_group_mapping: Dict[str, List[int]], dataset_cls=OPFDataset, **kwargs) classmethod

Create OPFMultiDataset from multiple groups across different cases.

This method allows loading multiple groups from multiple cases in a single call, which is useful when you want to combine several groups from different cases.

Parameters:

- root (str, required): Root directory where the dataset should be saved.
- case_group_mapping (Dict[str, List[int]], required): Dictionary mapping case names to lists of group IDs to load for each case.
- dataset_cls: Dataset class to instantiate for each group. (default: OPFDataset)
- **kwargs: Additional arguments passed to the OPFDataset constructor.

Returns:

- OPFMultiDataset: Combined dataset from multiple groups across multiple cases.

Example:

>>> # Load 3 groups from case14 and 2 groups from case30
>>> case_mapping = {
...     "pglib_opf_case14_ieee": [0, 1, 2],
...     "pglib_opf_case30_ieee": [0, 1]
... }
>>> multi_dataset = OPFMultiDataset.from_mixed_cases(root, case_mapping)

Source code in lumina/dataset/opf/opf_dataset.py
@classmethod
def from_mixed_cases(
    cls,
    root: str,
    case_group_mapping: Dict[str, List[int]],
    dataset_cls=OPFDataset,
    **kwargs,
):
    r"""Create OPFMultiDataset from multiple groups across different cases.

    This method allows loading multiple groups from multiple cases in a single call,
    which is useful when you want to combine several groups from different cases.

    Args:
        root (str): Root directory where the dataset should be saved.
        case_group_mapping (Dict[str, List[int]]): Dictionary mapping case names to
            lists of group IDs to load for each case.
        dataset_cls: Dataset class to instantiate for each group.
        **kwargs: Additional arguments passed to OPFDataset constructor.

    Returns:
        OPFMultiDataset: Combined dataset from multiple groups across multiple cases.

    Example:
        >>> # Load 3 groups from case14 and 2 groups from case30
        >>> case_mapping = {
        ...     "pglib_opf_case14_ieee": [0, 1, 2],
        ...     "pglib_opf_case30_ieee": [0, 1]
        ... }
        >>> multi_dataset = OPFMultiDataset.from_mixed_cases(root, case_mapping)
    """
    datasets = []
    for case_name, group_ids in case_group_mapping.items():
        for group_id in group_ids:
            ds = dataset_cls(root=root, case_name=case_name, group_id=group_id, **kwargs)
            datasets.append(ds)
    return cls(datasets)

metadata()

Returns the metadata of the first dataset (assuming all have same structure).

Source code in lumina/dataset/opf/opf_dataset.py
def metadata(self):
    r"""Returns the metadata of the first dataset (assuming all have same structure)."""
    return self.datasets[0].metadata()

On-Disk Dataset

OPFOnDiskDataset

Bases: OnDiskDataset

On-disk OPF dataset backed by SQLite or RocksDB.

Stores individual HeteroData samples in a database so that the full dataset need not reside in CPU memory. Supports the same cases and group structure as OPFDataset.

Parameters:

- root (str, required): Root directory where the dataset should be saved.
- case_name (str): Name of the pglib-opf case. (default: "pglib_opf_case14_ieee")
- group_id (int): Group to load (each group has 15,000 samples). (default: 0)
- transform (callable, optional): Per-access transform applied to each sample. (default: None)
- pre_transform (callable, optional): Transform applied once before writing to the database. (default: None)
- pre_filter (callable, optional): Predicate; samples that return False are skipped. (default: None)
- force_reload (bool): Re-process even if the database already exists. (default: False)
- keep_temp (bool): Keep extracted JSON temp files after processing. (default: False)
- n_jobs (int): Number of parallel workers for processing. (default: -1)
- local_raw_folder (str, optional): Local folder containing raw archives; skips download when set. (default: None)
- backend (str): Database backend, "sqlite" or "rocksdb". (default: "sqlite")
- schema (object): PyG schema for the stored data type. (default: object)
- log (bool): Enable logging. (default: True)
- write_batch_size (int): Number of samples to batch per DB write. (default: 128)
- sqlite_timeout_sec (float): SQLite connection timeout in seconds. (default: 600.0)
- sqlite_busy_timeout_ms (Optional[int]): SQLite busy timeout in ms. (default: None)
- sqlite_journal_mode (Optional[str]): SQLite journal mode PRAGMA. (default: "WAL")
- sqlite_synchronous (Optional[str]): SQLite synchronous PRAGMA. (default: "NORMAL")
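A minimal construction sketch (hedged: the root path and write_batch_size are illustrative; see the source below for the exact behaviour):

>>> ds = OPFOnDiskDataset(root='./data', case_name='pglib_opf_case14_ieee',
...                       backend='sqlite', write_batch_size=256)
>>> len(ds)         # number of samples stored in the on-disk database
>>> sample = ds[0]  # reads a single HeteroData sample from the database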
Source code in lumina/dataset/opf/opf_on_disk_dataset.py
class OPFOnDiskDataset(OnDiskDataset):
    r"""On-disk OPF dataset backed by SQLite or RocksDB.

    Stores individual HeteroData samples in a database so that the full
    dataset need not reside in CPU memory. Supports the same cases and
    group structure as :class:`OPFDataset`.

    Args:
        root (str): Root directory where the dataset should be saved.
        case_name (str): Name of the pglib-opf case.
            (default: :obj:`"pglib_opf_case14_ieee"`)
        group_id (int): Group to load (each group has 15,000 samples).
            (default: :obj:`0`)
        transform (callable, optional): Per-access transform applied to each
            sample. (default: :obj:`None`)
        pre_transform (callable, optional): Transform applied once before
            writing to the database. (default: :obj:`None`)
        pre_filter (callable, optional): Predicate; samples that return
            ``False`` are skipped. (default: :obj:`None`)
        force_reload (bool): Re-process even if the database already exists.
            (default: :obj:`False`)
        keep_temp (bool): Keep extracted JSON temp files after processing.
            (default: :obj:`False`)
        n_jobs (int): Number of parallel workers for processing.
            (default: :obj:`-1`)
        local_raw_folder (str, optional): Local folder containing raw archives;
            skips download when set. (default: :obj:`None`)
        backend (str): Database backend, ``"sqlite"`` or ``"rocksdb"``.
            (default: :obj:`"sqlite"`)
        schema (object): PyG schema for the stored data type.
        log (bool): Enable logging. (default: :obj:`True`)
        write_batch_size (int): Number of samples to batch per DB write.
            (default: :obj:`128`)
        sqlite_timeout_sec (float): SQLite connection timeout in seconds.
            (default: :obj:`600.0`)
        sqlite_busy_timeout_ms (Optional[int]): SQLite busy timeout in ms.
        sqlite_journal_mode (Optional[str]): SQLite journal mode PRAGMA.
            (default: :obj:`"WAL"`)
        sqlite_synchronous (Optional[str]): SQLite synchronous PRAGMA.
            (default: :obj:`"NORMAL"`)
    """

    url = "https://storage.googleapis.com/gridopt-dataset"

    def __init__(
        self,
        root: str,
        case_name: Literal[
            "pglib_opf_case14_ieee",
            "pglib_opf_case30_ieee",
            "pglib_opf_case57_ieee",
            "pglib_opf_case118_ieee",
            "pglib_opf_case500_goc",
            "pglib_opf_case2000_goc",
            "pglib_opf_case4661_sdet",
            "pglib_opf_case6470_rte",
            "pglib_opf_case10000_goc",
            "pglib_opf_case13659_pegase",
        ] = "pglib_opf_case14_ieee",
        group_id: int = 0,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        keep_temp: bool = False,
        n_jobs: int = -1,
        local_raw_folder: str = None,
        backend: str = "sqlite",
        schema: object = object,
        log: bool = True,
        write_batch_size: int = 128,
        sqlite_timeout_sec: float = 600.0,
        sqlite_busy_timeout_ms: Optional[int] = None,
        sqlite_journal_mode: Optional[str] = "WAL",
        sqlite_synchronous: Optional[str] = "NORMAL",
    ) -> None:
        if backend not in OnDiskDataset.BACKENDS:
            raise ValueError(
                f"Database backend must be one of {set(OnDiskDataset.BACKENDS.keys())} "
                f"(got '{backend}')"
            )

        self.backend = backend
        self.schema = schema
        self._db = None
        self._numel = None

        self.case_name = case_name
        self.group_id = int(group_id)

        self._raw_root = osp.join(root, "OPFData/raw")
        self._processed_root = osp.join(root, "OPFData/on_disk")
        self._release = "dataset_release_1"
        self.n_jobs = n_jobs
        self.keep_temp = keep_temp
        self.local_raw_folder = local_raw_folder
        self.write_batch_size = max(1, int(write_batch_size))
        self.sqlite_timeout_sec = float(sqlite_timeout_sec)
        self.sqlite_busy_timeout_ms = (
            int(sqlite_busy_timeout_ms) if sqlite_busy_timeout_ms is not None else None
        )
        self.sqlite_journal_mode = sqlite_journal_mode
        self.sqlite_synchronous = sqlite_synchronous

        Dataset.__init__(
            self,
            root=root,
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
            log=log,
            force_reload=force_reload,
        )

        if osp.exists(self.processed_dir):
            try:
                current_mode = os.stat(self.processed_dir).st_mode
                os.chmod(self.processed_dir, current_mode | stat.S_IWGRP)
            except OSError as exc:
                warnings.warn(
                    f"Failed to set group write permission on {self.processed_dir}: {exc}"
                )

    @property
    def raw_dir(self) -> str:
        return osp.join(self._raw_root, self._release)

    @property
    def processed_dir(self) -> str:
        return osp.join(self._processed_root, self._release, self.case_name)

    @property
    def tmp_dir(self) -> str:
        return osp.join(self.raw_dir, "gridopt-dataset-tmp", self._release, self.case_name)

    @property
    def raw_file_names(self) -> List[str]:
        return [f"{self.case_name}_{self.group_id}.tar.gz"]

    @property
    def processed_file_names(self) -> List[str]:
        return [self._db_filename()]

    def _db_filename(self) -> str:
        if self.backend == "rocksdb":
            return f"group_{self.group_id}.rocksdb"
        return f"group_{self.group_id}.{self.backend}.db"

    @property
    def db(self):
        if self._db is not None:
            return self._db

        os.makedirs(self.processed_dir, exist_ok=True)
        path = self.processed_paths[0]

        if self.backend == "sqlite":
            self._db = OPFSQLiteDatabase(
                path=path,
                name=self.__class__.__name__,
                schema=self.schema,
                timeout_sec=self.sqlite_timeout_sec,
                busy_timeout_ms=self.sqlite_busy_timeout_ms,
                journal_mode=self.sqlite_journal_mode,
                synchronous=self.sqlite_synchronous,
            )
        else:
            self._db = RocksDatabase(path=path, schema=self.schema)

        self._numel = len(self._db)
        return self._db

    def download(self) -> None:
        """Download and extract the raw tar.gz archive for this group."""
        self.download_and_extract(self.raw_file_names[0])

    def download_and_extract(self, name: str) -> None:
        """Download a tar.gz archive by name and extract it to the raw directory.

        Args:
            name (str): Filename of the archive to download.
        """
        url = f"{self.url}/{self._release}/{name}"
        path = download_url(url, self.raw_dir)
        extract_tar(path, self.raw_dir)

    def _clear_processed(self) -> None:
        path = self.processed_paths[0]
        if osp.isdir(path):
            shutil.rmtree(path)
            return
        if osp.exists(path):
            os.remove(path)
            for suffix in ("-wal", "-shm", "-journal"):
                sidecar = f"{path}{suffix}"
                if osp.exists(sidecar):
                    os.remove(sidecar)

    def process(self) -> None:
        """Process raw JSON files and write samples into the on-disk database."""
        if osp.exists(self.processed_paths[0]):
            self._clear_processed()

        if not osp.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

        try:
            self.process_json_group(self.group_id)
        except Exception as exc:
            print(f"Error processing group {self.group_id}: {exc}")
            raise exc

        if not self.keep_temp:
            shutil.rmtree(osp.join(self.raw_dir, "gridopt-dataset-tmp"))

    def process_json_group(self, group_id: int) -> None:
        """Parse all JSON files for a group and insert them into the database.

        Args:
            group_id (int): Group identifier to process.
        """
        group_json_files = glob(osp.join(self.tmp_dir, f"group_{group_id}", "*.json"))
        if len(group_json_files) < 15000:
            extract_tar(osp.join(self.raw_dir, self.raw_file_names[0]), self.raw_dir)
            group_json_files = glob(osp.join(self.tmp_dir, f"group_{group_id}", "*.json"))

        batch = []
        for json_file in tqdm.tqdm(group_json_files, desc=f"Group {group_id}"):
            data = process_json_file(json_file)
            if data is None:
                continue
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            batch.append(data)
            if len(batch) >= self.write_batch_size:
                self.extend(batch)
                batch = []

        if batch:
            self.extend(batch)

    def metadata(self):
        """Return node and edge feature dimensionality metadata from the first sample.

        Returns:
            dict: Dictionary with ``"nodes"`` and ``"edges"`` keys mapping node
                types and edge types to their respective feature dimensions.
        """
        sample = self.get(0)

        def node_dim(node_type: str) -> int:
            if node_type not in getattr(sample, "node_types", []):
                return 0
            store = sample[node_type]
            x = getattr(store, "x", None)
            if torch.is_tensor(x):
                return int(x.size(1)) if x.ndim > 1 else 1
            return 0

        def edge_dim(edge_type) -> int:
            if edge_type not in getattr(sample, "edge_types", []):
                return 0
            store = sample[edge_type]
            edge_attr = getattr(store, "edge_attr", None)
            if torch.is_tensor(edge_attr):
                return int(edge_attr.size(1)) if edge_attr.ndim > 1 else 1
            return 0

        return {
            "nodes": {
                "bus": node_dim("bus"),
                "generator": node_dim("generator"),
                "load": node_dim("load"),
                "shunt": node_dim("shunt"),
            },
            "edges": {
                ("bus", "ac_line", "bus"): edge_dim(("bus", "ac_line", "bus")),
                ("bus", "transformer", "bus"): edge_dim(("bus", "transformer", "bus")),
                ("generator", "generator_link", "bus"): 0,
                ("bus", "generator_link", "generator"): 0,
                ("load", "load_link", "bus"): 0,
                ("bus", "load_link", "load"): 0,
                ("shunt", "shunt_link", "bus"): 0,
                ("bus", "shunt_link", "shunt"): 0,
            },
        }

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({len(self)}, "
            f"case_name={self.case_name})"
        )

download() -> None

Download and extract the raw tar.gz archive for this group.

Source code in lumina/dataset/opf/opf_on_disk_dataset.py
def download(self) -> None:
    """Download and extract the raw tar.gz archive for this group."""
    self.download_and_extract(self.raw_file_names[0])

download_and_extract(name: str) -> None

Download a tar.gz archive by name and extract it to the raw directory.

Parameters:

Name Type Description Default
name str

Filename of the archive to download.

required
Source code in lumina/dataset/opf/opf_on_disk_dataset.py
def download_and_extract(self, name: str) -> None:
    """Download a tar.gz archive by name and extract it to the raw directory.

    Args:
        name (str): Filename of the archive to download.
    """
    url = f"{self.url}/{self._release}/{name}"
    path = download_url(url, self.raw_dir)
    extract_tar(path, self.raw_dir)

process() -> None

Process raw JSON files and write samples into the on-disk database.

Source code in lumina/dataset/opf/opf_on_disk_dataset.py
def process(self) -> None:
    """Process raw JSON files and write samples into the on-disk database."""
    if osp.exists(self.processed_paths[0]):
        self._clear_processed()

    if not osp.exists(self.tmp_dir):
        os.makedirs(self.tmp_dir)

    try:
        self.process_json_group(self.group_id)
    except Exception as exc:
        print(f"Error processing group {self.group_id}: {exc}")
        raise exc

    if not self.keep_temp:
        shutil.rmtree(osp.join(self.raw_dir, "gridopt-dataset-tmp"))

process_json_group(group_id: int) -> None

Parse all JSON files for a group and insert them into the database.

Parameters:

Name Type Description Default
group_id int

Group identifier to process.

required
Source code in lumina/dataset/opf/opf_on_disk_dataset.py
def process_json_group(self, group_id: int) -> None:
    """Parse all JSON files for a group and insert them into the database.

    Args:
        group_id (int): Group identifier to process.
    """
    group_json_files = glob(osp.join(self.tmp_dir, f"group_{group_id}", "*.json"))
    if len(group_json_files) < 15000:
        extract_tar(osp.join(self.raw_dir, self.raw_file_names[0]), self.raw_dir)
        group_json_files = glob(osp.join(self.tmp_dir, f"group_{group_id}", "*.json"))

    batch = []
    for json_file in tqdm.tqdm(group_json_files, desc=f"Group {group_id}"):
        data = process_json_file(json_file)
        if data is None:
            continue
        if self.pre_filter is not None and not self.pre_filter(data):
            continue
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        batch.append(data)
        if len(batch) >= self.write_batch_size:
            self.extend(batch)
            batch = []

    if batch:
        self.extend(batch)
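
The pre_filter and pre_transform hooks run once per sample during this pass, before each batch is written via extend. A minimal sketch of hooks that could be passed to the dataset constructor; both callables below are illustrative examples, not part of the library:

import torch

def drop_unsolved(data) -> bool:
    # pre_filter: keep only samples whose objective value is finite.
    return bool(torch.isfinite(data.objective).all())

def tag_num_buses(data):
    # pre_transform: record the bus count as a graph-level attribute.
    data.num_buses = data['bus'].x.size(0)
    return data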

metadata()

Return node and edge feature dimensionality metadata from the first sample.

Returns:

Name Type Description
dict

Dictionary with "nodes" and "edges" keys mapping node types and edge types to their respective feature dimensions.

Source code in lumina/dataset/opf/opf_on_disk_dataset.py
def metadata(self):
    """Return node and edge feature dimensionality metadata from the first sample.

    Returns:
        dict: Dictionary with ``"nodes"`` and ``"edges"`` keys mapping node
            types and edge types to their respective feature dimensions.
    """
    sample = self.get(0)

    def node_dim(node_type: str) -> int:
        if node_type not in getattr(sample, "node_types", []):
            return 0
        store = sample[node_type]
        x = getattr(store, "x", None)
        if torch.is_tensor(x):
            return int(x.size(1)) if x.ndim > 1 else 1
        return 0

    def edge_dim(edge_type) -> int:
        if edge_type not in getattr(sample, "edge_types", []):
            return 0
        store = sample[edge_type]
        edge_attr = getattr(store, "edge_attr", None)
        if torch.is_tensor(edge_attr):
            return int(edge_attr.size(1)) if edge_attr.ndim > 1 else 1
        return 0

    return {
        "nodes": {
            "bus": node_dim("bus"),
            "generator": node_dim("generator"),
            "load": node_dim("load"),
            "shunt": node_dim("shunt"),
        },
        "edges": {
            ("bus", "ac_line", "bus"): edge_dim(("bus", "ac_line", "bus")),
            ("bus", "transformer", "bus"): edge_dim(("bus", "transformer", "bus")),
            ("generator", "generator_link", "bus"): 0,
            ("bus", "generator_link", "generator"): 0,
            ("load", "load_link", "bus"): 0,
            ("bus", "load_link", "load"): 0,
            ("shunt", "shunt_link", "bus"): 0,
            ("bus", "shunt_link", "shunt"): 0,
        },
    }
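
The returned dictionary is convenient for sizing a model's input layers. A minimal usage sketch; the constructor arguments are an assumption (taken to mirror the in-memory dataset) and may differ:

from lumina.dataset.opf.opf_on_disk_dataset import OPFOnDiskDataset

dataset = OPFOnDiskDataset(root='./data', case_name='pglib_opf_case14_ieee')  # assumed constructor
meta = dataset.metadata()
bus_in_dim = meta["nodes"]["bus"]                       # bus node feature width
ac_line_dim = meta["edges"][("bus", "ac_line", "bus")]  # ac_line edge feature width
print(f"bus features: {bus_in_dim}, ac_line features: {ac_line_dim}")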

OPFOnDiskHomogeneousDataset

Bases: OPFOnDiskDataset

On-disk OPF dataset that converts to homogeneous graphs before storing.

Each HeteroData sample is converted to a homogeneous Data object via :class:OPFHomoWrapper, optionally pruned and cast to a compact dtype for storage. On retrieval the data is restored to float32 if needed.

Parameters:

Name Type Description Default
*args

Positional arguments forwarded to :class:OPFOnDiskDataset.

()
add_node_type bool

Append one-hot node type features. (default: :obj:True)

True
add_edge_type bool

Append one-hot edge type features. (default: :obj:True)

True
sanitize_targets bool

Replace non-finite targets with zero. (default: :obj:True)

True
log_bad_targets bool

Log warnings for non-finite targets. (default: :obj:True)

True
max_bad_target_logs int

Maximum bad-target warnings to emit. (default: :obj:1)

1
processed_suffix str

Suffix for the processed directory name. (default: :obj:"homo")

'homo'
attach_full_edge_attr bool

Store full concatenated edge attributes as edge_attr_full. (default: :obj:False)

False
prune_homo bool

Remove extraneous attributes from the homogeneous data before writing. (default: :obj:True)

True
storage_dtype Optional[str]

Dtype for on-disk storage of floating point tensors, e.g. "float16". None keeps the original dtype. (default: :obj:"float16")

'float16'
restore_fp32 bool

Cast tensors back to float32 on retrieval. (default: :obj:True)

True
**kwargs

Additional keyword arguments forwarded to :class:OPFOnDiskDataset.

{}
Source code in lumina/dataset/opf/opf_on_disk_dataset.py
class OPFOnDiskHomogeneousDataset(OPFOnDiskDataset):
    r"""On-disk OPF dataset that converts to homogeneous graphs before storing.

    Each HeteroData sample is converted to a homogeneous ``Data`` object via
    :class:`OPFHomoWrapper`, optionally pruned and cast to a compact dtype for
    storage. On retrieval the data is restored to float32 if needed.

    Args:
        *args: Positional arguments forwarded to :class:`OPFOnDiskDataset`.
        add_node_type (bool): Append one-hot node type features.
            (default: :obj:`True`)
        add_edge_type (bool): Append one-hot edge type features.
            (default: :obj:`True`)
        sanitize_targets (bool): Replace non-finite targets with zero.
            (default: :obj:`True`)
        log_bad_targets (bool): Log warnings for non-finite targets.
            (default: :obj:`True`)
        max_bad_target_logs (int): Maximum bad-target warnings to emit.
            (default: :obj:`1`)
        processed_suffix (str): Suffix for the processed directory name.
            (default: :obj:`"homo"`)
        attach_full_edge_attr (bool): Store full concatenated edge attributes
            as ``edge_attr_full``. (default: :obj:`False`)
        prune_homo (bool): Remove extraneous attributes from the homogeneous
            data before writing. (default: :obj:`True`)
        storage_dtype (Optional[str]): Dtype for on-disk storage of floating
            point tensors, e.g. ``"float16"``. ``None`` keeps the original
            dtype. (default: :obj:`"float16"`)
        restore_fp32 (bool): Cast tensors back to float32 on retrieval.
            (default: :obj:`True`)
        **kwargs: Additional keyword arguments forwarded to
            :class:`OPFOnDiskDataset`.
    """

    def __init__(
        self,
        *args,
        add_node_type: bool = True,
        add_edge_type: bool = True,
        sanitize_targets: bool = True,
        log_bad_targets: bool = True,
        max_bad_target_logs: int = 1,
        processed_suffix: str = "homo",
        attach_full_edge_attr: bool = False,
        prune_homo: bool = True,
        storage_dtype: Optional[str] = "float16",
        restore_fp32: bool = True,
        **kwargs,
    ) -> None:
        self._processed_suffix = processed_suffix or "homo"
        self._homo_wrapper = OPFHomoWrapper(
            add_node_type=add_node_type,
            add_edge_type=add_edge_type,
            attach_full_edge_attr=attach_full_edge_attr,
        )
        self._sanitize_targets = bool(sanitize_targets)
        self._log_bad_targets = bool(log_bad_targets)
        self._max_bad_target_logs = int(max_bad_target_logs)
        self._bad_target_logs = 0
        self._prune_homo = bool(prune_homo)
        self._storage_dtype = self._resolve_dtype(storage_dtype)
        self._restore_fp32 = bool(restore_fp32)

        user_pre_transform = kwargs.pop("pre_transform", None)
        user_transform = kwargs.pop("transform", None)

        def pre_transform(data):
            homo_data = self._homo_wrapper.convert(data)
            self._copy_graph_attrs(data, homo_data)
            self._sanitize_homo_targets(homo_data)
            if user_pre_transform is not None:
                homo_data = user_pre_transform(homo_data)
            if self._prune_homo:
                homo_data = self._prune_homo_data(homo_data)
            if self._storage_dtype is not None:
                homo_data = self._cast_homo_data(homo_data, self._storage_dtype)
            return homo_data

        def transform(data):
            if (
                self._restore_fp32
                and self._storage_dtype is not None
                and self._storage_dtype != torch.float32
            ):
                data = self._cast_homo_data(data, torch.float32)
            if user_transform is not None:
                data = user_transform(data)
            return data

        super().__init__(*args, pre_transform=pre_transform, transform=transform, **kwargs)

    @property
    def processed_dir(self) -> str:
        release = self._release
        if self._processed_suffix:
            release = f"{release}_{self._processed_suffix}"
        return osp.join(self._processed_root, release, self.case_name)

    def _resolve_dtype(self, dtype):
        if dtype is None:
            return None
        if isinstance(dtype, torch.dtype):
            return dtype
        if isinstance(dtype, str):
            key = dtype.strip().lower()
            if key in {"none", "null", ""}:
                return None
            if key in {"fp16", "float16"}:
                return torch.float16
            if key in {"bf16", "bfloat16"}:
                return torch.bfloat16
            if key in {"fp32", "float32"}:
                return torch.float32
        raise ValueError(f"Unsupported storage_dtype: {dtype}")

    def _cast_homo_data(self, data: Data, dtype: torch.dtype) -> Data:
        for key in ("x", "edge_attr", "edge_attr_full", "y"):
            value = getattr(data, key, None)
            if torch.is_tensor(value) and value.is_floating_point():
                setattr(data, key, value.to(dtype))
        return data

    def _prune_homo_data(self, data: Data) -> Data:
        keep = {}
        for key in ("x", "edge_index", "edge_attr", "edge_attr_full", "y", "y_mask", "node_type", "edge_type"):
            value = getattr(data, key, None)
            if value is not None:
                keep[key] = value
        pruned = Data(**keep)
        for key in ("node_type_names", "edge_type_names", "baseMVA", "base_mva"):
            if hasattr(data, key):
                setattr(pruned, key, getattr(data, key))
        return pruned

    def _copy_graph_attrs(self, hetero_data, homo_data):
        if hasattr(hetero_data, "baseMVA"):
            homo_data.baseMVA = hetero_data.baseMVA
        elif hasattr(hetero_data, "base_mva"):
            homo_data.baseMVA = hetero_data.base_mva

    def _sanitize_homo_targets(self, homo_data):
        y = getattr(homo_data, "y", None)
        if not torch.is_tensor(y):
            return

        finite_mask = torch.isfinite(y)
        if finite_mask.ndim > 1:
            row_mask = finite_mask.all(dim=-1)
        else:
            row_mask = finite_mask

        if bool(row_mask.all().item()):
            return

        if self._sanitize_targets:
            if y.ndim == 0:
                y = torch.zeros_like(y)
            else:
                y = y.clone()
                y[~row_mask] = 0
            homo_data.y = y

        homo_data.y_mask = row_mask.to(dtype=torch.bool)

        if self._should_log_bad_targets():
            bad_count = int((~row_mask).sum().item())
            total = int(row_mask.numel())
            action = "sanitized" if self._sanitize_targets else "left as-is"
            print(
                f"[OPFOnDiskHomogeneousDataset] Non-finite targets: "
                f"{bad_count}/{total} rows {action}; stored y_mask."
            )
            self._bad_target_logs += 1

    def _should_log_bad_targets(self):
        if not self._log_bad_targets:
            return False
        if self._bad_target_logs >= self._max_bad_target_logs:
            return False
        try:
            rank = int(os.environ.get("RANK", "0"))
        except ValueError:
            rank = 0
        return rank == 0
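
A minimal construction sketch showing the storage dtype options. The arguments forwarded to OPFOnDiskDataset (here root and case_name) are an assumption and may differ:

from lumina.dataset.opf.opf_on_disk_dataset import OPFOnDiskHomogeneousDataset

dataset = OPFOnDiskHomogeneousDataset(
    root='./data',                     # assumed parent-constructor arguments
    case_name='pglib_opf_case14_ieee',
    storage_dtype='float16',           # store floating-point tensors compactly on disk
    restore_fp32=True,                 # cast back to float32 when samples are read
    attach_full_edge_attr=False,
)

sample = dataset[0]                    # a homogeneous Data object
print(sample.x.dtype)                  # torch.float32 when restore_fp32=True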

Sharded Dataset

OPFShardedIterableDataset

Bases: IterableDataset

Iterable dataset that streams OPF samples from sharded .pt files.

Shards are distributed across DDP ranks and DataLoader workers so that each global worker processes a disjoint subset. Shard order can be shuffled per epoch.

Parameters:

Name Type Description Default
shards Iterable[ShardInfo]

Shard descriptors to iterate over.

required
shuffle_shards bool

Shuffle shard order each epoch. (default: :obj:False)

False
seed int

Base random seed for shard shuffling. (default: :obj:0)

0
transform Optional[Callable]

Per-sample transform applied on iteration. (default: :obj:None)

None
Source code in lumina/dataset/opf/opf_sharded_dataset.py
class OPFShardedIterableDataset(IterableDataset):
    """Iterable dataset that streams OPF samples from sharded ``.pt`` files.

    Shards are distributed across DDP ranks and DataLoader workers so that
    each global worker processes a disjoint subset. Shard order can be
    shuffled per epoch.

    Args:
        shards (Iterable[ShardInfo]): Shard descriptors to iterate over.
        shuffle_shards (bool): Shuffle shard order each epoch.
            (default: :obj:`False`)
        seed (int): Base random seed for shard shuffling.
            (default: :obj:`0`)
        transform (Optional[Callable]): Per-sample transform applied on
            iteration. (default: :obj:`None`)
    """

    def __init__(
        self,
        shards: Iterable[ShardInfo],
        shuffle_shards: bool = False,
        seed: int = 0,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        self.shards = list(shards)
        self.shuffle_shards = bool(shuffle_shards)
        self.seed = int(seed)
        self.epoch = 0
        self.transform = transform
        self._num_samples = self._compute_num_samples()

    def _compute_num_samples(self) -> Optional[int]:
        if not self.shards:
            return 0
        return sum(int(shard.num_samples) for shard in self.shards)

    def __len__(self) -> int:
        if self._num_samples is None:
            raise TypeError("Sharded dataset length is unknown.")
        return int(self._num_samples)

    def set_epoch(self, epoch: int) -> None:
        """Set the current epoch for deterministic shard shuffling.

        Args:
            epoch (int): Current training epoch number.
        """
        self.epoch = int(epoch)

    def _dist_info(self):
        if dist is not None and dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
        rank = int(os.environ.get("RANK", "0"))
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        return rank, world_size

    def _iter_inmemory(self, data, slices, num_samples: int):
        accessor = _InMemoryShardAccessor(data, slices)
        for idx in range(num_samples):
            sample = accessor.get(idx)
            if self.transform is not None:
                sample = self.transform(sample)
            yield sample

    def _iter_shard(self, shard: ShardInfo):
        obj = torch.load(shard.path, map_location="cpu", weights_only=False)
        try:
            if isinstance(obj, tuple) and len(obj) == 2:
                data, slices = obj
                yield from self._iter_inmemory(data, slices, shard.num_samples)
                return
            if isinstance(obj, list):
                for sample in obj:
                    if self.transform is not None:
                        sample = self.transform(sample)
                    yield sample
                return
            if isinstance(obj, dict) and "data" in obj and "slices" in obj:
                yield from self._iter_inmemory(obj["data"], obj["slices"], shard.num_samples)
                return
            raise ValueError(f"Unsupported shard payload type in {shard.path}.")
        finally:
            del obj

    def __iter__(self):
        worker_info = get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        num_workers = worker_info.num_workers if worker_info else 1
        rank, world_size = self._dist_info()
        total_workers = max(1, num_workers * world_size)
        global_worker_id = rank * num_workers + worker_id

        shard_indices = list(range(len(self.shards)))
        if self.shuffle_shards:
            rng = random.Random(self.seed + self.epoch)
            rng.shuffle(shard_indices)

        for local_idx, shard_idx in enumerate(shard_indices):
            if (local_idx % total_workers) != global_worker_id:
                continue
            shard = self.shards[shard_idx]
            yield from self._iter_shard(shard)

    def peek(self):
        """Return the first sample from the first shard without iterating all data.

        Returns:
            The first data sample from the dataset.

        Raises:
            RuntimeError: If no shards are available.
        """
        if not self.shards:
            raise RuntimeError("No shards available to sample from.")
        shard = self.shards[0]
        return next(self._iter_shard(shard))

    def metadata(self):
        """Return node and edge feature dimensionality metadata from the first sample.

        Returns:
            dict: Dictionary with ``"nodes"`` and ``"edges"`` keys mapping
                node types and edge types to their feature dimensions.
        """
        sample = self.peek()

        def node_dim(node_type: str) -> int:
            if node_type not in getattr(sample, "node_types", []):
                return 0
            store = sample[node_type]
            x = getattr(store, "x", None)
            if torch.is_tensor(x):
                return int(x.size(1)) if x.ndim > 1 else 1
            return 0

        def edge_dim(edge_type) -> int:
            if edge_type not in getattr(sample, "edge_types", []):
                return 0
            store = sample[edge_type]
            edge_attr = getattr(store, "edge_attr", None)
            if torch.is_tensor(edge_attr):
                return int(edge_attr.size(1)) if edge_attr.ndim > 1 else 1
            return 0

        return {
            "nodes": {
                "bus": node_dim("bus"),
                "generator": node_dim("generator"),
                "load": node_dim("load"),
                "shunt": node_dim("shunt"),
            },
            "edges": {
                ("bus", "ac_line", "bus"): edge_dim(("bus", "ac_line", "bus")),
                ("bus", "transformer", "bus"): edge_dim(("bus", "transformer", "bus")),
                ("generator", "generator_link", "bus"): 0,
                ("bus", "generator_link", "generator"): 0,
                ("load", "load_link", "bus"): 0,
                ("bus", "load_link", "load"): 0,
                ("shunt", "shunt_link", "bus"): 0,
                ("bus", "shunt_link", "shunt"): 0,
            },
        }

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(shards={len(self.shards)}, samples={len(self)})"
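
A usage sketch tying the pieces together: shard descriptors are typically built from a manifest (see load_shard_manifest and build_shard_infos below), and set_epoch should be called once per epoch when shard shuffling is enabled. The manifest path is a placeholder:

from torch_geometric.loader import DataLoader

from lumina.dataset.opf.opf_sharded_dataset import (
    OPFShardedIterableDataset,
    build_shard_infos,
    load_shard_manifest,
)

manifest = load_shard_manifest('./data/sharded/pglib_opf_case14_ieee/manifest.json')
dataset = OPFShardedIterableDataset(build_shard_infos(manifest), shuffle_shards=True, seed=0)
loader = DataLoader(dataset, batch_size=32, num_workers=4)

for epoch in range(3):
    dataset.set_epoch(epoch)   # re-seeds the per-epoch shard order
    for batch in loader:
        ...                    # training step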

set_epoch(epoch: int) -> None

Set the current epoch for deterministic shard shuffling.

Parameters:

Name Type Description Default
epoch int

Current training epoch number.

required
Source code in lumina/dataset/opf/opf_sharded_dataset.py
def set_epoch(self, epoch: int) -> None:
    """Set the current epoch for deterministic shard shuffling.

    Args:
        epoch (int): Current training epoch number.
    """
    self.epoch = int(epoch)

peek()

Return the first sample from the first shard without iterating all data.

Returns:

Type Description

The first data sample from the dataset.

Raises:

Type Description
RuntimeError

If no shards are available.

Source code in lumina/dataset/opf/opf_sharded_dataset.py
def peek(self):
    """Return the first sample from the first shard without iterating all data.

    Returns:
        The first data sample from the dataset.

    Raises:
        RuntimeError: If no shards are available.
    """
    if not self.shards:
        raise RuntimeError("No shards available to sample from.")
    shard = self.shards[0]
    return next(self._iter_shard(shard))

metadata()

Return node and edge feature dimensionality metadata from the first sample.

Returns:

Name Type Description
dict

Dictionary with "nodes" and "edges" keys mapping node types and edge types to their feature dimensions.

Source code in lumina/dataset/opf/opf_sharded_dataset.py
def metadata(self):
    """Return node and edge feature dimensionality metadata from the first sample.

    Returns:
        dict: Dictionary with ``"nodes"`` and ``"edges"`` keys mapping
            node types and edge types to their feature dimensions.
    """
    sample = self.peek()

    def node_dim(node_type: str) -> int:
        if node_type not in getattr(sample, "node_types", []):
            return 0
        store = sample[node_type]
        x = getattr(store, "x", None)
        if torch.is_tensor(x):
            return int(x.size(1)) if x.ndim > 1 else 1
        return 0

    def edge_dim(edge_type) -> int:
        if edge_type not in getattr(sample, "edge_types", []):
            return 0
        store = sample[edge_type]
        edge_attr = getattr(store, "edge_attr", None)
        if torch.is_tensor(edge_attr):
            return int(edge_attr.size(1)) if edge_attr.ndim > 1 else 1
        return 0

    return {
        "nodes": {
            "bus": node_dim("bus"),
            "generator": node_dim("generator"),
            "load": node_dim("load"),
            "shunt": node_dim("shunt"),
        },
        "edges": {
            ("bus", "ac_line", "bus"): edge_dim(("bus", "ac_line", "bus")),
            ("bus", "transformer", "bus"): edge_dim(("bus", "transformer", "bus")),
            ("generator", "generator_link", "bus"): 0,
            ("bus", "generator_link", "generator"): 0,
            ("load", "load_link", "bus"): 0,
            ("bus", "load_link", "load"): 0,
            ("shunt", "shunt_link", "bus"): 0,
            ("bus", "shunt_link", "shunt"): 0,
        },
    }

load_shard_manifest(path: str) -> Dict

Load a shard manifest JSON file and annotate it with its source path.

Parameters:

Name Type Description Default
path str

Path to the manifest JSON file.

required

Returns:

Name Type Description
Dict Dict

Parsed manifest dictionary with an added _manifest_path key.

Source code in lumina/dataset/opf/opf_sharded_dataset.py
def load_shard_manifest(path: str) -> Dict:
    """Load a shard manifest JSON file and annotate it with its source path.

    Args:
        path (str): Path to the manifest JSON file.

    Returns:
        Dict: Parsed manifest dictionary with an added ``_manifest_path`` key.
    """
    with open(path, "r") as f:
        manifest = json.load(f)
    manifest["_manifest_path"] = path
    return manifest

build_shard_infos(manifest: Dict) -> List[ShardInfo]

Build a list of ShardInfo objects from a parsed manifest dictionary.

Resolves relative shard paths against the manifest's base directory and infers num_samples from the shard file when not specified.

Parameters:

Name Type Description Default
manifest Dict

Parsed shard manifest containing a "shards" list.

required

Returns:

Type Description
List[ShardInfo]

List[ShardInfo]: Ordered list of shard descriptors.

Raises:

Type Description
KeyError

If the manifest is missing the "shards" key or a shard entry is missing the "path" key.

Source code in lumina/dataset/opf/opf_sharded_dataset.py
def build_shard_infos(manifest: Dict) -> List[ShardInfo]:
    """Build a list of ShardInfo objects from a parsed manifest dictionary.

    Resolves relative shard paths against the manifest's base directory and
    infers ``num_samples`` from the shard file when not specified.

    Args:
        manifest (Dict): Parsed shard manifest containing a ``"shards"`` list.

    Returns:
        List[ShardInfo]: Ordered list of shard descriptors.

    Raises:
        KeyError: If the manifest is missing the ``"shards"`` key or a shard
            entry is missing the ``"path"`` key.
    """
    if "shards" not in manifest:
        raise KeyError("Shard manifest missing 'shards' list.")
    base_dir = _manifest_base_dir(manifest)
    shard_infos: List[ShardInfo] = []
    for entry in manifest["shards"]:
        if "path" not in entry:
            raise KeyError("Shard entry missing 'path'.")
        path = entry["path"]
        if not osp.isabs(path):
            path = osp.join(base_dir, path)
        name = entry.get("name") or osp.basename(path)
        num_samples = entry.get("num_samples")
        if num_samples is None:
            num_samples = _infer_num_samples(path)
        group_id = entry.get("group_id")
        group_id = int(group_id) if group_id is not None else None
        shard_infos.append(
            ShardInfo(
                path=path,
                num_samples=int(num_samples),
                group_id=group_id,
                name=name,
            )
        )
    return shard_infos
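
Based on the keys read above, a manifest is expected to look roughly like the dictionary below. This layout is inferred, not normative, and the _manifest_path entry is normally added by load_shard_manifest rather than written by hand:

from lumina.dataset.opf.opf_sharded_dataset import build_shard_infos

# Inferred manifest structure; real manifests may carry additional fields.
manifest = {
    "_manifest_path": "./data/sharded/pglib_opf_case14_ieee/manifest.json",
    "shards": [
        {"path": "shard_000.pt", "num_samples": 1000, "group_id": 0, "name": "shard_000.pt"},
        {"path": "shard_001.pt", "num_samples": 1000, "group_id": 0},
    ],
}

shard_infos = build_shard_infos(manifest)
print(shard_infos[0].path, shard_infos[0].num_samples)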

Processing Functions

process_json_file(json_file)

Process a single json file.

Parameters:

Name Type Description Default
json_file str

Path to the json file.

required

Returns:

Name Type Description
data HeteroData

Processed single data object.

Source code in lumina/dataset/opf/opf_dataset.py
def process_json_file(json_file):
    r"""Process a single json file.

    Args:
        json_file (str): Path to the json file.

    Returns:
        data (HeteroData): Processed single data object.
    """
    with open(json_file) as f:
        try:
            obj = json.load(f)
        except json.JSONDecodeError:
            print(f"Error decoding JSON from file: {json_file}")
            return None

    return build_heterodata_from_grid(obj['grid'], obj['metadata'], obj['solution'])

build_heterodata_from_grid(grid: Dict, metadata: Dict, solution: Optional[Dict] = None)

Build a single HeteroData graph from OPFData grid and metadata.

Parameters:

Name Type Description Default
grid Dict

Grid payload from an OPFData object.

required
metadata Dict

Metadata payload from an OPFData object.

required
solution Dict

Solution payload from an OPFData object.

None

Returns:

Name Type Description
data HeteroData

Processed single data object.

Source code in lumina/dataset/opf/opf_dataset.py
def build_heterodata_from_grid(grid: Dict, metadata: Dict, solution: Optional[Dict] = None):
    r"""Build a single HeteroData graph from OPFData grid and metadata.

    Args:
        grid (Dict): Grid payload from an OPFData object.
        metadata (Dict): Metadata payload from an OPFData object.
        solution (Dict, optional): Solution payload from an OPFData object.

    Returns:
        data (HeteroData): Processed single data object.
    """
    obj = {'grid': grid}

    # Graph-level properties:
    hdata = HeteroData()
    hdata.baseMVA = torch.tensor(grid['context']).view(-1).item()
    hdata.objective = torch.tensor(metadata['objective'])

    # ! bus (only some have a target):
    # x: `base_kv, bus_type, vmin, vmax`
    bus_x = np.array(grid['nodes']['bus'])
    # bus_type (index 1)
    bus_type = bus_x[:, 1].astype(int)
    # One-hot encode bus_type (4 types: 1,2,3,4)
    # 1: pq, 2: pv, 3: ref, 4: isolated
    bus_type_onehot = np.eye(4)[bus_type - 1]  # bus_type assumed to be 1-based
    # Remove the original bus_type column and concatenate one-hot
    bus_x_wo_type = np.delete(bus_x, 1, axis=1)
    # x: `base_kv, vmin, vmax, pq, pv, ref, isolated`
    bus_x_final = np.concatenate([bus_x_wo_type, bus_type_onehot], axis=1)
    hdata['bus'].x = torch.tensor(bus_x_final)
    if solution is not None:
        # y: `va, vm`
        hdata['bus'].y = torch.tensor(solution['nodes']['bus'])

    # ! generator (only some have a target):
    # x: `mbase, pg, pmin, pmax, qg, qmin, qmax, vg, cost_squared, cost_linear, cost_offset`
    hdata['generator'].x = torch.tensor(grid['nodes']['generator'])
    if solution is not None:
        # y: `pg, qg`
        hdata['generator'].y = torch.tensor(solution['nodes']['generator'])

    # ! load (only some have a target):
    # x: `pd, qd`
    hdata['load'].x = torch.tensor(grid['nodes']['load'])

    # ! shunt (only some have a target):
    # x: `bs, gs`
    hdata['shunt'].x = torch.tensor(grid['nodes']['shunt'])

    # ! ac_line (only ac lines and transformers have features):
    hdata['bus', 'ac_line', 'bus'].edge_index = extract_edge_index(obj, 'ac_line')
    # edge_attr: `angmin, angmax, b_fr, b_to, br_r, br_x, rate_a, rate_b, rate_c`
    hdata['bus', 'ac_line', 'bus'].edge_attr = torch.tensor(grid['edges']['ac_line']['features'])
    if solution is not None:
        # edge_label: `pt, qt, pf, qf`
        hdata['bus', 'ac_line', 'bus'].edge_label = torch.tensor(solution['edges']['ac_line']['features'])

    # ! transformer (only ac lines and transformers have features):
    hdata['bus', 'transformer', 'bus'].edge_index = extract_edge_index(obj, 'transformer')
    # edge_attr: `angmin, angmax, br_r, br_x, rate_a, rate_b, rate_c, tap, shift, b_fr, b_to`
    hdata['bus', 'transformer', 'bus'].edge_attr = torch.tensor(grid['edges']['transformer']['features'])
    if solution is not None:
        # edge_label: `pt, qt, pf, qf`
        hdata['bus', 'transformer', 'bus'].edge_label = torch.tensor(solution['edges']['transformer']['features'])

    # ! virtual links:
    # bus-generator
    hdata['generator', 'generator_link', 'bus'].edge_index = extract_edge_index(obj, 'generator_link')
    hdata['bus', 'generator_link', 'generator'].edge_index = extract_edge_index_rev(obj, 'generator_link')
    # bus-load
    hdata['load', 'load_link', 'bus'].edge_index = extract_edge_index(obj, 'load_link')
    hdata['bus', 'load_link', 'load'].edge_index = extract_edge_index_rev(obj, 'load_link')
    # bus-shunt
    hdata['shunt', 'shunt_link', 'bus'].edge_index = extract_edge_index(obj, 'shunt_link')
    hdata['bus', 'shunt_link', 'shunt'].edge_index = extract_edge_index_rev(obj, 'shunt_link')

    return hdata
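
The resulting object carries four node types and eight edge types (two directed link relations per attached component). A quick inspection sketch with a placeholder file path:

from lumina.dataset.opf.opf_dataset import process_json_file

data = process_json_file('./raw/group_0/example_0.json')  # placeholder path
print(data.node_types)    # ['bus', 'generator', 'load', 'shunt']
print(data.edge_types)    # ac_line, transformer, and the six *_link relations
print(data['bus'].x.shape)                            # [num_buses, 7] after one-hot bus_type
print(data['bus', 'ac_line', 'bus'].edge_attr.shape)  # [num_ac_lines, 9]
print(float(data.objective))                          # graph-level objective value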

process_hdf5_scenario(scenario, scenario_key: str) -> Union[Optional[HeteroData], List[HeteroData]]

Process a single HDF5 scenario group into a HeteroData graph.

Parameters:

Name Type Description Default
scenario

An open HDF5 group containing grid, solution, and metadata subgroups for one OPF scenario.

required
scenario_key str

Key identifying the scenario within the HDF5 file.

required

Returns:

Type Description
Union[Optional[HeteroData], List[HeteroData]]

Union[Optional[HeteroData], List[HeteroData]]: The constructed HeteroData object, or None if processing fails.

Source code in lumina/dataset/opf/opf_dataset.py
def process_hdf5_scenario(scenario, scenario_key: str) -> Union[Optional[HeteroData], List[HeteroData]]:
    """Process a single HDF5 scenario group into a HeteroData graph.

    Args:
        scenario: An open HDF5 group containing ``grid``, ``solution``, and
            ``metadata`` subgroups for one OPF scenario.
        scenario_key (str): Key identifying the scenario within the HDF5 file.

    Returns:
        Union[Optional[HeteroData], List[HeteroData]]: The constructed
            HeteroData object, or ``None`` if processing fails.
    """
    try:
        grid = scenario['grid'] if 'grid' in scenario else scenario.file['grid']
        solution = scenario['solution']
        metadata = scenario['metadata']

        hdata = HeteroData()
        _process_nodes_hdf5(hdata, grid, solution)
        _process_edges_hdf5(hdata, grid, solution)

        hdata.baseMVA = torch.tensor(grid['context']['baseMVA'][()], dtype=torch.float32).view(-1)
        hdata.objective = torch.tensor(metadata.attrs['objective'], dtype=torch.float32)
        hdata.scenario_id = scenario_key

        return hdata
    except Exception as e:
        print(f"Error in process_hdf5_scenario for {scenario_key}: {e}")
        return None
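
A hedged sketch of driving this function with h5py. The file layout assumed here (scenario groups containing solution and metadata, plus an optional file-level shared grid group) is inferred from the lookups above and may not match the actual archives:

import h5py

from lumina.dataset.opf.opf_dataset import process_hdf5_scenario

with h5py.File('./raw/group_0/scenarios.h5', 'r') as f:   # placeholder path
    for key in f.keys():
        if key == 'grid':        # skip a file-level shared grid group, if present
            continue
        data = process_hdf5_scenario(f[key], key)
        if data is not None:
            print(key, data)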

Schema

OPFSchemaModel

Bases: BaseModel

Base Pydantic model defining a column-ordered feature schema for OPF data.

Subclasses declare fields whose order defines the canonical column layout of the corresponding numpy feature array. Provides utilities for schema introspection, cross-schema alignment, and numpy conversion.

Source code in lumina/dataset/opf/schema.py
class OPFSchemaModel(BaseModel):
    """Base Pydantic model defining a column-ordered feature schema for OPF data.

    Subclasses declare fields whose order defines the canonical column layout
    of the corresponding numpy feature array. Provides utilities for schema
    introspection, cross-schema alignment, and numpy conversion.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @classmethod
    def get_feature_names(cls) -> List[str]:
        """Return the ordered list of feature field names.

        Returns:
            List[str]: Field names in declaration order.
        """
        return list(cls.model_fields.keys())

    @classmethod
    def get_field_indices(cls) -> Dict[str, int]:
        """Return a mapping from field names to their column indices.

        Returns:
            Dict[str, int]: ``{field_name: column_index}`` for every field.
        """
        return {name: i for i, name in enumerate(cls.get_feature_names())}

    @classmethod
    def get_alignment_map(cls, other: type["OPFSchemaModel"]) -> Dict[int, int]:
        """Compute a column index mapping from this schema to another.

        Only fields present in both schemas are included.

        Args:
            other (type[OPFSchemaModel]): Target schema class.

        Returns:
            Dict[int, int]: ``{source_column: target_column}`` for shared fields.
        """
        cls_indices = cls.get_field_indices()
        other_indices = other.get_field_indices()

        mapping = {}
        for name, idx in cls_indices.items():
            if name in other_indices:
                mapping[idx] = other_indices[name]
        return mapping

    @classmethod
    def from_numpy(cls, data: np.ndarray) -> "OPFSchemaModel":
        """Create a model instance from a 1-D numpy array of feature values.

        Args:
            data (np.ndarray): Array whose elements correspond to schema fields
                in declaration order.

        Returns:
            OPFSchemaModel: Populated model instance.
        """
        return cls(**dict(zip(cls.get_feature_names(), data.tolist())))

    def to_numpy(self) -> np.ndarray:
        """Convert this model instance to a 1-D numpy array of feature values.

        Returns:
            np.ndarray: Feature values in declaration order.
        """
        return np.array([getattr(self, field) for field in self.get_feature_names()])

get_feature_names() -> List[str] classmethod

Return the ordered list of feature field names.

Returns:

Type Description
List[str]

List[str]: Field names in declaration order.

Source code in lumina/dataset/opf/schema.py
@classmethod
def get_feature_names(cls) -> List[str]:
    """Return the ordered list of feature field names.

    Returns:
        List[str]: Field names in declaration order.
    """
    return list(cls.model_fields.keys())

get_field_indices() -> Dict[str, int] classmethod

Return a mapping from field names to their column indices.

Returns:

Type Description
Dict[str, int]

Dict[str, int]: {field_name: column_index} for every field.

Source code in lumina/dataset/opf/schema.py
@classmethod
def get_field_indices(cls) -> Dict[str, int]:
    """Return a mapping from field names to their column indices.

    Returns:
        Dict[str, int]: ``{field_name: column_index}`` for every field.
    """
    return {name: i for i, name in enumerate(cls.get_feature_names())}

get_alignment_map(other: type[OPFSchemaModel]) -> Dict[int, int] classmethod

Compute a column index mapping from this schema to another.

Only fields present in both schemas are included.

Parameters:

Name Type Description Default
other type[OPFSchemaModel]

Target schema class.

required

Returns:

Type Description
Dict[int, int]

Dict[int, int]: {source_column: target_column} for shared fields.

Source code in lumina/dataset/opf/schema.py
@classmethod
def get_alignment_map(cls, other: type["OPFSchemaModel"]) -> Dict[int, int]:
    """Compute a column index mapping from this schema to another.

    Only fields present in both schemas are included.

    Args:
        other (type[OPFSchemaModel]): Target schema class.

    Returns:
        Dict[int, int]: ``{source_column: target_column}`` for shared fields.
    """
    cls_indices = cls.get_field_indices()
    other_indices = other.get_field_indices()

    mapping = {}
    for name, idx in cls_indices.items():
        if name in other_indices:
            mapping[idx] = other_indices[name]
    return mapping

from_numpy(data: np.ndarray) -> OPFSchemaModel classmethod

Create a model instance from a 1-D numpy array of feature values.

Parameters:

Name Type Description Default
data ndarray

Array whose elements correspond to schema fields in declaration order.

required

Returns:

Name Type Description
OPFSchemaModel OPFSchemaModel

Populated model instance.

Source code in lumina/dataset/opf/schema.py
@classmethod
def from_numpy(cls, data: np.ndarray) -> "OPFSchemaModel":
    """Create a model instance from a 1-D numpy array of feature values.

    Args:
        data (np.ndarray): Array whose elements correspond to schema fields
            in declaration order.

    Returns:
        OPFSchemaModel: Populated model instance.
    """
    return cls(**dict(zip(cls.get_feature_names(), data.tolist())))

to_numpy() -> np.ndarray

Convert this model instance to a 1-D numpy array of feature values.

Returns:

Type Description
ndarray

np.ndarray: Feature values in declaration order.

Source code in lumina/dataset/opf/schema.py
def to_numpy(self) -> np.ndarray:
    """Convert this model instance to a 1-D numpy array of feature values.

    Returns:
        np.ndarray: Feature values in declaration order.
    """
    return np.array([getattr(self, field) for field in self.get_feature_names()])

JSONBus

Bases: OPFSchemaModel

Bus features in JSON format.

Source code in lumina/dataset/opf/schema.py
class JSONBus(OPFSchemaModel):
    """Bus features in JSON format."""
    base_kv: float = Field(..., description="Base voltage (kV)")
    bus_type: int = Field(..., description="Bus type (1=PQ, 2=PV, 3=Ref, 4=Isolated)")
    vmin: float = Field(..., description="Minimum voltage magnitude (p.u.)")
    vmax: float = Field(..., description="Maximum voltage magnitude (p.u.)")

JSONGenerator

Bases: OPFSchemaModel

Generator features in JSON format.

Source code in lumina/dataset/opf/schema.py
class JSONGenerator(OPFSchemaModel):
    """Generator features in JSON format."""
    mbase: float = Field(..., description="Machine base power (MVA)")
    pg: float = Field(..., description="Active power generation (p.u.)")
    pmin: float = Field(..., description="Minimum active power output (p.u.)")
    pmax: float = Field(..., description="Maximum active power output (p.u.)")
    qg: float = Field(..., description="Reactive power generation (p.u.)")
    qmin: float = Field(..., description="Minimum reactive power output (p.u.)")
    qmax: float = Field(..., description="Maximum reactive power output (p.u.)")
    vg: float = Field(..., description="Voltage setpoint (p.u.)")
    cost_c2: float = Field(..., description="Quadratic cost coefficient ($/MW²/h)")
    cost_c1: float = Field(..., description="Linear cost coefficient ($/MW/h)")
    cost_c0: float = Field(..., description="Constant cost coefficient ($/h)")

H5Bus

Bases: OPFSchemaModel

Bus features in HDF5 format.

Source code in lumina/dataset/opf/schema.py
class H5Bus(OPFSchemaModel):
    """Bus features in HDF5 format."""
    vmin: float = Field(..., description="Minimum voltage magnitude (p.u.)")
    vmax: float = Field(..., description="Maximum voltage magnitude (p.u.)")
    zone: Optional[float] = Field(None, description="Zone identifier")
    area: Optional[float] = Field(None, description="Area identifier")
    bus_type: float = Field(..., description="Bus type (1=PQ, 2=PV, 3=Ref)")

H5Generator

Bases: OPFSchemaModel

Generator features in HDF5 format.

Source code in lumina/dataset/opf/schema.py
class H5Generator(OPFSchemaModel):
    """Generator features in HDF5 format."""
    pmax: float = Field(..., description="Maximum active power output (p.u.)")
    pmin: float = Field(..., description="Minimum active power output (p.u.)")
    qmax: float = Field(..., description="Maximum reactive power output (p.u.)")
    qmin: float = Field(..., description="Minimum reactive power output (p.u.)")
    cost_c2: float = Field(..., description="Quadratic cost coefficient ($/MW²/h)")
    cost_c1: float = Field(..., description="Linear cost coefficient ($/MW/h)")
    cost_c0: float = Field(..., description="Constant cost coefficient ($/h)")
    vg: float = Field(..., description="Voltage setpoint (p.u.)")
    mbase: float = Field(..., description="Machine base power (MVA)")
    gen_status: float = Field(..., description="Generator status (1=on, 0=off)")
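
Because the JSON and HDF5 layouts order generator columns differently, get_alignment_map can remap one layout onto the other. A minimal sketch with made-up data; fields present in only one schema (pg, qg, gen_status) are simply not mapped:

import numpy as np

from lumina.dataset.opf.schema import H5Generator, JSONGenerator

mapping = JSONGenerator.get_alignment_map(H5Generator)
print(mapping)   # {json_column_index: h5_column_index, ...} for shared fields

json_rows = np.zeros((4, len(JSONGenerator.get_feature_names())))
h5_rows = np.zeros((4, len(H5Generator.get_feature_names())))
for src, dst in mapping.items():
    h5_rows[:, dst] = json_rows[:, src]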

Case Tagging

CaseTaggedDataset

Bases: Dataset

Map-style dataset wrapper that attaches a case_id to every sample.

Delegates all attribute access to the wrapped dataset, adding only the case_id tensor to each retrieved sample.

Parameters:

Name Type Description Default
dataset

Underlying map-style dataset.

required
case_id

Integer case identifier attached to every sample.

required
Source code in lumina/dataset/opf/case_id.py
class CaseTaggedDataset(Dataset):
    """Map-style dataset wrapper that attaches a ``case_id`` to every sample.

    Delegates all attribute access to the wrapped dataset, adding only the
    ``case_id`` tensor to each retrieved sample.

    Args:
        dataset: Underlying map-style dataset.
        case_id: Integer case identifier attached to every sample.
    """

    def __init__(self, dataset, case_id):
        self.dataset = dataset
        self.case_id = int(case_id)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return attach_case_id(sample, self.case_id)

    def __getattr__(self, name):
        if name in {"dataset", "case_id"}:
            raise AttributeError(name)
        return getattr(self.dataset, name)

attach_case_id(sample, case_id)

Attach a case_id tensor to a PyG Data or HeteroData sample.

Parameters:

Name Type Description Default
sample

A PyG data object (or None).

required
case_id

Integer case identifier to store as a scalar torch.long tensor on the sample.

required

Returns:

Type Description

The input sample with case_id attribute set, or the unmodified sample if it is None or case_id cannot be converted to int.

Source code in lumina/dataset/opf/case_id.py
def attach_case_id(sample, case_id):
    """Attach a ``case_id`` tensor to a PyG Data or HeteroData sample.

    Args:
        sample: A PyG data object (or ``None``).
        case_id: Integer case identifier to store as a scalar ``torch.long``
            tensor on the sample.

    Returns:
        The input sample with ``case_id`` attribute set, or the unmodified
        sample if it is ``None`` or *case_id* cannot be converted to int.
    """
    if sample is None:
        return sample
    try:
        sample.case_id = torch.tensor(int(case_id), dtype=torch.long)
    except (TypeError, ValueError):
        return sample
    return sample
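
When training on several pglib cases at once, each per-case dataset can be wrapped so that every sample carries its case identity. A sketch; the two per-case datasets are stand-ins, and using torch.utils.data.ConcatDataset to mix them is an assumption, not something this module requires:

from torch.utils.data import ConcatDataset

from lumina.dataset.opf.case_id import CaseTaggedDataset

# `case14_ds` and `case57_ds` stand in for already-constructed map-style OPF datasets.
mixed = ConcatDataset([
    CaseTaggedDataset(case14_ds, case_id=0),
    CaseTaggedDataset(case57_ds, case_id=1),
])

sample = mixed[0]
print(sample.case_id)   # tensor(0)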

Staging

opf_release(processed_suffix: Optional[str] = None) -> str

Build the release directory name, optionally appending a suffix.

Parameters:

Name Type Description Default
processed_suffix Optional[str]

Suffix to append (e.g. "homo").

None

Returns:

Name Type Description
str str

Release string such as "dataset_release_1" or "dataset_release_1_homo".

Source code in lumina/dataset/opf/staging.py
def opf_release(processed_suffix: Optional[str] = None) -> str:
    """Build the release directory name, optionally appending a suffix.

    Args:
        processed_suffix (Optional[str]): Suffix to append (e.g. ``"homo"``).

    Returns:
        str: Release string such as ``"dataset_release_1"`` or
            ``"dataset_release_1_homo"``.
    """
    release = "dataset_release_1"
    if processed_suffix:
        release += f"_{processed_suffix}"
    return release
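
For example:

>>> opf_release()
'dataset_release_1'
>>> opf_release('homo')
'dataset_release_1_homo'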

stage_on_disk_group(source_root: str, stage_root: str, case_name: str, group_id: int, backend: str, processed_suffix: Optional[str] = None, log: bool = True) -> str

Copy an on-disk database from the source root to a local staging root.

Skips the copy if source and destination are the same path or if the destination already exists with the same file size. Also copies SQLite sidecar files (WAL, SHM, journal) when present.

Parameters:

Name Type Description Default
source_root str

Root directory containing the source database.

required
stage_root str

Local staging root to copy into.

required
case_name str

Name of the pglib-opf case.

required
group_id int

Group identifier.

required
backend str

Database backend ("sqlite" or "rocksdb").

required
processed_suffix Optional[str]

Release directory suffix.

None
log bool

Enable copy logging. (default: :obj:True)

True

Returns:

Name Type Description
str str

The effective root directory to use (either stage_root or source_root if staging was skipped).

Raises:

Type Description
FileNotFoundError

If the source database does not exist.

Source code in lumina/dataset/opf/staging.py
def stage_on_disk_group(
    source_root: str,
    stage_root: str,
    case_name: str,
    group_id: int,
    backend: str,
    processed_suffix: Optional[str] = None,
    log: bool = True,
) -> str:
    """Copy an on-disk database from the source root to a local staging root.

    Skips the copy if source and destination are the same path or if the
    destination already exists with the same file size. Also copies SQLite
    sidecar files (WAL, SHM, journal) when present.

    Args:
        source_root (str): Root directory containing the source database.
        stage_root (str): Local staging root to copy into.
        case_name (str): Name of the pglib-opf case.
        group_id (int): Group identifier.
        backend (str): Database backend (``"sqlite"`` or ``"rocksdb"``).
        processed_suffix (Optional[str]): Release directory suffix.
        log (bool): Enable copy logging. (default: :obj:`True`)

    Returns:
        str: The effective root directory to use (either *stage_root* or
            *source_root* if staging was skipped).

    Raises:
        FileNotFoundError: If the source database does not exist.
    """
    source_root = _expand_path(source_root) or source_root
    stage_root = _expand_path(stage_root) or stage_root

    if not stage_root or not source_root:
        return source_root

    if osp.abspath(stage_root) == osp.abspath(source_root):
        return source_root

    src_path = get_on_disk_db_path(
        source_root,
        case_name,
        group_id,
        backend,
        processed_suffix,
    )
    dst_path = get_on_disk_db_path(
        stage_root,
        case_name,
        group_id,
        backend,
        processed_suffix,
    )

    if not fs.exists(src_path):
        raise FileNotFoundError(f"On-disk DB missing at {src_path}")

    if fs.exists(dst_path):
        if not fs.isdir(src_path) and _same_size(src_path, dst_path):
            return stage_root

    dst_dir = osp.dirname(dst_path)
    fs.makedirs(dst_dir, exist_ok=True)

    if fs.isdir(src_path):
        if fs.exists(dst_path):
            fs.rm(dst_path, recursive=True)
        fs.cp(src_path, dst_dir, log=log)
    else:
        fs.cp(src_path, dst_path, log=log)
        _copy_sidecar_files(src_path, dst_path, log=log)

    return stage_root
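
A hedged sketch of staging a group onto fast local storage before constructing the dataset; the source and scratch paths are placeholders, and the dataset constructor in the final comment is assumed rather than documented here:

from lumina.dataset.opf.staging import stage_on_disk_group

root = stage_on_disk_group(
    source_root='/shared/opf',            # placeholder paths
    stage_root='/local/scratch/opf',
    case_name='pglib_opf_case14_ieee',
    group_id=0,
    backend='sqlite',
)
# `root` is the directory to actually use: the staging root, or the source
# root when the copy was skipped.
# dataset = OPFOnDiskDataset(root=root, case_name='pglib_opf_case14_ieee')  # assumed constructor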

stage_sharded_case(source_root: str, stage_root: str, case_name: str, processed_suffix: Optional[str] = None, manifest_name: str = 'manifest.json', log: bool = True) -> str

Copy an entire sharded case directory from source to a staging root.

Skips the copy if source and destination are the same path or if the destination manifest already exists with the same file size.

Parameters:

Name Type Description Default
source_root str

Root directory containing the source sharded data.

required
stage_root str

Local staging root to copy into.

required
case_name str

Name of the pglib-opf case.

required
processed_suffix Optional[str]

Release directory suffix.

None
manifest_name str

Manifest filename. (default: :obj:"manifest.json")

'manifest.json'
log bool

Enable copy logging. (default: :obj:True)

True

Returns:

Name Type Description
str str

The effective root directory to use (either stage_root or source_root if staging was skipped).

Raises:

Type Description
FileNotFoundError

If the source sharded directory does not exist.

Source code in lumina/dataset/opf/staging.py
def stage_sharded_case(
    source_root: str,
    stage_root: str,
    case_name: str,
    processed_suffix: Optional[str] = None,
    manifest_name: str = "manifest.json",
    log: bool = True,
) -> str:
    """Copy an entire sharded case directory from source to a staging root.

    Skips the copy if source and destination are the same path or if the
    destination manifest already exists with the same file size.

    Args:
        source_root (str): Root directory containing the source sharded data.
        stage_root (str): Local staging root to copy into.
        case_name (str): Name of the pglib-opf case.
        processed_suffix (Optional[str]): Release directory suffix.
        manifest_name (str): Manifest filename.
            (default: :obj:`"manifest.json"`)
        log (bool): Enable copy logging. (default: :obj:`True`)

    Returns:
        str: The effective root directory to use (either *stage_root* or
            *source_root* if staging was skipped).

    Raises:
        FileNotFoundError: If the source sharded directory does not exist.
    """
    source_root = _expand_path(source_root) or source_root
    stage_root = _expand_path(stage_root) or stage_root

    if not stage_root or not source_root:
        return source_root

    if osp.abspath(stage_root) == osp.abspath(source_root):
        return source_root

    src_dir = sharded_processed_dir(source_root, case_name, processed_suffix)
    dst_dir = sharded_processed_dir(stage_root, case_name, processed_suffix)
    src_manifest = get_sharded_manifest_path(
        source_root,
        case_name,
        processed_suffix,
        manifest_name,
    )
    dst_manifest = get_sharded_manifest_path(
        stage_root,
        case_name,
        processed_suffix,
        manifest_name,
    )

    if not fs.exists(src_dir):
        raise FileNotFoundError(f"Sharded dataset missing at {src_dir}")

    if fs.exists(dst_manifest) and _same_size(src_manifest, dst_manifest):
        return stage_root

    dst_parent = osp.dirname(dst_dir)
    fs.makedirs(dst_parent, exist_ok=True)
    if fs.exists(dst_dir):
        fs.rm(dst_dir, recursive=True)
    fs.cp(src_dir, dst_parent, log=log)
    return stage_root
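
The sharded analogue follows the same pattern. The paths below are placeholders, and importing get_sharded_manifest_path from the staging module is an assumption (it is only referenced, not defined, in the code shown above):

from lumina.dataset.opf.opf_sharded_dataset import (
    OPFShardedIterableDataset,
    build_shard_infos,
    load_shard_manifest,
)
from lumina.dataset.opf.staging import get_sharded_manifest_path, stage_sharded_case

root = stage_sharded_case(
    source_root='/shared/opf_sharded',        # placeholder paths
    stage_root='/local/scratch/opf_sharded',
    case_name='pglib_opf_case14_ieee',
)
manifest_path = get_sharded_manifest_path(root, 'pglib_opf_case14_ieee', None, 'manifest.json')
dataset = OPFShardedIterableDataset(build_shard_infos(load_shard_manifest(manifest_path)))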