
GNN Augmentation and Training

Do We Actually Need Prediction Heads?

Consider a GNN whose last layer produces node embeddings $H^{(L)} \in \mathbb{R}^{n \times d}$. Suppose you propose to remove all prediction heads and train directly using:

$$\mathcal{L} = \sum_{v \in \mathcal{D}} \ell\big(h_v^{(L)}, y_v\big)$$

where $\ell$ is a classification/regression loss and $\mathcal{D} \subseteq V$ is the set of labeled nodes used in training.

Answer briefly:

  1. How would this affect the generalization capability of the model? (Hint: Think of what the embeddings will encode.)

  2. In this scenario, describe what happens when classifying a previously unseen node (inference). What does the embedding of this node represent, and how is the class assigned?

  3. Explain why adding a prediction head (e.g., a linear layer + softmax) solves the issues identified in (1) and (2). How does it allow embeddings to remain generalizable while still supporting accurate classification?

Expressivity of Prediction Heads

You are given a GNN that computes node embeddings $h_u$ and $h_v$. Consider the following edge prediction heads (a PyTorch sketch of all three follows the tasks below):

  • Dot product: $s(u,v) = h_u^\top h_v$

  • Single MLP on concatenation: $s(u,v) = \text{MLP}([h_u \| h_v])$

  • Bilinear form: $s(u,v) = h_u^\top W h_v$

Tasks:

  1. Compare the expressive power of these edge prediction heads.

  2. Show that the dot product is a special (degenerate) case of the bilinear form.

  3. Give an example of an edge-labeling function on which the dot product performs strictly worse.
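
For concreteness, here is a minimal PyTorch sketch of the three heads. The embedding dimension d and the MLP width are illustrative choices, not given in the exercise; the point is only to make the shapes and parameterizations explicit.

import torch
import torch.nn as nn

d = 32  # assumed embedding dimension (illustration only)

# MLP on concatenation: MLP([h_u || h_v])
mlp = nn.Sequential(nn.Linear(2 * d, d), nn.ReLU(), nn.Linear(d, 1))

# learnable matrix for the bilinear form
W = nn.Parameter(torch.randn(d, d))

def dot_head(h_u, h_v):
    # s(u,v) = h_u^T h_v; note this is the bilinear head with W fixed to I
    return (h_u * h_v).sum(dim=-1)

def mlp_head(h_u, h_v):
    return mlp(torch.cat([h_u, h_v], dim=-1)).squeeze(-1)

def bilinear_head(h_u, h_v):
    # s(u,v) = h_u^T W h_v, computed row-wise for batched [N, d] inputs
    return ((h_u @ W) * h_v).sum(dim=-1)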

Is there a Leak?

You train a node classifier on a citation graph, and someone claims:

“If the graph is fixed, transductive evaluation is just cheating: you always leak test information.”

Tasks:

  1. Provide a rigorous argument for when this is not true (give a scenario where transductive inference is legitimate).

  2. Give a leakage scenario in which this claim is true.

  3. Explain why graph classification cannot be transductive even in principle.

Link Prediction on Cora

In this exercise, you’ll implement a basic transductive link prediction setup on the Cora dataset using PyG tools.

We start by loading the dataset from the Planetoid class (documentation).

import torch
from torch_geometric.datasets import Planetoid

# load dataset
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
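
To sanity-check the load, we can print the Data object; for Cora this should report 2,708 nodes, 1,433 bag-of-words features, and 10,556 directed edges.

# quick sanity check of the loaded graph
print(data)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}, Features: {dataset.num_features}")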

Next, we’ll use the RandomLinkSplit (documentation) to split our dataset.

For our simple exercise, we won’t need a validation set, and we’ll split the data so that 90% of the edges are in the training set and the remaining 10% are in the test set. Further, we’ll split the training edges into message (80%) and supervision (20%) edges.

  1. Given this, complete the following code snippet. Check the RandomLinkSplit documentation to find which parameters to use.

from torch_geometric.transforms import RandomLinkSplit

# edge split for link prediction
split = RandomLinkSplit(
    ########## Your code here ##########
    
    #################################### 
    is_undirected=True,
    add_negative_train_samples=True
)
train_data, _, test_data = split(data)

train_data, test_data = train_data.to(device), test_data.to(device)
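
One possible completion, as a sketch (see the RandomLinkSplit documentation for the exact parameter semantics): num_val=0.0 skips the validation set, num_test=0.1 gives the 90/10 train/test edge split, and disjoint_train_ratio=0.2 reserves 20% of the training edges as supervision-only edges.

split = RandomLinkSplit(
    num_val=0.0,                # no validation edges
    num_test=0.1,               # 10% of edges held out for testing
    disjoint_train_ratio=0.2,   # 20% of training edges used only for supervision
    is_undirected=True,
    add_negative_train_samples=True
)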

Next, let’s verify that the message and supervision edge sets do not overlap: we expect zero overlapping edges.

train_message_edges = train_data.edge_index
train_supervision_edges = train_data.edge_label_index
test_edges = test_data.edge_index

# message edges as set of (u, v) tuples
message_set = set(map(tuple, train_message_edges.t().cpu().numpy()))

# supervision edges (positive and sampled negative) as set of (u, v) tuples
supervision_set = set(map(tuple, train_supervision_edges.t().cpu().numpy()))

overlap = message_set & supervision_set  # intersection
print(f"Number of overlapping edges: {len(overlap)}")

# number of edges in each set
print(f"Number of message edges: {train_message_edges.size(1)}")
print(f"Number of supervision edges: {train_supervision_edges.size(1)}")

Next, we’ll use the GCNConv layer implemented in PyG. We’ll use a simple 2-layer GCN.

import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=64, out_dim=32):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
    

Now, we’ll need a prediction head. The simplest one would be the dot product.

  1. Complete the following function that computes the dot products for given edges.

def predict_dot(z, edge_index):
    """
    z: node embeddings, shape [num_nodes, embedding_dim]
    edge_index: edge indices, shape [2, num_edges]
    returns: scores for each edge, shape [num_edges]
    """

    ########## Your code here ##########
    
    #################################### 
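
One possible completion, as a sketch: gather the endpoint embeddings of each edge and sum their elementwise product.

def predict_dot(z, edge_index):
    src, dst = edge_index                  # endpoint indices, each of shape [num_edges]
    return (z[src] * z[dst]).sum(dim=-1)   # per-edge dot product h_u^T h_v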

At this point, we have everything we need, so let’s train our GCN!

  1. Complete the following training loop.

model = GCN(dataset.num_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    # get node embeddings
    ########## Your code here ##########
    z = ...
    #################################### 

    # get predictions for supervision edges
    ########## Your code here ##########
    preds = ...
    ####################################
    
    # labels of supervision edges 
    labels = train_data.edge_label.float().to(device)

    # compute BCE loss
    ########## Your code here ##########
    loss = F.binary_cross_entropy_with_logits(..., ...)
    ####################################
    
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
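
For reference, here is one way to fill in the three blanks (a sketch): propagate over the message edges only, score the supervision edges with the dot-product head, and apply the BCE loss to the raw scores (logits).

    # embeddings from the message-passing edges only
    z = model(train_data.x, train_data.edge_index)

    # scores for the supervision edges (positives and sampled negatives)
    preds = predict_dot(z, train_data.edge_label_index)

    # BCE loss on logits against the 0/1 edge labels
    loss = F.binary_cross_entropy_with_logits(preds, labels)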

Now it’s time to test our model. We’ll use the roc_auc_score from sklearn.metrics.

from sklearn.metrics import roc_auc_score

model.eval()
with torch.no_grad():
    
    # node embeddings computed using message-passing edges only
    # (test_data.edge_index holds the training edges, so no test edges leak in)
    z = model(test_data.x, test_data.edge_index)
    
    # supervision edges for test
    test_edge_index = test_data.edge_label_index
    test_labels = test_data.edge_label.float().to(device)
    
    # dot product predictions
    test_preds = predict_dot(z, test_edge_index)
    
    # ROC-AUC
    test_auc = roc_auc_score(test_labels.cpu(), test_preds.sigmoid().cpu())
    print(f"Test ROC-AUC: {test_auc:.4f}")

  1. Instead of the dot product, implement the CONCAT + MLP head and test your GCN. Compare your result with the dot-product head, and interpret your findings. A sketch of such a head follows below.
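
A minimal sketch of such a head (the name MLPHead and the hidden size are illustrative choices, not part of the exercise): concatenate the two endpoint embeddings and pass them through a small MLP that outputs one logit per edge. Remember to include the head’s parameters in the optimizer, and to call head(z, edge_index) in place of predict_dot during training and testing.

class MLPHead(torch.nn.Module):
    def __init__(self, emb_dim=32, hidden_dim=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        )

    def forward(self, z, edge_index):
        src, dst = edge_index
        # MLP([h_u || h_v]) -> one logit per edge
        return self.net(torch.cat([z[src], z[dst]], dim=-1)).squeeze(-1)

head = MLPHead(emb_dim=32).to(device)
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(head.parameters()), lr=0.01
)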