Skip to content

Commit 27c2ce4

Browse files
authored
Merge pull request #258 from FluxML/msgpass
Fix message-passing framework
2 parents 0ed1772 + 76f7dfe commit 27c2ce4

File tree

10 files changed

+172
-217
lines changed

10 files changed

+172
-217
lines changed

src/GeometricFlux.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Random
88
using Reexport
99

1010
using CUDA, CUDA.CUSPARSE
11+
using ChainRulesCore
1112
using ChainRulesCore: @non_differentiable
1213
using FillArrays: Fill
1314
using Flux
@@ -19,6 +20,8 @@ using Zygote
1920

2021
import Word2Vec: word2vec, wordvectors, get_vector
2122

23+
const ConcreteFeaturedGraph = Union{FeaturedGraph,FeaturedSubgraph}
24+
2225
export
2326
# layers/graphlayers
2427
AbstractGraphLayer,
@@ -50,8 +53,6 @@ export
5053
VGAE,
5154
InnerProductDecoder,
5255
VariationalEncoder,
53-
summarize,
54-
sample,
5556

5657
# layer/selector
5758
bypass_graph,

src/layers/conv.jl

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) =
3838

3939
Flux.trainable(l::GCNConv) = (l.weight, l.bias)
4040

41-
function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
41+
function (l::GCNConv)(fg::ConcreteFeaturedGraph, x::AbstractMatrix)
4242
Ã = Zygote.ignore() do
4343
GraphSignals.normalized_adjacency_matrix(fg, eltype(x); selfloop=true)
4444
end
4545
l.σ.(l.weight * x * Ã .+ l.bias)
4646
end
4747

48-
(l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
48+
(l::GCNConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
4949

5050
function Base.show(io::IO, l::GCNConv)
5151
out, in = size(l.weight)
@@ -91,7 +91,7 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
9191

9292
Flux.trainable(l::ChebConv) = (l.weight, l.bias)
9393

94-
function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
94+
function (c::ChebConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix{T}) where T
9595
GraphSignals.check_num_nodes(fg, X)
9696
@assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
9797

@@ -110,7 +110,7 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
110110
return Y .+ c.bias
111111
end
112112

113-
(l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
113+
(l::ChebConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
114114

115115
function Base.show(io::IO, l::ChebConv)
116116
out, in, k = size(l.weight)
@@ -165,14 +165,14 @@ message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j
165165

166166
update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias)
167167

168-
function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
169-
GraphSignals.check_num_nodes(fg, x)
170-
_, x, _ = propagate(gc, graph(fg), edge_feature(fg), x, global_feature(fg), +)
168+
function (gc::GraphConv)(fg::ConcreteFeaturedGraph, x::AbstractMatrix)
169+
# GraphSignals.check_num_nodes(fg, x)
170+
_, x, _ = propagate(gc, fg, edge_feature(fg), x, global_feature(fg), +)
171171
x
172172
end
173173

174-
(l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
175-
# (l::GraphConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
174+
(l::GraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
175+
# (l::GraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
176176

177177
function Base.show(io::IO, l::GraphConv)
178178
in_channel = size(l.weight1, ndims(l.weight1))
@@ -244,18 +244,22 @@ end
244244

245245
# After some reshaping due to the multihead, we get the α from each message,
246246
# then get the softmax over every α, and eventually multiply the message by α
247-
function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
248-
e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js)
249-
n = size(e_ij, 1)
250-
αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2)
251-
msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :)
252-
reshape(msgs, (n-1)*gat.heads, :)
253-
end
254-
255-
function update_batch_edge(gat::GATConv, sg::SparseGraph, E::AbstractMatrix, X::AbstractMatrix, u)
256-
@assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
257-
ys = map(i -> apply_batch_message(gat, i, GraphSignals.cpu_neighbors(sg, i), X), 1:nv(sg))
258-
return hcat(ys...)
247+
function graph_attention(gat::GATConv, i, js, X::AbstractMatrix)
248+
e_ij = map(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), js)
249+
E = hcat_by_sum(e_ij)
250+
n = size(E, 1)
251+
αs = Flux.softmax(reshape(view(E, 1, :), gat.heads, :), dims=2)
252+
msgs = view(E, 2:n, :) .* reshape(αs, 1, :)
253+
return reshape(msgs, (n-1)*gat.heads, :)
254+
end
255+
256+
function update_batch_edge(gat::GATConv, fg::AbstractFeaturedGraph, E::AbstractMatrix, X::AbstractMatrix, u)
257+
@assert Zygote.ignore(() -> check_self_loops(graph(fg))) "a vertex must have self loop (receive a message from itself)."
258+
nodes = Zygote.ignore(()->vertices(fg))
259+
nbr = i->cpu(GraphSignals.neighbors(graph(fg), i))
260+
ms = map(i -> graph_attention(gat, i, Zygote.ignore(()->nbr(i)), X), nodes)
261+
M = hcat_by_sum(ms)
262+
return M
259263
end
260264

261265
function check_self_loops(sg::SparseGraph)
@@ -267,7 +271,7 @@ function check_self_loops(sg::SparseGraph)
267271
return true
268272
end
269273

270-
function update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u)
274+
function update_batch_vertex(gat::GATConv, ::AbstractFeaturedGraph, M::AbstractMatrix, X::AbstractMatrix, u)
271275
M = M .+ gat.bias
272276
if !gat.concat
273277
N = size(M, 2)
@@ -276,14 +280,14 @@ function update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix,
276280
return M
277281
end
278282

