@@ -38,14 +38,14 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) =

Flux.trainable(l::GCNConv) = (l.weight, l.bias)

- function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
+ function (l::GCNConv)(fg::ConcreteFeaturedGraph, x::AbstractMatrix)
    Ã = Zygote.ignore() do
        GraphSignals.normalized_adjacency_matrix(fg, eltype(x); selfloop=true)
    end
    l.σ.(l.weight * x * Ã .+ l.bias)
end

- (l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+ (l::GCNConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))

function Base.show(io::IO, l::GCNConv)
    out, in = size(l.weight)
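
For orientation, both GCNConv entry points above can be exercised as follows. This is a minimal usage sketch with toy sizes; it assumes the GraphSignals FeaturedGraph constructor accepts an adjacency matrix plus an nf keyword, which is not shown in this diff:

    using GraphSignals, GeometricFlux, Flux

    adj = [0 1 0; 1 0 1; 0 1 0]          # 3-node path graph
    X   = rand(Float32, 4, 3)            # 4 input channels × 3 nodes (features stored column-wise)
    l   = GCNConv(4 => 8, relu)          # constructor shown in the hunk header above

    fg  = FeaturedGraph(adj, nf = X)     # assumed constructor form (adjacency + nf keyword)
    Y   = l(fg, X)                       # matrix method: returns an 8 × 3 feature matrix
    fg2 = l(fg)                          # FeaturedGraph method: returns a new FeaturedGraph with updated nf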
@@ -91,7 +91,7 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =

Flux.trainable(l::ChebConv) = (l.weight, l.bias)

- function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
+ function (c::ChebConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix{T}) where T
    GraphSignals.check_num_nodes(fg, X)
    @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
@@ -110,7 +110,7 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
    return Y .+ c.bias
end

- (l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+ (l::ChebConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))

function Base.show(io::IO, l::ChebConv)
    out, in, k = size(l.weight)
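
The body elided between the two ChebConv hunks evaluates the Chebyshev expansion on the scaled graph Laplacian. Below is a hedged sketch of that computation, assuming column-wise node features as elsewhere in this file; the function and variable names are illustrative, not the layer's actual fields:

    # Chebyshev recurrence on the scaled Laplacian L̂:
    #   Z₁ = X,  Z₂ = X * L̂,  Zₖ = 2 * Zₖ₋₁ * L̂ - Zₖ₋₂,  and  Y = Σₖ W[:, :, k] * Zₖ
    function cheb_expand(L̂::AbstractMatrix, X::AbstractMatrix, W::AbstractArray{<:Real,3})
        Z_prev, Z = X, X * L̂
        Y = W[:, :, 1] * Z_prev
        size(W, 3) ≥ 2 && (Y = Y + W[:, :, 2] * Z)
        for k in 3:size(W, 3)
            Z, Z_prev = 2 .* (Z * L̂) .- Z_prev, Z
            Y = Y + W[:, :, k] * Z
        end
        return Y                          # the layer then adds the bias: Y .+ c.bias
    end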
@@ -165,14 +165,14 @@ message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j

update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1 * x .+ m .+ gc.bias)

- function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
-     GraphSignals.check_num_nodes(fg, x)
-     _, x, _ = propagate(gc, graph(fg), edge_feature(fg), x, global_feature(fg), +)
+ function (gc::GraphConv)(fg::ConcreteFeaturedGraph, x::AbstractMatrix)
+     # GraphSignals.check_num_nodes(fg, x)
+     _, x, _ = propagate(gc, fg, edge_feature(fg), x, global_feature(fg), +)
    x
end

- (l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
- # (l::GraphConv)(fg::FeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this
+ (l::GraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+ # (l::GraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this

function Base.show(io::IO, l::GraphConv)
    in_channel = size(l.weight1, ndims(l.weight1))
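
Reading the message/update pair above together, GraphConv computes h_i = σ(W1*x_i + Σ_{j∈N(i)} W2*x_j + b) for every node. A hedged dense sketch of that rule (the adjacency-list representation and the function name are illustrative only):

    # adj[i] lists the neighbors of node i; X is in_channels × num_nodes.
    function graph_conv_dense(adj::Vector{<:AbstractVector{Int}}, X::AbstractMatrix,
                              W1::AbstractMatrix, W2::AbstractMatrix, b, σ = identity)
        msgs = reduce(hcat, [sum(W2 * X[:, j] for j in adj[i]) for i in eachindex(adj)])
        return σ.(W1 * X .+ msgs .+ b)
    end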
@@ -244,18 +244,22 @@ end

# After some reshaping due to the multi-head layout, we extract the α from each message,
# take the softmax over every α, and finally multiply each message by its α.
- function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
-     e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js)
-     n = size(e_ij, 1)
-     αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2)
-     msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :)
-     reshape(msgs, (n-1)*gat.heads, :)
- end
-
- function update_batch_edge(gat::GATConv, sg::SparseGraph, E::AbstractMatrix, X::AbstractMatrix, u)
-     @assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have a self-loop (receive a message from itself)."
-     ys = map(i -> apply_batch_message(gat, i, GraphSignals.cpu_neighbors(sg, i), X), 1:nv(sg))
-     return hcat(ys...)
+ function graph_attention(gat::GATConv, i, js, X::AbstractMatrix)
+     e_ij = map(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), js)
+     E = hcat_by_sum(e_ij)
+     n = size(E, 1)
+     αs = Flux.softmax(reshape(view(E, 1, :), gat.heads, :), dims=2)
+     msgs = view(E, 2:n, :) .* reshape(αs, 1, :)
+     return reshape(msgs, (n-1)*gat.heads, :)
+ end
+
+ function update_batch_edge(gat::GATConv, fg::AbstractFeaturedGraph, E::AbstractMatrix, X::AbstractMatrix, u)
+     @assert Zygote.ignore(() -> check_self_loops(graph(fg))) "a vertex must have a self-loop (receive a message from itself)."
+     nodes = Zygote.ignore(() -> vertices(fg))
+     nbr = i -> cpu(GraphSignals.neighbors(graph(fg), i))
+     ms = map(i -> graph_attention(gat, i, Zygote.ignore(() -> nbr(i)), X), nodes)
+     M = hcat_by_sum(ms)
+     return M
end

function check_self_loops(sg::SparseGraph)
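
To make the comment above concrete: for each destination node, graph_attention takes one raw score per head and per neighbor, softmaxes over that node's neighbors independently for each head, and then scales the per-head message values by the resulting α. A hedged toy illustration follows; the sizes and the column layout are assumptions, and the real code packs the score and the values into a single matrix per neighbor:

    using Flux                                       # for softmax

    heads, n_nbrs, out_ch = 4, 3, 8                  # toy sizes
    scores = randn(Float32, heads, n_nbrs)           # raw attention logits for one node's neighbors
    α = Flux.softmax(scores, dims = 2)               # each head's row now sums to 1 over the neighbors
    values = randn(Float32, out_ch, heads * n_nbrs)  # per-head values; columns neighbor-by-neighbor, heads fastest
    weighted = values .* reshape(α, 1, :)            # multiply every message column by its α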
@@ -267,7 +271,7 @@ function check_self_loops(sg::SparseGraph)
    return true
end

- function update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u)
+ function update_batch_vertex(gat::GATConv, ::AbstractFeaturedGraph, M::AbstractMatrix, X::AbstractMatrix, u)
    M = M .+ gat.bias
    if !gat.concat
        N = size(M, 2)
@@ -276,14 +280,14 @@ function update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix,
    return M
end

- function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
+ function (gat::GATConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
    GraphSignals.check_num_nodes(fg, X)
-     _, X, _ = propagate(gat, graph(fg), edge_feature(fg), X, global_feature(fg), +)
+     _, X, _ = propagate(gat, fg, edge_feature(fg), X, global_feature(fg), +)
    return X
end

- (l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
- # (l::GATConv)(fg::FeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this
+ (l::GATConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+ # (l::GATConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this

function Base.show(io::IO, l::GATConv)
    in_channel = size(l.weight, ndims(l.weight))
@@ -335,7 +339,7 @@ message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j

update(ggc::GatedGraphConv, m::AbstractVector, x) = m

- function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
+ function (ggc::GatedGraphConv)(fg::ConcreteFeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
    GraphSignals.check_num_nodes(fg, H)
    m, n = size(H)
    @assert (m <= ggc.out_ch) "The number of input features must be less than or equal to the number of output features."
@@ -347,14 +351,14 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
    end
    for i = 1:ggc.num_layers
        M = view(ggc.weight, :, :, i) * H
-         _, M = propagate(ggc, graph(fg), edge_feature(fg), M, global_feature(fg), +)
+         _, M = propagate(ggc, fg, edge_feature(fg), M, global_feature(fg), +)
        H, _ = ggc.gru(H, M)
    end
    H
end

- (l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
- # (l::GatedGraphConv)(fg::FeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this
+ (l::GatedGraphConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+ # (l::GatedGraphConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this

function Base.show(io::IO, l::GatedGraphConv)
@@ -392,14 +396,14 @@ Flux.trainable(l::EdgeConv) = (l.nn,)

message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
update(ec::EdgeConv, m::AbstractVector, x) = m

- function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
+ function (ec::EdgeConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
    GraphSignals.check_num_nodes(fg, X)
-     _, X, _ = propagate(ec, graph(fg), edge_feature(fg), X, global_feature(fg), ec.aggr)
+     _, X, _ = propagate(ec, fg, edge_feature(fg), X, global_feature(fg), ec.aggr)
    X
end

- (l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
- # (l::EdgeConv)(fg::FeaturedGraph) = propagate(l, fg, l.aggr) # the edge-number check breaks this
+ (l::EdgeConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+ # (l::EdgeConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, l.aggr) # the edge-number check breaks this

function Base.show(io::IO, l::EdgeConv)
    print(io, "EdgeConv(", l.nn)
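
Combining the message and update definitions above, EdgeConv computes h_i = aggr_{j∈N(i)} nn([x_i; x_j - x_i]), with the aggregator taken from the layer (ec.aggr). A hedged dense sketch of that rule; the adjacency-list form and the function name are illustrative only:

    # adj[i] lists the neighbors of node i; X is in_channels × num_nodes; nn maps 2*in -> out.
    function edge_conv_dense(adj::Vector{<:AbstractVector{Int}}, X::AbstractMatrix, nn, aggr = max)
        cols = map(eachindex(adj)) do i
            msgs = [nn(vcat(X[:, i], X[:, j] .- X[:, i])) for j in adj[i]]
            reduce((a, b) -> aggr.(a, b), msgs)      # element-wise aggregation over the neighbors
        end
        return reduce(hcat, cols)
    end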
@@ -443,15 +447,15 @@ Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)

message(g::GINConv, x_i::AbstractVector, x_j::AbstractVector) = x_j
update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)

- function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
+ function (g::GINConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
    gf = graph(fg)
    GraphSignals.check_num_nodes(gf, X)
-     _, X, _ = propagate(g, graph(fg), edge_feature(fg), X, global_feature(fg), +)
+     _, X, _ = propagate(g, fg, edge_feature(fg), X, global_feature(fg), +)
    X
end

- (l::GINConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
- # (l::GINConv)(fg::FeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this
+ (l::GINConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+ # (l::GINConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this

"""
@@ -512,17 +516,17 @@ message(c::CGConv,
end

update(c::CGConv, m::AbstractVector, x) = x + m

- function (c::CGConv)(fg::FeaturedGraph, X::AbstractMatrix, E::AbstractMatrix)
+ function (c::CGConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix, E::AbstractMatrix)
    GraphSignals.check_num_nodes(fg, X)
    GraphSignals.check_num_edges(fg, E)
-     _, Y, _ = propagate(c, graph(fg), E, X, global_feature(fg), +)
+     _, Y, _ = propagate(c, fg, E, X, global_feature(fg), +)
    Y
end

- (l::CGConv)(fg::FeaturedGraph) = FeaturedGraph(fg,
+ (l::CGConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg,
                            nf=l(fg, node_feature(fg), edge_feature(fg)),
                            ef=edge_feature(fg))
- # (l::CGConv)(fg::FeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this
+ # (l::CGConv)(fg::AbstractFeaturedGraph) = propagate(l, fg, +) # the edge-number check breaks this

(l::CGConv)(X::AbstractMatrix, E::AbstractMatrix) = l(l.fg, X, E)
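
CGConv is named after the crystal-graph convolution of Xie & Grossman (2018); the message body elided above is conventionally the gated form sketched below, which, combined with the update shown (x + m), leaves each node at x_i + Σ_j m_ij. The parameter names here are placeholders, not necessarily the layer's actual fields:

    using Flux: σ, softplus

    # z_ij stacks destination features, source features, and the features of edge (i, j);
    # a sigmoid gate modulates a softplus-transformed candidate message.
    function cg_message(Wf, bf, Ws, bs, x_i, x_j, e_ij)
        z = vcat(x_i, x_j, e_ij)
        return σ.(Wf * z .+ bf) .* softplus.(Ws * z .+ bs)
    end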