Skip to content

Commit

Permalink
Coloring refinement algorithm (#444)
Browse files Browse the repository at this point in the history
* add coloring refinement algorithm

* also in GNNlib

* docs
  • Loading branch information
CarloLucibello authored Jul 9, 2024
1 parent 3bcafbe commit acf4b6a
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 4 deletions.
2 changes: 2 additions & 0 deletions GNNlib/src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ include("operators.jl")

include("convert.jl")
include("utils.jl")
export sort_edge_index,
color_refinement

include("gatherscatter.jl")
# _gather, _scatter
Expand Down
63 changes: 62 additions & 1 deletion GNNlib/src/GNNGraphs/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ end

# Tuple form: splat the edge index and forward to the multi-argument method.
function sort_edge_index(eindex::Tuple)
    return sort_edge_index(eindex...)
end

"""
sort_edge_index(ei::Tuple) -> u', v'
sort_edge_index(u, v) -> u', v'
Return a sorted version of the tuple of vectors `ei = (u, v)`,
applying a common permutation to `u` and `v`.
The sorting is lexycographic, that is the pairs `(ui, vi)`
are sorted first according to the `ui` and then according to `vi`.
"""
function sort_edge_index(u, v)
uv = collect(zip(u, v))
p = sortperm(uv) # isless lexicographically defined for tuples
Expand Down Expand Up @@ -301,4 +310,56 @@ end
# Mark `normalize_graphdata` applied to `nothing` as non-differentiable.
@non_differentiable normalize_graphdata(::Nothing)

# Fallback method: reports `false` for any `AbstractArray`.
# NOTE(review): presumably a GPU-specific method returning `true` is defined elsewhere — confirm.
iscuarray(x::AbstractArray) = false
# Exclude `iscuarray` from differentiation for every argument type.
@non_differentiable iscuarray(::Any)
# NOTE(review): the next line repeats the previous one verbatim; likely a diff-rendering
# artifact — confirm before removing.
@non_differentiable iscuarray(::Any)


@doc raw"""
    color_refinement(g::GNNGraph, [x0]) -> x, num_colors, niters

Run the color refinement algorithm on the graph `g`.

Starting from the initial node coloring `x0`, the coloring is refined iteratively
until it stabilizes. At every step each node receives a new color computed from
its current color together with the sorted list of the colors of its neighbors:

```math
x_i' = hashmap((x_i, sort([x_j for j \in N(i)]))).
```

This algorithm is related to the 1-Weisfeiler-Lehman algorithm for graph isomorphism testing.

# Arguments

- `g::GNNGraph`: The graph to color.
- `x0::AbstractVector{<:Integer}`: The initial coloring. If not provided, all nodes are colored with 1.

# Returns

- `x::AbstractVector{<:Integer}`: The final coloring.
- `num_colors::Int`: The number of colors used.
- `niters::Int`: The number of iterations until convergence.
"""
function color_refinement(g::GNNGraph)
    # Default initial coloring: every node starts with color 1.
    uniform_coloring = ones(Int, g.num_nodes)
    return color_refinement(g, uniform_coloring)
end

# Color refinement starting from the user-supplied coloring `x0`.
#
# Returns `(x, num_colors, niters)`:
# - `x`: the stable coloring (this is `x0` itself when `x0` is already stable),
# - `num_colors`: the number of distinct colors in `x`,
# - `niters`: the number of refinement passes executed, including the final pass
#   that detected that no further refinement is possible.
#
# Every pass writes into a freshly allocated buffer. Reusing a single buffer would
# make the "current" and "refined" colorings alias the same array after the first
# pass, so any comparison between them would trivially succeed and the loop would
# stop after at most two passes regardless of convergence.
function color_refinement(g::GNNGraph, x0::AbstractVector{<:Integer})
    @assert length(x0) == g.num_nodes
    s, t = edge_index(g)
    t, s = sort_edge_index(t, s) # sort by target
    degs = degree(g, dir=:in)
    x = x0
    num_colors = length(union(x))

    # Maps the hash of (own color, sorted neighbor colors) to a color id.
    # The map is shared across passes, so color ids keep growing between passes;
    # convergence is therefore detected on the number of color classes rather
    # than on equality of the labels.
    hashmap = Dict{UInt64, Int}()
    niters = 0
    while true
        # Colors of the in-neighbors of each node, grouped per target node.
        xneigs = chunk(x[s], size=degs)
        x′ = zeros(Int, length(x0))
        for (i, (xi, xineigs)) in enumerate(zip(x, xneigs))
            idx = hash((xi, sort(xineigs)))
            x′[i] = get!(hashmap, idx, length(hashmap) + 1)
        end
        niters += 1
        # The signature contains the node's own color, so a pass can only split
        # color classes. If the count did not grow, the partition is unchanged
        # and `x` is already stable. `<=` also guarantees termination, since the
        # count is bounded by the number of nodes.
        refined_num_colors = length(union(x′))
        refined_num_colors <= num_colors && break
        x, num_colors = x′, refined_num_colors
    end
    return x, num_colors, niters
end
7 changes: 7 additions & 0 deletions docs/src/api/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ Pages = ["transform.jl"]
Private = false
```

## Utils

```@docs
GNNGraphs.sort_edge_index
GNNGraphs.color_refinement
```

## Generate

```@autodocs
Expand Down
3 changes: 2 additions & 1 deletion src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import KrylovKit
using ChainRulesCore
using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk
import Functors

include("chainrules.jl") # hacks for differentiability
Expand Down Expand Up @@ -104,6 +104,7 @@ include("operators.jl")

include("convert.jl")
include("utils.jl")
export sort_edge_index, color_refinement

include("gatherscatter.jl")
# _gather, _scatter
Expand Down
64 changes: 62 additions & 2 deletions src/GNNGraphs/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,22 @@ end

# Tuple form: splat the edge index and forward to the multi-argument method.
function sort_edge_index(eindex::Tuple)
    return sort_edge_index(eindex...)
end

"""
    sort_edge_index(ei::Tuple) -> u', v'
    sort_edge_index(u, v) -> u', v'

Sort the edge index `ei = (u, v)` and return the reordered vectors.

The same permutation is applied to both `u` and `v`. The order is lexicographic:
the pairs `(ui, vi)` are compared on `ui` first, with ties broken by `vi`.
"""
function sort_edge_index(u, v)
    # Tuples compare lexicographically, so sorting the pairs yields the wanted order.
    perm = sortperm(collect(zip(u, v)))
    return u[perm], v[perm]
end



cat_features(x1::Nothing, x2::Nothing) = nothing
cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims = ndims(x1))
function cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector})
Expand Down Expand Up @@ -301,4 +309,56 @@ end
# Mark `normalize_graphdata` applied to `nothing` as non-differentiable.
@non_differentiable normalize_graphdata(::Nothing)

