Graph Transformers

Self-Attention as Message Passing

Let’s assume that we have three tokens with scalar features:

$$x_1 = 1, \quad x_2 = 2, \quad x_3 = 0$$

Note: This is a toy model. In practice, token features are high-dimensional vectors.

For simplicity, we define $W^Q = W^K = W^V = I$, therefore:

$$q_i = x_i, \quad k_i = x_i, \quad v_i = x_i$$

And the self-attention update for token 1:

$$z_1 = \sum_{j=1}^{3} \alpha_{1j} v_j, \qquad \alpha_{1j} = \frac{e^{q_1 k_j}}{\sum_{\ell} e^{q_1 k_\ell}}$$

Before doing any math, answer:

  1. Which token do you expect token 1 to pay most attention to?

  2. Which token should receive the least attention from others?

Now, compute and answer:

  1. All $\alpha_{1j}$ and then $z_1$. Describe what $z_1$ represents in this example.

  2. Are $\alpha_{12}$ and $\alpha_{21}$ equal? What does this say about attention as an “edge weight”?

  3. What happens if all $x_i$ are equal? What kind of GNN is this equivalent to?

  4. What changes in the self-attention mechanism when moving from a Transformer to a GAT?
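
To check your hand computation, here is a minimal NumPy sketch of the toy example above (scalar features, $W^Q = W^K = W^V = I$, and no $1/\sqrt{d_k}$ scaling, matching the formula for $\alpha_{1j}$ given earlier):

import numpy as np

# toy scalar features from the example above
x = np.array([1.0, 2.0, 0.0])

# with W^Q = W^K = W^V = I, queries, keys and values are the features themselves
q, k, v = x, x, x

# attention scores of token 1 against all tokens, then softmax weights
scores = q[0] * k                              # q_1 * k_j for j = 1, 2, 3
alpha = np.exp(scores) / np.exp(scores).sum()  # alpha_1j

# self-attention update for token 1
z1 = (alpha * v).sum()

print("alpha_1j:", alpha)
print("z_1:", z1)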

Graph Laplacian Magic

In previous lectures, we used the Laplacian matrix $L$, but what does it actually mean?

Given an undirected graph with adjacency matrix $A$ and degree matrix $D$, $L$ is defined as:

$$L = D - A$$
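
As a quick sanity check of the definition, here is a minimal NumPy sketch that builds $L = D - A$ from an adjacency matrix (a triangle graph is used purely as a stand-in example):

import numpy as np

# adjacency matrix of a small undirected graph (a triangle, used only as an example)
A = np.array([[0, 1, 1],
              [1, 0, 1],
              [1, 1, 0]])

# degree matrix: node degrees on the diagonal
D = np.diag(A.sum(axis=1))

# graph Laplacian
L = D - A
print(L)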

Consider the following path graph of 3 nodes:

[Figure: a path graph with 3 nodes]

  1. What is the Laplacian matrix for this graph?

  2. Verify that the following are eigenvectors of $L$, and find the corresponding eigenvalues:

    • $(-1, 0, 1)^\top$

    • $(1, 1, 1)^\top$

    • $(1, -2, 1)^\top$

    Refresher: $Lv = \lambda v$, where $v$ is the eigenvector and $\lambda$ is the eigenvalue.

  3. Order the eigenvectors from slowly varying to rapidly varying (i.e., smooth vs. oscillatory) across the graph. Then look at the corresponding eigenvalues. What do you observe about how the eigenvalues are ordered relative to the variation speed? (A numeric check of this relationship is sketched after the code snippet below.)

  4. Based on your finding in (3), which eigenvectors capture the global structure and which eigenvectors capture the local variation?

  5. Now let’s verify our findings on a larger graph. We will revisit the Karate Club graph and plot what the eigenvectors highlight within it.

    The following code snippet computes the eigenvectors and sorts them by the corresponding eigenvalues. Plot different eigenvectors across the graph to see what information they encode (local vs. global). Does it match your finding in (4)?

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

# load the Karate Club graph
G = nx.karate_club_graph()

# compute the Laplacian matrix
L = nx.laplacian_matrix(G).toarray()

# compute eigenvalues and eigenvectors
eigvals, eigvecs = np.linalg.eigh(L)

# sort eigenvalues and eigenvectors
idx = np.argsort(eigvals)
eigvals = eigvals[idx]
eigvecs = eigvecs[:, idx]

# plot the eigenvectors
pos = nx.spring_layout(G, seed=42)
for k in [1]:  # try other indices too, e.g. 2 or 33, to compare smooth vs. oscillatory eigenvectors
    plt.figure()
    nx.draw(G, pos, node_color=eigvecs[:, k], cmap='coolwarm', with_labels=False)
    plt.title(f"Laplacian eigenvector {k}")
    plt.show()
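
One way to make “slowly varying vs. rapidly varying” quantitative is the Dirichlet energy $v^\top L v = \sum_{(i,j)\in E} (v_i - v_j)^2$, which for a unit-norm eigenvector equals its eigenvalue. Below is a minimal sketch (the helper dirichlet_energy is defined here only for illustration) comparing the two on the Karate Club graph:

import networkx as nx
import numpy as np

G = nx.karate_club_graph()
L = nx.laplacian_matrix(G).toarray()
eigvals, eigvecs = np.linalg.eigh(L)  # eigenvalues returned in ascending order

# Dirichlet energy of a node signal v: sum of squared differences across edges.
# For a unit-norm eigenvector of L this equals v^T L v, i.e. its eigenvalue.
def dirichlet_energy(G, v):
    return sum((v[i] - v[j]) ** 2 for i, j in G.edges())

for k in [0, 1, 2, 33]:
    v = eigvecs[:, k]
    print(f"eigenvector {k}: eigenvalue = {eigvals[k]:.3f}, "
          f"Dirichlet energy = {dirichlet_energy(G, v):.3f}")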

Eigenvector Sign Ambiguity

Recall that $Lv = \lambda v$, where $v$ is the eigenvector and $\lambda$ is the eigenvalue. This also means:

$$L(-v) = \lambda(-v)$$

  1. Why is this a problem when we want to use the eigenvectors as positional encodings?

  2. Would an attention-based model automatically be invariant to this sign ambiguity? Why or why not?

  3. If we need sign invariance, why don’t we just take $|v|$ before using it as a positional encoding?

  4. How does SignNet (paper link) solve this problem? You can play with the following code snippet, which implements a simplified version of SignNet.

import torch
import torch.nn as nn

class SimpleSignNet(nn.Module):
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.phi = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, v):
        # v: (n,) eigenvector
        v = v.unsqueeze(-1)        # (n, 1)

        out_pos = self.phi(v)      # φ(v)
        out_neg = self.phi(-v)     # φ(-v)

        return out_pos + out_neg   # sign-invariant
    
# toy eigenvector
v = torch.tensor([1.0, -2.0, 0.5])

model = SimpleSignNet()

z1 = model(v)
z2 = model(-v)

print(torch.allclose(z1, z2))
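
A note on the design choice: $\phi(v) + \phi(-v)$ is an even function of $v$, so flipping the sign of the eigenvector cannot change the output. In the full SignNet, this sign-invariant map is applied to each of the $k$ eigenvectors used as positional encodings, and a second network $\rho$ combines the results; the snippet above is a single-eigenvector simplification.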