Inductive Representation Learning on Large Graphs
February 3, 2026
Paper: Hamilton, W., Ying, Z., & Leskovec, J. (2017). Inductive Representation Learning on Large Graphs. NeurIPS 2017. (arXiv:1706.02216)
The Problem
Most early graph embedding methods, such as DeepWalk, node2vec and LINE, are transductive: they learn a fixed embedding vector for every node seen during training. At inference time, a node that wasn't in the training graph simply has no embedding, and getting one requires additional optimization, often retraining on the updated graph.
This is a serious limitation in practice. Graphs grow. New users join a social network. New proteins get discovered. New papers are published and cite old ones. Any real deployment needs a model that can generalize to nodes it has never seen.
GraphSAGE (SAmple and aggreGatE) solves this by shifting the learning objective: instead of learning what the embedding of each node $v$ is, it learns how to compute an embedding from a node's local neighborhood. The aggregation function is what gets trained, and it can be applied to any node, seen or unseen.
The Core Idea
The key insight is simple: a node’s identity is largely determined by its neighborhood.
Instead of storing a lookup table of embeddings, GraphSAGE parameterizes an aggregation function that takes a node’s feature vector and the features of its sampled neighbors, and produces an embedding. This function is learned during training and can generalize inductively.
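The forward pass for a single layer $k$ looks like this:
$$h_{\mathcal{N}(v)}^k = \text{AGG}_k\big(\{h_u^{k-1}, \forall u \in \mathcal{N}(v)\}\big)$$
$$h_v^k = \sigma\big(W^k \cdot \text{CONCAT}(h_v^{k-1},\ h_{\mathcal{N}(v)}^k)\big)$$
Where:
- $h_v^k$ is the representation of node $v$ at layer $k$, with $h_v^0 = x_v$ (the input features)
- $\mathcal{N}(v)$ is a sampled subset of $v$'s neighbors
- $\text{AGG}_k$ is a trainable aggregation function
- $W^k$ is a weight matrix learned at layer $k$
- $\sigma$ is a nonlinearity (e.g., ReLU)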
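After each layer the representations are $\ell_2$-normalized:
$$h_v^k \leftarrow \frac{h_v^k}{\lVert h_v^k \rVert_2}$$
This normalization is a small but important detail: it places every embedding on the unit sphere, keeping scales comparable across nodes, so that dot products between final embeddings behave like cosine similarities.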
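A minimal NumPy sketch of one layer follows: concatenate each node's own representation with the mean of its sampled neighbors, multiply by $W^k$, apply ReLU, then $\ell_2$-normalize. It assumes the mean aggregator, dense features stored one row per node, and a precomputed dict of sampled neighbor lists; `sage_layer` and its arguments are illustrative names, not taken from the paper's code.

```python
import numpy as np

def sage_layer(h, sampled_neighbors, W):
    """One GraphSAGE layer with a mean aggregator (illustrative sketch).

    h: (num_nodes, d_in) array of layer k-1 representations
    sampled_neighbors: dict {node id: list of sampled neighbor ids}; each list non-empty
    W: (d_out, 2 * d_in) weight matrix applied to CONCAT(self, neighborhood)
    Rows for nodes missing from sampled_neighbors are left as zeros.
    """
    out = np.zeros((h.shape[0], W.shape[0]))
    for v, neigh in sampled_neighbors.items():
        h_neigh = h[neigh].mean(axis=0)          # AGG_k: element-wise mean of neighbors
        z = W @ np.concatenate([h[v], h_neigh])  # W^k . CONCAT(h_v^{k-1}, h_N(v)^k)
        z = np.maximum(z, 0.0)                   # sigma = ReLU
        norm = np.linalg.norm(z)
        out[v] = z / norm if norm > 0 else z     # l2-normalize (all-zero vectors left as-is)
    return out

# Toy usage: 4 nodes with 3-dim features, mapped to 5-dim embeddings
rng = np.random.default_rng(0)
h0 = rng.normal(size=(4, 3))
W1 = rng.normal(size=(5, 6))
nbrs = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}
h1 = sage_layer(h0, nbrs, W1)
```

Stacking two such calls, each with its own weight matrix, gives the two-layer model used in the paper.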
Neighbor Sampling
A subtle but crucial design choice is that GraphSAGE does not use the full neighborhood of each node. For large graphs, aggregating over hundreds or thousands of neighbors at every layer would be computationally intractable.
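Instead, at each layer $k$, a fixed-size set of $S_k$ neighbors is sampled uniformly at random from $v$'s full neighborhood, with a fresh sample drawn at each training iteration:
$$\mathcal{N}(v) \sim \text{Uniform}\big(\{u : (u, v) \in E\}\big), \qquad |\mathcal{N}(v)| = S_k$$
The paper uses $S_1 = 25$ and $S_2 = 10$ for a two-layer model. This caps the receptive field at $S_1 \cdot S_2 = 250$ nodes at the outermost hop for each target node, which is manageable regardless of the true degree of any node.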
This is what makes mini-batch training on large graphs feasible. You expand outward from a batch of target nodes, sampling at each hop, to construct a fixed-size computation graph.
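The sampling step can be sketched in plain Python. This is an illustration under assumptions, not the paper's sampler: the graph is a dict-based adjacency list, `sample_neighbors` and `build_computation_graph` are hypothetical helpers, and neighbors are drawn with replacement so that every node contributes exactly $S_k$ samples even when its degree is smaller (one common convention). Hop 1 uses the first size in the list, matching the order of PyG's `num_neighbors` argument.

```python
import random

def sample_neighbors(adj, node, num_samples, rng):
    """Uniformly sample exactly num_samples neighbors of node, with replacement."""
    neigh = adj[node]
    return rng.choices(neigh, k=num_samples) if neigh else []

def build_computation_graph(adj, batch, sizes, seed=0):
    """Expand outward from the target nodes, one hop per entry in sizes.

    Returns a list of hops: hops[0] is the batch and hops[i] holds the nodes
    sampled at hop i (duplicates kept, so each hop has a fixed size).
    """
    rng = random.Random(seed)
    hops = [list(batch)]
    for size in sizes:
        frontier = [u for v in hops[-1] for u in sample_neighbors(adj, v, size, rng)]
        hops.append(frontier)
    return hops

# Toy graph as an adjacency list; one target node, sample sizes 25 then 10
adj = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}
hops = build_computation_graph(adj, batch=[0], sizes=[25, 10])
print([len(h) for h in hops])  # [1, 25, 250]
```

The sketch only tracks which nodes are sampled; a full implementation would also record the sampled edges so that each hop's representations can be aggregated back toward the targets.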
The Three Aggregators
The authors propose and evaluate three choices for AGG:
1. Mean Aggregator
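Simple element-wise mean over neighbor representations:
$$\text{AGG}_k^{\text{mean}} = \text{MEAN}\big(\{h_u^{k-1}, \forall u \in \mathcal{N}(v)\}\big)$$
The paper notes that if you drop the CONCAT and instead apply $W^k$ directly to the mean of $h_v^{k-1}$ and its neighbors' representations,
$$h_v^k = \sigma\big(W^k \cdot \text{MEAN}(\{h_v^{k-1}\} \cup \{h_u^{k-1}, \forall u \in \mathcal{N}(v)\})\big),$$
you recover something close to a graph convolutional network (GCN); the paper evaluates this as GraphSAGE-GCN. The mean aggregator with CONCAT can therefore be read as an inductive, GCN-like layer that keeps the node's self-representation separate from its neighborhood summary.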
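To make the difference concrete, here is a NumPy sketch contrasting the two updates for a single node: the CONCAT form gives the node and its neighborhood mean separate weight blocks, while the GCN-style variant averages the node in with its neighbors under one weight matrix. Function names are illustrative, and ReLU is assumed for $\sigma$.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def concat_mean_update(h_v, h_neigh, W):
    """sigma(W . CONCAT(h_v, MEAN(neighbors))); W has shape (d_out, 2 * d_in)."""
    return relu(W @ np.concatenate([h_v, h_neigh.mean(axis=0)]))

def gcn_style_update(h_v, h_neigh, W):
    """sigma(W . MEAN({h_v} U neighbors)); W has shape (d_out, d_in)."""
    return relu(W @ np.vstack([h_v, h_neigh]).mean(axis=0))

h_v = np.array([1.0, 0.0])                    # the node's own representation
h_neigh = np.array([[0.0, 1.0], [0.0, 3.0]])  # two sampled neighbors

# CONCAT with W = [W_self | W_neigh] = [I | I]: self features keep full weight
print(concat_mean_update(h_v, h_neigh, np.hstack([np.eye(2), np.eye(2)])))  # -> [1. 2.]
# GCN-style mean: self features are diluted to 1/3 by the two neighbors
print(gcn_style_update(h_v, h_neigh, np.eye(2)))  # -> approximately [0.333, 1.333]
```

In the GCN-style form the weight on the node's own features shrinks as more neighbors are averaged in; the CONCAT form lets training decide that weight independently.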
2. LSTM Aggregator
Uses an LSTM applied to a random permutation of the neighborhood:
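```python
import torch

def lstm_aggregate(neighbor_embeddings, lstm):
    # neighbor_embeddings: (num_neighbors, dim) tensor for one node
    # lstm: torch.nn.LSTM(dim, out_dim, batch_first=True)
    # randomly permute neighbors (no natural ordering on a graph)
    perm = torch.randperm(neighbor_embeddings.size(0))
    _, (h, _) = lstm(neighbor_embeddings[perm].unsqueeze(0))  # batch of one sequence
    return h[-1, 0]  # final hidden state of the last LSTM layer
```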
LSTMs have more expressive capacity than a simple mean. The catch is that they are not permutation invariant: they assume a sequential order, but a neighborhood has none, so the neighbors are fed in a random permutation. Empirically the LSTM aggregator still outperforms the mean on some tasks, suggesting it learns useful patterns even without a consistent ordering.
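The ordering issue is easy to see numerically. The sketch below swaps in a toy Elman RNN for the LSTM (an assumption made to keep it NumPy-only; an LSTM is order-sensitive in the same way) and feeds the same neighbor set in two orders. The mean returns the same vector both times, while the recurrent aggregator generally does not.

```python
import numpy as np

def mean_aggregate(neighbors):
    """Element-wise mean: the result does not depend on neighbor order."""
    return neighbors.mean(axis=0)

def rnn_aggregate(neighbors, W_h, W_x):
    """Toy Elman RNN over the neighbor sequence (a simple stand-in for an LSTM)."""
    h = np.zeros(W_h.shape[0])
    for x in neighbors:                  # consumes neighbors one at a time, in order
        h = np.tanh(W_h @ h + W_x @ x)
    return h                             # final hidden state

rng = np.random.default_rng(0)
neighbors = rng.normal(size=(5, 4))      # 5 neighbors with 4-dim features
W_h, W_x = rng.normal(size=(3, 3)), rng.normal(size=(3, 4))
reordered = neighbors[::-1]              # same neighborhood, different order

same_mean = np.allclose(mean_aggregate(neighbors), mean_aggregate(reordered))
same_rnn = np.allclose(rnn_aggregate(neighbors, W_h, W_x), rnn_aggregate(reordered, W_h, W_x))
print(same_mean, same_rnn)
```

Randomizing the order at each step means the model cannot rely on any particular ordering.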
3. Pooling Aggregator
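Each neighbor is passed through a fully-connected layer, then element-wise max-pooling is applied:
$$\text{AGG}_k^{\text{pool}} = \max\big(\{\sigma(W_{\text{pool}}\, h_u^{k-1} + b),\ \forall u \in \mathcal{N}(v)\}\big)$$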
The fully-connected layer before pooling lets each neighbor "highlight" which of its features are most relevant. Max-pooling then keeps the strongest activation of each feature across all neighbors. In the paper's experiments, pooling performed on par with the LSTM aggregator while being considerably faster to train.
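A NumPy sketch of the pooling aggregator, assuming ReLU for $\sigma$ and a single fully-connected layer ($W_{\text{pool}}$, $b$) applied to every neighbor; the function name is illustrative.

```python
import numpy as np

def pool_aggregate(h_neigh, W_pool, b):
    """Element-wise max over sigma(W_pool h_u + b) across sampled neighbors u.

    h_neigh: (num_neighbors, d_in); W_pool: (d_pool, d_in); b: (d_pool,)
    sigma is taken to be ReLU here.
    """
    transformed = np.maximum(h_neigh @ W_pool.T + b, 0.0)  # per-neighbor FC layer + ReLU
    return transformed.max(axis=0)                          # max over neighbors, per feature

h_neigh = np.array([[1.0, -1.0], [0.0, 2.0], [3.0, 0.5]])
W_pool, b = np.eye(2), np.zeros(2)
print(pool_aggregate(h_neigh, W_pool, b))        # -> [3. 2.]
print(pool_aggregate(h_neigh[::-1], W_pool, b))  # same result: order does not matter
```

Because max is symmetric in its inputs, this aggregator is permutation invariant by construction, so no random-ordering trick is needed.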
Training
GraphSAGE can be trained in two settings:
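Unsupervised: Uses a graph-based loss that encourages nearby nodes to have similar representations and distant nodes to have dissimilar ones:
$$J_{\mathcal{G}}(z_u) = -\log\big(\sigma(z_u^\top z_v)\big) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)} \log\big(\sigma(-z_u^\top z_{v_n})\big)$$
Where $z_u = h_u^K$ is the final-layer output for node $u$, $v$ is a node co-occurring near $u$ in a random walk, $v_n$ is a negative sample drawn from a noise distribution $P_n$, $Q$ is the number of negative samples, and $\sigma$ here is the sigmoid.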
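The loss for one positive pair can be sketched in NumPy as follows, assuming the embeddings are already computed and the $Q$ negatives have already been drawn from $P_n$ (how $P_n$ is chosen is not shown here). `unsup_loss` is an illustrative name.

```python
import numpy as np

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))."""
    return -np.logaddexp(0.0, -x)

def unsup_loss(z_u, z_pos, z_negs):
    """Loss for one (u, v) pair and Q negatives.

    z_u, z_pos: (d,) output embeddings; z_negs: (Q, d), one row per negative v_n.
    Summing over the Q drawn negatives is an unbiased Monte Carlo estimate of
    Q * E_{v_n ~ P_n}[log sigmoid(-z_u . z_vn)].
    """
    pos_term = log_sigmoid(z_u @ z_pos)
    neg_term = log_sigmoid(-(z_negs @ z_u)).sum()
    return -pos_term - neg_term

z_u = np.array([1.0, 0.0])
z_negs = np.array([[-1.0, 0.0], [0.0, 1.0]])  # Q = 2 negative samples
close = unsup_loss(z_u, np.array([1.0, 0.0]), z_negs)  # co-occurring node aligned with u
far = unsup_loss(z_u, np.array([-1.0, 0.0]), z_negs)   # co-occurring node pointing away
print(close < far)  # True: the loss rewards similar embeddings for co-occurring nodes
```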
Supervised / Semi-supervised: Standard cross-entropy loss on labeled nodes. In practice, supervised GraphSAGE trained end-to-end with task labels outperforms the unsupervised variant on most benchmarks.
Results
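The paper evaluates on three inductive benchmarks. The table lists the best supervised GraphSAGE variant on each (micro-averaged F1):
| Dataset | Task | Best variant (supervised) | Micro-F1 |
|---|---|---|---|
| Citation (Web of Science) | Paper subject classification | GraphSAGE-pool | 0.839 |
| Reddit posts | Community (subreddit) prediction | GraphSAGE-LSTM | 0.954 |
| Protein-protein interaction (PPI) | Multi-label protein function classification | GraphSAGE-LSTM | 0.612 |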
The PPI task is the most demanding inductive benchmark: the test graphs are entirely unseen, covering different human tissues than the training graphs. GraphSAGE-LSTM reaches a micro-F1 of 0.612 (GraphSAGE-pool 0.600, the GCN-style variant 0.500), against 0.422 for a feature-only baseline that ignores graph structure. Transductive methods such as DeepWalk have no principled way to embed nodes in graphs they were never trained on.
Key Takeaways
- Inductive vs. transductive is an important distinction. For any application where the graph changes over time (almost all real graphs), transductive methods are not viable at scale.
- Sampling is what makes scale possible. You don't need all neighbors: fixed-size sampling enables mini-batch training on graphs with millions of nodes.
- The aggregator choice matters, but not enormously. Pooling and LSTM were the top performers overall in the paper's experiments, and pooling is much cheaper to run than the LSTM; mean is simplest to implement and often competitive.
- Concatenating self with aggregated neighbors (rather than averaging them together) consistently helps. It preserves the distinction between the node's own identity and its context.
- GraphSAGE was one of the early works that demonstrated GNNs can scale. It laid the groundwork for systems like PinSage (deployed at Pinterest) and much of the subsequent work on scalable graph learning.
Implementation
PyTorch Geometric has a clean implementation:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # SAGEConv uses mean aggregation by default; normalize=True adds l2 normalization
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)  # raw logits; apply the loss downstream
        return x
```
For large-scale inductive training with neighbor sampling, use NeighborLoader:
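```python
from torch_geometric.loader import NeighborLoader

# `data` is a torch_geometric.data.Data object with x, edge_index, y and train_mask
loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],  # 25 neighbors at hop 1, 10 at hop 2 (the paper's sample sizes)
    batch_size=512,
    input_nodes=data.train_mask,  # seed (target) nodes for each mini-batch
)
```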
GraphSAGE remains one of the most practically useful GNN architectures. When someone asks "which GNN should I start with?", GraphSAGE is usually a good answer: it's well-understood, scales well, and is a strong baseline on large real-world graphs.