
Commit d8bc455

Merge pull request #254 from FluxML/examples

Fix GCN examples

2 parents 27c2ce4 + 38af46c

25 files changed: +761 -216 lines

Project.toml

Lines changed: 4 additions & 4 deletions

```diff
@@ -27,13 +27,13 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 CUDA = "3"
 ChainRulesCore = "1.7"
 DataStructures = "0.18"
-FillArrays = "0.12"
+FillArrays = "0.12 - 0.13"
 Flux = "0.12"
 GraphMLDatasets = "0.1"
 GraphSignals = "0.3"
-Graphs = "1.4"
-NNlib = "0.7"
-NNlibCUDA = "0.1"
+Graphs = "1"
+NNlib = "0.7 - 0.8"
+NNlibCUDA = "0.1 - 0.2"
 Reexport = "1.1"
 Word2Vec = "0.5"
 Zygote = "0.6"
```

docs/make.jl

Lines changed: 5 additions & 0 deletions

```diff
@@ -21,6 +21,11 @@ makedocs(
          "Building layers" => "basics/layers.md",
          "Graph passing" => "basics/passgraph.md"],
      "Cooperate with Flux layers" => "cooperate.md",
+     "Tutorials" =>
+         [
+             "Semi-supervised learning with GCN" => "tutorials/semisupervised_gcn.md",
+             "GCN with Fixed Graph" => "tutorials/gcn_fixed_graph.md",
+         ],
      "Abstractions" =>
          ["Message passing scheme" => "abstractions/msgpass.md",
           "Graph network block" => "abstractions/gn.md"],
```

docs/src/tutorials/gcn_fixed_graph.md

Lines changed: 109 additions & 0 deletions
# GCN with Fixed Graph

In the semi-supervised learning tutorial, variable graphs are fed to the GNN through `FeaturedGraph` objects, each of which bundles a graph with its node features. Different `FeaturedGraph` objects can carry different graphs and different node features, and all of them can be trained on the same GNN model. However, a variable graph is not stored in the form a GNN layer actually requires, so it has to be converted on every call, which makes training and inference inefficient. When the graph structure stays the same, GeometricFlux provides a fixed-graph strategy for training a GNN model against one shared graph.

## Fixed Graph

A fixed graph is bound to a layer with `WithGraph`. `WithGraph` wraps a `FeaturedGraph` object and a GNN layer as its first and second arguments, respectively.

```julia
fg = FeaturedGraph(graph)
WithGraph(fg, GCNConv(1024=>256, relu))
```

This way, we can bind a different graph to each layer, and the layer specializes the graph into whatever form it requires. For example, a `GCNConv` layer needs the graph as a normalized adjacency matrix, so once the graph is bound to a `GCNConv` layer, it is transformed into a normalized adjacency matrix and stored inside the `WithGraph` object. This speeds up training and inference because the transformation is not recomputed on every call. Note that any node features held by the `FeaturedGraph` inside a `WithGraph` are not used for training or inference.

## Array in, Array out

With this approach, a GNN layer wrapped in `WithGraph` works on plain feature arrays: it takes an array as input and returns an array, just like a regular deep learning layer.
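
Below is a minimal sketch of this array-in, array-out behavior; the 3-node adjacency matrix, the feature sizes, and the random input are made up purely for illustration.

```julia
using GeometricFlux, GraphSignals, Flux

# toy fully-connected graph on 3 nodes, given as an adjacency matrix
A = [0 1 1; 1 0 1; 1 1 0]
fg = FeaturedGraph(A, :adjm)

layer = WithGraph(fg, GCNConv(1024=>256, relu))

X = rand(Float32, 1024, 3)  # feature array: input_dim × num_nodes
Y = layer(X)                # output array:  256 × num_nodes
```
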
## Batch Learning

Since the features are plain arrays, they can be stacked along a third dimension and batched up for batch learning; a short sketch of the batched shapes follows, and the remaining steps demonstrate the full workflow.
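
Reusing the hypothetical `layer` from the sketch above, a batch is just a trailing batch dimension on the feature array (the batch size here is arbitrary):

```julia
Xb = rand(Float32, 1024, 3, 8)  # input_dim × num_nodes × batch_size
Yb = layer(Xb)                  # 256 × num_nodes × batch_size
```
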
## Step 1: Load Dataset

Unlike the semi-supervised learning example, here we use `alldata` to load the dataset for supervised learning, and pass `padding=true` so that the features of the labeled nodes are padded out to the full node set. The padded feature matrix contains zeros for the nodes that are not meant to be trained on. Here `dataset` is the sub-dataset symbol, e.g. `:cora`.

```julia
train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset, padding=true))
```

We need the graph and the node indices for training as well.

```julia
g = graphdata(Planetoid(), dataset)
train_idx = 1:size(train_X, 2)
```
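
As a quick sanity check of the padded shapes (assuming, as in the semi-supervised tutorial, that `graphdata` returns a Graphs.jl graph, so `nv` gives its number of nodes):

```julia
using Graphs: nv

@assert size(train_X, 2) == nv(g)  # padded features cover every node in the graph
```
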
## Step 2: Batch up Features and Labels

To make batch learning possible, we keep the graph separate from the node features and do not take a subgraph here. Since the Planetoid dataset has no built-in batch setting, the node features are batched up by simply repeating them for demonstration purposes. Different repeat counts can be specified for the training and test sets (e.g. via `train_repeats` and `test_repeats`).

```julia
fg = FeaturedGraph(g)
train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
```
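
The batched arrays can then be wrapped in a `Flux.Data.DataLoader`; this is only a sketch of how the `train_loader` used in Step 5 could be built, with the batch size taken from `args.batch_size` as an assumption:

```julia
using Flux.Data: DataLoader

train_loader = DataLoader(train_data, batchsize=args.batch_size, shuffle=true)
```
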
## Step 3: Build a GCN model

Now we build the GCN model. The model is built just like a regular Flux model, except that each `GCNConv` layer is wrapped in `WithGraph`.

```julia
model = Chain(
    WithGraph(fg, GCNConv(args.input_dim=>args.hidden_dim, relu)),
    Dropout(0.5),
    WithGraph(fg, GCNConv(args.hidden_dim=>args.target_dim)),
)
```
## Step 4: Loss Functions and Accuracy

Almost all of the code is the same as in the semi-supervised learning example, except that node indices are needed to slice the relevant nodes out of the model output when computing the loss.

```julia
l2norm(x) = sum(abs2, x)

function model_loss(model, λ, X, y, idx)
    loss = logitcrossentropy(model(X)[:,idx,:], y[:,idx,:])
    loss += λ*sum(l2norm, Flux.params(model[1]))
    return loss
end
```

The accuracy measurement needs the indices as well.

```julia
function accuracy(model, X::AbstractArray, y::AbstractArray, idx)
    return mean(onecold(softmax(cpu(model(X))[:,idx,:])) .== onecold(cpu(y)[:,idx,:]))
end

accuracy(model, loader::DataLoader, device, idx) = mean(accuracy(model, X |> device, y |> device, idx) for (X, y) in loader)
```
## Step 5: Training GCN Model

```julia
train_loader, test_loader, fg, train_idx, test_idx = load_data(:cora, args.batch_size)

# optimizer
opt = ADAM(args.η)

# parameters
ps = Flux.params(model)

# training
train_steps = 0
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"

    for (X, y) in train_loader
        grad = gradient(() -> model_loss(model, args.λ, X |> device, y |> device, train_idx |> device), ps)
        Flux.Optimise.update!(opt, ps, grad)
        train_steps += 1
    end
end
```
Now we can train the GCN model directly!
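
After training, the same `accuracy` helpers can be used for evaluation; a short sketch reusing the loaders, `device`, and index ranges returned by `load_data` above:

```julia
train_acc = accuracy(model, train_loader, device, train_idx)
test_acc  = accuracy(model, test_loader, device, test_idx)
@info "Accuracy" train=train_acc test=test_acc
```
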
docs/src/tutorials/semisupervised_gcn.md

Lines changed: 135 additions & 0 deletions
# Semi-supervised Learning with Graph Convolution Networks (GCN)

Graph convolution networks (GCN) are often regarded as the first step into graph neural networks (GNN). This example goes through how to train a vanilla GCN.

## Semi-supervised Learning in Graph Neural Networks

In the semi-supervised setting, features and labels are available for only part of the nodes in a graph. We train the model on the labeled subset of nodes and test it on a different subset of nodes in the same graph.

## Node Classification Task

In this task, we learn a model that predicts a label for each node in the graph: the GCN takes node features as input and outputs node labels.
## Step 1: Load Dataset

GeometricFlux provides the Planetoid dataset in `GeometricFlux.Datasets`, backed by GraphMLDatasets. The Planetoid dataset has three sub-datasets, Cora, Citeseer and PubMed; this example uses Cora. `traindata` loads training data from various datasets: the dataset is specified by the first argument and the sub-dataset by the second.

```julia
using GeometricFlux.Datasets

train_X, train_y = traindata(Planetoid(), :cora)
```

`traindata` returns pre-defined training features and labels; the features are node features. We collect them into plain matrices:

```julia
train_X, train_y = map(x->Matrix(x), traindata(Planetoid(), :cora))
```

The graph can be loaded with `graphdata`, which returns it preprocessed into a `SimpleGraph` from Graphs.

```julia
g = graphdata(Planetoid(), :cora)
train_idx = train_indices(Planetoid(), :cora)
```

We also need node indices to index a subgraph out of the original graph; `train_indices` gives the node indices used for training.
## Step 2: Wrapping Graph and Features into `FeaturedGraph`

`FeaturedGraph` is a container that holds a graph together with node, edge and global features. It is provided by GraphSignals. To wrap a graph and node features into a `FeaturedGraph`, pass the graph `g` as the first argument and the node features through the `nf` keyword.

```julia
using GraphSignals

FeaturedGraph(g, nf=train_X)
```

To take a subgraph of a `FeaturedGraph` object, call `subgraph` and provide the node indices `train_idx` as the second argument.

```julia
subgraph(FeaturedGraph(g, nf=train_X), train_idx)
```
## Step 3: Build a GCN model

The GCN model is composed of two `GCNConv` layers, with `relu` as the activation of the first layer and a `Dropout` layer in between. Since `Dropout` is a regular Flux layer, it is wrapped in `GraphParallel` with `node_layer=Dropout(0.5)` so that it is applied to the node features.

```julia
model = Chain(
    GCNConv(input_dim=>hidden_dim, relu),
    GraphParallel(node_layer=Dropout(0.5)),
    GCNConv(hidden_dim=>target_dim),
    node_feature,
)
```
Since the model input is a `FeaturedGraph` object, each GNN layer also outputs a `FeaturedGraph` object. At the end of the model, `node_feature` extracts the node features out of the final `FeaturedGraph`.
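
A minimal sketch of a forward pass, reusing the pieces defined above (the variable names are only for illustration):

```julia
fg_train = subgraph(FeaturedGraph(g, nf=train_X), train_idx)
ŷ = model(fg_train)  # target_dim × (number of nodes in the subgraph) array of class scores
```
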
## Step 4: Loss Functions and Accuracy

Since this is a node classification task, the model loss is `logitcrossentropy` plus an L2 regularization term. As in the vanilla GCN, only the first layer is L2-regularized, and the strength can be adjusted through the hyperparameter `λ`.

```julia
l2norm(x) = sum(abs2, x)

function model_loss(model, λ, batch)
    loss = 0.f0
    for (x, y) in batch
        loss += logitcrossentropy(model(x), y)
        loss += λ*sum(l2norm, Flux.params(model[1]))
    end
    return loss
end
```

Accuracy measures are provided for a single batch and for a whole data loader.

```julia
function accuracy(model, batch::AbstractVector)
    return mean(mean(onecold(softmax(cpu(model(x)))) .== onecold(cpu(y))) for (x, y) in batch)
end

accuracy(model, loader::DataLoader, device) = mean(accuracy(model, batch |> device) for batch in loader)
```
## Step 5: Training GCN Model

We train the model with the same process as training a Flux model.

```julia
train_loader, test_loader = load_data(:cora, args.batch_size)

# optimizer
opt = ADAM(args.η)

# parameters
ps = Flux.params(model)

# training
train_steps = 0
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"

    for batch in train_loader
        grad = gradient(() -> model_loss(model, args.λ, batch |> device), ps)
        Flux.Optimise.update!(opt, ps, grad)
        train_steps += 1
    end
end
```
This completes a basic tutorial for training a GCN model!

For the complete example, please check the script `examples/semisupervised_gcn.jl`.
## Acceleration by Pre-computing Normalized Adjacency Matrix

The training process can be slow in this example. Because the graph and the features are packed together in `FeaturedGraph` objects, `GCNConv` has to compute a normalized adjacency matrix during training, which leads to long training times. We can accelerate training by pre-computing the normalized adjacency matrix for the `FeaturedGraph` objects before training. Calling the following function does exactly that for `fg`.

```julia
GraphSignals.normalized_adjacency_matrix!(fg)
```

Since `GCNConv` works on the normalized adjacency matrix, pre-computing it for `GCNConv` is safe. If a layer does not operate on a normalized adjacency matrix, this step will lead to an error.
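
For batched training data, the same pre-computation can be applied to every sample. A sketch, assuming the samples are stored as a vector of `(FeaturedGraph, label)` pairs (an assumption about the example script, not shown above):

```julia
# `train_samples` is a hypothetical Vector of (FeaturedGraph, label) pairs
for (fg_i, _) in train_samples
    GraphSignals.normalized_adjacency_matrix!(fg_i)
end
```
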

examples/gae.jl

Lines changed: 7 additions & 3 deletions

```diff
@@ -1,4 +1,5 @@
 using GeometricFlux
+using GraphSignals
 using Flux
 using Flux: throttle
 using Flux.Losses: logitbinarycrossentropy
@@ -9,6 +10,8 @@ using SparseArrays
 using Graphs.SimpleGraphs
 using CUDA
 
+CUDA.allowscalar(false)
+
 @load "data/cora_features.jld2" features
 @load "data/cora_graph.jld2" g
 
@@ -20,14 +23,15 @@ target_catg = 7
 epochs = 200
 
 ## Preprocessing data
-fg = FeaturedGraph(g) |> gpu
+fg = FeaturedGraph(g)  # pass to gpu together in model layers
 train_X = Matrix{Float32}(features) |> gpu  # dim: num_features * num_nodes
-train_y = fg  # dim: num_nodes * num_nodes
+train_y = fg |> GraphSignals.adjacency_matrix |> gpu  # dim: num_nodes * num_nodes
 
 ## Model
 encoder = Chain(GCNConv(fg, num_features=>hidden1, relu),
                 GCNConv(fg, hidden1=>hidden2))
-model = Chain(GAE(encoder, σ)) |> gpu
+model = Chain(GAE(encoder, σ)) |> gpu;
+# do not show model architecture, showing CuSparseMatrix will trigger errors
 
 ## Loss
 loss(x, y) = logitbinarycrossentropy(model(x), y)
```

examples/gat.jl

Lines changed: 9 additions & 8 deletions

```diff
@@ -2,10 +2,10 @@ using GeometricFlux
 using Flux
 using Flux: onehotbatch, onecold, logitcrossentropy, throttle
 using Flux: @epochs
-using Flux.Data: DataLoader
 using JLD2
 using Statistics: mean
 using SparseArrays
+using LinearAlgebra
 using Graphs.SimpleGraphs
 using Graphs: adjacency_matrix
 using CUDA
@@ -24,29 +24,30 @@ epochs = 10
 ## Preprocessing data
 train_X = Matrix{Float32}(features) |> gpu  # dim: num_features * num_nodes
 train_y = Matrix{Float32}(labels) |> gpu  # dim: target_catg * num_nodes
-adj_mat = Matrix{Float32}(adjacency_matrix(g)) |> gpu
+A = Matrix{Int}((adjacency_matrix(g) + I) .≥ 1)
+fg = FeaturedGraph(A, :adjm)
 
 ## Model
-model = Chain(GATConv(g, num_features=>hidden, heads=heads),
+model = Chain(GATConv(fg, num_features=>hidden, heads=heads),
               Dropout(0.6),
-              GATConv(g, hidden*heads=>target_catg, heads=heads, concat=false)
+              GATConv(fg, hidden*heads=>target_catg, heads=heads, concat=false)
              ) |> gpu
 # test model
-# @show model(train_X)
+@show model(train_X)
 
 ## Loss
 loss(x, y) = logitcrossentropy(model(x), y)
 accuracy(x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))
 
 # test loss
-# @show loss(train_X, train_y)
+@show loss(train_X, train_y)
 
 # test gradient
-# @show gradient(X -> loss(X, train_y), train_X)
+@show gradient(()->loss(train_X, train_y), Flux.params(model))
 
 ## Training
 ps = Flux.params(model)
-train_data = DataLoader(train_X, train_y, batchsize=num_nodes)
+train_data = Flux.Data.DataLoader((train_X, train_y), batchsize=num_nodes)
 opt = ADAM(0.01)
 evalcb() = @show(accuracy(train_X, train_y))
```