Chapter 22: A Simple Neural Network Module for Relational Reasoning

“We introduce a simple plug-and-play module for relational reasoning that can be added to any neural network architecture.”

Based on: “A Simple Neural Network Module for Relational Reasoning” (Adam Santoro, David Raposo, David G.T. Barrett, et al., 2017)

📄 Original Paper: arXiv:1706.01427 NeurIPS 2017

22.1 The Relational Reasoning Challenge

Many AI tasks require relational reasoning: understanding relationships between objects.

graph TB
    subgraph "Relational Reasoning Tasks"
        Q1["'Is the red ball<br/>to the left of the blue cube?'"]
        Q2["'How many objects<br/>are the same color?'"]
        Q3["'What is the relationship<br/>between object A and B?'"]
    end
    
    R["Requires comparing<br/>and relating entities"]
    
    Q1 --> R
    Q2 --> R
    Q3 --> R
    
    style R fill:#ffe66d,color:#000

Standard CNNs process images globally—they don’t explicitly model pairwise relationships.


22.2 The Relation Network (RN)

Core Idea

Explicitly compute relationships between all pairs of objects:

graph TB
    subgraph "Relation Network"
        O["Objects<br/>{o₁, o₂, ..., o_n}"]
        
        PAIRS["All pairs<br/>(o_i, o_j)"]
        
        REL["Relation function<br/>r_ij = f(o_i, o_j)"]
        
        AGG["Aggregate<br/>Σ r_ij"]
        
        OUT["Output"]
    end
    
    O --> PAIRS --> REL --> AGG --> OUT
    
    K["Explicitly models<br/>pairwise relationships"]
    
    REL --> K
    
    style K fill:#4ecdc4,color:#fff

The Formula

\[RN(O) = f_\phi\left(\sum_{i,j} g_\theta(o_i, o_j)\right)\]

Where:

  • $g_\theta$ = relation function (MLP)
  • $f_\phi$ = aggregation function (MLP)
  • $O = {o_1, …, o_n}$ = set of objects

22.3 Architecture Details

The Relation Function

For each pair $(o_i, o_j)$:

graph LR
    subgraph "Relation Function"
        OI["o_i<br/>(object i)"]
        OJ["o_j<br/>(object j)"]
        CONCAT["Concatenate<br/>[o_i, o_j]"]
        MLP["MLP"]
        RIJ["r_ij<br/>(relationship)"]
    end
    
    OI --> CONCAT
    OJ --> CONCAT
    CONCAT --> MLP --> RIJ
    
    K["Learns to compute<br/>relationship between<br/>any two objects"]
    
    RIJ --> K
    
    style K fill:#ffe66d,color:#000

Mathematical Formulation

\[r_{ij} = g_\theta(o_i, o_j) = \text{MLP}([o_i, o_j])\]

The MLP learns what relationships to extract.


22.4 Application: Visual Question Answering

The CLEVR Dataset

CLEVR (Compositional Language and Elementary Visual Reasoning):

  • Synthetic images with geometric objects
  • Questions requiring relational reasoning
  • Example: “How many red objects are to the left of the blue cube?”
graph TB
    subgraph "Visual Question Answering"
        IMG["Image<br/>(objects in scene)"]
        CNN["CNN<br/>(extract objects)"]
        OBJ["Object features<br/>{o₁, o₂, ..., o_n}"]
        Q["Question<br/>'How many red objects<br/>are left of blue cube?'"]
        Q_ENC["Question encoder"]
        Q_FEAT["Question features q"]
        RN["Relation Network<br/>RN({o_i}, q)"]
        ANS["Answer"]
    end
    
    IMG --> CNN --> OBJ
    Q --> Q_ENC --> Q_FEAT
    OBJ --> RN
    Q_FEAT --> RN
    RN --> ANS
    
    K["RN reasons about<br/>object relationships<br/>conditioned on question"]
    
    RN --> K
    
    style K fill:#4ecdc4,color:#fff

Question-Conditioned Relations

The relation function can be conditioned on the question:

\[r_{ij} = g_\theta(o_i, o_j, q)\]

This allows the network to focus on relevant relationships for the question.


22.5 The Complete Architecture

For Visual Question Answering