# Fallback method: reports `false` for any `AbstractArray`.
# NOTE(review): presumably a GPU-specific method returning `true` is defined elsewhere — confirm.
iscuarray(x::AbstractArray) = false
# Exclude `iscuarray` from differentiation for every argument type.
@non_differentiable iscuarray(::Any)
# NOTE(review): the next line repeats the previous one verbatim; likely a diff-rendering
# artifact — confirm before removing.
@non_differentiable iscuarray(::Any)


@doc raw"""
    color_refinement(g::GNNGraph, [x0]) -> x, num_colors, niters

Run the color refinement algorithm on the graph `g`.

Starting from the initial node coloring `x0`, the coloring is refined iteratively
until it stabilizes. At every step each node receives a new color computed from
its current color together with the sorted list of the colors of its neighbors:

```math
x_i' = hashmap((x_i, sort([x_j for j \in N(i)]))).
```

This algorithm is related to the 1-Weisfeiler-Lehman algorithm for graph isomorphism testing.

# Arguments

- `g::GNNGraph`: The graph to color.
- `x0::AbstractVector{<:Integer}`: The initial coloring. If not provided, all nodes are colored with 1.

# Returns

- `x::AbstractVector{<:Integer}`: The final coloring.
- `num_colors::Int`: The number of colors used.
- `niters::Int`: The number of iterations until convergence.
"""
function color_refinement(g::GNNGraph)
    # Default initial coloring: every node starts with color 1.
    uniform_coloring = ones(Int, g.num_nodes)
    return color_refinement(g, uniform_coloring)
end

# Color refinement starting from the user-supplied coloring `x0`.
#
# Returns `(x, num_colors, niters)`:
# - `x`: the stable coloring (this is `x0` itself when `x0` is already stable),
# - `num_colors`: the number of distinct colors in `x`,
# - `niters`: the number of refinement passes executed, including the final pass
#   that detected that no further refinement is possible.
#
# Every pass writes into a freshly allocated buffer. Reusing a single buffer would
# make the "current" and "refined" colorings alias the same array after the first
# pass, so any comparison between them would trivially succeed and the loop would
# stop after at most two passes regardless of convergence.
function color_refinement(g::GNNGraph, x0::AbstractVector{<:Integer})
    @assert length(x0) == g.num_nodes
    s, t = edge_index(g)
    t, s = sort_edge_index(t, s) # sort by target
    degs = degree(g, dir=:in)
    x = x0
    num_colors = length(union(x))

    # Maps the hash of (own color, sorted neighbor colors) to a color id.
    # The map is shared across passes, so color ids keep growing between passes;
    # convergence is therefore detected on the number of color classes rather
    # than on equality of the labels.
    hashmap = Dict{UInt64, Int}()
    niters = 0
    while true
        # Colors of the in-neighbors of each node, grouped per target node.
        xneigs = chunk(x[s], size=degs)
        x′ = zeros(Int, length(x0))
        for (i, (xi, xineigs)) in enumerate(zip(x, xneigs))
            idx = hash((xi, sort(xineigs)))
            x′[i] = get!(hashmap, idx, length(hashmap) + 1)
        end
        niters += 1
        # The signature contains the node's own color, so a pass can only split
        # color classes. If the count did not grow, the partition is unchanged
        # and `x` is already stable. `<=` also guarantees termination, since the
        # count is bounded by the number of nodes.
        refined_num_colors = length(union(x′))
        refined_num_colors <= num_colors && break
        x, num_colors = x′, refined_num_colors
    end
    return x, num_colors, niters
end
12 changes: 12 additions & 0 deletions test/GNNGraphs/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,15 @@
@test sdec == snew
@test tdec == tnew
end

@testset "color_refinment" begin
    # The seed is fixed so that the expected values below are reproducible.
    g = rand_graph(10, 20, seed=17, graph_type = GRAPH_T)
    init_colors = ones(Int, 10)
    colors, ncolors, niters = color_refinement(g, init_colors)
    @test ncolors == 8
    @test niters == 2
    @test colors == [4, 5, 6, 7, 8, 5, 8, 9, 10, 11]

    # Omitting the initial coloring must give the same result as passing all ones.
    default_colors = first(color_refinement(g))
    @test default_colors == colors
end

0 comments on commit acf4b6a

Please sign in to comment.