279-
function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
283+
function (gat::GATConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
280284
GraphSignals.check_num_nodes(fg, X)
281-
_, X, _ = propagate(gat, graph(fg), edge_feature(fg), X, global_feature(fg), +)
285+
_, X, _ = propagate(gat, fg, edge_feature(fg), X, global_feature(fg), +)
282286
return X
283287
end
284288

285-
(l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
286-
# (l::GATConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
289+
(l::GATConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
290+
# (l::GATConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
287291

288292
function Base.show(io::IO, l::GATConv)
289293
in_channel = size(l.weight, ndims(l.weight))
@@ -335,7 +339,7 @@ message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
335339
update(ggc::GatedGraphConv, m::AbstractVector, x) = m
336340

337341

338-
function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
342+
function (ggc::GatedGraphConv)(fg::ConcreteFeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
339343
GraphSignals.check_num_nodes(fg, H)
340344
m, n = size(H)
341345
@assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
@@ -347,14 +351,14 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
347351
end
348352
for i = 1:ggc.num_layers
349353
M = view(ggc.weight, :, :, i) * H
350-
_, M = propagate(ggc, graph(fg), edge_feature(fg), M, global_feature(fg), +)
354+
_, M = propagate(ggc, fg, edge_feature(fg), M, global_feature(fg), +)
351355
H, _ = ggc.gru(H, M)
352356
end
353357
H
354358
end
355359

356-
(l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
357-
# (l::GatedGraphConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
360+
(l::GatedGraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
361+
# (l::GatedGraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
358362

359363

360364
function Base.show(io::IO, l::GatedGraphConv)
@@ -392,14 +396,14 @@ Flux.trainable(l::EdgeConv) = (l.nn,)
392396
message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
393397
update(ec::EdgeConv, m::AbstractVector, x) = m
394398

395-
function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
399+
function (ec::EdgeConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
396400
GraphSignals.check_num_nodes(fg, X)
397-
_, X, _ = propagate(ec, graph(fg), edge_feature(fg), X, global_feature(fg), ec.aggr)
401+
_, X, _ = propagate(ec, fg, edge_feature(fg), X, global_feature(fg), ec.aggr)
398402
X
399403
end
400404

401-
(l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
402-
# (l::EdgeConv)(fg::FeaturedGraph) = propagate(l, fg, l.aggr) # edge number check break this
405+
(l::EdgeConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
406+
# (l::EdgeConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, l.aggr) # edge number check break this
403407

404408
function Base.show(io::IO, l::EdgeConv)
405409
print(io, "EdgeConv(", l.nn)
@@ -443,15 +447,15 @@ Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)
443447
message(g::GINConv, x_i::AbstractVector, x_j::AbstractVector) = x_j
444448
update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)
445449

446-
function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
450+
function (g::GINConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
447451
gf = graph(fg)
448452
GraphSignals.check_num_nodes(gf, X)
449-
_, X, _ = propagate(g, graph(fg), edge_feature(fg), X, global_feature(fg), +)
453+
_, X, _ = propagate(g, fg, edge_feature(fg), X, global_feature(fg), +)
450454
X
451455
end
452456

453-
(l::GINConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
454-
# (l::GINConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
457+
(l::GINConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
458+
# (l::GINConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
455459

456460

457461
"""
@@ -512,17 +516,17 @@ message(c::CGConv,
512516
end
513517
update(c::CGConv, m::AbstractVector, x) = x + m
514518

515-
function (c::CGConv)(fg::FeaturedGraph, X::AbstractMatrix, E::AbstractMatrix)
519+
function (c::CGConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix, E::AbstractMatrix)
516520
GraphSignals.check_num_nodes(fg, X)
517521
GraphSignals.check_num_edges(fg, E)
518-
_, Y, _ = propagate(c, graph(fg), E, X, global_feature(fg), +)
522+
_, Y, _ = propagate(c, fg, E, X, global_feature(fg), +)
519523
Y
520524
end
521525

522-
(l::CGConv)(fg::FeaturedGraph) = FeaturedGraph(fg,
526+
(l::CGConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg,
523527
nf=l(fg, node_feature(fg), edge_feature(fg)),
524528
ef=edge_feature(fg))
525-
# (l::CGConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
529+
# (l::CGConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # edge number check break this
526530

527531
(l::CGConv)(X::AbstractMatrix, E::AbstractMatrix) = l(l.fg, X, E)
528532

src/layers/gn.jl

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
1-
_view(::Nothing, i) = nothing
2-
_view(A::Fill{T,2,Axes}, i) where {T,Axes} = view(A, :, 1)
1+
_view(::Nothing, idx) = nothing
2+
_view(A::Fill{T,2,Axes}, idx) where {T,Axes} = fill(A.value, A.axes[1], length(idx))
3+
4+
function _view(A::SubArray{T,2,S}, idx) where {T,S<:Fill}
5+
p = parent(A)
6+
return Fill(p.value, p.axes[1].stop, length(idx))
7+
end
8+
39
_view(A::AbstractMatrix, idx) = view(A, :, idx)
410

11+
function _view(A::SubArray{T,2,S}, idxs) where {T,S<:AbstractMatrix}
12+
view_idx = A.indices[2]
13+
if view_idx == idxs
14+
return A
15+
else
16+
idxs = findall(x -> x in idxs, view_idx)
17+
return view(A, :, idxs)
18+
end
19+
end
20+
521
aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2))
622
aggregate(aggr::typeof(-), X) = -vec(sum(X, dims=2))
723
aggregate(aggr::typeof(*), X) = vec(prod(X, dims=2))
@@ -16,43 +32,48 @@ abstract type GraphNet <: AbstractGraphLayer end
1632
@inline update_vertex(gn::GraphNet, ē, vi, u) = vi
1733
@inline update_global(gn::GraphNet, ē, v̄, u) = u
1834

19-
@inline function update_batch_edge(gn::GraphNet, sg::SparseGraph, E, V, u)
20-
ys = map(i -> apply_batch_message(gn, sg, i, GraphSignals.cpu_neighbors(sg, i), E, V, u), vertices(sg))
21-
return hcat(ys...)
35+
@inline function update_batch_edge(gn::GraphNet, fg::AbstractFeaturedGraph, E, V, u)
36+
es = Zygote.ignore(()->cpu(GraphSignals.incident_edges(fg)))
37+
xs = Zygote.ignore(()->cpu(GraphSignals.repeat_nodes(fg)))
38+
nbrs = Zygote.ignore(()->cpu(GraphSignals.neighbors(fg)))
39+
ms = map((e,i,j)->update_edge(gn, _view(E, e), _view(V, i), _view(V, j), u), es, xs, nbrs)
40+
M = hcat_by_sum(ms)
41+
return M
2242
end
2343

24-
@inline function apply_batch_message(gn::GraphNet, sg::SparseGraph, i, js, E, V, u)
25-
# js still CuArray
26-
es = Zygote.ignore(() -> GraphSignals.cpu_incident_edges(sg, i))
27-
ys = map(k -> update_edge(gn, _view(E, es[k]), _view(V, i), _view(V, js[k]), u), 1:length(js))
28-
return hcat(ys...)
44+
@inline function update_batch_vertex(gn::GraphNet, fg::AbstractFeaturedGraph, Ē, V, u)
45+
nodes = Zygote.ignore(()->vertices(fg))
46+
vs = map(n->update_vertex(gn, _view(Ē, n), _view(V, n), u), nodes)
47+
V_ = hcat_by_sum(vs)
48+
return V_
2949
end
3050

31-
@inline function update_batch_vertex(gn::GraphNet, Ē, V, u)
32-
ys = map(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), 1:size(V,2))
33-
return hcat(ys...)
51+
@inline function aggregate_neighbors(gn::GraphNet, fg::AbstractFeaturedGraph, aggr, E)
52+
N = nv(parent(fg))
53+
xs = Zygote.ignore(()->cpu(GraphSignals.repeat_nodes(fg)))
54+
Ē = NNlib.scatter(aggr, E, xs; dstsize=(size(E, 1), N))
55+
return Ē
3456
end
35-
36-
@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr, E) = neighbor_scatter(aggr, E, sg)
37-
@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr::Nothing, @nospecialize E) = nothing
57+
@inline aggregate_neighbors(gn::GraphNet, fg::AbstractFeaturedGraph, aggr::Nothing, @nospecialize E) = nothing
3858

3959
@inline aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E)
4060
@inline aggregate_edges(gn::GraphNet, aggr::Nothing, @nospecialize E) = nothing
4161

4262
@inline aggregate_vertices(gn::GraphNet, aggr, V) = aggregate(aggr, V)
4363
@inline aggregate_vertices(gn::GraphNet, aggr::Nothing, @nospecialize V) = nothing
4464

45-
function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing)
46-
E, V, u = propagate(gn, graph(fg), edge_feature(fg), node_feature(fg), global_feature(fg), naggr, eaggr, vaggr)
65+
function propagate(gn::GraphNet, fg::AbstractFeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing)
66+
E, V, u = propagate(gn, fg, edge_feature(fg), node_feature(fg), global_feature(fg), naggr, eaggr, vaggr)
4767
FeaturedGraph(fg, nf=V, ef=E, gf=u)
4868
end
4969

50-
function propagate(gn::GraphNet, sg::SparseGraph, E, V, u, naggr=nothing, eaggr=nothing, vaggr=nothing)
51-
E = update_batch_edge(gn, sg, E, V, u)
52-
Ē = aggregate_neighbors(gn, sg, naggr, E)
53-
V = update_batch_vertex(gn, Ē, V, u)
70+
function propagate(gn::GraphNet, fg::AbstractFeaturedGraph, E::AbstractArray, V::AbstractArray, u::AbstractArray,
71+
naggr=nothing, eaggr=nothing, vaggr=nothing)
72+
E = update_batch_edge(gn, fg, E, V, u)
73+
Ē = aggregate_neighbors(gn, fg, naggr, E)
74+
V = update_batch_vertex(gn, fg, Ē, V, u)
5475
ē = aggregate_edges(gn, eaggr, E)
5576
v̄ = aggregate_vertices(gn, vaggr, V)
5677
u = update_global(gn, ē, v̄, u)
57-
return E, V, u
78+
return parent(E), parent(V), u
5879
end

src/layers/graphlayers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
abstract type AbstractGraphLayer end
22

33
(l::AbstractGraphLayer)(x::AbstractMatrix) = l(l.fg, x)
4+
(l::AbstractGraphLayer)(::NullGraph, x::AbstractMatrix) = throw(ArgumentError("concrete FeaturedGraph is not provided."))

src/utils.jl

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,31 @@
1-
"""
2-
accumulated_edges(adj)
1+
function hcat_by_sum(xs::AbstractVector)
2+
T = eltype(xs[1])
3+
dim = size(xs[1], 1)
4+
N = length(xs)
35

4-
Return a vector which acts as a mapping table. The index is the vertex index,
5-
value is accumulated numbers of edge (current vertex not included).
6-
"""
7-
accumulated_edges(adj::AbstractVector{<:AbstractVector{<:Integer}}) = [0, cumsum(map(length, adj))...]
6+
ns = map(x->size(x,2), xs)
7+
pushfirst!(ns, 1)
8+
cumsum!(ns, ns)
89

9-
function generate_cluster(M::AbstractArray{T,N}, accu_edge) where {T,N}
10-
num_V = length(accu_edge) - 1
11-
num_E = accu_edge[end]
12-
cluster = similar(M, Int, num_E)
13-
@inbounds for i = 1:num_V
14-
j = accu_edge[i]
15-
k = accu_edge[i+1]
16-
cluster[j+1:k] .= i
10+
A = similar(xs[1], T, dim, ns[end]-1)
11+
for i in 1:N
12+
A[:, ns[i]:(ns[i+1]-1)] .= xs[i]
1713
end
18-
cluster
14+
return A
1915
end
2016

21-
@non_differentiable accumulated_edges(x...)
22-
@non_differentiable generate_cluster(x...)
17+
function ChainRulesCore.rrule(::typeof(hcat_by_sum), xs::AbstractVector)
18+
N = length(xs)
19+
20+
ns = map(x->size(x,2), xs)
21+
pushfirst!(ns, 1)
22+
cumsum!(ns, ns)
23+
24+
hcat_by_sum_pullback(Δ) = (NoTangent(), ntuple(i->view(Δ,:,ns[i]:(ns[i+1]-1)), N))
25+
hcat_by_sum(xs), hcat_by_sum_pullback
26+
end
27+
28+
function ChainRulesCore.rrule(::typeof(parent), A::Base.SubArray)
29+
parent_pullback(Δ) = (NoTangent(), view(Δ, A.indices...))
30+
parent(A), parent_pullback
31+
end

0 commit comments

Comments
 (0)