graph TB
    subgraph "RN for VQA"
        IMG["Image"]
        CNN["CNN<br/>(ResNet)"]
        OBJ["Object features<br/>14×14×1024"]
        SPAT["Spatial coordinates<br/>(x, y positions)"]
        CONCAT1["Concatenate<br/>object + position"]
        
        Q["Question"]
        LSTM["LSTM"]
        Q_FEAT["Question features"]
        
        PAIRS["All pairs<br/>(o_i, o_j, q)"]
        REL["Relation function<br/>g(o_i, o_j, q)"]
        SUM["Sum all relations"]
        AGG["Aggregation<br/>f(Σ r_ij)"]
        ANS["Answer"]
    end
    
    IMG --> CNN --> OBJ
    OBJ --> CONCAT1
    SPAT --> CONCAT1
    CONCAT1 --> PAIRS
    
    Q --> LSTM --> Q_FEAT --> PAIRS
    
    PAIRS --> REL --> SUM --> AGG --> ANS
    
    K["14×14 = 196 objects<br/>→ 196×196 = 38,416 pairs!"]
    
    PAIRS --> K
    
    style K fill:#ff6b6b,color:#fff

Computational Complexity

For $n$ objects:

  • Pairs: $O(n^2)$
  • Relation computation: $O(n^2)$
  • Total: $O(n^2)$

This can be expensive for large $n$!


22.6 Results on CLEVR

Performance

xychart-beta
    title "CLEVR Accuracy"
    x-axis ["Human", "CNN+LSTM", "CNN+RN", "CNN+RN (ours)"]
    y-axis "Accuracy %" 0 --> 100
    bar [92.6, 52.3, 68.5, 95.5]

Relation Networks achieve near-human performance on CLEVR!

What the Network Learned

The RN learns to answer questions like:

  • Counting: “How many objects?”
  • Spatial: “What is left of X?”
  • Attribute: “What color is the cube?”
  • Comparison: “Are there more red than blue objects?”
  • Compositional: “What is the shape of the object that is the same size as the red sphere?”

22.7 Why Relation Networks Work

Explicit Relationship Modeling

graph TB
    subgraph "Standard CNN"
        IMG["Image"]
        CNN["CNN"]
        GLOBAL["Global features"]
        PRED["Prediction"]
    end
    
    subgraph "Relation Network"
        IMG2["Image"]
        CNN2["CNN"]
        OBJ["Object features"]
        PAIRS["All pairs"]
        REL["Relationships"]
        PRED2["Prediction"]
    end
    
    IMG --> CNN --> GLOBAL --> PRED
    IMG2 --> CNN2 --> OBJ --> PAIRS --> REL --> PRED2
    
    K["RN explicitly reasons<br/>about relationships<br/>→ Better for relational tasks"]
    
    REL --> K
    
    style K fill:#4ecdc4,color:#fff

Compositionality

Relations can be composed:

graph LR
    subgraph "Compositional Reasoning"
        R1["r(A, B)<br/>'A is left of B'"]
        R2["r(B, C)<br/>'B is left of C'"]
        COMP["Compose<br/>→ 'A is left of C'"]
    end
    
    R1 --> COMP
    R2 --> COMP
    
    K["Learns transitive<br/>relationships"]
    
    COMP --> K
    
    style K fill:#ffe66d,color:#000

22.8 Comparison with Other Approaches

Standard VQA Models

Approach CLEVR Accuracy
CNN + LSTM ~52%
Attention-based ~68%
Relation Network ~96%

Why RN Wins

graph TB
    subgraph "Advantages"
        A1["Explicit relationship<br/>modeling"]
        A2["Compositional<br/>reasoning"]
        A3["Question-conditioned<br/>relations"]
        A4["Simple architecture<br/>(easy to add)"]
    end
    
    S["Superior performance<br/>on relational tasks"]
    
    A1 --> S
    A2 --> S
    A3 --> S
    A4 --> S
    
    style S fill:#4ecdc4,color:#fff

22.9 Efficiency Considerations

The O(n²) Problem

For 196 objects (14×14 grid):

  • 38,416 pairs to process
  • Computationally expensive

Solutions

graph TB
    subgraph "Efficiency Solutions"
        S1["Object detection<br/>(fewer objects)"]
        S2["Sampling pairs<br/>(not all pairs)"]
        S3["Hierarchical relations<br/>(coarse to fine)"]
        S4["Attention-based<br/>(focus on relevant)"]
    end
    
    E["Reduce computation<br/>while maintaining<br/>performance"]
    
    S1 --> E
    S2 --> E
    S3 --> E
    S4 --> E

22.10 Connection to Attention

Relation Networks as Attention

Relation Networks can be viewed as a form of attention:

