Data Loading

DGDataLoader

loader

Classes:

  • DGDataLoader

    Iterate and materialize batches from a DGraph.

DGDataLoader

DGDataLoader(
    dg: DGraph,
    batch_size: int = 1,
    batch_unit: str = 'r',
    on_empty: Literal['skip', 'raise', None] = 'skip',
    hook_manager: HookManager | None = None,
    **kwargs: Any,
)

Bases: _SkippableDataLoaderMixin, DataLoader

Iterate and materialize batches from a DGraph.

This DataLoader supports both event-ordered and time-ordered temporal graphs. Optional hooks can be applied to each batch, and empty batches can be skipped or raise an exception depending on configuration.

Parameters:

  • dg (DGraph) –

    The dynamic graph to iterate.

  • batch_size (int, default: 1 ) –

    The batch size to yield at each iteration.

  • batch_unit (str, default: 'r' ) –

    The unit corresponding to the batch_size ('r' for event-ordered batches, or a time unit for time-ordered). Defaults to 'r'.

  • on_empty (Literal['skip', 'raise', None], default: 'skip' ) –

    Behavior for empty batches. 'skip' to ignore, 'raise' to throw an error. Defaults to 'skip'.

  • hook_manager (HookManager | None, default: None ) –

    Optional hooks to apply transformations to each batch before returning. Defaults to None.

  • **kwargs (Any, default: {} ) –

    Additional arguments passed to torch.utils.data.DataLoader.
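The effect of on_empty on time-ordered iteration can be illustrated with a small standalone sketch. It does not use the library: `iter_time_batches` and the local `EmptyBatchError` class are hypothetical stand-ins, windows are assumed to be half-open `[t, t + window)`, and only the documented 'skip' and 'raise' values are modeled.

```python
from typing import Iterator, List, Literal, Sequence


class EmptyBatchError(Exception):
    """Local stand-in for the library exception of the same name."""


def iter_time_batches(
    timestamps: Sequence[int],
    window: int,
    on_empty: Literal["skip", "raise"] = "skip",
) -> Iterator[List[int]]:
    """Group event timestamps into consecutive time windows.

    Assumption: each window is half-open, [t, t + window).
    """
    start, stop = min(timestamps), max(timestamps) + 1
    for t in range(start, stop, window):
        batch = [ts for ts in timestamps if t <= ts < t + window]
        if not batch:
            if on_empty == "skip":
                continue  # the window is planned but never yielded
            raise EmptyBatchError(f"no events in [{t}, {t + window})")
        yield batch


events = [0, 1, 2, 10, 11]  # nothing happens between t=3 and t=9
yielded = list(iter_time_batches(events, window=5))
planned = len(range(min(events), max(events) + 1, 5))
```

Here three windows are planned but only two batches are yielded, which is the same kind of mismatch that makes a precomputed length unreliable when empty batches are skipped.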

Raises:

  • ValueError

    If batch_size <= 0.

  • EventOrderedConversionError

    If iterating an event-ordered DGraph using a time-ordered batch_unit.

  • InvalidDiscretizationError

    If a time-ordered DGraph has a time unit coarser than the batch_unit.

  • EmptyBatchError

    If an empty batch is encountered with on_empty='raise'.

Note
  • Event-ordered batching ('r') iterates sequentially over event indices. Time-ordered batching iterates over temporal slices according to batch_unit.
  • For time-ordered batching, the DGraph time delta must not be coarser than the batch span (batch_size × batch_unit). Otherwise, an InvalidDiscretizationError is raised.
  • For time-ordered batching, batch_size is converted internally into units of the graph's time delta and truncated to an integer, so the effective batch size may differ from the value passed in.
  • The length returned by len(DGDataLoader) may be inaccurate for time-ordered batches with on_empty='skip', since skipped batches are still counted.
  • Slicing and batch materialization return new DGBatch objects. The underlying graph storage is not copied; views are used for efficiency.
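The granularity check and batch-size conversion described in these notes can be sketched without the library. The unit table, the `effective_batch_size` helper, and the local `InvalidDiscretizationError` class below are all hypothetical; the real conversion is handled by TimeDeltaDG, whose supported units are not listed here.

```python
# Hypothetical unit table (seconds per unit); the library's own unit names may differ.
SECONDS_PER_UNIT = {"s": 1, "m": 60, "h": 3600, "D": 86400}


class InvalidDiscretizationError(Exception):
    """Local stand-in for the library exception of the same name."""


def effective_batch_size(graph_unit: str, batch_unit: str, batch_size: int) -> int:
    """Express a batch span of `batch_size` x `batch_unit` in graph time units."""
    graph_span = SECONDS_PER_UNIT[graph_unit]
    batch_span = SECONDS_PER_UNIT[batch_unit] * batch_size
    if graph_span > batch_span:
        # The graph is coarser than one batch, so a batch cannot be resolved.
        raise InvalidDiscretizationError(
            f"graph unit {graph_unit!r} is coarser than {batch_size} x {batch_unit!r}"
        )
    # Truncate, mirroring the int(...) cast in the constructor.
    return batch_span // graph_span


one_hour_in_seconds = effective_batch_size("s", "h", 1)    # 3600
ninety_s_in_minutes = effective_batch_size("m", "s", 90)   # 1 (1.5 truncated)
```

The second call shows why the effective batch size can differ from the requested one: a span that is not a whole multiple of the graph's time delta is rounded down.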
Source code in tgm/data/loader.py
def __init__(
    self,
    dg: DGraph,
    batch_size: int = 1,
    batch_unit: str = 'r',
    on_empty: Literal['skip', 'raise', None] = 'skip',
    hook_manager: HookManager | None = None,
    **kwargs: Any,
) -> None:
    if batch_size <= 0:
        raise ValueError(f'batch_size must be > 0 but got {batch_size}')

    batch_time_delta = TimeDeltaDG(batch_unit)
    logger.info(
        'Initializing DGDataLoader: batch_size=%d, batch_unit=%s',
        batch_size,
        batch_unit,
    )

    if dg.time_delta.is_event_ordered and batch_time_delta.is_time_ordered:
        raise EventOrderedConversionError(
            'Cannot iterate event-ordered dg using time-ordered batch_unit'
        )
    if dg.time_delta.is_time_ordered and batch_time_delta.is_time_ordered:
        # Ensure the graph time unit is more granular than batch time unit.
        batch_time_delta = TimeDeltaDG(batch_unit, value=batch_size)
        if dg.time_delta.is_coarser_than(batch_time_delta):
            raise InvalidDiscretizationError(
                f'Tried to construct a data loader on a DGraph with time delta: {dg.time_delta} '
                f'which is strictly coarser than the batch_unit: {batch_unit}, batch_size: {batch_size}. '
                'Either choose a larger batch size, batch unit or consider iterate using event-ordered batching.'
            )
        batch_size = int(batch_time_delta.convert(dg.time_delta))

    # Warning: Cache miss
    assert dg.start_time is not None and dg.end_time is not None

    self._dg = dg
    self._batch_size = batch_size
    self._hook_manager = hook_manager

    if batch_time_delta.is_event_ordered:
        self._slice_op = dg.slice_events
        start_idx, stop_idx = 0, dg.num_events
    else:
        self._slice_op = dg.slice_time  # type: ignore
        start_idx, stop_idx = dg.start_time, dg.end_time + 1

    if kwargs.get('drop_last', False):
        slice_start = range(start_idx, stop_idx - batch_size, batch_size)
    else:
        slice_start = range(start_idx, stop_idx, batch_size)

    super().__init__(
        slice_start, 1, shuffle=False, collate_fn=self, on_empty=on_empty, **kwargs
    )