Chapter 21: Neural Message Passing for Quantum Chemistry
“We introduce a unified framework for learning on graphs, generalizing convolutional neural networks to graph-structured data.”
Based on: “Neural Message Passing for Quantum Chemistry” (Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, George E. Dahl, 2017)
| 📄 Original Paper: arXiv:1704.01212 | ICML 2017 |
21.1 The Graph Problem
Many real-world problems involve graph-structured data:
- Molecules: Atoms (nodes) connected by bonds (edges)
- Social networks: People (nodes) with friendships (edges)
- Knowledge graphs: Entities (nodes) with relations (edges)
- Code: Functions (nodes) with calls (edges)
graph TB
subgraph "Graph Data"
N1["Node 1<br/>(atom, person, etc.)"]
N2["Node 2"]
N3["Node 3"]
E1["Edge<br/>(bond, connection)"]
E2["Edge"]
end
N1 --- E1 --- N2
N2 --- E2 --- N3
Q["How to apply neural networks<br/>to graph-structured data?"]
E1 --> Q
style Q fill:#ffe66d,color:#000
Standard CNNs and RNNs assume grid/sequence structure. Graphs are irregular.
21.2 Why Standard Architectures Fail
The Irregularity Problem
graph TB
subgraph "Regular Structures"
IMG["Image<br/>Regular grid"]
SEQ["Sequence<br/>Linear order"]
end
subgraph "Irregular Structures"
GRAPH["Graph<br/>Variable neighbors<br/>No fixed order"]
end
IMG -->|"CNNs work"| CNN
SEQ -->|"RNNs work"| RNN
GRAPH -->|"Need new approach"| GNN
K["Graphs have:<br/>• Variable node degrees<br/>• No spatial locality<br/>• Permutation invariance"]
GRAPH --> K
style K fill:#ff6b6b,color:#fff
The Challenge
- Variable structure: Each graph has different connectivity
- Permutation invariance: Node ordering shouldn’t matter
- Variable size: Graphs can have any number of nodes
21.3 The Message Passing Framework
Core Idea
Nodes exchange messages with their neighbors:
graph TB
subgraph "Message Passing"
N1["Node 1"]
N2["Node 2"]
N3["Node 3"]
M1["Message from N2"]
M2["Message from N1"]
M3["Message from N2"]
AGG["Aggregate messages"]
UPDATE["Update node state"]
end
N1 -->|"sends"| M2
N2 -->|"sends"| M1
N2 -->|"sends"| M3
M1 --> AGG
M2 --> AGG
M3 --> AGG
AGG --> UPDATE
K["Each node updates based on<br/>messages from neighbors"]
UPDATE --> K
style K fill:#4ecdc4,color:#fff
The General Framework
For each node $v$:
- Collect messages from neighbors
- Aggregate messages
- Update node state
21.4 The Message Passing Neural Network (MPNN)
Formal Definition
graph TB
subgraph "MPNN Framework"
H0["Initial node features<br/>h_v^0"]
subgraph "Message Passing (T steps)"
M["Message function M_t<br/>m_v^t = Σ M_t(h_v^{t-1}, h_w^{t-1}, e_{vw})"]
U["Update function U_t<br/>h_v^t = U_t(h_v^{t-1}, m_v^t)"]
end
READ["Readout function R<br/>ŷ = R({h_v^T})"]
OUTPUT["Graph-level prediction"]
end
H0 --> M --> U --> M
U --> READ --> OUTPUT
K["T steps of message passing<br/>→ Final node representations<br/>→ Graph-level prediction"]
READ --> K
style K fill:#ffe66d,color:#000
The Equations
Message: \(m_v^t = \sum_{w \in \mathcal{N}(v)} M_t(h_v^{t-1}, h_w^{t-1}, e_{vw})\)
Update: \(h_v^t = U_t(h_v^{t-1}, m_v^t)\)
Readout: \(\hat{y} = R(\{h_v^T : v \in G\})\)
Where:
- $\mathcal{N}(v)$ = neighbors of node $v$
- $e_{vw}$ = edge features
- $T$ = number of message passing steps
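To make the three equations concrete, here is a minimal, dependency-free sketch of one message passing step. The linear message map `W`, the residual update, and all shapes are illustrative assumptions, not the paper's parameterization:

```python
import torch

def mpnn_step(h, edges, edge_feats, W):
    """One message passing step: message, sum-aggregate, update."""
    # h: [N, d] node states; edges: list of (receiver v, neighbor w);
    # edge_feats: dict mapping (v, w) -> edge feature vector
    N, d = h.shape
    m = torch.zeros(N, d)
    for (v, w) in edges:
        # Message M_t(h_v, h_w, e_vw): here a single linear map (assumed)
        m[v] += torch.cat([h[v], h[w], edge_feats[(v, w)]]) @ W
    # Update U_t(h_v, m_v): here a simple residual add (assumed)
    return h + m

# Toy usage: a 3-node path graph, edges listed in both directions
h = torch.randn(3, 4)
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
edge_feats = {e: torch.ones(1) for e in edges}
W = torch.randn(4 + 4 + 1, 4) * 0.1
h = mpnn_step(h, edges, edge_feats, W)   # h: [3, 4]
```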
21.5 Specific Instantiations
Variant 1: Graph Convolutional Network (GCN)
graph TB
subgraph "GCN Variant"
H["Node features h"]
A["Adjacency matrix A"]
NORM["Normalize: D^(-1/2) A D^(-1/2)"]
CONV["Convolution: H' = σ(ÃHW)"]
end
H --> CONV
A --> NORM --> CONV
CONV --> H
K["Message = normalized sum<br/>of neighbor features"]
CONV --> K
Message: $m_v = \sum_{w \in \mathcal{N}(v)} \frac{1}{\sqrt{d_v d_w}} h_w$
Update: $h_v' = \sigma(W m_v)$
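A minimal dense-adjacency sketch of this variant in PyTorch; note that the original GCN of Kipf & Welling also adds self-loops (Ã = A + I) before normalizing, which this sketch omits for clarity:

```python
import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, H, A):
        # H: [N, in_dim] node features; A: [N, N] symmetric adjacency
        deg = A.sum(dim=1).clamp(min=1.0)              # degrees d_v
        d_inv_sqrt = deg.pow(-0.5)
        A_norm = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)
        return torch.relu(self.linear(A_norm @ H))     # sigma(A_tilde H W)

# Usage on a small 4-node graph
A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 1., 1.],
                  [0., 1., 0., 0.],
                  [0., 1., 0., 0.]])
out = GCNLayer(8, 16)(torch.randn(4, 8), A)   # [4, 16]
```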
Variant 2: Gated Graph Neural Network (GGNN)
Uses a GRU cell for the node update:
\[h_v^t = \text{GRU}(h_v^{t-1}, m_v^t)\]
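In code, this update is literally a GRU cell: the aggregated message plays the role of the input, and the previous node state is the hidden state (the dimensions below are assumptions for the sketch):

```python
import torch
import torch.nn as nn

gru = nn.GRUCell(input_size=16, hidden_size=16)  # message dim = state dim (assumed)
h_prev = torch.randn(8, 16)                      # states of 8 nodes at step t-1
m = torch.randn(8, 16)                           # aggregated messages m_v^t
h_next = gru(m, h_prev)                          # h_v^t = GRU(h_v^{t-1}, m_v^t)
```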
Variant 3: Interaction Networks
Includes edge updates:
graph TB
subgraph "Interaction Network"
N["Node features"]
E["Edge features"]
M["Message: m = f(n_i, n_j, e_ij)"]
N_UPD["Node update"]
E_UPD["Edge update"]
end
N --> M
E --> M
M --> N_UPD
M --> E_UPD
K["Both nodes and edges<br/>get updated"]
N_UPD --> K
E_UPD --> K
21.6 Application: Molecular Property Prediction
The Task
Predict properties of molecules (e.g., energy, solubility) from their structure.
graph TB
subgraph "Molecular Property Prediction"
MOL["Molecule<br/>(graph of atoms)"]
MPNN["Message Passing<br/>over molecular graph"]
REP["Node representations<br/>(after T steps)"]
READOUT["Graph-level readout"]
PROP["Property prediction<br/>(energy, solubility, etc.)"]
end
MOL --> MPNN --> REP --> READOUT --> PROP
K["Learns to aggregate<br/>molecular structure<br/>into property prediction"]
READOUT --> K
style K fill:#4ecdc4,color:#fff
Molecular Graph
graph LR
subgraph "Molecule as Graph"
C1["C (carbon)"]
C2["C"]
O["O (oxygen)"]
H1["H (hydrogen)"]
H2["H"]
C1 ---|"single bond"| C2
C1 ---|"single bond"| H1
C2 ---|"double bond"| O
C2 ---|"single bond"| H2
end
K["Nodes = atoms<br/>Edges = bonds<br/>Features = atom/bond types"]
C1 --> K
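As a concrete sketch, the graph above can be encoded with tensors like these; the one-hot atom and bond featurizations are simplifying assumptions (real pipelines, including the paper's, use richer features such as hybridization and aromaticity). The `edge_index` layout matches the implementation in Section 21.13:

```python
import torch

atom_types = {"C": 0, "O": 1, "H": 2}
bond_types = {"single": 0, "double": 1}

atoms = ["C", "C", "O", "H", "H"]                 # C1, C2, O, H1, H2
bonds = [(0, 1, "single"), (0, 3, "single"),
         (1, 2, "double"), (1, 4, "single")]

# One-hot node features, one row per atom -> [5, 3]
node_features = torch.eye(3)[[atom_types[a] for a in atoms]]
# Each bond contributes two directed edges so messages flow both ways
src = [i for i, j, _ in bonds] + [j for i, j, _ in bonds]
tgt = [j for i, j, _ in bonds] + [i for i, j, _ in bonds]
edge_index = torch.tensor([src, tgt])             # [2, 8]
# One-hot bond features, duplicated for both directions -> [8, 2]
edge_features = torch.eye(2)[[bond_types[b] for _, _, b in bonds] * 2]
```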
21.7 Message Passing Steps
One Step of Message Passing
graph TB
subgraph "Step t"
H_PREV["h_v^{t-1}<br/>(previous state)"]
NEIGHBORS["Neighbors<br/>{h_w^{t-1}}"]
EDGES["Edge features<br/>{e_{vw}}"]
MSG["Message function<br/>m_v^t = Σ M(h_v, h_w, e_{vw})"]
AGG["Aggregate<br/>(sum, mean, max)"]
UPD["Update function<br/>h_v^t = U(h_v^{t-1}, m_v^t)"]
end
H_PREV --> MSG
NEIGHBORS --> MSG
EDGES --> MSG
MSG --> AGG --> UPD
K["Node receives information<br/>from its neighborhood"]
UPD --> K
style K fill:#ffe66d,color:#000
Multiple Steps
graph LR
subgraph "Multi-Step Message Passing"
T0["t=0: Local<br/>(1-hop neighbors)"]
T1["t=1: 2-hop<br/>(neighbors of neighbors)"]
T2["t=2: 3-hop<br/>(wider context)"]
T3["t=T: Global<br/>(entire graph)"]
end
T0 --> T1 --> T2 --> T3
K["More steps = larger<br/>receptive field"]
T3 --> K
style K fill:#ffe66d,color:#000
After $T$ steps, each node has information from nodes up to $T$ hops away!
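A toy numerical illustration of this growth, assuming the simplest possible scheme (messages are summed neighbor states, the update is a residual add): a marker placed on one end of a path graph reaches exactly one more hop per step.

```python
import torch

# 5-node path graph 0-1-2-3-4 with a marker on node 0
adj = torch.zeros(5, 5)
for i in range(4):
    adj[i, i + 1] = adj[i + 1, i] = 1.0
h = torch.zeros(5, 1)
h[0] = 1.0
for t in range(1, 4):
    h = h + adj @ h   # message = sum of neighbors, update = residual add
    reached = (h.squeeze(1) > 0).nonzero().squeeze(1).tolist()
    print(f"after step {t}: marker has reached nodes {reached}")
    # step 1 -> [0, 1]; step 2 -> [0, 1, 2]; step 3 -> [0, 1, 2, 3]
```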
21.8 Aggregation Functions
Common Aggregators
graph TB
subgraph "Aggregation Options"
SUM["Sum: Σ m_w"]
MEAN["Mean: (1/|N|) Σ m_w"]
MAX["Max: max(m_w)"]
ATT["Attention: Σ α_w m_w"]
end
I["Different aggregators<br/>capture different patterns"]
SUM --> I
MEAN --> I
MAX --> I
ATT --> I
style I fill:#ffe66d,color:#000
Which to Use?
| Aggregator | Use Case |
|---|---|
| Sum | When quantity matters (e.g., counting) |
| Mean | When average is meaningful |
| Max | When presence matters (e.g., “has feature X”) |
| Attention | When importance varies |
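Sum, mean, and max can each be written in a couple of lines of pure PyTorch, as sketched below; `tgt` holds the receiving node of each message, all shapes are assumptions, and attention is omitted because it requires learned scores:

```python
import torch

N, d = 4, 8
tgt = torch.tensor([0, 0, 1, 2, 2, 2])   # receiving node of each message
msgs = torch.randn(len(tgt), d)          # one message per incoming edge

agg_sum = torch.zeros(N, d).index_add_(0, tgt, msgs)
counts = torch.zeros(N).index_add_(0, tgt, torch.ones(len(tgt)))
agg_mean = agg_sum / counts.clamp(min=1).unsqueeze(1)
agg_max = torch.full((N, d), float("-inf")).scatter_reduce_(
    0, tgt.unsqueeze(1).expand(-1, d), msgs, reduce="amax", include_self=False
)
# Node 3 receives no messages: its sum/mean rows are zero and its max row
# stays -inf, so real code masks or zero-fills isolated nodes.
```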
21.9 Readout Functions
Graph-Level Prediction
After message passing, aggregate all node features:
graph TB
subgraph "Readout"
H_ALL["{h_v^T : v ∈ G}<br/>(all node features)"]
subgraph "Options"
SUM_R["Sum: Σ h_v"]
MEAN_R["Mean: (1/n) Σ h_v"]
MAX_R["Max: max(h_v)"]
SET2VEC["Set2Vec<br/>(learned aggregation)"]
end
GRAPH_REP["Graph representation"]
PRED["Prediction"]
end
H_ALL --> SUM_R --> GRAPH_REP
H_ALL --> MEAN_R --> GRAPH_REP
H_ALL --> MAX_R --> GRAPH_REP
H_ALL --> SET2VEC --> GRAPH_REP
GRAPH_REP --> PRED
K["Must be permutation-invariant<br/>(order of nodes doesn't matter)"]
GRAPH_REP --> K
style K fill:#ffe66d,color:#000
Set2Set
A learned readout that is more expressive than simple pooling. Set2set (Vinyals et al., 2015) runs an LSTM whose state repeatedly attends over the node embeddings:
\[q_t = \text{LSTM}(q^*_{t-1}), \quad \alpha_v = \mathrm{softmax}_v(h_v \cdot q_t), \quad r_t = \sum_{v} \alpha_v h_v, \quad q^*_t = [q_t, r_t]\]
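Below is a minimal single-graph sketch of this readout following those equations; the embedding dimension, the number of processing steps, and the zero initializations are assumptions, and batching over multiple graphs is omitted.

```python
import torch
import torch.nn as nn

class Set2Set(nn.Module):
    """Set2set-style readout: an LSTM query attends over node embeddings."""
    def __init__(self, dim, steps=3):
        super().__init__()
        self.lstm = nn.LSTMCell(2 * dim, dim)
        self.dim, self.steps = dim, steps

    def forward(self, h):                       # h: [N, dim] final node states
        q_star = h.new_zeros(1, 2 * self.dim)
        hx = h.new_zeros(1, self.dim)
        cx = h.new_zeros(1, self.dim)
        for _ in range(self.steps):
            hx, cx = self.lstm(q_star, (hx, cx))
            attn = torch.softmax(h @ hx.squeeze(0), dim=0)  # alpha_v over nodes
            r = (attn.unsqueeze(1) * h).sum(dim=0)          # attended summary r_t
            q_star = torch.cat([hx.squeeze(0), r]).unsqueeze(0)
        return q_star.squeeze(0)                # [2 * dim], permutation-invariant

graph_vec = Set2Set(dim=16)(torch.randn(7, 16))  # 7-node graph -> [32] vector
```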
21.10 Experimental Results
Quantum Chemistry Datasets
The paper evaluates on molecular property prediction:
xychart-beta
title "MAE on QM9 Dataset (lower is better)"
x-axis ["Baseline", "MPNN (sum)", "MPNN (set2set)"]
y-axis "MAE" 0 --> 5
bar [4.2, 2.1, 1.5]
MPNNs achieved state-of-the-art results on the QM9 benchmark at publication, predicting 11 of the 13 target properties to within chemical accuracy.
Learned Representations
The model learns meaningful molecular features:
- Atom types: Carbon, oxygen, nitrogen, etc.
- Bond patterns: Single, double, triple bonds
- Molecular structure: Rings, chains, branches
21.11 Connection to Other GNN Variants
Unified View
graph TB
subgraph "GNN Variants as MPNNs"
GCN["Graph Convolutional Network<br/>Message = normalized sum"]
GAT["Graph Attention Network<br/>Message = attention-weighted"]
GIN["Graph Isomorphism Network<br/>Message = MLP(sum)"]
GGNN["Gated Graph NN<br/>Update = GRU"]
end
MPNN["MPNN Framework<br/>(unifies all)"]
GCN --> MPNN
GAT --> MPNN
GIN --> MPNN
GGNN --> MPNN
K["All are special cases<br/>of message passing!"]
MPNN --> K
style K fill:#4ecdc4,color:#fff
The Power of Unification
This paper showed that many GNN architectures are instantiations of the same framework:
- Different message functions
- Different update functions
- Different aggregation strategies
21.12 Modern Graph Neural Networks
Evolution
timeline
title GNN Evolution
2017 : MPNN paper
: Unified framework
2018 : Graph Attention Networks
: Attention in graphs
2019 : Graph Isomorphism Networks
: Expressive power analysis
2020 : Graph Transformers
: Self-attention on graphs
2020s : Modern GNNs
: Scalable, efficient
: Applications everywhere
Current Applications
- Drug discovery: Molecular property prediction
- Social networks: Node classification, link prediction
- Recommendation: User-item graphs
- Knowledge graphs: Entity linking, reasoning
- Code analysis: Program graphs
21.13 Implementation
Simple MPNN Layer
```python
import torch
import torch.nn as nn

class MPNNLayer(nn.Module):
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Message function M_t: maps (h_src, h_tgt, e_vw) to a message
        self.message_net = nn.Sequential(
            nn.Linear(node_dim * 2 + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Update function U_t: a GRU cell with the message as input and
        # the previous node state as the hidden state
        self.update_net = nn.GRUCell(hidden_dim, node_dim)

    def forward(self, node_features, edge_index, edge_features):
        # node_features: [N, node_dim]
        # edge_index: [2, E] rows (source, target); messages flow source -> target
        # edge_features: [E, edge_dim]
        messages = []
        for i in range(edge_index.shape[1]):
            src, tgt = edge_index[0, i], edge_index[1, i]
            # Concatenate source state, target state, and edge features
            msg_input = torch.cat([
                node_features[src],
                node_features[tgt],
                edge_features[i],
            ])
            messages.append((int(tgt), self.message_net(msg_input)))

        # Aggregate messages per receiving node (plain-int dict keys;
        # tensor keys would compare by identity and never match below)
        aggregated = {}
        for tgt, msg in messages:
            aggregated.setdefault(tgt, []).append(msg)

        # Update every node; mean aggregation here (the chapter's equations
        # use a sum -- replace .mean with .sum to match them exactly)
        new_features = []
        for i in range(node_features.shape[0]):
            if i in aggregated:
                msg = torch.stack(aggregated[i]).mean(dim=0)
            else:
                msg = torch.zeros(self.hidden_dim)   # isolated node: no messages
            new_features.append(self.update_net(msg, node_features[i]))
        return torch.stack(new_features)
```
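A quick smoke test of the layer above on random inputs shaped like the molecular graph from Section 21.6 (values are arbitrary, and reusing one layer for all T steps is a weight-tying choice, not a requirement of the framework):

```python
import torch

node_features = torch.randn(5, 3)                 # 5 atoms, 3-dim features
edge_index = torch.tensor([[0, 0, 1, 1, 1, 3, 2, 4],
                           [1, 3, 2, 4, 0, 0, 1, 1]])  # both directions
edge_features = torch.randn(edge_index.shape[1], 2)

layer = MPNNLayer(node_dim=3, edge_dim=2, hidden_dim=16)
h = node_features
for t in range(3):                                # T = 3 message passing steps
    h = layer(h, edge_index, edge_features)
print(h.shape)                                    # torch.Size([5, 3])
```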
21.14 Connection to Other Chapters
graph TB
CH21["Chapter 21<br/>Message Passing"]
CH21 --> CH19["Chapter 19: Seq2Seq for Sets<br/><i>Set aggregation</i>"]
CH21 --> CH22["Chapter 22: Relational Reasoning<br/><i>Pairwise relations</i>"]
CH21 --> CH16["Chapter 16: Transformers<br/><i>Attention mechanisms</i>"]
CH21 --> CH14["Chapter 14: Relational RNNs<br/><i>Relational processing</i>"]
style CH21 fill:#ff6b6b,color:#fff
21.15 Key Equations Summary
Message Passing
\[m_v^t = \sum_{w \in \mathcal{N}(v)} M_t(h_v^{t-1}, h_w^{t-1}, e_{vw})\]
Update
\[h_v^t = U_t(h_v^{t-1}, m_v^t)\]
Readout
\[\hat{y} = R(\{h_v^T : v \in G\})\]
GCN Variant
\[h_v' = \sigma\left(W \sum_{w \in \mathcal{N}(v)} \frac{1}{\sqrt{d_v d_w}} h_w\right)\]
21.16 Chapter Summary
graph TB
subgraph "Key Takeaways"
T1["Message passing unifies<br/>graph neural networks"]
T2["Nodes exchange messages<br/>with neighbors"]
T3["Multiple steps create<br/>larger receptive fields"]
T4["Readout aggregates node<br/>features for graph prediction"]
T5["Works for molecules,<br/>social networks, etc."]
end
T1 --> C["The Message Passing framework<br/>provides a unified view of graph<br/>neural networks, where nodes<br/>exchange information with neighbors<br/>over multiple steps to learn<br/>graph-structured representations."]
T2 --> C
T3 --> C
T4 --> C
T5 --> C
style C fill:#ffe66d,color:#000,stroke:#000,stroke-width:2px
In One Sentence
The Message Passing framework unifies graph neural networks by having nodes exchange messages with neighbors over multiple steps, enabling learning on graph-structured data like molecules, social networks, and knowledge graphs.
Exercises
- Conceptual: Explain why message passing is permutation-invariant. Why is this important for graph learning?
- Implementation: Implement a simple MPNN for node classification on a small graph dataset (e.g., Cora or CiteSeer).
- Analysis: Compare the receptive field of a 3-layer MPNN with that of a 3-layer CNN. How do they differ?
- Extension: How would you modify message passing to handle directed graphs? What about graphs with multiple edge types?
References & Further Reading
| Resource | Link |
|---|---|
| Original Paper (Gilmer et al., 2017) | arXiv:1704.01212 |
| Graph Convolutional Networks | arXiv:1609.02907 |
| Graph Attention Networks | arXiv:1710.10903 |
| Graph Isomorphism Networks | arXiv:1810.00826 |
| PyTorch Geometric | pytorch-geometric.readthedocs.io |
| Deep Learning on Graphs (Ma & Tang) | Cambridge University Press |
Next Chapter: Chapter 22: A Simple Neural Network Module for Relational Reasoning — We explore Relation Networks, which explicitly model pairwise relationships between objects for tasks like visual question answering.