graph TB
    subgraph "Attention View"
        Q["Query (question)"]
        K["Keys (objects)"]
        V["Values (objects)"]
        ATT["Attention<br/>α_ij = f(o_i, o_j, q)"]
        OUT["Weighted combination"]
    end
    
    Q --> ATT
    K --> ATT
    V --> ATT
    ATT --> OUT
    
    K2["RN computes attention<br/>over all object pairs"]
    
    ATT --> K2
    
    style K2 fill:#ffe66d,color:#000

Difference from Standard Attention

  • Standard attention: Attends to individual objects
  • Relation Networks: Attend to pairs of objects

22.11 Modern Applications

Where Relation Networks Appear

graph TB
    subgraph "Applications"
        VQA["Visual Question Answering<br/>(CLEVR, VQA dataset)"]
        REASON["Scene understanding<br/>(spatial reasoning)"]
        PHYSICS["Physics simulation<br/>(object interactions)"]
        GAMES["Game playing<br/>(strategic reasoning)"]
    end
    
    RN["Relation Networks"]
    
    RN --> VQA
    RN --> REASON
    RN --> PHYSICS
    RN --> GAMES
    
    style RN fill:#4ecdc4,color:#fff

In Modern Architectures

  • Transformer attention: Can be viewed as relation computation
  • Graph neural networks: Explicitly model relationships
  • Object-centric models: Use relational reasoning

22.12 Connection to Other Chapters

graph TB
    CH22["Chapter 22<br/>Relational Reasoning"]
    
    CH22 --> CH14["Chapter 14: Relational RNNs<br/><i>Relational processing</i>"]
    CH22 --> CH21["Chapter 21: Message Passing<br/><i>Graph relationships</i>"]
    CH22 --> CH19["Chapter 19: Seq2Seq for Sets<br/><i>Set processing</i>"]
    CH22 --> CH16["Chapter 16: Transformers<br/><i>Self-attention as relations</i>"]
    
    style CH22 fill:#ff6b6b,color:#fff

22.13 Key Equations Summary

Basic Relation Network

\[RN(O) = f_\phi\left(\sum_{i,j} g_\theta(o_i, o_j)\right)\]

Question-Conditioned

\[RN(O, q) = f_\phi\left(\sum_{i,j} g_\theta(o_i, o_j, q)\right)\]

Relation Function

\[r_{ij} = g_\theta(o_i, o_j) = \text{MLP}([o_i, o_j])\]

With Question

\[r_{ij} = g_\theta(o_i, o_j, q) = \text{MLP}([o_i, o_j, q])\]

22.14 Chapter Summary

graph TB
    subgraph "Key Takeaways"
        T1["Relation Networks explicitly<br/>model pairwise relationships"]
        T2["Compute relations for<br/>all object pairs"]
        T3["Question-conditioned<br/>relations focus on relevance"]
        T4["Achieves near-human<br/>performance on CLEVR"]
        T5["Simple plug-and-play<br/>module"]
    end
    
    T1 --> C["Relation Networks provide a simple<br/>yet powerful way to add relational<br/>reasoning to neural networks by<br/>explicitly computing pairwise<br/>relationships between objects,<br/>enabling superior performance on<br/>tasks requiring compositional reasoning."]
    T2 --> C
    T3 --> C
    T4 --> C
    T5 --> C
    
    style C fill:#ffe66d,color:#000,stroke:#000,stroke-width:2px

In One Sentence

Relation Networks add explicit relational reasoning to neural networks by computing pairwise relationships between all objects, achieving near-human performance on visual question answering tasks like CLEVR.


Exercises

  1. Conceptual: Explain why computing all pairwise relationships is important for relational reasoning tasks. What are the computational trade-offs?

  2. Implementation: Implement a simple Relation Network for a small visual question answering task. Start with 5-10 objects.

  3. Analysis: Compare the computational complexity of Relation Networks vs standard attention mechanisms. When does each have advantages?

  4. Extension: How would you modify Relation Networks to handle higher-order relationships (triplets, quadruplets) efficiently?


References & Further Reading

Resource Link
Original Paper (Santoro et al., 2017) arXiv:1706.01427
CLEVR Dataset GitHub
Visual Question Answering Survey arXiv:1610.01465
Object-Centric Learning arXiv:1806.08572
Relational Deep Reinforcement Learning arXiv:1806.01830

Next Chapter: Chapter 23: Variational Lossy Autoencoder — We explore how variational autoencoders can be improved using lossy compression principles, connecting back to the MDL foundations from Chapter 1.



Back to top

Educational content based on public research papers. All original papers are cited with links to their sources.