diff --git a/GNNlib/src/GNNGraphs/GNNGraphs.jl b/GNNlib/src/GNNGraphs/GNNGraphs.jl index 2e7f05207..0fbf871d0 100644 --- a/GNNlib/src/GNNGraphs/GNNGraphs.jl +++ b/GNNlib/src/GNNGraphs/GNNGraphs.jl @@ -101,6 +101,8 @@ include("operators.jl") include("convert.jl") include("utils.jl") +export sort_edge_index, + color_refinement include("gatherscatter.jl") # _gather, _scatter diff --git a/GNNlib/src/GNNGraphs/utils.jl b/GNNlib/src/GNNGraphs/utils.jl index e2b821369..f6e25dc80 100644 --- a/GNNlib/src/GNNGraphs/utils.jl +++ b/GNNlib/src/GNNGraphs/utils.jl @@ -49,6 +49,15 @@ end sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...) +""" + sort_edge_index(ei::Tuple) -> u', v' + sort_edge_index(u, v) -> u', v' + +Return a sorted version of the tuple of vectors `ei = (u, v)`, +applying a common permutation to `u` and `v`. +The sorting is lexycographic, that is the pairs `(ui, vi)` +are sorted first according to the `ui` and then according to `vi`. +""" function sort_edge_index(u, v) uv = collect(zip(u, v)) p = sortperm(uv) # isless lexicographically defined for tuples @@ -301,4 +310,56 @@ end @non_differentiable normalize_graphdata(::Nothing) iscuarray(x::AbstractArray) = false -@non_differentiable iscuarray(::Any) \ No newline at end of file +@non_differentiable iscuarray(::Any) + + +@doc raw""" + color_refinement(g::GNNGraph, [x0]) -> x, num_colors, niters + +The color refinement algorithm for graph coloring. +Given a graph `g` and an initial coloring `x0`, the algorithm +iteratively refines the coloring until a fixed point is reached. + +At each iteration the algorithm computes a hash of the coloring and the sorted list of colors +of the neighbors of each node. This hash is used to determine if the coloring has changed. + +```math +x_i' = hashmap((x_i, sort([x_j for j \in N(i)]))). +```` + +This algorithm is related to the 1-Weisfeiler-Lehman algorithm for graph isomorphism testing. + +# Arguments +- `g::GNNGraph`: The graph to color. +- `x0::AbstractVector{<:Integer}`: The initial coloring. If not provided, all nodes are colored with 1. + +# Returns +- `x::AbstractVector{<:Integer}`: The final coloring. +- `num_colors::Int`: The number of colors used. +- `niters::Int`: The number of iterations until convergence. +""" +color_refinement(g::GNNGraph) = color_refinement(g, ones(Int, g.num_nodes)) + +function color_refinement(g::GNNGraph, x0::AbstractVector{<:Integer}) + @assert length(x0) == g.num_nodes + s, t = edge_index(g) + t, s = sort_edge_index(t, s) # sort by target + degs = degree(g, dir=:in) + x = x0 + + hashmap = Dict{UInt64, Int}() + x′ = zeros(Int, length(x0)) + niters = 0 + while true + xneigs = chunk(x[s], size=degs) + for (i, (xi, xineigs)) in enumerate(zip(x, xneigs)) + idx = hash((xi, sort(xineigs))) + x′[i] = get!(hashmap, idx, length(hashmap) + 1) + end + niters += 1 + x == x′ && break + x = x′ + end + num_colors = length(union(x)) + return x, num_colors, niters +end \ No newline at end of file diff --git a/docs/src/api/gnngraph.md b/docs/src/api/gnngraph.md index 402e37524..be8c39026 100644 --- a/docs/src/api/gnngraph.md +++ b/docs/src/api/gnngraph.md @@ -52,6 +52,13 @@ Pages = ["transform.jl"] Private = false ``` +## Utils + +```@docs +GNNGraphs.sort_edge_index +GNNGraphs.color_refinement +``` + ## Generate ```@autodocs diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl index 1e28fc9c8..9238c5d06 100644 --- a/src/GNNGraphs/GNNGraphs.jl +++ b/src/GNNGraphs/GNNGraphs.jl @@ -14,7 +14,7 @@ import KrylovKit using ChainRulesCore using LinearAlgebra, Random, Statistics import MLUtils -using MLUtils: getobs, numobs, ones_like, zeros_like +using MLUtils: getobs, numobs, ones_like, zeros_like, chunk import Functors include("chainrules.jl") # hacks for differentiability @@ -104,6 +104,7 @@ include("operators.jl") include("convert.jl") include("utils.jl") +export sort_edge_index, color_refinement include("gatherscatter.jl") # _gather, _scatter diff --git a/src/GNNGraphs/utils.jl b/src/GNNGraphs/utils.jl index e2b821369..4bba304ef 100644 --- a/src/GNNGraphs/utils.jl +++ b/src/GNNGraphs/utils.jl @@ -49,6 +49,15 @@ end sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...) +""" + sort_edge_index(ei::Tuple) -> u', v' + sort_edge_index(u, v) -> u', v' + +Return a sorted version of the tuple of vectors `ei = (u, v)`, +applying a common permutation to `u` and `v`. +The sorting is lexycographic, that is the pairs `(ui, vi)` +are sorted first according to the `ui` and then according to `vi`. +""" function sort_edge_index(u, v) uv = collect(zip(u, v)) p = sortperm(uv) # isless lexicographically defined for tuples @@ -56,7 +65,6 @@ function sort_edge_index(u, v) end - cat_features(x1::Nothing, x2::Nothing) = nothing cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims = ndims(x1)) function cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector}) @@ -301,4 +309,56 @@ end @non_differentiable normalize_graphdata(::Nothing) iscuarray(x::AbstractArray) = false -@non_differentiable iscuarray(::Any) \ No newline at end of file +@non_differentiable iscuarray(::Any) + + +@doc raw""" + color_refinement(g::GNNGraph, [x0]) -> x, num_colors, niters + +The color refinement algorithm for graph coloring. +Given a graph `g` and an initial coloring `x0`, the algorithm +iteratively refines the coloring until a fixed point is reached. + +At each iteration the algorithm computes a hash of the coloring and the sorted list of colors +of the neighbors of each node. This hash is used to determine if the coloring has changed. + +```math +x_i' = hashmap((x_i, sort([x_j for j \in N(i)]))). +```` + +This algorithm is related to the 1-Weisfeiler-Lehman algorithm for graph isomorphism testing. + +# Arguments +- `g::GNNGraph`: The graph to color. +- `x0::AbstractVector{<:Integer}`: The initial coloring. If not provided, all nodes are colored with 1. + +# Returns +- `x::AbstractVector{<:Integer}`: The final coloring. +- `num_colors::Int`: The number of colors used. +- `niters::Int`: The number of iterations until convergence. +""" +color_refinement(g::GNNGraph) = color_refinement(g, ones(Int, g.num_nodes)) + +function color_refinement(g::GNNGraph, x0::AbstractVector{<:Integer}) + @assert length(x0) == g.num_nodes + s, t = edge_index(g) + t, s = sort_edge_index(t, s) # sort by target + degs = degree(g, dir=:in) + x = x0 + + hashmap = Dict{UInt64, Int}() + x′ = zeros(Int, length(x0)) + niters = 0 + while true + xneigs = chunk(x[s], size=degs) + for (i, (xi, xineigs)) in enumerate(zip(x, xneigs)) + idx = hash((xi, sort(xineigs))) + x′[i] = get!(hashmap, idx, length(hashmap) + 1) + end + niters += 1 + x == x′ && break + x = x′ + end + num_colors = length(union(x)) + return x, num_colors, niters +end \ No newline at end of file diff --git a/test/GNNGraphs/utils.jl b/test/GNNGraphs/utils.jl index d4575933a..db65b6357 100644 --- a/test/GNNGraphs/utils.jl +++ b/test/GNNGraphs/utils.jl @@ -48,3 +48,15 @@ @test sdec == snew @test tdec == tnew end + +@testset "color_refinment" begin + g = rand_graph(10, 20, seed=17, graph_type = GRAPH_T) + x0 = ones(Int, 10) + x, ncolors, niters = color_refinement(g, x0) + @test ncolors == 8 + @test niters == 2 + @test x == [4, 5, 6, 7, 8, 5, 8, 9, 10, 11] + + x2, _, _ = color_refinement(g) + @test x2 == x +end \ No newline at end of file