Neural Layers

Encoders

ctan

Classes:

  • CTAN

    An implementation of CTAN.

  • CTANMemory

    The CTAN Memory model.

CTAN

CTAN(
    edge_dim: int,
    memory_dim: int,
    time_dim: int,
    node_dim: int,
    num_iters: int = 1,
    mean_delta_t: float = 0.0,
    std_delta_t: float = 1.0,
    epsilon: float = 0.1,
    gamma: float = 0.1,
)

Bases: Module

An implementation of CTAN.

Parameters:

  • edge_dim (int) –

    Dimension of edge features.

  • memory_dim (int) –

    Dimension of memory embeddings.

  • time_dim (int) –

    Dimension of time encodings.

  • node_dim (int) –

    Dimension of static/dynamic node features.

  • num_iters (int, default: 1 ) –

    Number of AntiSymmetricConv layers.

  • mean_delta_t (float, default: 0.0 ) –

    Mean delta time between edge events (used to normalize time signal).

  • std_delta_t (float, default: 1.0 ) –

    Std delta time between edge events (used to normalize time signal).

  • epsilon (float, default: 0.1 ) –

    Discretization step size for AntiSymmetricConv.

  • gamma (float, default: 0.1 ) –

    The strength of the diffusion in the AntiSymmetricConv.

Reference: https://arxiv.org/abs/2406.02740

Methods:

  • forward

    Forward pass.

Source code in tgm/nn/encoder/ctan.py
def __init__(
    self,
    edge_dim: int,
    memory_dim: int,
    time_dim: int,
    node_dim: int,
    num_iters: int = 1,
    mean_delta_t: float = 0.0,
    std_delta_t: float = 1.0,
    epsilon: float = 0.1,
    gamma: float = 0.1,
) -> None:
    super().__init__()
    self.mean_delta_t = mean_delta_t
    self.std_delta_t = std_delta_t
    self.time_enc = TimeEncoder(time_dim)
    self.enc_x = nn.Linear(memory_dim + node_dim, memory_dim)

    phi = TransformerConv(
        memory_dim, memory_dim, edge_dim=edge_dim + time_dim, root_weight=False
    )
    self.aconv = AntiSymmetricConv(
        memory_dim, phi, num_iters=num_iters, epsilon=epsilon, gamma=gamma
    )
forward
forward(
    x: Tensor,
    last_update: Tensor,
    edge_index: Tensor,
    t: Tensor,
    msg: Tensor,
) -> Tensor

Forward pass.

Parameters:

  • x (PyTorch Float Tensor) –

    Node features.

  • last_update (PyTorch Tensor) –

    Last memory update timestamps.

  • edge_index (PyTorch Tensor) –

    Graph edge indices.

  • t (PyTorch Tensor) –

    Graph edge timestamps.

  • msg (PyTorch Tensor) –

    Memory embeddings.

Returns:

  • PyTorch Float Tensor

    Embeddings for the batch of node ids.

Source code in tgm/nn/encoder/ctan.py
def forward(
    self,
    x: torch.Tensor,
    last_update: torch.Tensor,
    edge_index: torch.Tensor,
    t: torch.Tensor,
    msg: torch.Tensor,
) -> torch.Tensor:
    """Forward pass.

    Args:
        x (PyTorch Float Tensor): Node features.
        last_update (PyTorch Tensor): Last memory update timestamps.
        edge_index (PyTorch Tensor): Graph edge indices.
        t (PyTorch Tensor): Graph edge timestamps.
        msg (PyTorch Tensor): Memory embeddings.

    Returns:
        (PyTorch Float Tensor): Embeddings for the batch of node ids.
    """
    rel_t = (last_update[edge_index[0]] - t).abs()
    rel_t = ((rel_t - self.mean_delta_t) / self.std_delta_t).to(x.dtype)
    enc_x = self.enc_x(x)
    edge_attr = torch.cat([msg, self.time_enc(rel_t)], dim=-1)
    z = self.aconv(enc_x, edge_index, edge_attr=edge_attr)
    z = torch.tanh(z)
    return z
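
The time handling at the top of `forward` can be traced with plain numbers. The sketch below mirrors only the two `rel_t` lines, using Python lists in place of tensors; the function name `normalized_rel_t` and the sample values are illustrative assumptions, not part of the library.

```python
def normalized_rel_t(last_update, edge_src, t, mean_delta_t=0.0, std_delta_t=1.0):
    """Standardized absolute gap between each edge time and its source node's last update."""
    out = []
    for src, edge_time in zip(edge_src, t):
        gap = abs(last_update[src] - edge_time)          # (last_update[edge_index[0]] - t).abs()
        out.append((gap - mean_delta_t) / std_delta_t)   # normalize with the dataset statistics
    return out


last_update = [10, 40, 25]   # one timestamp per node
edge_src = [0, 2, 1]         # source node of each edge (edge_index[0])
t = [15, 20, 40]             # timestamp of each edge
rel = normalized_rel_t(last_update, edge_src, t, mean_delta_t=5.0, std_delta_t=5.0)
print(rel)
```

Because the gap is taken in absolute value, the encoding depends only on how far apart the two times are, not on which one is later.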

CTANMemory

CTANMemory(
    num_nodes: int,
    memory_dim: int,
    aggr_module: Callable,
    init_time: int = 0,
)

Bases: Module

The CTAN Memory model.

Parameters:

  • num_nodes (int) –

    The number of nodes to save memories for.

  • memory_dim (int) –

    The hidden memory dimensionality.

  • aggr_module (Callable) –

    The message aggregator function which aggregates messages to the same destination into a single representation.

  • init_time (int, default: 0 ) –

    Start time of the graph, used during memory reset.

Reference: https://arxiv.org/abs/2406.02740

Source code in tgm/nn/encoder/ctan.py
def __init__(
    self, num_nodes: int, memory_dim: int, aggr_module: Callable, init_time: int = 0
) -> None:
    super().__init__()

    self.num_nodes = num_nodes
    self.memory_dim = memory_dim
    self.init_time = init_time
    self.aggr_module = aggr_module

    self.register_buffer('memory', torch.zeros(num_nodes, memory_dim))
    self.register_buffer(
        'last_update', torch.ones(self.num_nodes, dtype=torch.long) * init_time
    )
    self.register_buffer('_assoc', torch.empty(num_nodes, dtype=torch.long))

dygformer

Adapted from https://github.com/yule-BUAA/DyGLib_TGB.

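Classes:

  • DyGFormer

    An implementation of DyGFormer.

  • NeighborCooccurrenceEncoder

    An implementation of Neighbor Co-occurrence Encoding Scheme.

  • TransformerEncoder

    An implementation of Transformer Encoder.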

DyGFormer

DyGFormer(
    node_feat_dim: int,
    edge_x_dim: int,
    time_feat_dim: int,
    channel_embedding_dim: int,
    output_dim: int = 172,
    patch_size: int = 1,
    num_layers: int = 2,
    num_heads: int = 2,
    dropout: float = 0.1,
    max_input_sequence_length: int = 512,
    num_channels: int = 4,
    time_encoder: Callable[..., Module] = Time2Vec,
    device: str = 'cpu',
)

Bases: Module

An implementation of DyGFormer.

Parameters:

  • node_feat_dim (int) –

    Dimension of static/dynamic node features (d_N).

  • edge_x_dim (int) –

    Dimension of edge features (d_E).

  • time_feat_dim (int) –

    Dimension of time encodings (d_T).

  • channel_embedding_dim (int) –

    Dimension of each channel embedding.

  • output_dim (int, default: 172 ) –

    Dimension of output embedding.

  • patch_size (int, default: 1 ) –

    Patch size (\mathbf{P}).

  • num_layers (int, default: 2 ) –

    Number of transformer layers.

  • num_heads (int, default: 2 ) –

    Number of attention heads.

  • dropout (float, default: 0.1 ) –

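    Dropout rate.

  • max_input_sequence_length (int, default: 512 ) –

    Maximum length of the input sequence for each node. Must be a multiple of patch_size.

  • num_channels (int, default: 4 ) –

    Number of channels; the transformer attention dimension is num_channels * channel_embedding_dim.

  • time_encoder (PyTorch Module, default: Time2Vec ) –

    Time encoder module.

  • device (str, default: 'cpu' ) –

    Device to run on ('cpu' or 'cuda').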

Reference: https://arxiv.org/abs/2303.13047.

Source code in tgm/nn/encoder/dygformer.py
def __init__(
    self,
    node_feat_dim: int,
    edge_x_dim: int,
    time_feat_dim: int,
    channel_embedding_dim: int,
    output_dim: int = 172,
    patch_size: int = 1,
    num_layers: int = 2,
    num_heads: int = 2,
    dropout: float = 0.1,
    max_input_sequence_length: int = 512,
    num_channels: int = 4,
    time_encoder: Callable[..., nn.Module] = Time2Vec,
    device: str = 'cpu',
) -> None:
    super().__init__()
    if max_input_sequence_length % patch_size != 0:
        raise ValueError('Max sequence length must be a multiple of path size')

    self.node_feat_dim = node_feat_dim
    self.edge_x_dim = edge_x_dim
    self.time_feat_dim = time_feat_dim
    self.channel_embedding_dim = channel_embedding_dim
    self.patch_size = patch_size
    self.max_input_sequence_length = max_input_sequence_length
    self.neighbor_co_occurrence_feat_dim = self.channel_embedding_dim
    self.device = device
    self.num_channels = num_channels
    self.num_patches = max_input_sequence_length // patch_size

    self.time_encoder = time_encoder(time_feat_dim)
    self.co_occurrence_encoder = NeighborCooccurrenceEncoder(
        feat_dim=self.neighbor_co_occurrence_feat_dim,
        device=self.device,
    )
    self.projection_layer = nn.ModuleDict(
        {
            'node': nn.Linear(
                in_features=self.patch_size * self.node_feat_dim,
                out_features=self.channel_embedding_dim,
                bias=True,
            ),
            'edge': nn.Linear(
                in_features=self.patch_size * self.edge_x_dim,
                out_features=self.channel_embedding_dim,
                bias=True,
            ),
            'time': nn.Linear(
                in_features=self.patch_size * self.time_feat_dim,
                out_features=self.channel_embedding_dim,
                bias=True,
            ),
            'neighbor_co_occurrence': nn.Linear(
                in_features=self.patch_size * self.neighbor_co_occurrence_feat_dim,
                out_features=self.channel_embedding_dim,
                bias=True,
            ),
        }
    ).to(device)
    self.transformers = nn.ModuleList(
        [
            TransformerEncoder(
                attention_dim=self.num_channels * self.channel_embedding_dim,
                num_heads=num_heads,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ]
    ).to(device)

    self.output_layer = nn.Linear(
        in_features=self.num_channels * self.channel_embedding_dim,
        out_features=output_dim,
        bias=True,
    ).to(device)
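
The constructor above ties several sizes together: the sequence length must divide evenly into patches, each projection layer takes `patch_size` stacked feature vectors, and the transformer width is the concatenation of all channels. A minimal sketch of that arithmetic follows; `dygformer_shapes` is a hypothetical helper written for illustration and the feature sizes in the call are arbitrary.

```python
def dygformer_shapes(node_feat_dim, edge_x_dim, time_feat_dim, channel_embedding_dim,
                     patch_size=1, max_input_sequence_length=512, num_channels=4):
    """Reproduce the size bookkeeping done in DyGFormer.__init__ (no tensors involved)."""
    if max_input_sequence_length % patch_size != 0:
        raise ValueError('Max sequence length must be a multiple of patch size')
    return {
        'num_patches': max_input_sequence_length // patch_size,
        # in_features of the four projection layers
        'node_in': patch_size * node_feat_dim,
        'edge_in': patch_size * edge_x_dim,
        'time_in': patch_size * time_feat_dim,
        'co_occurrence_in': patch_size * channel_embedding_dim,
        # width seen by every TransformerEncoder and by the output layer
        'attention_dim': num_channels * channel_embedding_dim,
    }


shapes = dygformer_shapes(node_feat_dim=172, edge_x_dim=172, time_feat_dim=100,
                          channel_embedding_dim=50, patch_size=2,
                          max_input_sequence_length=64)
print(shapes)
```

Since `nn.MultiheadAttention` requires the embedding dimension to be divisible by the number of heads, `num_channels * channel_embedding_dim` should also be a multiple of `num_heads`.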

NeighborCooccurrenceEncoder

NeighborCooccurrenceEncoder(feat_dim: int, device: str)

Bases: Module

An implementation of Neighbor Co-occurrence Encoding Scheme.

Parameters:

  • feat_dim (int) –

    dimension of neighbor co-occurrence features (encodings).

  • device (str) –

    Device (cpu or gpu)

Reference: https://arxiv.org/abs/2303.13047.

Methods:

  • forward

    Forward pass. Encode neighbor co-occurrence (Section 4.1).

Source code in tgm/nn/encoder/dygformer.py
def __init__(self, feat_dim: int, device: str) -> None:
    super().__init__()
    self.feat_dim = feat_dim
    self.device = device

    self.neighbor_co_occurrence_encoder = nn.Sequential(
        nn.Linear(in_features=1, out_features=self.feat_dim),
        nn.ReLU(),
        nn.Linear(in_features=self.feat_dim, out_features=self.feat_dim),
    ).to(device)
forward
forward(
    src_neighbour_nodes_ids: Tensor,
    dst_neighbour_nodes_ids: Tensor,
) -> Tuple[Tensor, Tensor]

Forward pass. Encode neighbor co-occurrence (Section 4.1).

Parameters:

  • src_neighbour_nodes_ids (Tensor) –

    Padded list of the source node's neighbours.

  • dst_neighbour_nodes_ids (Tensor) –

    Padded list of the destination node's neighbours.

Returns:

  • X ( Tuple of PyTorch Float Tensors ) –

    Neighbor co-occurrence features (X^{t}_{*,C}) for the source and the destination node, in that order.

Source code in tgm/nn/encoder/dygformer.py
def forward(
    self,
    src_neighbour_nodes_ids: torch.Tensor,
    dst_neighbour_nodes_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Forward pass. Encode neighbor co-occurrence (Section 4.1).

    Args:
        src_neighbour_nodes_ids (Tensor): Padded list of source node's neighbour.
        dst_neighbour_nodes_ids (Tensor): Padded list of destination node's neighbour.

    Returns:
        X (PyTorch Float Tensor): Neighbor co-occurrence features (`X^{t}_{*,C}`).
    """
    source_freq, dst_freq = self._count_nodes_freq(
        src_neighbour_nodes_ids, dst_neighbour_nodes_ids
    )
    src_neighbors_co_occurrence_feat = self.neighbor_co_occurrence_encoder(
        source_freq.unsqueeze(dim=-1)
    ).sum(dim=2)
    dst_neighbors_co_occurrence_feat = self.neighbor_co_occurrence_encoder(
        dst_freq.unsqueeze(dim=-1)
    ).sum(dim=2)
    return src_neighbors_co_occurrence_feat, dst_neighbors_co_occurrence_feat
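
The helper `_count_nodes_freq` is not shown on this page. As a rough sketch of the counting step, assuming it follows the scheme described in the DyGFormer paper (each neighbour gets a pair: its count in the source sequence and its count in the destination sequence) and assuming that id 0 marks padding, a single source/destination pair could be handled as below. The function name, the `pad_id` argument and the example sequences are illustrative.

```python
from collections import Counter


def count_nodes_freq(src_neighbours, dst_neighbours, pad_id=0):
    """For every position, return [count in src sequence, count in dst sequence]."""
    src_counts = Counter(n for n in src_neighbours if n != pad_id)
    dst_counts = Counter(n for n in dst_neighbours if n != pad_id)

    def freq(seq):
        # Padded slots carry no co-occurrence signal (assumption).
        return [[0, 0] if n == pad_id else [src_counts[n], dst_counts[n]] for n in seq]

    return freq(src_neighbours), freq(dst_neighbours)


src_freq, dst_freq = count_nodes_freq([0, 1, 2, 1], [0, 0, 2, 3])
print(src_freq)
print(dst_freq)
```

In `forward`, each of the two counts is then passed through the same small MLP (hence the `unsqueeze(dim=-1)`), and the two resulting vectors are summed over `dim=2` to give one `feat_dim`-sized feature per neighbour.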

TransformerEncoder

TransformerEncoder(
    attention_dim: int, num_heads: int, dropout: float = 0.1
)

Bases: Module

An implementation of Transformer Encoder.

Parameters:

  • attention_dim (int) –

    dimension of the attention vector.

  • num_heads (int) –

    number of attention heads.

  • dropout (float, default: 0.1 ) –

    dropout rate.

Reference: https://arxiv.org/abs/2303.13047.

Methods:

  • forward

    Forward pass. Encode the inputs by Transformer encoder (Section 4.1).

Source code in tgm/nn/encoder/dygformer.py
def __init__(
    self, attention_dim: int, num_heads: int, dropout: float = 0.1
) -> None:
    super().__init__()
    self.attention_dim = attention_dim
    self.num_heads = num_heads
    self.dropout_rate = dropout

    self.multi_head_attention = nn.MultiheadAttention(
        embed_dim=attention_dim, num_heads=num_heads, dropout=dropout
    )

    self.dropout = nn.Dropout(self.dropout_rate)

    self.linear_layers = nn.ModuleList(
        [
            nn.Linear(in_features=attention_dim, out_features=4 * attention_dim),
            nn.Linear(in_features=4 * attention_dim, out_features=attention_dim),
        ]
    )
    self.norm_layers = nn.ModuleList(
        [nn.LayerNorm(attention_dim), nn.LayerNorm(attention_dim)]
    )
forward
forward(inputs: Tensor) -> Tensor

Forward pass. Encode the inputs by Transformer encoder (Section 4.1).

Parameters:

  • inputs (PyTorch Float Tensor) –

    Z^{t} = [Z^{t}_u, Z^{t}_v].

Returns:

  • H ( PyTorch Float Tensor ) –

    Representations of all nodes.

Source code in tgm/nn/encoder/dygformer.py
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    r"""Forward pass. Encode the inputs by Transformer encoder (Section 4.1).

    Args:
        inputs (PyTorch Float Tensor): `Z^{t} = [Z^{t}_u, Z^{t}_v]`.

    Returns:
        H (PyTorch Float Tensor): Representations of all nodes.
    """
    transposed_inputs = inputs.transpose(0, 1)
    transposed_inputs = self.norm_layers[0](transposed_inputs)

    # E.q 5 - Section 4.1
    hidden_states = self.multi_head_attention(
        query=transposed_inputs, key=transposed_inputs, value=transposed_inputs
    )[0].transpose(0, 1)

    # E.q 6 - Section 4.1
    outputs = inputs + self.dropout(hidden_states)

    # E.q 7 - Section 4.1
    hidden_states = self.linear_layers[1](
        self.dropout(F.gelu(self.linear_layers[0](self.norm_layers[1](outputs))))
    )

    # E.q 7 - Section 4.1
    outputs = outputs + self.dropout(hidden_states)

    return outputs
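
Read as equations, the forward pass above is a pre-norm transformer block: layer normalization is applied before each sub-layer and the residual connection bypasses it. The block below restates the code in that form; the symbols \tilde{Z}, H, O, W_1 and W_2 are names chosen here for illustration (biases omitted), not notation taken from the library.

```latex
\begin{aligned}
\tilde{Z} &= \mathrm{LayerNorm}_0(Z) \\
H &= Z + \mathrm{Dropout}\big(\mathrm{MHA}(\tilde{Z}, \tilde{Z}, \tilde{Z})\big) \\
O &= H + \mathrm{Dropout}\Big(W_2\,\mathrm{Dropout}\big(\mathrm{GELU}(W_1\,\mathrm{LayerNorm}_1(H))\big)\Big)
\end{aligned}
```

Here W_1 maps from attention_dim to 4 * attention_dim and W_2 maps back, matching the two entries of `linear_layers`.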

tpnet

Classes:

  • RandomProjectionModule

    This model maintains a series of temporal walk matrices $A^{(0)}(t), A^{(1)}(t), \dots, A^{(k)}(t)$ through random feature propagation.

  • TPNet

    An implementation of TPNet.

RandomProjectionModule

RandomProjectionModule(
    num_nodes: int,
    num_layer: int,
    time_decay_weight: float,
    beginning_time: float,
    use_matrix: bool = True,
    scale_random_projection: bool = True,
    enforce_dim: int | None = None,
    num_edges: int | None = None,
    dim_factor: int | None = None,
    concat_src_dst: bool = True,
    device: str = 'cpu',
)

Bases: Module

This model maintains a series of temporal walk matrices $A^{(0)}(t), A^{(1)}(t), \dots, A^{(k)}(t)$ through random feature propagation, and extracts pairwise features from the resulting random projections.

Parameters:

  • num_nodes (int) –

    the number of nodes

  • num_layer (int) –

    the max hop of the maintained temporal walk matrices

  • time_decay_weight (float) –

    the time decay weight (lambda of the original paper)

  • beginning_time (float) –

    the earliest time in the given temporal graph

  • use_matrix (bool, default: True ) –

    if True, explicitly maintain the temporal walk matrices

  • scale_random_projection (bool, default: True ) –

    if True, the inner product of nodes' random projections will be scaled

  • enforce_dim (int, default: None ) –

    if not None, explicitly set the dimension of random projections to enforce_dim

  • num_edges (int, default: None ) –

    the number of edges

  • dim_factor (int, default: None ) –

    Controls the dimension of the random projections. Specifically, the dimension is set to dim_factor * floor(log(2 * num_edges)), capped at num_nodes.

  • concat_src_dst (bool, default: True ) –

    If True, random features are computed from the concatenation of src and dst, as in the original model design. Set this to False to make the model more scalable, at some cost in performance.

  • device (str, default: 'cpu' ) –

    torch device

Note: For large-scale datasets, the authors suggest setting use_matrix=False and passing num_edges together with dim_factor=10 to keep the model scalable.

Source code in tgm/nn/encoder/tpnet.py
def __init__(
    self,
    num_nodes: int,
    num_layer: int,
    time_decay_weight: float,
    beginning_time: float,
    use_matrix: bool = True,
    scale_random_projection: bool = True,
    enforce_dim: int | None = None,
    num_edges: int | None = None,
    dim_factor: int | None = None,
    concat_src_dst: bool = True,
    device: str = 'cpu',
) -> None:
    super().__init__()
    if not use_matrix:
        if enforce_dim is not None:
            self.dim = enforce_dim
        elif num_edges is not None and dim_factor is not None:
            self.dim = min(int(math.log(num_edges * 2)) * dim_factor, num_nodes)
        else:
            raise ValueError(
                'When `use_matrix` is False, either providing enforce_dim or both num_edges and dim_factor'
            )
    else:
        self.dim = num_nodes
    self.num_layer = num_layer
    self.time_decay_weight = time_decay_weight
    self.use_matrix = use_matrix
    self.device = device
    self.scale = scale_random_projection
    self.concat_src_dst = concat_src_dst

    self.beginning_time = nn.Parameter(
        torch.tensor(beginning_time), requires_grad=False
    ).to(device)
    self.now_time = nn.Parameter(
        torch.tensor(beginning_time), requires_grad=False
    ).to(device)
    self.random_projections = nn.ParameterList()

    if use_matrix:
        for i in range(self.num_layer + 1):
            if i == 0:
                self.random_projections.append(
                    nn.Parameter(torch.eye(self.dim), requires_grad=False)
                )
            else:
                self.random_projections.append(
                    nn.Parameter(
                        torch.zeros_like(self.random_projections[i - 1]),
                        requires_grad=False,
                    )
                )
    else:
        for i in range(self.num_layer + 1):
            if i == 0:
                self.random_projections.append(
                    nn.Parameter(
                        torch.normal(
                            0, 1 / math.sqrt(self.dim), (num_nodes, self.dim)
                        ),
                        requires_grad=False,
                    )
                )
            else:
                self.random_projections.append(
                    nn.Parameter(
                        torch.zeros_like(self.random_projections[i - 1]),
                        requires_grad=False,
                    )
                )
    if concat_src_dst:
        self.out_dim = (2 * self.num_layer + 2) ** 2
    else:
        self.out_dim = (self.num_layer + 1) ** 2
    self.mlp = nn.Sequential(
        nn.Linear(self.out_dim, self.out_dim * 4),
        nn.ReLU(),
        nn.Linear(self.out_dim * 4, self.out_dim),
    )
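
The branching at the top of this constructor decides two sizes: the projection dimension `dim` and the pairwise feature size `out_dim`. The sketch below replays just that arithmetic without building any parameters; `random_projection_dims` is a hypothetical helper and the graph sizes in the call are made up for illustration.

```python
import math


def random_projection_dims(num_nodes, num_layer, use_matrix=True, enforce_dim=None,
                           num_edges=None, dim_factor=None, concat_src_dst=True):
    """Return (dim, out_dim) following the rules in RandomProjectionModule.__init__."""
    if use_matrix:
        dim = num_nodes                      # exact temporal walk matrices
    elif enforce_dim is not None:
        dim = enforce_dim                    # caller-specified projection size
    elif num_edges is not None and dim_factor is not None:
        dim = min(int(math.log(num_edges * 2)) * dim_factor, num_nodes)
    else:
        raise ValueError('Provide enforce_dim, or both num_edges and dim_factor')
    if concat_src_dst:
        out_dim = (2 * num_layer + 2) ** 2
    else:
        out_dim = (num_layer + 1) ** 2
    return dim, out_dim


dim, out_dim = random_projection_dims(num_nodes=10_000, num_layer=2, use_matrix=False,
                                      num_edges=100_000, dim_factor=10)
print(dim, out_dim)
```

With these inputs the natural log of 200,000 is about 12.2, which `int` truncates to 12 before the multiplication, so the projections have 120 columns instead of the 10,000 an explicit matrix would need.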

TPNet

TPNet(
    node_feat_dim: int,
    edge_x_dim: int,
    time_feat_dim: int,
    output_dim: int,
    num_neighbors: int,
    num_layers: int = 2,
    dropout: float = 0.1,
    random_projections: RandomProjectionModule
    | None = None,
    device: str = 'cpu',
    time_encoder: Callable[..., Module] = Time2Vec,
)

Bases: Module

An implementation of TPNet.

Parameters:

  • node_feat_dim (int) –

    Dimension of static/dynamic node features (d_N).

  • edge_x_dim (int) –

    Dimension of edge features (d_E).

  • time_feat_dim (int) –

    Dimension of time encodings (d_T).

  • output_dim (int) –

    Dimension of output embedding.

  • num_neighbors (int) –

    Number of recent temporal neighbors to consider.

  • num_layers (int, default: 2 ) –

    Number of MLP-Mixer layers.

  • dropout (float, default: 0.1 ) –

    Dropout rate.

  • random_projections (Module, default: None ) –

    Random projection module that maintains a series of temporal walk matrices.

  • device (str, default: 'cpu' ) –

    Device to run on ('cpu' or 'cuda').

  • time_encoder (PyTorch Module, default: Time2Vec ) –

    Time encoder module.

Reference: https://arxiv.org/abs/2410.04013.

Source code in tgm/nn/encoder/tpnet.py
def __init__(
    self,
    node_feat_dim: int,
    edge_x_dim: int,
    time_feat_dim: int,
    output_dim: int,
    num_neighbors: int,
    num_layers: int = 2,
    dropout: float = 0.1,
    random_projections: RandomProjectionModule | None = None,
    device: str = 'cpu',
    time_encoder: Callable[..., nn.Module] = Time2Vec,
) -> None:
    super().__init__()
    self.device = device
    self.time_encoder = time_encoder(time_feat_dim).to(device)
    self.random_projections = random_projections
    self.num_neighbors = num_neighbors
    if self.random_projections is None:
        self.random_feature_dim = 0
    else:
        self.random_feature_dim = self.random_projections.out_dim * 2

    self.projection_layer = nn.Sequential(
        nn.Linear(
            node_feat_dim + edge_x_dim + time_feat_dim + self.random_feature_dim,
            output_dim * 2,
        ),
        nn.ReLU(),
        nn.Linear(output_dim * 2, output_dim),
    ).to(device)
    self.mlp_mixers = nn.ModuleList(
        [
            MLPMixer(
                num_tokens=num_neighbors,
                num_channels=output_dim,
                token_dim_expansion_factor=0.5,
                channel_dim_expansion_factor=4.0,
                dropout=dropout,
            ).to(device)
            for _ in range(num_layers)
        ]
    )
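
The input width of `projection_layer` depends on whether a random projection module is supplied. A small sketch of that sum, using a hypothetical helper name and arbitrary feature sizes:

```python
def tpnet_projection_in_features(node_feat_dim, edge_x_dim, time_feat_dim,
                                 random_out_dim=None):
    """Input size of TPNet's first Linear layer, as computed in __init__."""
    # random_out_dim stands for random_projections.out_dim; None means no module was given.
    random_feature_dim = 0 if random_out_dim is None else random_out_dim * 2
    return node_feat_dim + edge_x_dim + time_feat_dim + random_feature_dim


with_rp = tpnet_projection_in_features(172, 172, 100, random_out_dim=36)
without_rp = tpnet_projection_in_features(172, 172, 100)
print(with_rp, without_rp)
```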

gclstm

Adapted from https://github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/torch_geometric_temporal/nn/recurrent/gc_lstm.py.

Classes:

  • GCLSTM

    An implementation of Integrated Graph Convolutional Long Short Term Memory Cell.

GCLSTM

GCLSTM(
    in_channels: int,
    out_channels: int,
    K: int,
    normalization: str = 'sym',
    bias: bool = True,
)

Bases: Module

An implementation of Integrated Graph Convolutional Long Short Term Memory Cell.

Parameters:

  • in_channels (int) –

    Number of input features.

  • out_channels (int) –

    Number of output features.

  • K (int) –

    Chebyshev filter size K.

  • normalization (str, default: 'sym' ) –

    The normalization scheme for the graph Laplacian:

    1. None: No normalization, \mathbf{L} = \mathbf{D} - \mathbf{A}

    2. "sym": Symmetric normalization, \mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}

    3. "rw": Random-walk normalization, \mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}

    You need to pass lambda_max to the forward method of this operator when the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, default: True ) –

    If set to False, the layer will not learn an additive bias.

Reference: https://arxiv.org/abs/1812.04206
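
To make the three normalization options concrete, the sketch below evaluates each Laplacian formula on a three-node path graph using dense Python lists. It is only an illustration of the formulas listed under `normalization`: the function name is hypothetical, every node is assumed to have at least one neighbour, and the library itself works on sparse edge lists instead.

```python
def laplacian(adj, normalization='sym'):
    """Dense graph Laplacian for a symmetric adjacency matrix with no isolated nodes."""
    n = len(adj)
    deg = [sum(row) for row in adj]
    lap = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            identity = 1.0 if i == j else 0.0
            if normalization is None:        # L = D - A
                lap[i][j] = (deg[i] if i == j else 0.0) - adj[i][j]
            elif normalization == 'sym':     # L = I - D^{-1/2} A D^{-1/2}
                lap[i][j] = identity - adj[i][j] / (deg[i] * deg[j]) ** 0.5
            elif normalization == 'rw':      # L = I - D^{-1} A
                lap[i][j] = identity - adj[i][j] / deg[i]
            else:
                raise ValueError(f'Unknown normalization: {normalization}')
    return lap


path = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]    # 0 - 1 - 2
lap_none = laplacian(path, None)
lap_sym = laplacian(path, 'sym')
lap_rw = laplacian(path, 'rw')
print(lap_none, lap_sym, lap_rw, sep='\n')
```

Only the "sym" result stays symmetric; the "rw" variant scales each row by its own degree, which is why rows 0 and 1 of `lap_rw` differ in their off-diagonal entries.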

Source code in tgm/nn/encoder/gclstm.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    K: int,
    normalization: str = 'sym',
    bias: bool = True,
) -> None:
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.K = K
    self.normalization = normalization
    self.bias = bias

    self._create_input_gate_parameters_and_layers()
    self._create_forget_gate_parameters_and_layers()
    self._create_cell_state_parameters_and_layers()
    self._create_output_gate_parameters_and_layers()
    glorot(self.W_i)
    glorot(self.W_f)
    glorot(self.W_c)
    glorot(self.W_o)
    zeros(self.b_i)
    zeros(self.b_f)
    zeros(self.b_c)
    zeros(self.b_o)

tgcn

Adapted from https://github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/torch_geometric_temporal/nn/recurrent/temporalgcn.py.

Classes:

  • TGCN

    An implementation of Temporal Graph Convolutional Gated Recurrent Cell.

TGCN

TGCN(
    in_channels: int,
    out_channels: int,
    improved: bool = False,
    cached: bool = False,
    add_self_loops: bool = True,
)

Bases: Module

An implementation of Temporal Graph Convolutional Gated Recurrent Cell.

Parameters:

  • in_channels (int) –

    Number of input features.

  • out_channels (int) –

    Number of output features.

  • improved (bool, default: False ) –

    If True, use stronger self-loops: the adjacency becomes A + 2I instead of A + I, giving each node's own features more influence during aggregation.

  • cached (bool, default: False ) –

    If True, compute the normalized adjacency matrix only once and cache it. This speeds up training but restricts use to transductive settings, where the graph structure is assumed to be static.

  • add_self_loops (bool, default: True ) –

    If True, add self-loops for smoothing.

Reference: https://arxiv.org/abs/1811.05320
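
The effect of `improved` is easiest to see on a tiny dense adjacency matrix. The sketch below is a plain-Python illustration of the A + I versus A + 2I description above; the function is hypothetical and does not reproduce the normalization that follows inside the graph convolution.

```python
def add_self_loops(adj, improved=False):
    """Return A + I, or A + 2I when `improved` is True (dense list-of-lists)."""
    weight = 2.0 if improved else 1.0
    n = len(adj)
    return [[adj[i][j] + (weight if i == j else 0.0) for j in range(n)]
            for i in range(n)]


adj = [[0.0, 1.0], [1.0, 0.0]]
plain = add_self_loops(adj)
strong = add_self_loops(adj, improved=True)
print(plain, strong)
```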

Source code in tgm/nn/encoder/tgcn.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    improved: bool = False,
    cached: bool = False,
    add_self_loops: bool = True,
) -> None:
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.improved = improved
    self.cached = cached
    self.add_self_loops = add_self_loops

    self._create_candidate_state_parameters_and_layers()
    self._create_reset_gate_parameters_and_layers()
    self._create_update_gate_parameters_and_layers()

roland

Classes:

  • ROLAND

    An implementation of ROLAND.

ROLAND

ROLAND(
    input_channel: int,
    out_channel: int,
    num_nodes: int,
    dropout: float = 0.0,
    update: str | None = 'learnable',
    tau: float = 0.5,
)

Bases: Module

An implementation of ROLAND (https://arxiv.org/abs/2208.07239).

Parameters:

  • input_channel (int) –

    Dimension of input.

  • out_channel (int) –

    Dimension of output.

  • num_nodes (int) –

    Maximum number of nodes.

  • dropout (float, default: 0.0 ) –

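    Dropout rate.

  • update (str | None, default: 'learnable' ) –

    Update mechanism. Choose from ['moving', 'learnable', 'gru', 'mlp', None]. If update is None, the embeddings are updated with the fixed weight tau.

  • tau (float, default: 0.5 ) –

    Fixed update weight, used only when update is None; must lie in [0, 1].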

Reference: https://github.com/manuel-dileo/dynamic-gnn .

Source code in tgm/nn/encoder/roland.py
def __init__(
    self,
    input_channel: int,
    out_channel: int,
    num_nodes: int,
    dropout: float = 0.0,
    update: str | None = 'learnable',
    tau: float = 0.5,
) -> None:
    assert update in ('moving', 'learnable', 'gru', 'mlp', None)

    super(ROLAND, self).__init__()

    self.conv1 = GCNConv(input_channel, out_channel)
    self.conv2 = GCNConv(out_channel, out_channel)

    self.dropout = dropout
    self.update = update
    if update == 'moving':
        self.tau = torch.Tensor([0])
    elif update == 'learnable':
        self.tau = torch.nn.Parameter(torch.Tensor([0]))
    elif update == 'gru':
        self.gru1 = GRUCell(out_channel, out_channel)
        self.gru2 = GRUCell(out_channel, out_channel)
    elif update == 'mlp':
        self.mlp1 = Linear(out_channel * 2, out_channel)
        self.mlp2 = Linear(out_channel * 2, out_channel)
    else:
        assert tau >= 0 and tau <= 1
        self.tau = torch.Tensor([tau])
    self.previous_embeddings = [
        torch.Tensor([[0 for i in range(out_channel)] for j in range(num_nodes)]),
        torch.Tensor([[0 for i in range(out_channel)] for j in range(num_nodes)]),
    ]
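
The `forward` method is not shown on this page, so the exact use of `tau` cannot be confirmed from it. As a hedged sketch, assuming the convex combination from the ROLAND paper's moving-average update (previous embedding weighted by tau, new embedding by 1 - tau), the blend for one node would look like this; `blend` is a hypothetical name.

```python
def blend(previous, current, tau):
    """Convex combination of a stored embedding and a freshly computed one."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError('tau must lie in [0, 1]')
    return [tau * p + (1.0 - tau) * c for p, c in zip(previous, current)]


mixed = blend([0.0, 2.0], [4.0, 6.0], tau=0.5)
print(mixed)
```

Under this assumption, tau = 0 discards the history entirely and tau = 1 freezes the stored embedding, which is consistent with the constructor's check that tau stays within [0, 1].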

Decoders

graphproppred

Classes:

  • GraphPredictor

    Perform pooling over provided node features and perform graph level task.

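Functions:

  • mean_pooling

    Default graph pooling: Mean pooling.

  • sum_pooling

    Default graph pooling: Sum pooling.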

GraphPredictor

GraphPredictor(
    in_dim: int,
    out_dim: int = 1,
    nlayers: int = 2,
    hidden_dim: int = 64,
    graph_pooling: str = 'mean',
)

Bases: Module

Perform pooling over provided node features and perform graph level task.

Parameters:

  • in_dim (int) –

    Dimension of input

  • out_dim (int, default: 1 ) –

    Dimension of output

  • nlayers (int, default: 2 ) –

    Number of layers

  • hidden_dim (int, default: 64 ) –

    Size of each hidden embedding

  • graph_pooling (str, default: 'mean' ) –

    Graph pooling operation (mean or sum)

Methods:

  • forward

    Forward pass.

Source code in tgm/nn/decoder/graphproppred.py
def __init__(
    self,
    in_dim: int,
    out_dim: int = 1,
    nlayers: int = 2,
    hidden_dim: int = 64,
    graph_pooling: str = 'mean',
) -> None:
    super().__init__()
    if graph_pooling not in POOLING_OP:
        raise ValueError(
            f'{graph_pooling} pooling operations is not supported. Please choose from {list(POOLING_OP.keys())}'
        )

    self.model = Sequential()
    self.model.append(nn.Linear(in_dim, hidden_dim))
    self.model.append(nn.ReLU())

    for i in range(1, nlayers - 1):
        self.model.append(nn.Linear(hidden_dim, hidden_dim))
        self.model.append(nn.ReLU())

    self.model.append(nn.Linear(hidden_dim, out_dim))
    self.graph_pooling = POOLING_OP[graph_pooling]
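
How `nlayers` translates into actual layers follows from the `range(1, nlayers - 1)` loop. The sketch below lists the layers the constructor would append, using tuples instead of real modules; `mlp_layout` is a hypothetical helper for illustration.

```python
def mlp_layout(in_dim, out_dim=1, nlayers=2, hidden_dim=64):
    """List the (kind, in, out) entries appended to self.model, in order."""
    layers = [('Linear', in_dim, hidden_dim), ('ReLU',)]
    for _ in range(1, nlayers - 1):          # adds nlayers - 2 hidden blocks
        layers += [('Linear', hidden_dim, hidden_dim), ('ReLU',)]
    layers.append(('Linear', hidden_dim, out_dim))
    return layers


def count_linear(layers):
    return sum(1 for layer in layers if layer[0] == 'Linear')


two = mlp_layout(128, nlayers=2)
four = mlp_layout(128, nlayers=4)
one = mlp_layout(128, nlayers=1)
print(count_linear(two), count_linear(four), count_linear(one))
```

The input and output layers are always created, so any `nlayers` of 2 or less yields the same two-layer network.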
forward
forward(z_nodes: Tensor) -> Tensor

Forward pass.

Parameters:

  • z_nodes (Tensor) –

    embedding of nodes

Source code in tgm/nn/decoder/graphproppred.py
def forward(self, z_nodes: torch.Tensor) -> torch.Tensor:
    r"""Forward pass.

    Args:
        z_nodes (torch.Tensor): embedding of nodes
    """
    z_graph = self.graph_pooling(z_nodes)
    return self.model(z_graph)

mean_pooling

mean_pooling(z: Tensor) -> Tensor

Default graph pooling: Mean pooling.

Source code in tgm/nn/decoder/graphproppred.py
def mean_pooling(z: torch.Tensor) -> torch.Tensor:
    r"""Default graph pooling: Mean pooling."""
    # @TODO: we can define this in different module and have a base class for this
    return torch.mean(z, dim=0).squeeze()

sum_pooling

sum_pooling(z: Tensor) -> Tensor

Default graph pooling: Sum pooling.

Source code in tgm/nn/decoder/graphproppred.py
def sum_pooling(z: torch.Tensor) -> torch.Tensor:
    r"""Default graph pooling: Sunm pooling."""
    # @TODO: we can define this in different module and have a base class for this
    return torch.sum(z, dim=0).squeeze()
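
Both pooling functions reduce over `dim=0`, the node dimension, so a matrix of shape (num_nodes, feat_dim) becomes a single feature vector for the whole graph. A list-based sketch of the same reduction is shown below; the function names are illustrative, and the extra `.squeeze()` in the library code (which drops size-one dimensions) is not modelled.

```python
def mean_pool(z):
    """Average each feature column over all nodes."""
    return [sum(col) / len(z) for col in zip(*z)]


def sum_pool(z):
    """Sum each feature column over all nodes."""
    return [sum(col) for col in zip(*z)]


z_nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # 3 nodes, 2 features
pooled_mean = mean_pool(z_nodes)
pooled_sum = sum_pool(z_nodes)
print(pooled_mean, pooled_sum)
```

Sum pooling grows with the number of nodes while mean pooling does not, which matters when graphs of different sizes are compared.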

linkproppred

Classes:

  • LearnableSumMerge

    Sum node-level embeddings after a linear projection.

  • LinkPredictor

    Compute edge embedding given src and dst node embeddings.

Functions:

  • cat_merge

    Default merging operation: Concat.

LearnableSumMerge

LearnableSumMerge(node_dim: int)

Bases: Module

Sum node-level embeddings after a linear projection.

Source code in tgm/nn/decoder/linkproppred.py
def __init__(self, node_dim: int) -> None:
    super().__init__()
    self.lin_src = nn.Linear(node_dim, node_dim)
    self.lin_dst = nn.Linear(node_dim, node_dim)

LinkPredictor

LinkPredictor(
    node_dim: int,
    out_dim: int = 1,
    nlayers: int = 2,
    hidden_dim: int = 64,
    merge_op: str = 'concat',
)

Bases: Module

Compute edge embedding given src and dst node embeddings.

Parameters:

  • node_dim (int) –

    Dimension of node embedding

  • out_dim (int, default: 1 ) –

    Dimension of output

  • nlayers (int, default: 2 ) –

    Number of layers

  • hidden_dim (int, default: 64 ) –

    Size of each hidden embedding

  • merge_op (str, default: 'concat' ) –

    Operation to merge 2 node embeddings (concat)

Methods:

  • forward

    Forward pass.

Source code in tgm/nn/decoder/linkproppred.py
def __init__(
    self,
    node_dim: int,
    out_dim: int = 1,
    nlayers: int = 2,
    hidden_dim: int = 64,
    merge_op: str = 'concat',
) -> None:
    super().__init__()

    if merge_op not in MERGE_OP:
        raise ValueError(
            f'{merge_op} merge operation is not supported. Please choose from {list(MERGE_OP.keys())}'
        )

    if merge_op == 'concat':
        self.merge = MERGE_OP[merge_op]
        in_dim = node_dim * 2
    else:
        self.merge = MERGE_OP[merge_op](node_dim)
        in_dim = node_dim

    self.model = Sequential()
    self.model.append(nn.Linear(in_dim, hidden_dim))
    self.model.append(nn.ReLU())

    for _ in range(1, nlayers - 1):
        self.model.append(nn.Linear(hidden_dim, hidden_dim))
        self.model.append(nn.ReLU())

    self.model.append(nn.Linear(hidden_dim, out_dim))
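
The constructor above builds a plain MLP whose depth depends on `nlayers` in a slightly non-obvious way. As a rough sketch, the hypothetical helper `linear_shapes` below (not part of tgm) replays the same control flow and returns the (in_features, out_features) pair of each `nn.Linear` that would be created. It assumes that any merge operation other than 'concat' keeps the width at `node_dim`, as the `else` branch above does.

```python
from typing import List, Tuple


def linear_shapes(
    node_dim: int,
    out_dim: int = 1,
    nlayers: int = 2,
    hidden_dim: int = 64,
    merge_op: str = 'concat',
) -> List[Tuple[int, int]]:
    """Return the (in, out) sizes of the Linear layers, in order."""
    # Concatenating src and dst doubles the input width.
    in_dim = node_dim * 2 if merge_op == 'concat' else node_dim

    shapes = [(in_dim, hidden_dim)]
    # Same loop bounds as the constructor: nlayers - 2 hidden layers.
    for _ in range(1, nlayers - 1):
        shapes.append((hidden_dim, hidden_dim))
    shapes.append((hidden_dim, out_dim))
    return shapes


print(linear_shapes(16))             # [(32, 64), (64, 1)]
print(linear_shapes(16, nlayers=4))  # [(32, 64), (64, 64), (64, 64), (64, 1)]
```

With the default `nlayers=2` the loop body never runs, so the model is two Linear layers with one ReLU between them. Because `range(1, 0)` is also empty, `nlayers=1` yields the same two-layer stack.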
forward
forward(z_src: Tensor, z_dst: Tensor) -> Tensor

Forward pass.

Parameters:

  • z_src (Tensor) –

    embedding of src node

  • z_dst (Tensor) –

    embedding of dst node

Source code in tgm/nn/decoder/linkproppred.py
def forward(self, z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
    r"""Forward pass.

    Args:
        z_src (torch.Tensor): embedding of src node
        z_dst (torch.Tensor): embedding of dst node
    """
    h = self.merge(z_src, z_dst)
    return self.model(h).view(-1)

cat_merge

cat_merge(z_src: Tensor, z_dst: Tensor) -> Tensor

Default merging operation: Concat.

Source code in tgm/nn/decoder/linkproppred.py
def cat_merge(z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
    r"""Default merging operation: Concat."""
    # @TODO: we can define this in different module and have a base class for this
    return torch.cat([z_src, z_dst], dim=1)

nodeproppred

Classes:

  • NodePredictor

    Encoder for node property prediction.

NodePredictor

NodePredictor(
    in_dim: int,
    out_dim: int = 1,
    nlayers: int = 2,
    hidden_dim: int = 64,
)

Bases: Module

Encoder for node property prediction.

Parameters:

  • in_dim (int) –

    Dimension of input

  • out_dim (int, default: 1 ) –

    Dimension of output

  • nlayers (int, default: 2 ) –

    Number of layers

  • hidden_dim (int, default: 64 ) –

    Size of hidden embedding

Methods:

  • forward

    Forward pass.

Source code in tgm/nn/decoder/nodeproppred.py
def __init__(
    self,
    in_dim: int,
    out_dim: int = 1,
    nlayers: int = 2,
    hidden_dim: int = 64,
) -> None:
    super().__init__()

    self.model = Sequential()
    self.model.append(nn.Linear(in_dim, hidden_dim))
    self.model.append(nn.ReLU())

    for i in range(1, nlayers - 1):
        self.model.append(nn.Linear(hidden_dim, hidden_dim))
        self.model.append(nn.ReLU())

    self.model.append(nn.Linear(hidden_dim, out_dim))
forward
forward(z_node: Tensor) -> Tensor

Forward pass.

Parameters:

  • z_node (Tensor) –

    embedding of a node

Source code in tgm/nn/decoder/nodeproppred.py
def forward(self, z_node: torch.Tensor) -> torch.Tensor:
    r"""Forward pass.

    Args:
        z_node (torch.Tensor): embedding of a node
    """
    return self.model(z_node)

ncnpred

Classes:

  • NCNPredictor

    An implementation of Temporal Neural Common Neighbor (TNCN).

NCNPredictor

NCNPredictor(
    in_channels: int,
    hidden_dim: int,
    out_channels: int,
    k: int = 2,
    cn_time_decay: bool = False,
)

Bases: Module

An implementation of Temporal Neural Common Neighbor (TNCN).

Parameters:

  • in_channels (int) –

    Number of input features.

  • out_channels (int) –

    Number of output features.

  • hidden_dim (int) –

    Size of each hidden embedding.

  • k (int, default: 2 ) –

    Selects which common-neighbor (CN) embeddings are extracted. Must be one of 2, 4, or 8.

  • cn_time_decay (bool, default: False ) –

    Whether to apply exponential time decay to the CN weights. Requires time_info to be passed to forward.

Reference: https://arxiv.org/abs/2406.07926.

Methods:

  • forward

    Forward pass.

  • get_cn_emb

    Obtain the CNs embeddings for each node pair in a batch (Torch version).

Source code in tgm/nn/decoder/ncnpred.py
def __init__(
    self,
    in_channels: int,
    hidden_dim: int,
    out_channels: int,
    k: int = 2,
    cn_time_decay: bool = False,
) -> None:
    super().__init__()
    if k not in [2, 4, 8]:
        raise ValueError('Please choose k from [2,4,8]')

    self.k = k
    self.xslin = torch.nn.Linear(k * in_channels, out_channels)
    self.xsmlp = torch.nn.Sequential(
        torch.nn.Linear(k * in_channels, hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, out_channels),
    )
    self.cn_time_decay = cn_time_decay
forward
forward(
    x: Tensor,
    edge_index: Tensor,
    tar_ei: Tensor,
    time_info: Optional[Tuple[Tensor, Tensor]] = None,
) -> Tensor

Forward pass.

Parameters:

  • x (Tensor) –

    Node features.

  • edge_index (Tensor) –

    Edge list of the subgraph.

  • tar_ei (Tensor) –

    Edge list to predict.

  • time_info (Optional[Tuple[Tensor, Tensor]], default: None ) –

    A tuple (last_update, pos_t) holding the last update time of each node and the current time of each target edge. Required when cn_time_decay is True.

Source code in tgm/nn/decoder/ncnpred.py
def forward(
    self,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    tar_ei: torch.Tensor,
    time_info: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    r"""Forward pass.

    Args:
        x (torch.Tensor): Node features.
        edge_index (torch.Tensor): Edge list of the subgraph.
        tar_ei (torch.Tensor): Edge list to predict.
        time_info (Optional[Tuple[torch.Tensor, torch.Tensor]]): A tuple of last update and current time of each edge
    """
    xi = x[tar_ei[0]]
    xj = x[tar_ei[1]]

    xij = torch.mul(xi, xj).reshape(-1, x.size(1))
    cn_emb = self.get_cn_emb(x, edge_index, tar_ei, time_info)
    xs = torch.cat([xij, cn_emb], dim=-1)

    # NOTE: Tensor.relu() is not in-place, so the result of this call is discarded.
    xs.relu()
    xs = self.xsmlp(xs)

    return xs.view(-1)
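
The forward pass concatenates the pairwise product `xij` with the CN embedding, and the result must match the `k * in_channels` input width of `xsmlp`. The sketch below checks that arithmetic. The table `CN_BLOCKS` and the helper `ncn_input_width` are illustrative names, not tgm API; the block counts are read off the `torch.cat` calls in `get_cn_emb` (one block for k=2, three for k=4, seven for k=8).

```python
# Number of in_channels-wide blocks that get_cn_emb concatenates for each k.
CN_BLOCKS = {2: 1, 4: 3, 8: 7}


def ncn_input_width(in_channels: int, k: int) -> int:
    """Width of torch.cat([xij, cn_emb], dim=-1) for a given k."""
    if k not in CN_BLOCKS:
        raise ValueError('Please choose k from [2,4,8]')
    pair_width = in_channels                # xij = x[i] * x[j]
    cn_width = CN_BLOCKS[k] * in_channels   # concatenated CN embeddings
    return pair_width + cn_width


for k in (2, 4, 8):
    # Always equals k * in_channels, the input size of xsmlp.
    print(k, ncn_input_width(16, k))
```

In other words, `k` counts the total number of feature blocks fed to the MLP, including the pairwise product.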
get_cn_emb
get_cn_emb(
    x: Tensor,
    edge_index: Tensor,
    tar_ei: Tensor,
    time_info: Optional[Tuple[Tensor, Tensor]] = None,
) -> Tensor

Obtain the CNs embeddings for each node pair in a batch (Torch version).

Parameters:

  • x (Tensor) –

    Node features.

  • edge_index (Tensor) –

    Edge list of the subgraph.

  • tar_ei (Tensor) –

    Edge list to predict.

  • time_info (Optional[Tuple[Tensor, Tensor]], default: None ) –

    A tuple of last update and current time of each edge
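
When `cn_time_decay` is enabled, `time_info` is turned into a B×N matrix of decay factors, exp(-(pos_t - last_update) / 10000), one row per target edge and one column per node. A minimal plain-Python sketch of that computation follows. The function name and the `time_scale` argument are illustrative assumptions; the source hard-codes the scale of 10000.

```python
import math
from typing import List


def time_decay_matrix(
    last_update: List[float],
    pos_t: List[float],
    time_scale: float = 10000.0,
) -> List[List[float]]:
    """Return a B x N matrix of exp(-(pos_t[b] - last_update[n]) / time_scale)."""
    return [
        [math.exp(-(t - u) / time_scale) for u in last_update]
        for t in pos_t
    ]


# Two nodes, one target edge queried at time 10000.
decay = time_decay_matrix(last_update=[0.0, 10000.0], pos_t=[10000.0])
print(decay[0][1])  # 1.0: node updated at query time, no decay
```

A node that was updated exactly at the query time keeps weight 1, and the weight shrinks towards 0 as its last update falls further into the past.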

Source code in tgm/nn/decoder/ncnpred.py
def get_cn_emb(
    self,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    tar_ei: torch.Tensor,
    time_info: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    r"""Obtain the CNs embeddings for each node pair in a batch (Torch version).

    Args:
        x (torch.Tensor): Node features.
        edge_index (torch.Tensor): Edge list of the subgraph.
        tar_ei (torch.Tensor): Edge list to predict.
        time_info (Optional[Tuple[torch.Tensor, torch.Tensor]]): A tuple of last update and current time of each edge
    """
    tar_i, tar_j = tar_ei[0], tar_ei[1]
    if self.cn_time_decay:
        if time_info is None:
            raise RuntimeError(
                'Please provide time_information to perform time decay'
            )
        last_update, pos_t = time_info
        last_update = last_update.unsqueeze(0)  # 1*N
        pos_t = pos_t.unsqueeze(1)  # B*1
        time_decay_matrix = (pos_t - last_update) / 10000  # time scale
        time_decay_matrix = torch.exp(-time_decay_matrix)

    id_num = x.size(0)
    adj1 = (
        torch.sparse_coo_tensor(
            torch.cat(
                (edge_index, torch.stack([edge_index[1], edge_index[0]])),
                dim=-1,
            ),
            torch.ones(edge_index.shape[1] * 2, device=x.device),
            size=(id_num, id_num),
        )
        .coalesce()
        .to(x.device)
    )

    if self.k == 4:
        indices = torch.arange(id_num, device=x.device)
        adj0 = torch.sparse_coo_tensor(
            torch.stack([indices, indices], dim=0),
            torch.ones(id_num, device=x.device),
            size=(id_num, id_num),
            device=x.device,
        )

        i_0_v, i_1_v, j_0_v, j_1_v = (
            _sparse_sliding(adj0, tar_i),
            _sparse_sliding(adj1, tar_i),
            _sparse_sliding(adj0, tar_j),
            _sparse_sliding(adj1, tar_j),
        )

        i_0_e, i_1_e, j_0_e, j_1_e = (
            _fill(i_0_v, 1.0),
            _fill(i_1_v, 1.0),
            _fill(j_0_v, 1.0),
            _fill(j_1_v, 1.0),
        )

        cn_0_1, cn_1_0 = (i_0_v * j_1_v), (i_1_v * j_0_v)
        cn_1_1 = i_1_v * j_1_v

        if self.cn_time_decay:
            cn_0_1, cn_1_0, cn_1_1 = (
                cn_0_1 * time_decay_matrix,
                cn_1_0 * time_decay_matrix,
                cn_1_1 * time_decay_matrix,
            )
        xcn_0_1, xcn_1_0, xcn_1_1 = (
            torch.sparse.mm(cn_0_1, x),
            torch.sparse.mm(cn_1_0, x),
            torch.sparse.mm(cn_1_1, x),
        )
        cn_emb = torch.cat([xcn_0_1, xcn_1_0, xcn_1_1], dim=-1)

    elif self.k == 2:
        i_1_v, j_1_v = (
            _sparse_sliding(adj1, tar_i),
            _sparse_sliding(adj1, tar_j),
        )
        i_1_e, j_1_e = _fill(i_1_v, 1.0), _fill(j_1_v, 1.0)
        cn_1_1 = i_1_v * j_1_v
        if self.cn_time_decay:
            cn_1_1 = cn_1_1 * time_decay_matrix
        xcn_1_1 = torch.sparse.mm(cn_1_1, x)
        cn_emb = torch.cat([xcn_1_1], dim=-1)

    elif self.k == 8:
        indices = torch.arange(id_num, device=x.device)
        adj0 = torch.sparse_coo_tensor(
            torch.stack([indices, indices], dim=0),
            torch.ones(id_num, device=x.device),
            size=(id_num, id_num),
            device=x.device,
        )

        adj2 = torch.sparse.mm(adj1, adj1)  # self: fake 2 hop
        k3cycle = torch.sparse.mm(adj2, adj1)
        i_0_v, i_1_v, i_2_v, j_0_v, j_1_v, j_2_v = (
            _sparse_sliding(adj0, tar_i),
            _sparse_sliding(adj1, tar_i),
            _sparse_sliding(adj2, tar_i),
            _sparse_sliding(adj0, tar_j),
            _sparse_sliding(adj1, tar_j),
            _sparse_sliding(adj2, tar_j),
        )

        i_0_e, i_1_e, i_2_e, j_0_e, j_1_e, j_2_e = (
            _fill(i_0_v, 1.0),
            _fill(i_1_v, 1.0),
            _fill(i_2_v, 1.0),
            _fill(j_0_v, 1.0),
            _fill(j_1_v, 1.0),
            _fill(j_2_v, 1.0),
        )

        cn_0_1, cn_1_0 = (i_0_v * j_1_v), (i_1_v * j_0_v)
        cn_1_1 = i_1_v * j_1_v
        cn_1_2, cn_2_1, cn_2_2 = (
            (i_1_v * j_2_v),
            (i_2_v * j_1_v),
            (i_2_v * j_2_v),
        )

        u_v_value = _sparse_sliding(adj1, tar_i, tar_j).to_dense().diag().reshape(
            -1, 1
        ) * (-1)
        delta_1_2 = i_1_v * i_1_v * u_v_value
        delta_2_1 = j_1_v * j_1_v * u_v_value
        neg_cn_1_1 = torch.sparse_coo_tensor(
            cn_1_1.indices(),
            cn_1_1.values() * -1,
            cn_1_1.size(),
            device=x.device,
        )
        delta_2_2 = (
            i_1_e
            * _sparse_sliding(k3cycle, tar_i, tar_i)
            .to_dense()
            .diag()
            .reshape(-1, 1)
            + j_1_e
            * _sparse_sliding(k3cycle, tar_j, tar_j)
            .to_dense()
            .diag()
            .reshape(-1, 1)
            + neg_cn_1_1
        ) * u_v_value
        special_2_2 = torch.sparse.mm(cn_1_1, adj1)
        delta_2_2 = delta_2_2 + special_2_2

        cn_1_2, cn_2_1 = cn_1_2 + delta_1_2, cn_2_1 + delta_2_1
        cn_2_2 = cn_2_2 + delta_2_2
        idx = torch.arange(0, len(tar_i), device=x.device).repeat(2)
        u_v_mask = torch.cat([tar_i, tar_j], dim=0)

        cn_1_2, cn_2_1, cn_2_2 = (
            cn_1_2.to_dense(),
            cn_2_1.to_dense(),
            cn_2_2.to_dense(),
        )
        cn_1_2[idx, u_v_mask] = 0
        cn_2_1[idx, u_v_mask] = 0
        cn_2_2[idx, u_v_mask] = 0
        cn_2_2[cn_2_2 < 0] = 0

        if self.cn_time_decay:
            cn_0_1, cn_1_0, cn_1_1 = (
                cn_0_1.to_dense(),
                cn_1_0.to_dense(),
                cn_1_1.to_dense(),
            )
            cn_0_1, cn_1_0, cn_1_1, cn_1_2, cn_2_1, cn_2_2 = (
                cn_0_1 * time_decay_matrix,
                cn_1_0 * time_decay_matrix,
                cn_1_1 * time_decay_matrix,
                cn_1_2 * time_decay_matrix,
                cn_2_1 * time_decay_matrix,
                cn_2_2 * time_decay_matrix,
            )
            cn_0_1, cn_1_0, cn_1_1 = (
                cn_0_1.to_sparse_coo(),
                cn_1_0.to_sparse_coo(),
                cn_1_1.to_sparse_coo(),
            )
        cn_1_2, cn_2_1, cn_2_2 = (
            cn_1_2.to_sparse_coo(),
            cn_2_1.to_sparse_coo(),
            cn_2_2.to_sparse_coo(),
        )
        xcn_0_1, xcn_1_0, xcn_1_1, xcn_1_2, xcn_2_1, xcn_2_2 = (
            torch.sparse.mm(cn_0_1, x),
            torch.sparse.mm(cn_1_0, x),
            torch.sparse.mm(cn_1_1, x),
            torch.sparse.mm(cn_1_2, x),
            torch.sparse.mm(cn_2_1, x),
            torch.sparse.mm(cn_2_2, x),
        )
        special_xcn_2_2 = torch.sparse.mm(special_2_2, x)
        cn_emb = torch.cat(
            [
                xcn_0_1,
                xcn_1_0,
                xcn_1_1,
                xcn_1_2,
                xcn_2_1,
                xcn_2_2,
                special_xcn_2_2,
            ],
            dim=-1,
        )

    return cn_emb
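
The `k == 2` branch is the easiest one to follow, so here is a dense, plain-Python sketch of it without time decay. It assumes that `_sparse_sliding(adj1, idx)` selects the adjacency rows of the given nodes (that helper is not shown in this section), and the function name `cn_embedding_k2` is hypothetical. Duplicate edges accumulate, mirroring how `coalesce()` sums repeated entries.

```python
from typing import List


def cn_embedding_k2(
    x: List[List[float]],
    edge_index: List[List[int]],
    tar_ei: List[List[int]],
) -> List[List[float]]:
    """Sum the features of common 1-hop neighbours for each target pair."""
    n, dim = len(x), len(x[0])

    # Symmetric adjacency: every edge is added in both directions.
    adj = [[0.0] * n for _ in range(n)]
    for u, v in zip(edge_index[0], edge_index[1]):
        adj[u][v] += 1.0
        adj[v][u] += 1.0

    out = []
    for i, j in zip(tar_ei[0], tar_ei[1]):
        # cn_1_1 = row_i * row_j (elementwise): non-zero only for common neighbours.
        weights = [adj[i][c] * adj[j][c] for c in range(n)]
        # xcn_1_1 = cn_1_1 @ x
        out.append([sum(w * x[c][d] for c, w in enumerate(weights)) for d in range(dim)])
    return out


x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [5.0, 5.0]]
edges = [[0, 1, 0], [2, 2, 3]]  # 0-2, 1-2, 0-3
print(cn_embedding_k2(x, edges, [[0], [1]]))  # [[2.0, 2.0]]
```

Nodes 0 and 1 share only node 2 as a neighbour, so the CN embedding of the pair (0, 1) is the feature vector of node 2.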

Modules

edgebank

Classes:

  • EdgeBankPredictor

    EdgeBank link predictor with fixed or unlimited memory.

EdgeBankPredictor

EdgeBankPredictor(
    src: Tensor,
    dst: Tensor,
    ts: Tensor,
    memory_mode: Literal[
        'unlimited', 'fixed'
    ] = 'unlimited',
    window_ratio: float = 0.15,
    pos_prob: float = 1.0,
)

Reference: https://arxiv.org/abs/2207.10128.

This predictor implements the EdgeBank baseline for dynamic link prediction, introduced in https://arxiv.org/abs/2207.10128. It stores a memory of past edges and predicts the probability of a link reoccurring based on whether the queried edge is present in memory. The memory can be either unlimited (retains all edges) or fixed (retains only edges within a sliding window).

Parameters:

  • src (Tensor) –

    Source node IDs of edges used for initialization.

  • dst (Tensor) –

    Destination node IDs of edges used for initialization.

  • ts (Tensor) –

    Timestamps of edges used for initialization.

  • memory_mode (Literal['unlimited', 'fixed'], default: 'unlimited' ) –

    • 'unlimited': Keeps all observed edges in memory.
    • 'fixed': Keeps only edges within a sliding window of time.

    Defaults to 'unlimited'.

  • window_ratio (float, default: 0.15 ) –

    Ratio of the sliding window length to the total observed time span (only used if memory_mode='fixed'). Must be in (0, 1]. Defaults to 0.15.

  • pos_prob (float, default: 1.0 ) –

    The probability assigned to edges present in memory. Defaults to 1.0.

Raises:

  • ValueError

    If memory_mode is not one of 'unlimited' or 'fixed'.

  • ValueError

    If window_ratio is not in the range (0, 1].

  • TypeError

    If src, dst, or ts are not all torch.Tensor.

  • ValueError

    If src, dst, and ts do not have the same length, or if they are empty.

Note
  • In unlimited mode, memory grows with the number of observed edges.
  • In fixed mode, only edges within the most recent time window are retained. The window size is proportional to window_ratio.
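
To illustrate the two memory modes, here is a small dictionary-based sketch of the windowing logic. It is not the tgm implementation: `TinyEdgeBank` and its `predict` method are hypothetical names, inputs are plain lists instead of tensors, stale edges are pruned with a dictionary rebuild rather than the linked list used in the source, and returning 0.0 for unseen edges is an assumption.

```python
from typing import Dict, List, Tuple


class TinyEdgeBank:
    def __init__(self, src: List[int], dst: List[int], ts: List[float],
                 memory_mode: str = 'unlimited', window_ratio: float = 0.15,
                 pos_prob: float = 1.0) -> None:
        if memory_mode not in ('unlimited', 'fixed'):
            raise ValueError('memory_mode must be "unlimited" or "fixed"')
        if not 0 < window_ratio <= 1.0:
            raise ValueError('Window ratio must be in (0, 1]')
        self.pos_prob = pos_prob
        self.fixed = memory_mode == 'fixed'

        t_min, t_max = min(ts), max(ts)
        self.window_end = t_max
        # Fixed mode keeps only the most recent fraction of the time span.
        self.window_start = (
            t_max - window_ratio * (t_max - t_min) if self.fixed else t_min
        )
        self.window_size = self.window_end - self.window_start
        self.memory: Dict[Tuple[int, int], float] = {}
        self.update(src, dst, ts)

    def update(self, src: List[int], dst: List[int], ts: List[float]) -> None:
        # Slide the window so that it ends at the latest timestamp seen.
        self.window_end = max(self.window_end, max(ts))
        self.window_start = self.window_end - self.window_size
        for s, d, t in zip(src, dst, ts):
            if t >= self.window_start:
                self.memory[(s, d)] = t
        if self.fixed:
            # Simplified clean-up: drop edges that fell out of the window.
            self.memory = {
                e: t for e, t in self.memory.items() if t >= self.window_start
            }

    def predict(self, s: int, d: int) -> float:
        # Assumption: edges absent from memory get probability 0.0.
        return self.pos_prob if (s, d) in self.memory else 0.0


bank = TinyEdgeBank([0, 1, 2], [1, 2, 3], [0, 50, 100],
                    memory_mode='fixed', window_ratio=0.5)
print(bank.window_start)   # 50.0
print(bank.predict(1, 2))  # 1.0
print(bank.predict(0, 1))  # 0.0 (timestamp 0 is outside the window)
```

After `bank.update([3], [4], [120])` the window slides to [70, 120], so the edge (1, 2) observed at time 50 is forgotten while (2, 3) observed at time 100 is kept.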

Methods:

  • update

    Update EdgeBank memory with a batch of edges.

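Attributes:

  • window_end (int | float) –

    The end timestamp of the current memory window.

  • window_ratio (float) –

    The ratio of the memory window size to the full time span.

  • window_start (int | float) –

    The start timestamp of the current memory window.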

Source code in tgm/nn/modules/edgebank.py
def __init__(
    self,
    src: torch.Tensor,
    dst: torch.Tensor,
    ts: torch.Tensor,
    memory_mode: Literal['unlimited', 'fixed'] = 'unlimited',
    window_ratio: float = 0.15,
    pos_prob: float = 1.0,
) -> None:
    """Edgebank link predictor with fixed or unlimited memory.

    Reference: https://arxiv.org/abs/2207.10128.

    This predictor implements the EdgeBank baseline for dynamic link prediction,
    introduced in `https://arxiv.org/abs/2207.10128`. It stores a memory of past
    edges and predicts the probability of a link reoccurring based on whether
    the queried edge is present in memory. The memory can be either unlimited
    (retains all edges) or fixed (retains only edges within a sliding window).

    Args:
        src (torch.Tensor): Source node IDs of edges used for initialization.
        dst (torch.Tensor): Destination node IDs of edges used for initialization.
        ts (torch.Tensor): Timestamps of edges used for initialization.
        memory_mode (Literal['unlimited', 'fixed'], optional):
            - ``'unlimited'``: Keeps all observed edges in memory.
            - ``'fixed'``: Keeps only edges within a sliding window of time.
            Defaults to ``'unlimited'``.
        window_ratio (float, optional): Ratio of the sliding window length to
            the total observed time span (only used if ``memory_mode='fixed'``).
            Must be in ``(0, 1]``. Defaults to ``0.15``.
        pos_prob (float, optional): The probability assigned to edges present
            in memory. Defaults to ``1.0``.

    Raises:
        ValueError: If ``memory_mode`` is not one of ``'unlimited'`` or ``'fixed'``.
        ValueError: If ``window_ratio`` is not in the range ``(0, 1]``.
        TypeError: If ``src``, ``dst``, or ``ts`` are not all ``torch.Tensor``.
        ValueError: If ``src``, ``dst``, and ``ts`` do not have the same length,
            or if they are empty.

    Note:
        - In ``unlimited`` mode, memory grows with the number of observed edges.
        - In ``fixed`` mode, only edges within the most recent time window are
          retained. The window size is proportional to ``window_ratio``.
    """
    if memory_mode not in ['unlimited', 'fixed']:
        raise ValueError('memory_mode must be "unlimited" or "fixed"')
    if not 0 < window_ratio <= 1.0:
        raise ValueError('Window ratio must be in (0, 1]')
    self._check_input_data(src, dst, ts)

    self.pos_prob = pos_prob
    self._window_ratio = window_ratio
    self._fixed_memory = memory_mode == 'fixed'

    self._window_start, self._window_end = ts.min(), ts.max()
    if self._fixed_memory:
        self._window_start = ts.max() - window_ratio * (ts.max() - ts.min())
    self._window_size = self._window_end - self._window_start

    self.memory: Dict[Tuple[int, int], int] = {}
    # maintain bidirectional linked list with 2 pointers
    self._head: Optional[_Event] = None
    self._tail: Optional[_Event] = None

    logger.warning(
        'EdgeBank will be slow if events are added/updated out of order.'
    )

    self.update(src, dst, ts)
window_end property
window_end: int | float

Return the end timestamp of the current memory window.

window_ratio property
window_ratio: float

Return the ratio of the memory window size to the full time span.

window_start property
window_start: int | float

Return the start timestamp of the current memory window.

update
update(src: Tensor, dst: Tensor, ts: Tensor) -> None

Update EdgeBank memory with a batch of edges.

Parameters:

  • src (Tensor) –

    Source node IDs of the edges.

  • dst (Tensor) –

    Destination node IDs of the edges.

  • ts (Tensor) –

    Timestamps of the edges.

Raises:

  • TypeError

    If inputs are not torch.Tensor.

  • ValueError

    If input tensors do not have the same length, or are empty.

Source code in tgm/nn/modules/edgebank.py
def update(self, src: torch.Tensor, dst: torch.Tensor, ts: torch.Tensor) -> None:
    """Update EdgeBank memory with a batch of edges.

    Args:
        src (torch.Tensor): Source node IDs of the edges.
        dst (torch.Tensor): Destination node IDs of the edges.
        ts (torch.Tensor): Timestamps of the edges.

    Raises:
        TypeError: If inputs are not ``torch.Tensor``.
        ValueError: If input tensors do not have the same length, or are empty.
    """
    self._check_input_data(src, dst, ts)
    self._window_end = torch.max(self._window_end, ts.max())
    self._window_start = self._window_end - self._window_size

    if (
        self._fixed_memory
        and self._head is not None
        and self._tail is not None
        and self._head.ts < self._window_start
    ):
        self._clean_up()

    for src_, dst_, ts_ in zip(src, dst, ts):
        src_, dst_, ts_ = src_.item(), dst_.item(), ts_.item()
        if ts_ >= self._window_start:
            self.memory[(src_, dst_)] = ts_
            if self._head is None and self._tail is None:
                self._head = self._tail = _Event((src_, dst_), ts_, None, None)
            elif self._head is not None and self._tail is not None:
                new_event = _Event((src_, dst_), ts_, left=None, right=None)
                curr: _Event | None = self._tail

                # This while loop should never run assuming events added in time-ascending order.
                # When events added out of order, time complexity would be O(n)
                while curr is not None and ts_ < curr.ts:
                    curr = curr.left

                if curr is None:
                    new_event.right = self._head
                    if self._head is not None:
                        self._head.left = new_event
                    self._head = new_event
                else:
                    new_event.left = curr
                    new_event.right = curr.right  # type: ignore[union-attr]
                    if curr.right is not None:  # type: ignore[union-attr]
                        curr.right.left = new_event  # type: ignore[union-attr]
                    curr.right = new_event  # type: ignore[union-attr]
                    if curr == self._tail:
                        self._tail = new_event
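
The insertion logic in `update` keeps events in a doubly linked list sorted by timestamp, scanning backwards from the tail. The standalone sketch below reproduces just that part. `Node` and `TimeSortedList` are illustrative stand-ins, since the real `_Event` class is not shown in this section.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(eq=False)
class Node:
    ts: float
    left: Optional['Node'] = None
    right: Optional['Node'] = None


class TimeSortedList:
    def __init__(self) -> None:
        self.head: Optional[Node] = None
        self.tail: Optional[Node] = None

    def insert(self, ts: float) -> None:
        new = Node(ts)
        if self.head is None:
            self.head = self.tail = new
            return
        # Walk left from the tail until an event that is not later than ts.
        # For time-ascending input this loop never runs, so insertion is O(1).
        curr = self.tail
        while curr is not None and ts < curr.ts:
            curr = curr.left
        if curr is None:
            # ts is earlier than everything stored: new head.
            new.right = self.head
            self.head.left = new
            self.head = new
        else:
            # Splice the new node in right after curr.
            new.left, new.right = curr, curr.right
            if curr.right is not None:
                curr.right.left = new
            curr.right = new
            if curr is self.tail:
                self.tail = new

    def to_list(self) -> List[float]:
        out, curr = [], self.head
        while curr is not None:
            out.append(curr.ts)
            curr = curr.right
        return out


events = TimeSortedList()
for t in (1, 3, 2, 0):  # deliberately out of order
    events.insert(t)
print(events.to_list())  # [0, 1, 2, 3]
```

Keeping the list sorted means the oldest event is always at the head, which is what makes dropping expired edges from a fixed window cheap.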

time_encoding

Classes:

  • Time2Vec

    Time encoder representation.

Time2Vec

Time2Vec(time_dim: int)

Bases: Module

Parameters:

  • time_dim (int) –

    The dimension of time encodings.

Source code in tgm/nn/modules/time_encoding.py
def __init__(self, time_dim: int) -> None:
    """Time encoder representation.

    Args:
        time_dim (int): The dimension of time encodings.
    """
    super().__init__()
    self.time_dim = time_dim
    self.w = torch.nn.Linear(1, time_dim)

    # Initialization from: https://github.com/yule-BUAA/DyGLib/blob/master/models/modules.py
    w = (1 / 10 ** np.linspace(0, 9, time_dim)).reshape(time_dim, 1)
    self.w.weight = torch.nn.Parameter(torch.from_numpy(w).float())
    self.w.bias = torch.nn.Parameter(torch.zeros(time_dim))
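
The weight initialization above spreads the `time_dim` frequencies log-uniformly between 1 and 1e-9, so that different output dimensions respond to very different time scales. A NumPy-free sketch of the same values follows; `init_frequencies` is a hypothetical helper name.

```python
from typing import List


def init_frequencies(time_dim: int) -> List[float]:
    """Equivalent of 1 / 10 ** np.linspace(0, 9, time_dim) in plain Python."""
    if time_dim == 1:
        return [1.0]  # np.linspace(0, 9, 1) is [0.]
    step = 9.0 / (time_dim - 1)
    return [10.0 ** -(i * step) for i in range(time_dim)]


# Exponents 0, -3, -6, -9, i.e. frequencies of roughly 1, 1e-3, 1e-6, 1e-9.
print(init_frequencies(4))
```

The bias is initialized to zeros, so at construction time the linear layer simply scales each timestamp by these frequencies.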

attention

Classes:

  • TemporalAttention

    Multi-head Temporal Attention Module for dynamic/temporal graphs.

TemporalAttention

TemporalAttention(
    n_heads: int,
    node_dim: int,
    edge_dim: int,
    time_dim: int,
    dropout: float = 0.1,
)

Bases: Module

Multi-head Temporal Attention Module for dynamic/temporal graphs.

This module computes attention over a node's neighbors considering node features, edge features, and time features. It supports multiple attention heads and applies residual connection, dropout, and layer normalization to the output.

Parameters:

  • n_heads (int) –

    Number of attention heads.

  • node_dim (int) –

    Dimensionality of node features.

  • edge_dim (int) –

    Dimensionality of edge features.

  • time_dim (int) –

    Dimensionality of temporal features.

  • dropout (float, default: 0.1 ) –

    Dropout probability applied to attention and output layers. Default is 0.1.

Raises:

  • ValueError

    If n_heads, node_dim, edge_dim, or time_dim are <= 0.

Note

The output dimension is node_dim + time_dim, padded to be divisible by n_heads if necessary.
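
The padding rule in the note can be checked with a few lines of arithmetic. The helper below is only an illustration (the name `attention_dims` is made up); it mirrors the constructor logic and returns the padding, the padded output dimension, and the per-head dimension.

```python
from typing import Tuple


def attention_dims(n_heads: int, node_dim: int, time_dim: int) -> Tuple[int, int, int]:
    """Return (pad_dim, out_dim, head_dim) as computed in __init__."""
    out_dim = node_dim + time_dim
    pad_dim = 0
    if out_dim % n_heads != 0:
        # Pad up to the next multiple of n_heads.
        pad_dim = n_heads - out_dim % n_heads
        out_dim += pad_dim
    return pad_dim, out_dim, out_dim // n_heads


print(attention_dims(n_heads=4, node_dim=10, time_dim=5))  # (1, 16, 4)
print(attention_dims(n_heads=2, node_dim=8, time_dim=8))   # (0, 16, 8)
```

The padding is applied to the node features in `forward`, so callers should expect the returned embeddings to have `out_dim` columns, which can be larger than `node_dim + time_dim`.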

Methods:

  • forward

    Forward pass of the Temporal Attention module.

Source code in tgm/nn/modules/attention.py
def __init__(
    self,
    n_heads: int,
    node_dim: int,
    edge_dim: int,
    time_dim: int,
    dropout: float = 0.1,
) -> None:
    super().__init__()
    if any((x <= 0 for x in [n_heads, node_dim, edge_dim, time_dim])):
        raise ValueError('n_heads, node_dim, edge_dim, and time_dim must be > 0')

    out_dim = node_dim + time_dim
    self.pad_dim = 0
    if out_dim % n_heads != 0:
        self.pad_dim = n_heads - out_dim % n_heads
        out_dim += self.pad_dim

    self.n_heads = n_heads
    self.head_dim = out_dim // n_heads
    self.out_dim = out_dim

    key_dim = node_dim + edge_dim + time_dim
    self.W_Q = torch.nn.Linear(out_dim, out_dim, bias=False)
    self.W_KV = torch.nn.Linear(key_dim, out_dim * 2, bias=False)
    self.W_O = torch.nn.Linear(out_dim, out_dim)

    self.dropout = torch.nn.Dropout(dropout)
    self.layer_norm = torch.nn.LayerNorm(out_dim)
forward
forward(
    node_feat: Tensor,
    time_feat: Tensor,
    edge_feat: Tensor,
    nbr_node_feat: Tensor,
    nbr_time_feat: Tensor,
    valid_nbr_mask: Tensor,
) -> Tensor

Forward pass of the Temporal Attention module.

Computes multi-head attention over neighbors, using node, edge, and time features, followed by a residual connection, dropout, and layer normalization.

Parameters:

  • node_feat (Tensor) –

    Node features of shape (B, node_dim).

  • time_feat (Tensor) –

    Node time features of shape (B, time_dim).

  • edge_feat (Tensor) –

    Edge features for each neighbor of shape (B, num_nbrs, edge_dim).

  • nbr_node_feat (Tensor) –

    Neighbor node features of shape (B, num_nbrs, node_dim).

  • nbr_time_feat (Tensor) –

    Neighbor time features of shape (B, num_nbrs, time_dim).

  • valid_nbr_mask (Tensor) –

    Boolean mask indicating valid neighbors of shape (B, num_nbrs). True indicates valid neighbors.

Returns:

  • Tensor

    Updated node features of shape (B, out_dim).

Notes
  • If a node has no neighbors, masked attention values are set to -1e10 to avoid NaNs in softmax.
  • The output dimension is padded if necessary to be divisible by the number of heads.
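
The choice of -1e10 over negative infinity matters only for nodes whose neighbors are all masked. The plain-Python sketch below shows the difference with a numerically stable softmax; `masked_softmax` is an illustrative function, not the module's code.

```python
import math
from typing import List


def masked_softmax(scores: List[float], valid: List[bool],
                   fill: float = -1e10) -> List[float]:
    """Softmax over scores, with invalid positions replaced by `fill`."""
    masked = [s if ok else fill for s, ok in zip(scores, valid)]
    peak = max(masked)
    exps = [math.exp(m - peak) for m in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Partially masked: the masked neighbour gets zero attention.
print(masked_softmax([2.0, 1.0, 0.5], [True, False, True]))

# Fully masked with -1e10: a finite, uniform distribution.
print(masked_softmax([2.0, 1.0, 0.5], [False, False, False]))

# Fully masked with -inf: (-inf) - (-inf) is NaN, so every weight is NaN.
print(masked_softmax([2.0, 1.0], [False, False], fill=float('-inf')))
```

With the finite fill value, a node without valid neighbors attends uniformly to its padded neighbor slots instead of producing NaNs that would propagate through the rest of the model.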
Source code in tgm/nn/modules/attention.py
def forward(
    self,
    node_feat: torch.Tensor,
    time_feat: torch.Tensor,
    edge_feat: torch.Tensor,
    nbr_node_feat: torch.Tensor,
    nbr_time_feat: torch.Tensor,
    valid_nbr_mask: torch.Tensor,
) -> torch.Tensor:
    """Forward pass of the Temporal Attention module.

    Computes multi-head attention over neighbors, using node, edge, and time
    features, followed by a residual connection, dropout, and layer normalization.

    Args:
        node_feat (torch.Tensor): Node features of shape (B, node_dim).
        time_feat (torch.Tensor): Node time features of shape (B, time_dim).
        edge_feat (torch.Tensor): Edge features for each neighbor of shape
            (B, num_nbrs, edge_dim).
        nbr_node_feat (torch.Tensor): Neighbor node features of shape
            (B, num_nbrs, node_dim).
        nbr_time_feat (torch.Tensor): Neighbor time features of shape
            (B, num_nbrs, time_dim).
        valid_nbr_mask (torch.Tensor): Boolean mask indicating valid neighbors
            of shape (B, num_nbrs). True indicates valid neighbors.

    Returns:
        torch.Tensor: Updated node features of shape (B, out_dim).

    Notes:
        - If a node has no neighbors, masked attention values are set to -1e10
          to avoid NaNs in softmax.
        - The output dimension is padded if necessary to be divisible by the number
          of heads.
    """
    node_feat = F.pad(node_feat, (0, self.pad_dim)) if self.pad_dim else node_feat

    Q = R = torch.cat([node_feat, time_feat], dim=1).unsqueeze(1)  # (B, 1, out_dim)
    Q = self.W_Q(Q)  # (B, 1, out_dim)

    Z = torch.cat([nbr_node_feat, edge_feat, nbr_time_feat], dim=-1)
    Z = self.W_KV(Z)
    K = Z[:, :, : self.out_dim]  # (B, num_nbrs, out_dim)
    V = Z[:, :, self.out_dim :]  # (B, num_nbrs, out_dim)

    Q = Q.reshape(Q.shape[0], -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    K = K.reshape(K.shape[0], -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    V = V.reshape(V.shape[0], -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
    del Z

    A = torch.einsum('bhld,bhnd->bhln', Q, K)  # (B, n_heads, 1, num_nbrs)
    A *= self.head_dim**-0.5
    del Q, K

    valid_nbr_mask = valid_nbr_mask.reshape(valid_nbr_mask.shape[0], 1, 1, -1)
    valid_nbr_mask = valid_nbr_mask.repeat(1, self.n_heads, 1, 1)

    # If a node has no neighbors (valid_nbr_mask all False), setting masks to -np.inf will cause softmax nans
    # Choose a very large negative number (-1e10 following TGAT) instead
    A = A.masked_fill(~valid_nbr_mask, -1e10)
    A = torch.softmax(A, dim=-1)
    A = self.dropout(A)

    O = torch.einsum('bhln,bhnd->bhld', A, V)  # (B, n_heads, 1, head_dim)
    O = O.flatten(start_dim=1)  # (B, out_dim)
    del A

    out = self.W_O(O)  # (B, out_dim)
    out = self.dropout(out)
    out = self.layer_norm(out + R.squeeze(1))
    return out

poptrack

Classes:

  • PopTrackPredictor

    PopTrack predictor.

PopTrackPredictor

PopTrackPredictor(
    src: Tensor,
    dst: Tensor,
    ts: Tensor,
    num_nodes: int,
    k: int = 50,
    decay: float = 0.9,
)

Reference: https://openreview.net/pdf?id=9kLDrE5rsW

This predictor implements the PopTrack baseline for dynamic link prediction, introduced in https://openreview.net/pdf?id=9kLDrE5rsW. It predicts the probability of a link reoccurring based on the popularity score of the queried edge's destination.

Parameters:

  • src (Tensor) –

    Source node IDs of edges used for initialization.

  • dst (Tensor) –

    Destination node IDs of edges used for initialization.

  • ts (Tensor) –

    Timestamps of edges used for initialization.

  • num_nodes (int) –

    The total number of nodes.

  • k (int, default: 50 ) –

    Number of popular nodes to retrieve from.

  • decay (float, default: 0.9 ) –

    Temporal decay parameter. Must be in (0, 1]. Defaults to 0.9.

Raises:

  • ValueError

    If k is nonpositive or greater than num_nodes, if decay is not in (0, 1], or if num_nodes is nonpositive.

  • TypeError

    If src, dst, or ts are not all torch.Tensor.

  • ValueError

    If src, dst, and ts do not have the same length, or if they are empty.

Note
  • The predictions are not conditional on the source.

Methods:

  • update

    Update PopTrack cache with a batch of edges.

Source code in tgm/nn/modules/poptrack.py
def __init__(
    self,
    src: torch.Tensor,
    dst: torch.Tensor,
    ts: torch.Tensor,
    num_nodes: int,
    k: int = 50,
    decay: float = 0.9,
) -> None:
    """PopTrack Predictor.

    Reference: https://openreview.net/pdf?id=9kLDrE5rsW

    This predictor implements the PopTrack baseline for dynamic link prediction,
    introduced in `https://openreview.net/pdf?id=9kLDrE5rsW`.
    It predicts the probability of a link reoccurring based on
    the popularity score of the queried edge's destination.

    Args:
        src (torch.Tensor): Source node IDs of edges used for initialization.
        dst (torch.Tensor): Destination node IDs of edges used for initialization.
        ts (torch.Tensor): Timestamps of edges used for initialization.
        num_nodes (int): The total number of nodes.
        k (int, optional): Number of popular nodes to retrieve from.
        decay (float, optional): Temporal decay parameter.
            Must be in ``(0, 1]``. Defaults to ``0.9``.

    Raises:
        ValueError: If ``k`` is nonpositive.
        TypeError: If ``src``, ``dst``, or ``ts`` are not all ``torch.Tensor``.
        ValueError: If ``src``, ``dst``, and ``ts`` do not have the same length,
            or if they are empty.

    Note:
        - The predictions are not conditional on the source.
    """
    if 0 >= k:
        raise ValueError('K must be positive')

    if decay <= 0 or decay > 1:
        raise ValueError('Decay must be in (0,1]')

    if num_nodes <= 0:
        raise ValueError('``num_nodes`` must be set to the total number of nodes.')

    if k > num_nodes:
        raise ValueError('``k`` must not exceed ``num_nodes``.')

    self._check_input_data(src, dst, ts)
    self.popularity = torch.zeros(num_nodes)
    self.k = k
    self.decay = decay
    self.update(src, dst, ts)

update
update(src: Tensor, dst: Tensor, ts: Tensor) -> None

Update PopTrack cache with a batch of edges.

Parameters:

  • src (Tensor) –

    Source node IDs of the edges.

  • dst (Tensor) –

    Destination node IDs of the edges.

  • ts (Tensor) –

    Timestamps of the edges.

Raises:

  • TypeError

    If inputs are not torch.Tensor.

  • ValueError

    If input tensors do not have the same length, or are empty.

Source code in tgm/nn/modules/poptrack.py
def update(self, src: torch.Tensor, dst: torch.Tensor, ts: torch.Tensor) -> None:
    """Update PopTrack cache with a batch of edges.

    Args:
        src (torch.Tensor): Source node IDs of the edges.
        dst (torch.Tensor): Destination node IDs of the edges.
        ts (torch.Tensor): Timestamps of the edges.

    Raises:
        TypeError: If inputs are not ``torch.Tensor``.
        ValueError: If input tensors do not have the same length, or are empty.
    """
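    self._check_input_data(src, dst, ts)
    # Count one interaction per destination occurrence in the batch.
    self.popularity.index_add_(
        0, dst, torch.ones_like(dst, dtype=self.popularity.dtype)
    )
    # Apply the temporal decay once per batch, after the counts are added.
    self.popularity *= self.decay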
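
The update rule above (increment destination counters, then decay all counters once per batch) can be sketched in plain Python. Lists stand in for tensors here. The `top_k` helper is an assumption about how the `k` parameter might be used to pick the most popular nodes; the scoring code is not shown in this section, so treat it as illustrative only.

```python
def poptrack_update(popularity, dst, decay):
    """Increment each destination's counter, then decay every counter once."""
    popularity = list(popularity)
    for d in dst:
        popularity[d] += 1.0
    return [p * decay for p in popularity]


def top_k(popularity, k):
    """Hypothetical helper: IDs of the k highest-scoring nodes."""
    order = sorted(range(len(popularity)), key=lambda n: popularity[n], reverse=True)
    return order[:k]


pop = [0.0] * 4
pop = poptrack_update(pop, dst=[1, 1, 2], decay=0.5)
pop = poptrack_update(pop, dst=[3, 3, 3], decay=0.5)
most_popular = top_k(pop, k=2)
```

Because the decay is applied to every counter on each batch, node 1's two early interactions are worth less after the second batch than node 3's three recent ones.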

t_comem

Classes:

tCoMemPredictor

tCoMemPredictor(
    src: Tensor,
    dst: Tensor,
    ts: Tensor,
    num_nodes: int,
    k: int = 50,
    window_ratio: float = 0.15,
    co_occurrence_weight: float = 0.8,
)

Reference: https://www.arxiv.org/abs/2506.12764

This predictor implements the t-CoMem module for dynamic link prediction, introduced in https://www.arxiv.org/abs/2506.12764. It is a memory-based module that mixes popularity with co-occurrence.

Parameters:

  • src (Tensor) –

    Source node IDs of edges used for initialization.

  • dst (Tensor) –

    Destination node IDs of edges used for initialization.

  • ts (Tensor) –

    Timestamps of edges used for initialization.

  • num_nodes (int) –

    Total number of nodes in the dataset.

  • k (int, default: 50 ) –

    Threshold for the popularity effect. Defaults to 50; must be positive and no larger than num_nodes. Larger k generally improves performance at the cost of higher memory usage, though the gains usually level off beyond a certain point.

  • window_ratio (float, default: 0.15 ) –

    Ratio of the sliding window length to the total observed time span (only used if memory_mode='fixed'). Must be in (0, 1]. Defaults to 0.15.

  • co_occurrence_weight (float, default: 0.8 ) –

    Weighting parameter for co-occurrence. Must be in (0, 1]. Defaults to 0.8.

Raises:

  • TypeError

    If src, dst, or ts are not all torch.Tensor.

  • ValueError

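    If src, dst, and ts do not have the same length, or if they are empty. Also raised if window_ratio or co_occurrence_weight is not in (0, 1], if k is nonpositive or greater than num_nodes, or if num_nodes is nonpositive.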

Methods:

  • update

    Update t-CoMem memory with a batch of edges.

Attributes:

Source code in tgm/nn/modules/t_comem.py
def __init__(
    self,
    src: torch.Tensor,
    dst: torch.Tensor,
    ts: torch.Tensor,
    num_nodes: int,
    k: int = 50,
    window_ratio: float = 0.15,
    co_occurrence_weight: float = 0.8,
) -> None:
    """t-CoMem link predictor with fixed or unlimited memory.

    Reference: https://www.arxiv.org/abs/2506.12764

    This predictor implements the t-CoMem module for dynamic link prediction,
    introduced in `https://www.arxiv.org/abs/2506.12764`.
    It is a memory-based module that mixes popularity with co-occurrence.

    Args:
        src (torch.Tensor): Source node IDs of edges used for initialization.
        dst (torch.Tensor): Destination node IDs of edges used for initialization.
        ts (torch.Tensor): Timestamps of edges used for initialization.
        num_nodes (int): Total number of nodes in the dataset.
        k (int, optional): Threshold for the popularity effect.
            Defaults to 50; must be positive and no larger than ``num_nodes``.
            Larger ``k`` generally improves performance at the cost of higher
            memory usage, though the gains usually level off beyond a certain point.
        window_ratio (float, optional): Ratio of the sliding window length to
            the total observed time span (only used if ``memory_mode='fixed'``).
            Must be in ``(0, 1]``. Defaults to ``0.15``.
        co_occurrence_weight (float, optional): Weighting parameter for co-occurrence.
            Must be in ``(0, 1]``. Defaults to ``0.8``.

    Raises:
        TypeError: If ``src``, ``dst``, or ``ts`` are not all ``torch.Tensor``.
        ValueError: If ``src``, ``dst``, and ``ts`` do not have the same length,
            or if they are empty.

    """
    if not 0 < window_ratio <= 1.0:
        raise ValueError('Window ratio must be in (0, 1]')

    if not 0 < co_occurrence_weight <= 1.0:
        raise ValueError('Co-occurrence weight must be in (0, 1]')

    if 0 >= k:
        raise ValueError('K must be positive')

    if num_nodes <= 0:
        raise ValueError('``num_nodes`` must be set to the total number of nodes.')

    if k > num_nodes:
        raise ValueError('``k`` must not exceed ``num_nodes``.')

    self._check_input_data(src, dst, ts)

    self._window_ratio = window_ratio
    self._window_start, self._window_end = ts.min(), ts.max()
    self._window_size = torch.clamp(self._window_end - self._window_start, min=1.0)

    self.device = src.device
    self.num_nodes = num_nodes
    self.k = k

    self.recent_ts = torch.full(
        (self.num_nodes, self.k), fill_value=-float('inf'), device=self.device
    )

    self.recent_dst = torch.full(
        (self.num_nodes, self.k), fill_value=-1, device=self.device
    )

    self.recent_len = torch.zeros(self.num_nodes)
    self.recent_pos = torch.zeros(self.num_nodes)

    self.node_to_co_occurrence: DefaultDict[int, Dict[int, int]] = defaultdict(dict)
    self.popularity = torch.zeros(num_nodes)
    self.co_occurrence_weight = co_occurrence_weight

    self.update(src, dst, ts)

window_end property
window_end: int | float

Return the end timestamp of the current memory window.

window_ratio property
window_ratio: float

Return the ratio of the memory window size to the full time span.

window_size property
window_size: int

Return the absolute size of the memory window.

window_start property
window_start: int | float

Return the start timestamp of the current memory window.
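
How these window properties move as batches arrive can be sketched from the two window lines visible in the `__init__` and `update` listings: the window end only moves forward, and the start trails it by a fixed size. This is a plain-Python sketch with hypothetical helper names (`init_window`, `advance_window`). It mirrors only the lines visible in the source and does not apply `window_ratio`, since that scaling is not shown here.

```python
def init_window(ts):
    """Initial window spans the observed timestamps; size is at least 1.0."""
    start, end = min(ts), max(ts)
    size = max(end - start, 1.0)
    return start, end, size


def advance_window(window_end, window_size, batch_ts):
    """Move the window end forward (never backward) and recompute the start."""
    new_end = max(window_end, max(batch_ts))
    return new_end - window_size, new_end


start, end, size = init_window([0.0, 40.0, 100.0])

# A newer batch pushes the window forward.
start, end = advance_window(end, size, [120.0, 150.0])

# A batch containing only older timestamps leaves the window where it is.
late_start, late_end = advance_window(end, size, [110.0])
```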

update
update(src: Tensor, dst: Tensor, ts: Tensor) -> None

Update t-CoMem memory with a batch of edges.

Parameters:

  • src (Tensor) –

    Source node IDs of the edges.

  • dst (Tensor) –

    Destination node IDs of the edges.

  • ts (Tensor) –

    Timestamps of the edges.

Raises:

  • TypeError

    If inputs are not torch.Tensor.

  • ValueError

    If input tensors do not have the same length, or are empty.

Source code in tgm/nn/modules/t_comem.py
def update(
    self,
    src: torch.Tensor,
    dst: torch.Tensor,
    ts: torch.Tensor,
) -> None:
    """Update t-CoMem memory with a batch of edges.

    Args:
        src (torch.Tensor): Source node IDs of the edges.
        dst (torch.Tensor): Destination node IDs of the edges.
        ts (torch.Tensor): Timestamps of the edges.

    Raises:
        TypeError: If inputs are not ``torch.Tensor``.
        ValueError: If input tensors do not have the same length, or are empty.
    """
    self._check_input_data(src, dst, ts)

    self._window_end = torch.max(self._window_end, ts.max())
    self._window_start = self._window_end - self._window_size

    for s, d, t in zip(src.long().tolist(), dst.long().tolist(), ts.tolist()):
        pos = int(self.recent_pos[s])

        self.recent_ts[s, pos] = t
        self.recent_dst[s, pos] = d
        self.recent_pos[s] = (pos + 1) % self.k
        if self.recent_len[s] < self.k:
            self.recent_len[s] += 1
        self.node_to_co_occurrence[s][d] = (
            self.node_to_co_occurrence[s].get(d, 0) + 1
        )
        self.node_to_co_occurrence[d][s] = (
            self.node_to_co_occurrence[d].get(s, 0) + 1
        )

    self.popularity.index_add_(
        0, dst.long(), torch.ones_like(dst, dtype=self.popularity.dtype)
    )
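
The per-source bookkeeping in the loop above behaves like a fixed-size ring buffer of the `k` most recent destinations, alongside symmetric co-occurrence counts. The following is a minimal plain-Python sketch of that pattern. `RecentNeighborBuffer` is a hypothetical stand-in that uses lists and dicts instead of tensors, and is not part of the library.

```python
from collections import defaultdict


class RecentNeighborBuffer:
    """Hypothetical stand-in for the per-node ring buffers; not part of tgm."""

    def __init__(self, num_nodes: int, k: int) -> None:
        self.k = k
        self.recent_ts = [[float("-inf")] * k for _ in range(num_nodes)]
        self.recent_dst = [[-1] * k for _ in range(num_nodes)]
        self.recent_pos = [0] * num_nodes  # next slot to overwrite
        self.recent_len = [0] * num_nodes  # number of filled slots, capped at k
        self.co_occurrence = defaultdict(dict)

    def update(self, src, dst, ts) -> None:
        for s, d, t in zip(src, dst, ts):
            pos = self.recent_pos[s]
            # Overwrite the oldest slot once the buffer has wrapped around.
            self.recent_ts[s][pos] = t
            self.recent_dst[s][pos] = d
            self.recent_pos[s] = (pos + 1) % self.k
            self.recent_len[s] = min(self.recent_len[s] + 1, self.k)
            # Co-occurrence counts are kept symmetric.
            self.co_occurrence[s][d] = self.co_occurrence[s].get(d, 0) + 1
            self.co_occurrence[d][s] = self.co_occurrence[d].get(s, 0) + 1


buf = RecentNeighborBuffer(num_nodes=4, k=2)
buf.update(src=[0, 0, 0], dst=[1, 2, 3], ts=[10.0, 20.0, 30.0])
```

With `k=2`, the third interaction of node 0 overwrites the slot that held destination 1, while the co-occurrence counts still remember all three partners.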