diff --git a/GNNlib/src/msgpass.jl b/GNNlib/src/msgpass.jl index 413a60556..acab02217 100644 --- a/GNNlib/src/msgpass.jl +++ b/GNNlib/src/msgpass.jl @@ -1,15 +1,12 @@ """ - propagate(fmsg, g, aggr [layer]; [xi, xj, e]) - propagate(fmsg, g, aggr, [layer,] xi, xj, e=nothing) + propagate(fmsg, g, aggr; [xi, xj, e]) + propagate(fmsg, g, aggr xi, xj, e=nothing) Performs message passing on graph `g`. Takes care of materializing the node features on each edge, applying the message function `fmsg`, and returning an aggregated message ``\\bar{\\mathbf{m}}`` (depending on the return value of `fmsg`, an array or a named tuple of arrays with last dimension's size `g.num_nodes`). -If also a [`GNNLayer`](@ref) `layer` is provided, it will be passed to `fmsg` -as a first argument. - It can be decomposed in two steps: ```julia @@ -35,10 +32,8 @@ providing as input `f` a closure. with the same batch size. If also `layer` is passed to propagate, the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` instead of `fmsg(xi, xj, e)`. -- `layer`: A [`GNNLayer`](@ref). If provided it will be passed to `fmsg` as a first argument. - `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`. - # Examples ```julia @@ -86,8 +81,8 @@ end ## APPLY EDGES """ - apply_edges(fmsg, g, [layer]; [xi, xj, e]) - apply_edges(fmsg, g, [layer,] xi, xj, e=nothing) + apply_edges(fmsg, g; [xi, xj, e]) + apply_edges(fmsg, g, xi, xj, e=nothing) Returns the message from node `j` to node `i` applying the message function `fmsg` on the edges in graph `g`. @@ -99,9 +94,6 @@ The function `fmsg` operates on batches of edges, therefore `xi`, `xj`, and `e` are tensors whose last dimension is the batch size, or can be named tuples of such tensors. - -If also a [`GNNLayer`](@ref) `layer` is provided, it will be passed to `fmsg` -as a first argument. # Arguments @@ -117,7 +109,6 @@ as a first argument. with the same batch size. If also `layer` is passed to propagate, the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` instead of `fmsg(xi, xj, e)`. -- `layer`: A [`GNNLayer`](@ref). If provided it will be passed to `fmsg` as a first argument. See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref). """ diff --git a/docs/Project.toml b/docs/Project.toml index 5bee032ef..60f0e00d0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" +GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/docs/make.jl b/docs/make.jl index e0a40ee13..869aa94f1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -26,7 +26,7 @@ prettyurls = get(ENV, "CI", nothing) == "true" mathengine = MathJax3() makedocs(; - modules = [GraphNeuralNetworks, GNNGraphs], + modules = [GraphNeuralNetworks, GNNGraphs, GNNlib], doctest = false, clean = true, plugins = [interlinks], diff --git a/docs/src/api/messagepassing.md b/docs/src/api/messagepassing.md index e7ade6d5b..aba1e0bba 100644 --- a/docs/src/api/messagepassing.md +++ b/docs/src/api/messagepassing.md @@ -14,19 +14,19 @@ Pages = ["messagepassing.md"] ## Interface ```@docs -apply_edges -aggregate_neighbors -propagate +GNNlib.apply_edges +GNNlib.aggregate_neighbors +GNNlib.propagate ``` ## Built-in message functions ```@docs -copy_xi -copy_xj -xi_dot_xj -xi_sub_xj -xj_sub_xi -e_mul_xj -w_mul_xj +GNNlib.copy_xi +GNNlib.copy_xj +GNNlib.xi_dot_xj +GNNlib.xi_sub_xj +GNNlib.xj_sub_xi +GNNlib.e_mul_xj +GNNlib.w_mul_xj ```