Do We Actually Need Prediction Heads?
Consider a GNN whose last layer produces node embeddings $\mathbf{h}_v \in \mathbb{R}^d$ for every node $v$. Suppose you propose to remove all prediction heads and train directly using

$$\min_{\theta} \; \sum_{v \in \mathcal{V}_{\text{label}}} \mathcal{L}\big(\mathbf{h}_v, y_v\big),$$

where $\mathcal{L}$ is a classification/regression loss and $\mathcal{V}_{\text{label}}$ is the set of labeled nodes used in training.
Answer briefly:
How would this affect the generalization capability of the model? (Hint: Think of what the embeddings will encode.)
In this scenario, describe what happens when classifying a previously unseen node (inference). What does the embedding of this node represent, and how is the class assigned?
Explain why adding a prediction head (e.g., a linear layer + softmax) solves the issues identified in (1) and (2). How does it allow embeddings to remain generalizable while still supporting accurate classification?
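For reference when answering (3), here is a minimal sketch of what "embeddings plus a linear prediction head" looks like in PyTorch. It is only an illustration; the class and variable names (NodeClassifier, encoder, head) are our own and not part of the exercise code.

import torch.nn as nn

class NodeClassifier(nn.Module):
    def __init__(self, encoder, emb_dim, num_classes):
        super().__init__()
        self.encoder = encoder                         # any GNN producing node embeddings
        self.head = nn.Linear(emb_dim, num_classes)    # prediction head: embeddings -> class logits

    def forward(self, x, edge_index):
        h = self.encoder(x, edge_index)                # task-agnostic node embeddings
        return self.head(h)                            # logits; softmax is applied inside the loss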
Expressivity of Prediction Heads
You are given a GNN that computes node embeddings $\mathbf{z}_u \in \mathbb{R}^d$, $u \in \mathcal{V}$. You consider the following edge prediction heads:
Dot product: $s(u, v) = \mathbf{z}_u^{\top} \mathbf{z}_v$
Single MLP on concatenation: $s(u, v) = \mathrm{MLP}\big([\mathbf{z}_u \,\|\, \mathbf{z}_v]\big)$
Bilinear form: $s(u, v) = \mathbf{z}_u^{\top} \mathbf{W} \mathbf{z}_v$, with a learnable matrix $\mathbf{W} \in \mathbb{R}^{d \times d}$
Tasks:
Compare the expressive power of these edge prediction heads.
Show that the dot product is a special (degenerate) case of the bilinear form.
Give an example of an edge-labeling function for which the dot product performs strictly worse.
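For concreteness, the three heads defined above can be written in a few lines of PyTorch. This is only an illustrative sketch (the module and method names are our own, and the hidden size is an arbitrary choice), not part of the exercise code.

import torch
import torch.nn as nn

class EdgeHeads(nn.Module):
    def __init__(self, d, hidden=64):
        super().__init__()
        self.W = nn.Parameter(torch.empty(d, d))   # bilinear weight matrix
        nn.init.xavier_uniform_(self.W)
        self.mlp = nn.Sequential(nn.Linear(2 * d, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def dot(self, z_u, z_v):
        return (z_u * z_v).sum(dim=-1)              # z_u^T z_v

    def bilinear(self, z_u, z_v):
        return (z_u @ self.W * z_v).sum(dim=-1)     # z_u^T W z_v; W = I recovers the dot product

    def concat_mlp(self, z_u, z_v):
        return self.mlp(torch.cat([z_u, z_v], dim=-1)).squeeze(-1)  # MLP([z_u || z_v])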
Is there a Leak?
You train a node classifier on a citation graph, and someone claims:
“If the graph is fixed, transductive evaluation is just cheating: you always leak test information.”
Tasks:
Provide a rigorous argument for when this is not true (give a scenario where transductive inference is legitimate).
Give a leakage scenario in which this claim is true.
Explain why graph classification cannot be transductive even in principle.
Programming: Transductive Link Prediction with Multiple Heads
In this exercise, you’ll implement a basic transductive link prediction setup using PyG tools on the Cora dataset.
We start by loading the dataset from the Planetoid class (documentation).
import torch
from torch_geometric.datasets import Planetoid
# load dataset
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Next, we’ll use the RandomLinkSplit (documentation) to split our dataset.
For our simple exercise we won’t need a validation set, and we’ll split the edges between a training set and a test set. Further, we’ll split the training edges into message-passing edges and supervision edges.
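As a quick reminder of the attributes we will rely on below (these names come from the Data objects returned by PyG’s RandomLinkSplit, assuming its default negative sampling):

# Each split returned by RandomLinkSplit is a Data object carrying:
#   .edge_index        - edges used for message passing in that split
#   .edge_label_index  - supervision edges to be scored (positives + sampled negatives)
#   .edge_label        - binary labels for the supervision edges (1 = real edge, 0 = negative)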
Given this, complete the following code snippet. Check the RandomLinkSplit documentation to find which parameters to use.
from torch_geometric.transforms import RandomLinkSplit
# edge split for link prediction
split = RandomLinkSplit(
########## Your code here ##########
####################################
is_undirected=True,
add_negative_train_samples=True
)
train_data, _, test_data = split(data)
train_data, test_data = train_data.to(device), test_data.to(device)

Next, let’s make sure that the message and supervision edge sets do not overlap; we expect zero overlapping edges.
train_message_edges = train_data.edge_index
train_supervision_edges = train_data.edge_label_index
test_edges = test_data.edge_index
# message edges as set of (u, v) tuples
message_set = set(map(tuple, train_message_edges.t().cpu().numpy()))
# supervision edges as set of (u, v) tuples
supervision_set = set(map(tuple, train_supervision_edges.t().cpu().numpy()))
overlap = message_set & supervision_set # intersection
print(f"Number of overlapping edges: {len(overlap)}")
# number of edges in each set
print(f"Number of message edges: {train_message_edges.size(1)}")
print(f"Number of supervision edges: {train_supervision_edges.size(1)}")Next, we’ll use the GCNConv layer implemented in PyG. We’ll use a simple 2-layer GCN.
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=64, out_dim=32):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Now, we’ll need a prediction head. The simplest one is the dot product.
Complete the following function that computes the dot products for given edges.
def predict_dot(z, edge_index):
    """
    z: node embeddings, shape [num_nodes, embedding_dim]
    edge_index: edge indices, shape [2, num_edges]
    returns: scores for each edge, shape [num_edges]
    """
    ########## Your code here ##########
    ####################################

At this point, we have everything we need, so let’s train our GCN!
Complete the following training loop.
model = GCN(dataset.num_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    # get node embeddings
    ########## Your code here ##########
    z = ...
    ####################################

    # get predictions for supervision edges
    ########## Your code here ##########
    preds = ...
    ####################################

    # labels of supervision edges
    labels = train_data.edge_label.float().to(device)

    # compute BCE loss
    ########## Your code here ##########
    loss = F.binary_cross_entropy_with_logits(..., ...)
    ####################################

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')

Now it’s time to test our model. We’ll use the roc_auc_score from sklearn.metrics.
from sklearn.metrics import roc_auc_score
model.eval()
with torch.no_grad():
    # node embeddings computed using only the message-passing graph of the test split
    # (test_data.edge_index holds the training edges, so no test edges are used here)
    z = model(test_data.x, test_data.edge_index)
    # supervision edges for test
    test_edge_index = test_data.edge_label_index
    test_labels = test_data.edge_label.float()
    # dot product predictions
    test_preds = predict_dot(z, test_edge_index)
    # ROC-AUC
    test_auc = roc_auc_score(test_labels.cpu(), test_preds.sigmoid().cpu())
    print(f"Test ROC-AUC: {test_auc:.4f}")

Instead of the dot product, implement the CONCAT + MLP approach and test your GCN. Compare your result with the dot-product head, and interpret your findings.