# Quick Start Guide
## Minimal Example
Here is a minimal example that trains a TGCN model for dynamic node property prediction on the `tgbn-trade` dataset:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tgm import DGraph, DGBatch
from tgm.data import DGData, DGDataLoader
from tgm.nn import TGCN, NodePredictor
# Fetch tgbn-trade from TGB and partition it into train / validation / test splits.
train_data, val_data, test_data = DGData.from_tgb("tgbn-trade").split()

# Build a DGraph over the training split and iterate it one year ("Y") at a time.
train_dg = DGraph(train_data)
train_loader = DGDataLoader(train_dg, batch_unit="Y")

# tgbn-trade provides no static node features, so draw one standard-normal
# feature vector of width NODE_FEAT_DIM for every node in the training graph.
NODE_FEAT_DIM = 64
static_node_x = torch.randn(train_dg.num_nodes, NODE_FEAT_DIM)
class RecurrentGCN(torch.nn.Module):
    """Recurrent graph encoder: a TGCN cell followed by ReLU and a linear projection.

    Args:
        node_dim: Width of the node features passed to ``forward`` (used as
            TGCN's ``in_channels``).
        embed_dim: Width of the TGCN hidden state and of the output embeddings.
    """

    def __init__(self, node_dim: int, embed_dim: int) -> None:
        super().__init__()
        self.recurrent = TGCN(in_channels=node_dim, out_channels=embed_dim)
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(
        self,
        batch: DGBatch,
        node_feat: torch.Tensor,
        h: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode one snapshot and advance the recurrent state.

        Args:
            batch: Snapshot whose ``edge_src`` / ``edge_dst`` tensors define the
                edges used for this step.
            node_feat: Node feature matrix fed to the TGCN cell; presumably
                (num_nodes, node_dim) -- confirm against TGCN's expectations.
            h: Hidden state returned by the previous call, or ``None``; it is
                passed straight through to TGCN as ``H``.

        Returns:
            ``(z, h_new)``: the projected embeddings, and the raw TGCN hidden
            state to feed back in as ``h`` on the next snapshot.
        """
        # Stack sources over destinations into a 2-row edge_index
        # (assumes edge_src / edge_dst are 1-D tensors of equal length).
        edge_index = torch.stack([batch.edge_src, batch.edge_dst], dim=0)
        h_new = self.recurrent(node_feat, edge_index, H=h)
        # The activation/projection only shape the output embeddings; the
        # un-projected TGCN state is what gets carried to the next step.
        z = self.linear(F.relu(h_new))
        return z, h_new
# Initialize our model and optimizer
encoder = RecurrentGCN(node_dim=static_node_x.shape[1], embed_dim=128)
decoder = NodePredictor(in_dim=128, out_dim=train_dg.node_y_dim)
# Pass parameters as an ordered list, not a set: a set's iteration order
# depends on object ids, so the optimizer's parameter order (and the layout
# of its state_dict) would not be reproducible between runs.
opt = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=0.001
)

# Training loop
h_0 = None  # recurrent state carried from one yearly snapshot to the next
for batch in train_loader:
    opt.zero_grad()
    y_true = batch.node_y
    if y_true is None:
        # No node labels in this snapshot, so there is nothing to supervise.
        # NOTE(review): skipped snapshots do not advance h_0 -- confirm intended.
        continue

    z, h_0 = encoder(batch, static_node_x, h_0)
    # Predict only for the nodes that carry labels in this snapshot.
    y_pred = decoder(z[batch.node_y_nids])
    loss = F.cross_entropy(y_pred, y_true)
    loss.backward()
    opt.step()

    # Truncated backprop-through-time: keep the state's values but cut its
    # autograd history so the next step does not backprop into this snapshot.
    h_0 = h_0.detach()
## Running Pre-packaged Examples
TGM includes pre-packaged example scripts to help you get started quickly. These examples require extra dependencies beyond the core library, which you can install with:
pip install -e ".[examples]"
After installing the dependencies, you can run any of our examples. For instance, to run TGAT for dynamic link prediction on `tgbl-wiki`:
python examples/linkproppred/tgat.py --dataset tgbl-wiki --device cuda
> [!NOTE]
> By default, our link prediction examples use `tgbl-wiki` and our node prediction examples use `tgbn-trade`. Examples run on the CPU by default; use the `--device` flag to override this, as shown above.