Skip to content

Loader API

DataLoader

Bases: DataLoader

A data loader which merges data objects from a :class:torch_geometric.data.Dataset to a mini-batch. Data objects can be either of type :class:~torch_geometric.data.Data or :class:~torch_geometric.data.HeteroData.

Parameters:

Name Type Description Default
dataset Dataset

The dataset from which to load the data.

required
batch_size int

How many samples per batch to load. (default: :obj:1)

1
shuffle bool

If set to :obj:True, the data will be reshuffled at every epoch. (default: :obj:False)

False
follow_batch List[str]

Creates assignment batch vectors for each key in the list. (default: :obj:None)

None
exclude_keys List[str]

Will exclude each key in the list. (default: :obj:None)

None
**kwargs optional

Additional arguments of :class:torch.utils.data.DataLoader.

{}
Source code in lumina/loader/opf/opf_loader.py
class DataLoader(torch.utils.data.DataLoader):
    r"""Mini-batch loader for graph datasets.

    Merges data objects from a :class:`torch_geometric.data.Dataset` into a
    mini-batch. Elements may be of type :class:`~torch_geometric.data.Data`
    or :class:`~torch_geometric.data.HeteroData`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`None`)
        exclude_keys (List[str], optional): Will exclude each key in the
            list. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # PyTorch Lightning may inject its own ``collate_fn``; drop it so
        # the PyG-aware collater below always takes effect.
        kwargs.pop('collate_fn', None)

        # Stored as attributes for PyTorch Lightning < 1.6 compatibility.
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=Collater(dataset, follow_batch, exclude_keys),
            **kwargs,
        )

Collater

Custom collation function for PyG data objects.

Handles batching of BaseData objects via Batch.from_data_list, as well as tensors, TensorFrame instances, and nested Python collections (dicts, named tuples, lists).

Parameters:

Name Type Description Default
dataset Dataset | Sequence[BaseData] | DatasetAdapter

Source dataset (used for type inference; not indexed during collation).

required
follow_batch list[str]

Keys for which to create assignment batch vectors.

None
exclude_keys list[str]

Keys to exclude from batching.

None
Source code in lumina/loader/opf/opf_loader.py
class Collater:
    """Collation callable that turns a list of samples into one batch.

    ``BaseData`` elements are merged via ``Batch.from_data_list``; tensors,
    ``TensorFrame`` instances, numbers, strings, and nested containers
    (mappings, named tuples, sequences) are each handled by a dedicated
    branch, recursing where needed.

    Args:
        dataset (Dataset | Sequence[BaseData] | DatasetAdapter): Source
            dataset (kept for reference; not indexed during collation).
        follow_batch (list[str], optional): Keys for which to create
            assignment batch vectors.
        exclude_keys (list[str], optional): Keys to exclude from batching.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ):
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch: List[Any]) -> Any:
        """Collate a list of data elements into a batch.

        Args:
            batch (list[Any]): List of individual data elements.

        Returns:
            Any: Batched data, whose type depends on the element type.

        Raises:
            TypeError: If the element type is not supported.
        """
        # Dispatch on the first element; the batch is assumed homogeneous.
        first = batch[0]

        if isinstance(first, BaseData):
            return Batch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
            )
        if isinstance(first, torch.Tensor):
            return default_collate(batch)
        if isinstance(first, TensorFrame):
            return torch_frame.cat(batch, dim=0)
        if isinstance(first, float):
            return torch.tensor(batch, dtype=torch.float)
        if isinstance(first, int):
            return torch.tensor(batch)
        if isinstance(first, str):
            # Strings are left as a plain Python list.
            return batch
        if isinstance(first, Mapping):
            # Recurse per key, collating column-wise.
            return {k: self([sample[k] for sample in batch]) for k in first}
        if isinstance(first, tuple) and hasattr(first, '_fields'):
            # Named tuple: collate each field, then rebuild the tuple type.
            return type(first)(*(self(group) for group in zip(*batch)))
        if isinstance(first, Sequence) and not isinstance(first, str):
            # Generic sequence: transpose and collate element-wise.
            return [self(group) for group in zip(*batch)]

        raise TypeError(f"DataLoader found invalid type: '{type(first)}'")

__call__(batch: List[Any]) -> Any

Collate a list of data elements into a batch.

Parameters:

Name Type Description Default
batch list[Any]

List of individual data elements.

required

Returns:

Name Type Description
Any Any

Batched data, whose type depends on the element type.

Raises:

Type Description
TypeError

If the element type is not supported.

Source code in lumina/loader/opf/opf_loader.py
def __call__(self, batch: List[Any]) -> Any:
    """Collate a list of data elements into a batch.

    Args:
        batch (list[Any]): List of individual data elements.

    Returns:
        Any: Batched data, whose type depends on the element type.

    Raises:
        TypeError: If the element type is not supported.
    """
    # Dispatch on the first element; the batch is assumed homogeneous.
    first = batch[0]

    if isinstance(first, BaseData):
        return Batch.from_data_list(
            batch,
            follow_batch=self.follow_batch,
            exclude_keys=self.exclude_keys,
        )
    if isinstance(first, torch.Tensor):
        return default_collate(batch)
    if isinstance(first, TensorFrame):
        return torch_frame.cat(batch, dim=0)
    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float)
    if isinstance(first, int):
        return torch.tensor(batch)
    if isinstance(first, str):
        # Strings are left as a plain Python list.
        return batch
    if isinstance(first, Mapping):
        # Recurse per key, collating column-wise.
        return {k: self([sample[k] for sample in batch]) for k in first}
    if isinstance(first, tuple) and hasattr(first, '_fields'):
        # Named tuple: collate each field, then rebuild the tuple type.
        return type(first)(*(self(group) for group in zip(*batch)))
    if isinstance(first, Sequence) and not isinstance(first, str):
        # Generic sequence: transpose and collate element-wise.
        return [self(group) for group in zip(*batch)]

    raise TypeError(f"DataLoader found invalid type: '{type(first)}'")