Skip to content

Hooks

Base Classes

DGHook

Bases: Protocol

The behaviours to be executed on a DGraph before materializing.

StatelessHook dataclass

StatelessHook(
    _requires: Set[str] = set(),
    _produces: Set[str] = set(),
    _id: str | None = None,
    has_state: bool = False,
)

Bases: BaseDGHook

Base class for hooks without internal state.

StatefulHook dataclass

StatefulHook(
    _requires: Set[str] = set(),
    _produces: Set[str] = set(),
    _id: str | None = None,
    has_state: bool = False,
)

Bases: BaseDGHook

Base class for hooks that maintain internal state.

Concrete Hooks

PinMemoryHook dataclass

PinMemoryHook(
    _requires: Set[str] = set(),
    _produces: Set[str] = set(),
    _id: str | None = None,
    has_state: bool = False,
)

Bases: StatelessHook

Pin all tensors in the DGBatch to page-locked memory for faster async CPU-GPU transfers.

DeviceTransferHook

DeviceTransferHook(device: str | device)

Bases: StatelessHook

Moves all tensors in the DGBatch to the specified device.

Source code in tgm/hooks/device.py
31
32
33
def __init__(self, device: str | torch.device) -> None:
    """Create a hook that moves every tensor of a DGBatch to ``device``.

    Args:
        device: Target device, either a string such as ``'cpu'`` /
            ``'cuda:0'`` or an existing ``torch.device``. The value is
            normalized through ``torch.device`` and stored as ``self.device``.
    """
    super().__init__()
    target_device = torch.device(device)
    self.device = target_device

DeduplicationHook

DeduplicationHook(
    seed_nodes_keys: List[str] | None = None,
    id: str | None = None,
)

Bases: StatelessHook, SeedableHook

Deduplicate node IDs from batch fields and create index mappings to unique node embeddings.

Note: Supports batches with or without negative samples and multi-hop neighbors.

Source code in tgm/hooks/dedup.py
24
25
26
27
28
29
30
def __init__(
    self,
    seed_nodes_keys: List[str] | None = None,
    id: str | None = None,
) -> None:
    """Set up the stateless node-ID deduplication hook.

    Args:
        seed_nodes_keys: Optional batch attribute keys selecting which node
            fields to deduplicate; stored as ``self.seed_keys``. What ``None``
            means is decided by the hook body (not visible here).
        id: Optional unique identifier for the hook, stored as ``self._id``.
    """
    super().__init__()
    self._id, self.seed_keys = id, seed_nodes_keys
    # Run the base post-initialization only once _id and seed_keys exist.
    self.__post_init__()

TGBNegativeEdgeSamplerHook dataclass

TGBNegativeEdgeSamplerHook(
    dataset_name: str,
    split_mode: str,
    id: str | None = None,
)

Bases: TGBNegativeEdgeSamplerBase

Load data from DGraph using pre-generated TGB negative samples. Make sure to perform dataset.load_val_ns() or dataset.load_test_ns() before using this hook.

Parameters:

  • dataset_name (str) –

    The name of the TGB dataset to produce sampler for.

  • split_mode (str) –

    The split mode to use for sampling, either 'val' or 'test'.

  • id (str, default: None ) –

    A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this id.

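Raises:

  • ValueError –

    If split_mode is not 'val' or 'test', or if dataset_name does not start with the TGB dataset prefix expected by this hook.

  • ImportError –

    If the TGB package (py-tgb) is not installed.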

Source code in tgm/hooks/negatives/tgb_sampler.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def __init__(
    self, dataset_name: str, split_mode: str, id: str | None = None
) -> None:
    """Load the pre-generated TGB negative edges for one evaluation split.

    Args:
        dataset_name: Name of the TGB dataset. It must start with
            ``f'{self._dataset_prefix}-'``.
        split_mode: Evaluation split to load, either ``'val'`` or ``'test'``.
        id: Optional unique identifier. The hook's name and all attributes it
            produces are suffixed with this id.

    Raises:
        ValueError: If ``split_mode`` is not ``'val'``/``'test'``, or if
            ``dataset_name`` lacks the expected prefix.
        ImportError: If the ``py-tgb`` package is not installed. The
            original import error is chained as the cause.
    """
    super().__init__()
    if split_mode not in ('val', 'test'):
        raise ValueError(f'split_mode must be "val" or "test", got: {split_mode}')

    try:
        from tgb.utils.info import DATA_VERSION_DICT, PROJ_DIR
    except ImportError as err:
        raise ImportError(
            f'TGB required for {self.__class__.__name__}, try `pip install py-tgb`'
        ) from err

    # _dataset_prefix is defined elsewhere in the class hierarchy. Subclasses
    # also run through this initializer, so report the concrete class name
    # rather than a hard-coded one.
    if not dataset_name.startswith(f'{self._dataset_prefix}-'):
        raise ValueError(
            f'{self.__class__.__name__} should only be registered for '
            f'"{self._dataset_prefix}-xxx" datasets, but got: {dataset_name}'
        )

    neg_sampler = self._build_sampler(dataset_name)

    # Load evaluation sets. Files for dataset versions > 1 carry a
    # "_v<version>" suffix; version 1 (or unknown datasets) has none.
    root = Path(PROJ_DIR + 'datasets') / dataset_name.replace('-', '_')
    version = DATA_VERSION_DICT.get(dataset_name, 1)
    version_suffix = f'_v{version}' if version > 1 else ''

    ns_fname = root / f'{dataset_name}_{split_mode}_ns{version_suffix}.pkl'
    logger.debug(
        'Loading %s split (neg_sampler.load_eval_set) for dataset: %s from file: %s',
        split_mode,
        dataset_name,
        ns_fname,
    )
    neg_sampler.load_eval_set(fname=str(ns_fname), split_mode=split_mode)

    self.neg_sampler = neg_sampler
    self.split_mode = split_mode
    self._id = id
    self.__post_init__()

TGBTHGNegativeEdgeSamplerHook

TGBTHGNegativeEdgeSamplerHook(
    dataset_name: str,
    split_mode: str,
    first_node_id: int,
    last_node_id: int,
    node_type: Tensor,
    id: str | None = None,
)

Bases: TGBNegativeEdgeSamplerBase

Load data from DGraph using pre-generated TGB negative samples for heterogeneous graph. Make sure to perform dataset.load_val_ns() or dataset.load_test_ns() before using this hook.

Parameters:

  • dataset_name (str) –

    The name of the TGB dataset to produce sampler for.

  • split_mode (str) –

    The split mode to use for sampling, either 'val' or 'test'.

  • first_node_id (int) –

    ID of the first node.

  • last_node_id (int) –

    ID of the last node; node_type must contain at least last_node_id entries.

  • node_type (Tensor) –

    the node type of each node

  • id (str, default: None ) –

    A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this id.

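Raises:

  • ValueError –

    If first_node_id or last_node_id is negative, if node_type is None, or if node_type has fewer than last_node_id entries.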

Source code in tgm/hooks/negatives/tgb_sampler.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def __init__(
    self,
    dataset_name: str,
    split_mode: str,
    first_node_id: int,
    last_node_id: int,
    node_type: torch.Tensor,
    id: str | None = None,
) -> None:
    if first_node_id < 0 or last_node_id < 0:
        raise ValueError('First and last ID of node must be positive')

    if node_type is None:
        raise ValueError('Node type must not be None')

    if node_type.shape[0] < last_node_id:
        raise ValueError(f'last_node_id {last_node_id} must be within node_type')

    self._first_node_id = first_node_id
    self._last_node_id = last_node_id
    self._node_type = node_type
    super().__init__(dataset_name, split_mode, id)

TGBTKGNegativeEdgeSamplerHook

TGBTKGNegativeEdgeSamplerHook(
    dataset_name: str,
    split_mode: str,
    first_dst_id: int,
    last_dst_id: int,
    id: str | None = None,
)

Bases: TGBNegativeEdgeSamplerBase

Load data from DGraph using pre-generated TGB negative samples for knowledge graph. Make sure to perform dataset.load_val_ns() or dataset.load_test_ns() before using this hook.

Parameters:

  • dataset_name (str) –

    The name of the TGB dataset to produce sampler for.

  • split_mode (str) –

    The split mode to use for sampling, either 'val' or 'test'.

  • first_dst_id (int) –

    identity of the first destination node

  • last_dst_id (int) –

    identity of the last destination node

  • id (str, default: None ) –

    A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this id.

Raises:

  • ValueError –

    If first_dst_id or last_dst_id is negative.

Source code in tgm/hooks/negatives/tgb_sampler.py
261
262
263
264
265
266
267
268
269
270
271
272
273
def __init__(
    self,
    dataset_name: str,
    split_mode: str,
    first_dst_id: int,
    last_dst_id: int,
    id: str | None = None,
) -> None:
    if first_dst_id < 0 or last_dst_id < 0:
        raise ValueError('First and last ID of node must be positive')
    self._first_dst_id = first_dst_id
    self._last_dst_id = last_dst_id
    super().__init__(dataset_name, split_mode, id)

RandomNegativeEdgeSamplerHook

RandomNegativeEdgeSamplerHook(
    low: int,
    high: int,
    neg_ratio: float = 1.0,
    id: str | None = None,
)

Bases: StatelessHook

Random sampling negative edges for dynamic link prediction.

Parameters:

  • low (int) –

    The minimum node id to sample

  • high (int) –

    The maximum node id to sample

  • neg_ratio (float, default: 1.0 ) –

    The ratio of sampled negative destination nodes to the number of positive destination nodes (default = 1.0).

  • id (str, default: None ) –

    A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this id.

Source code in tgm/hooks/negatives/sampler.py
27
28
29
30
31
32
33
34
35
36
37
38
39
def __init__(
    self,
    low: int,
    high: int,
    neg_ratio: float = 1.0,
    id: str | None = None,
) -> None:
    """Configure uniform random sampling of negative destinations.

    Args:
        low: Lower bound of the node-ID range to sample from.
        high: Upper bound of the node-ID range; ``low < high`` must hold.
        neg_ratio: Ratio of sampled negative destinations to positive
            destinations, restricted to the interval (0, 1].
        id: Optional unique identifier, stored as ``self._id``.

    Raises:
        ValueError: If ``neg_ratio`` is outside (0, 1] or ``low`` is not
            strictly less than ``high``.
    """
    super().__init__()
    ratio_in_range = 0 < neg_ratio <= 1
    if not ratio_in_range:
        raise ValueError(f'neg_ratio must be in (0, 1], got: {neg_ratio}')
    bounds_ordered = low < high
    if not bounds_ordered:
        raise ValueError(f'low ({low}) must be strictly less than high ({high})')
    self.low, self.high, self.neg_ratio = low, high, neg_ratio
    self._id = id
    self.__post_init__()

HistoricalNegativeEdgeSamplerHook

HistoricalNegativeEdgeSamplerHook(id: str | None = None)

Bases: StatefulHook

Sample negative edges from past interactions for dynamic link prediction.

Parameters:

  • id (str, default: None ) –

    A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this id.

Notes

If a node has no past interactions, PADDED_NODE_ID (-1) is returned as its negative destination.

valid_neg_mask (BoolTensor): Boolean mask of shape (num_neg,) indicating which entries in neg are real negative samples. True means the corresponding node ID is a valid negative; False means the entry is a padding placeholder (PADDED_NODE_ID) and should be excluded from loss computation and evaluation.

Source code in tgm/hooks/negatives/sampler.py
82
83
84
85
86
87
88
89
90
91
92
def __init__(self, id: str | None = None) -> None:
    """Create a stateful sampler that draws negatives from past interactions.

    Args:
        id: Optional unique identifier, stored as ``self._id``.
    """
    super().__init__()
    self._id = id

    # Sampling state starts empty (no memory tensor, zero count); it is
    # presumably filled in as batches are processed.
    self._count: int = 0
    self._memory: torch.Tensor | None = None
    self.__post_init__()

NeighborSamplerHook

NeighborSamplerHook(
    num_nbrs: List[int],
    seed_nodes_keys: List[str],
    seed_times_keys: List[str],
    directed: bool = False,
    id: str | None = None,
)

Bases: StatelessHook, SeedableHook

Load neighbors from DGraph using a memory based sampling function.

Parameters:

  • num_nbrs (List[int]) –

    Number of neighbors to sample at each hop. Each entry must be a positive integer.

  • directed (bool, default: False ) –

    If true, aggregates interactions in edge_src->edge_dst direction only (default=False).

  • seed_nodes_keys (List[str]) –

    List of batch attribute keys to identify the initial seed nodes to sample for.

  • seed_times_keys (List[str]) –

    List of batch attribute keys to identify the initial seed times to sample for.

  • id (str, default: None ) –

    A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this id.

Note

The order of the output tensors respects the order of seed_nodes_keys. For instance, with seed node keys ['edge_src', 'edge_dst', 'neg'], the first output index (hop 0) contains the concatenation of batch.edge_src, batch.edge_dst and batch.neg (in that order). The next index (hop 1) contains the first-hop neighbors of batch.edge_src, followed by those of batch.edge_dst, and then those of batch.neg. This pattern repeats for deeper hops.

Raises:

  • ValueError

    If the num_nbrs list is empty or has non-positive entries.

  • ValueError

    If len(seed_nodes_keys) != len(seed_times_keys).

Source code in tgm/hooks/neighbors/uniform.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def __init__(
    self,
    num_nbrs: List[int],
    seed_nodes_keys: List[str],
    seed_times_keys: List[str],
    directed: bool = False,
    id: str | None = None,
) -> None:
    """Configure neighbor sampling for the given seed-node/seed-time keys.

    Args:
        num_nbrs: Number of neighbors to sample at each hop. Must be
            non-empty, and every entry must be a positive ``int``
            (``bool`` values are rejected).
        seed_nodes_keys: Batch attribute keys naming the seed-node tensors.
        seed_times_keys: Batch attribute keys naming the seed-time tensors.
            Must have the same length as ``seed_nodes_keys``.
        directed: If True, aggregate interactions in edge_src->edge_dst
            direction only.
        id: Optional unique identifier, stored as ``self._id``.

    Raises:
        ValueError: If ``num_nbrs`` is empty or has a non-positive or
            non-integer entry, or if the two key lists differ in length.
    """
    super().__init__()
    if len(num_nbrs) == 0:
        raise ValueError('num_nbrs must be non-empty')
    # bool is a subclass of int, so exclude it explicitly; otherwise [True]
    # would be silently accepted as "one neighbor".
    if not all(
        isinstance(x, int) and not isinstance(x, bool) and x > 0 for x in num_nbrs
    ):
        raise ValueError('Each value in num_nbrs must be a positive integer')
    self._num_nbrs = num_nbrs
    self._directed = directed

    if len(seed_nodes_keys) != len(seed_times_keys):
        raise ValueError(
            f'len(seed_nodes_keys) ({len(seed_nodes_keys)}) '
            f'!= len(seed_times_keys) ({len(seed_times_keys)})\n'
            f'seed_nodes_keys={seed_nodes_keys}, '
            f'seed_times_keys={seed_times_keys}'
        )
    self._seed_nodes_keys = seed_nodes_keys
    self._seed_times_keys = seed_times_keys
    logger.debug(
        'Seed nodes keys: %s, Seed times keys: %s',
        self._seed_nodes_keys,
        self._seed_times_keys,
    )
    # Presumably a warn-once flag for missing seed attributes; confirm in the
    # sampling code.
    self._warned_seed_None = False
    self._id = id
    self.seed_keys = seed_nodes_keys
    self.__post_init__()

RecencyNeighborHook

RecencyNeighborHook(
    num_nodes: int,
    num_nbrs: List[int],
    seed_nodes_keys: List[str],
    seed_times_keys: List[str],
    directed: bool = False,
    id: str | None = None,
)

Bases: StatefulHook, SeedableHook

Load neighbors from DGraph using a recency sampling. Each node maintains a fixed number of recent neighbors.

Parameters:

  • num_nodes (int) –

    Total number of nodes to track.

  • num_nbrs (List[int]) –

    Number of neighbors to sample at each hop (max neighbors to keep).

  • directed (bool, default: False ) –

    If true, aggregates interactions in edge_src->edge_dst direction only (default=False).

  • seed_nodes_keys (List[str]) –

    List of batch attribute keys to identify the initial seed nodes to sample for.

  • seed_times_keys (List[str]) –

    List of batch attribute keys to identify the initial seed times to sample for.

  • id (str, default: None ) –

    A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this id.

Note
  • RecencyNeighborHook assumes queries occur in chronological order. If query timestamps lag behind the most recently pushed events, the computed neighbors may be incorrect.
  • The order of the output tensors respects the order of seed_nodes_keys. For instance, with seed node keys ['edge_src', 'edge_dst', 'neg'], the first output index (hop 0) contains the concatenation of batch.edge_src, batch.edge_dst and batch.neg (in that order). The next index (hop 1) contains the first-hop neighbors of batch.edge_src, followed by those of batch.edge_dst, and then those of batch.neg. This pattern repeats for deeper hops.

Raises:

  • ValueError

    If the num_nbrs list is empty or has non-positive entries.

  • ValueError

    If len(seed_nodes_keys) != len(seed_times_keys).

Source code in tgm/hooks/neighbors/recency.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def __init__(
    self,
    num_nodes: int,
    num_nbrs: List[int],
    seed_nodes_keys: List[str],
    seed_times_keys: List[str],
    directed: bool = False,
    id: str | None = None,
) -> None:
    """Allocate per-node recency buffers and configure neighbor sampling.

    Args:
        num_nodes: Total number of nodes to track; must be non-negative.
        num_nbrs: Number of neighbors to sample at each hop. Must be
            non-empty, and every entry must be a positive ``int``
            (``bool`` values are rejected). The maximum entry sets how many
            recent neighbors are kept per node.
        seed_nodes_keys: Batch attribute keys naming the seed-node tensors.
        seed_times_keys: Batch attribute keys naming the seed-time tensors.
            Must have the same length as ``seed_nodes_keys``.
        directed: If True, aggregate interactions in edge_src->edge_dst
            direction only.
        id: Optional unique identifier, stored as ``self._id``.

    Raises:
        ValueError: If ``num_nbrs`` is empty or has a non-positive or
            non-integer entry, if ``num_nodes`` is negative, or if the two
            key lists differ in length.
    """
    super().__init__()
    if len(num_nbrs) == 0:
        raise ValueError('num_nbrs must be non-empty')
    # bool is a subclass of int, so exclude it explicitly; otherwise [True]
    # would be silently accepted as "one neighbor".
    if not all(
        isinstance(x, int) and not isinstance(x, bool) and x > 0 for x in num_nbrs
    ):
        raise ValueError('Each value in num_nbrs must be a positive integer')
    # Without this check, a negative count only fails later inside
    # torch.full with an opaque "negative dimension" RuntimeError.
    if num_nodes < 0:
        raise ValueError(f'num_nodes must be non-negative, got: {num_nodes}')

    self._num_nodes = num_nodes
    self._num_nbrs = num_nbrs
    self._max_nbrs = max(num_nbrs)
    self._directed = directed
    # The buffers below are created on the default (CPU) device; this presumably
    # records their current location for later transfers.
    self._device = torch.device('cpu')

    if len(seed_nodes_keys) != len(seed_times_keys):
        raise ValueError(
            f'len(seed_nodes_keys) ({len(seed_nodes_keys)}) '
            f'!= len(seed_times_keys) ({len(seed_times_keys)})\n'
            f'seed_nodes_keys={seed_nodes_keys}, '
            f'seed_times_keys={seed_times_keys}'
        )
    self._seed_nodes_keys = seed_nodes_keys
    self._seed_times_keys = seed_times_keys
    logger.debug(
        'Seed nodes keys: %s, Seed times keys: %s',
        self._seed_nodes_keys,
        self._seed_times_keys,
    )
    # Presumably a warn-once flag for missing seed attributes.
    self._warned_seed_None = False

    # One row of _max_nbrs slots per node: neighbor IDs start as PADDED_NODE_ID
    # and times as 0. _write_pos looks like a per-node write cursor into the
    # row (confirm in the update logic).
    self._nbr_ids = torch.full(
        (num_nodes, self._max_nbrs), PADDED_NODE_ID, dtype=torch.int32
    )
    self._nbr_times = torch.zeros((num_nodes, self._max_nbrs), dtype=torch.int64)
    self._write_pos = torch.zeros(num_nodes, dtype=torch.int32)

    # Wait until first __call__ to infer the edge_x_dim on the underlying graph
    self._need_to_initialize_nbr_feats = True
    self._edge_x_dim = None
    self._nbr_feats = None
    self._id = id
    self.seed_keys = seed_nodes_keys
    self.__post_init__()

BatchAnalyticsHook

BatchAnalyticsHook(id: str | None = None)

Bases: StatelessHook

Compute simple batch-level statistics.

Source code in tgm/hooks/analytics/batch_analytics.py
31
32
33
34
def __init__(self, id: str | None = None) -> None:
    """Create the stateless batch-statistics hook.

    Args:
        id: Optional unique identifier for the hook, kept as ``self._id``.
            It is assigned before the base post-initialization runs.
    """
    super().__init__()
    self._id = id
    self.__post_init__()

NodeAnalyticsHook

NodeAnalyticsHook(
    tracked_nodes: Tensor,
    num_nodes: int,
    id: str | None = None,
)

Bases: StatefulHook

Compute node-centric statistics for a specific set of tracked nodes.

This hook maintains state across batches to compute temporal statistics for a specified set of nodes.

Parameters:

  • tracked_nodes (Tensor) –

    1D tensor of node IDs to track statistics for.

  • num_nodes (int) –

    Total number of nodes in the graph.

  • id (str, default: None ) –

    A unique identifier for the hook. The hook’s name and all attributes it produces will be suffixed with this id.

Produces

node_stats (Dict[int, Dict[str, float]]): Dictionary mapping node_id to statistics:
  - degree: Number of edges connected to the node in the current batch.
  - activity: Fraction of unique timesteps in which the node has appeared.
  - new_neighbors: Number of new neighbors the node encountered in the current batch.
  - lifetime: Time since the node was first seen.
  - time_since_last_seen: Time since the node was last seen.
  - appearances: Total number of unique timesteps the node has appeared in.

node_macro_stats (Dict): Batch-level node statistics:
  - node_novelty: Fraction of tracked nodes in the batch that are appearing for the first time.
  - new_node_count: Number of tracked nodes in the batch that are appearing for the first time.

edge_stats (Dict): Batch-level edge statistics:
  - edge_novelty: Fraction of edges in the batch that were not seen in previous batches.
  - edge_density: Number of edges in this batch divided by the number of possible edges among its unique nodes.
  - new_edge_count: Number of edges in the batch that were not seen in previous batches.

Methods:

Source code in tgm/hooks/analytics/node_analytics.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def __init__(
    self, tracked_nodes: Tensor, num_nodes: int, id: str | None = None
) -> None:
    """Set up per-node statistic tracking for ``tracked_nodes``.

    Args:
        tracked_nodes: Tensor of node IDs to track. Duplicates are removed
            with ``Tensor.unique()``. Every ID must lie in
            ``[0, num_nodes)``.
        num_nodes: Total number of nodes in the graph; must be positive.
        id: Optional unique identifier, stored as ``self._id``.

    Raises:
        ValueError: If ``num_nodes`` is not positive, or if any tracked node
            ID is outside ``[0, num_nodes)``.
    """
    super().__init__()
    if num_nodes <= 0:
        raise ValueError('num_nodes must be positive')
    self.tracked_nodes = tracked_nodes.unique()
    self.num_nodes = num_nodes

    # Validate the IDs before using them as indices. A negative ID would
    # silently wrap around and mark the wrong node as tracked, and an ID
    # >= num_nodes would fail with an opaque IndexError.
    if self.tracked_nodes.numel() > 0:
        lo = int(self.tracked_nodes.min())
        hi = int(self.tracked_nodes.max())
        if lo < 0 or hi >= num_nodes:
            raise ValueError(
                f'tracked_nodes must be in [0, {num_nodes}), '
                f'got values in [{lo}, {hi}]'
            )

    # Create a mask for fast lookup of tracked nodes
    self._tracked_mask = torch.zeros(num_nodes, dtype=torch.bool)
    self._tracked_mask[self.tracked_nodes] = True

    # State dictionaries for each tracked node
    self._first_seen: Dict[int, float] = {}
    self._last_seen: Dict[int, float] = {}
    self._appearances: Dict[int, int] = {}  # Count of unique timesteps per node
    self._total_timesteps: Set[float] = set()  # Track all unique timesteps seen
    self._node_timesteps: Dict[int, Set[float]] = {  # Track timesteps per node
        int(node): set() for node in self.tracked_nodes.tolist()
    }

    # Neighbor tracking
    self._all_neighbors: Dict[int, Set[int]] = {
        int(node): set() for node in self.tracked_nodes.tolist()
    }
    self._engagement_sum: Dict[int, float] = {}

    # Edge tracking
    self._seen_edges: Set[tuple] = set()

    self._id = id
    self.__post_init__()

reset_state

reset_state() -> None

Reset internal state.

Source code in tgm/hooks/analytics/node_analytics.py
221
222
223
224
225
226
227
228
229
230
231
232
def reset_state(self) -> None:
    """Forget everything observed so far.

    Clears the first/last-seen maps, appearance counts, the global timestep
    set, engagement sums and the seen-edge set in place. Rebuilds the
    per-node timestep and neighbor maps with a fresh empty set for each
    tracked node.
    """
    for store in (
        self._first_seen,
        self._last_seen,
        self._appearances,
        self._total_timesteps,
    ):
        store.clear()
    tracked_ids = [int(node) for node in self.tracked_nodes.tolist()]
    self._node_timesteps = {node_id: set() for node_id in tracked_ids}
    self._all_neighbors = {node_id: set() for node_id in tracked_ids}
    for store in (self._engagement_sum, self._seen_edges):
        store.clear()