Add possibility to pass weights to GCNConv (#447)
* Modify GCNConv

* Add tests

* Update src/layers/conv.jl

---------
aurorarossi authored Jul 18, 2024
1 parent df56b7e commit 0f8e13c
Showing 2 changed files with 28 additions and 6 deletions.
23 changes: 17 additions & 6 deletions src/layers/conv.jl
@@ -32,14 +32,15 @@ and optionally an edge weight vector.
  # Forward
-     (::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing, norm_fn::Function = d -> 1 ./ sqrt.(d)) -> AbstractMatrix
+     (::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing, norm_fn::Function = d -> 1 ./ sqrt.(d), conv_weight::Union{Nothing,AbstractMatrix} = nothing) -> AbstractMatrix
- Takes as input a graph `g`,ca node feature matrix `x` of size `[in, num_nodes]`,
+ Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`,
  and optionally an edge weight vector. Returns a node feature matrix of size
  `[out, num_nodes]`.
  The `norm_fn` parameter allows for custom normalization of the graph convolution operation by passing a function as argument.
  By default, it computes ``\frac{1}{\sqrt{d}}`` i.e the inverse square root of the degree (`d`) of each node in the graph.
+ If `conv_weight` is an `AbstractMatrix` of size `[out, in]`, then the convolution is performed using that weight matrix instead of the weights stored in the model.
  # Examples
@@ -102,11 +103,21 @@ check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing
  function (l::GCNConv)(g::AbstractGNNGraph,
                        x,
                        edge_weight::EW = nothing,
-                       norm_fn::Function = d -> 1 ./ sqrt.(d)
+                       norm_fn::Function = d -> 1 ./ sqrt.(d);
+                       conv_weight::Union{Nothing,AbstractMatrix} = nothing
                        ) where {EW <: Union{Nothing, AbstractVector}}

      check_gcnconv_input(g, edge_weight)

+     if conv_weight === nothing
+         weight = l.weight
+     else
+         weight = conv_weight
+         if size(weight) != size(l.weight)
+             throw(ArgumentError("The weight matrix has the wrong size. Expected $(size(l.weight)) but got $(size(weight))"))
+         end
+     end
+
      if l.add_self_loops
          g = add_self_loops(g)
          if edge_weight !== nothing
@@ -116,11 +127,11 @@ function (l::GCNConv)(g::AbstractGNNGraph,
              @assert length(edge_weight) == g.num_edges
          end
      end
-     Dout, Din = size(l.weight)
+     Dout, Din = size(weight)
      if Dout < Din && !(g isa GNNHeteroGraph)
          # multiply before convolution if it is more convenient, otherwise multiply after
          # (this works only for homogenous graph)
-         x = l.weight * x
+         x = weight * x
      end

      xj, xi = expand_srcdst(g, x) # expand only after potential multiplication
@@ -150,7 +161,7 @@ function (l::GCNConv)(g::AbstractGNNGraph,
      end
      x = x .* cin'
      if Dout >= Din || g isa GNNHeteroGraph
-         x = l.weight * x
+         x = weight * x
      end
      return l.σ.(x .+ l.bias)
  end
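For context, a minimal usage sketch of the new `conv_weight` keyword argument; the graph, sizes, and feature values below are illustrative and not taken from the package's examples:

using GraphNeuralNetworks

in_channel, out_channel, num_nodes = 3, 5, 10
g = rand_graph(num_nodes, 30)                    # small random graph with 30 edges
x = rand(Float32, in_channel, num_nodes)         # node features of size [in, num_nodes]
l = GCNConv(in_channel => out_channel)

y_default = l(g, x)                              # uses the weights stored in the layer

# Override the stored weights for this call only; the matrix must match
# size(l.weight), i.e. [out, in], otherwise an ArgumentError is thrown.
w = rand(Float32, out_channel, in_channel)
y_custom = l(g, x; conv_weight = w)

size(y_custom) == (out_channel, num_nodes)       # true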
11 changes: 11 additions & 0 deletions test/layers/conv.jl
@@ -64,6 +64,17 @@ test_graphs = [g1, g_single_vertex]
          @test gradient(w -> sum(l(g, x, w)), w)[1] isa AbstractVector{T} # redundant test but more explicit
          test_layer(l, g, rtol = RTOL_HIGH, outsize = (1, g.num_nodes), test_gpu = false)
      end
+
+     @testset "conv_weight" begin
+         l = GraphNeuralNetworks.GCNConv(in_channel => out_channel)
+         w = zeros(T, out_channel, in_channel)
+         g1 = GNNGraph(adj1, ndata = ones(T, in_channel, N))
+         @test l(g1, g1.ndata.x, conv_weight = w) == zeros(T, out_channel, N)
+         a = rand(T, in_channel, N)
+         g2 = GNNGraph(adj1, ndata = a)
+         @test l(g2, g2.ndata.x, conv_weight = w) == w * a
+
+     end
  end

  @testset "ChebConv" begin
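The new test passes an all-zeros `conv_weight`, so the expected result does not depend on the graph structure or normalization: the supplied matrix replaces `l.weight`, and with the layer's default zero bias and identity activation the output is all zeros. A standalone sketch of the same property, using an illustrative random graph instead of the test's `adj1`:

using GraphNeuralNetworks

in_channel, out_channel, num_nodes = 3, 5, 4
l = GCNConv(in_channel => out_channel)
g = rand_graph(num_nodes, 6, ndata = ones(Float32, in_channel, num_nodes))
w = zeros(Float32, out_channel, in_channel)

# With a zero weight matrix the convolution output is zero for any input graph.
l(g, g.ndata.x; conv_weight = w) == zeros(Float32, out_channel, num_nodes)   # true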
