diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index ae4440f22..5fde68fea 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -28,6 +28,7 @@ export
     broadcast_nodes,
     broadcast_edges,
     softmax_edge_neighbors,
+    topk_feature,
 
     # msgpass
     apply_edges,
diff --git a/src/utils.jl b/src/utils.jl
index 8434c63c8..dd2a4a4b7 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -120,6 +120,97 @@ function broadcast_edges(g::GNNGraph, x)
     return gather(x, gi)
 end
 
+# Sort the columns of `matrix` by the values found in row `sortby`.
+# Returns the reordered matrix together with the column permutation that was applied.
+function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
+    index = sortperm(view(matrix, sortby, :); rev)
+    return matrix[:, index], index
+end
+
+# Top-`k` selection on a single feature matrix whose columns are the ranked items.
+# - `sortby === nothing`: every row is sorted independently; returns the first `k` sorted
+#   values of each row and a matrix holding the matching column indices.
+# - `sortby::Int`: whole columns are ordered by row `sortby`; returns the first `k`
+#   columns and a vector with their original column indices.
+function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true,
+                      sortby::Union{Nothing, Int} = nothing)
+    if sortby === nothing
+        sorted_matrix = sort(matrix, dims = 2; rev)[:, 1:k]
+        vector_indices = map(x -> sortperm(x; rev), eachrow(matrix))
+        # The adjoint is recursive, so each permutation becomes a row for `vcat` to stack.
+        indices = reduce(vcat, vector_indices')[:, 1:k]
+        return sorted_matrix, indices
+    else
+        sorted_matrix, indices = _sort_col(matrix; rev, sortby)
+        return sorted_matrix[:, 1:k], indices[1:k]
+    end
+end
+
+# Run `_topk_matrix` once per matrix and stack the per-graph results along a third
+# dimension.
+function _topk_batch(matrices::AbstractArray, k::Int; rev::Bool = true,
+                     sortby::Union{Nothing, Int} = nothing)
+    num_graphs = length(matrices)
+    num_feat = size(matrices[1], 1)
+    results = map(x -> _topk_matrix(x, k; rev, sortby), matrices)
+    output_matrix = reshape(reduce(hcat, first.(results)), num_feat, k, num_graphs)
+    indices = reduce(hcat, last.(results))
+    # Row-wise sorting yields one index matrix per graph; sorting by a single feature
+    # yields one index vector of length `k` per graph.
+    index_size = sortby === nothing ? (num_feat, k, num_graphs) : (k, 1, num_graphs)
+    return output_matrix, reshape(indices, index_size)
+end
+
+"""
+    topk_feature(g, feat, k; rev = true, sortby = nothing)
+
+Graph-wise top-`k` on the feature array `feat` according to the `sortby` index.
+Returns a tuple of the top-`k` features and their indices.
+
+# Arguments
+
+- `g`: a `GNNGraph`.
+- `feat`: a feature array of size `(number_features, g.num_nodes)` or `(number_features, g.num_edges)` of the graph `g`.
+- `k`: the number of top features to return. It must not exceed the number of columns of `feat` that belong to each graph.
+- `rev`: if `true`, sort in descending order; otherwise return the `k` smallest elements.
+- `sortby`: the index of the feature to sort by. If `nothing`, every row is sorted independently.
+
+# Returns
+
+- If `g` is a single graph: the features have size `(number_features, k)`; the indices are a `(number_features, k)` matrix, or a vector of length `k` when `sortby` is given.
+- If `g` is a batch: the features have size `(number_features, k, g.num_graphs)`; the indices have size `(number_features, k, g.num_graphs)`, or `(k, 1, g.num_graphs)` when `sortby` is given, and are relative to the columns of each individual graph.
+
+# Examples
+
+```julia
+julia> g = rand_graph(5, 4, ndata = rand(3,5));
+
+julia> g.ndata.x
+3×5 Matrix{Float64}:
+ 0.333661  0.683551  0.315145  0.794089   0.840085
+ 0.263023  0.726028  0.626617  0.412247   0.0914052
+ 0.296433  0.186584  0.960758  0.0999844  0.813808
+
+julia> topk_feature(g, g.ndata.x, 2)
+([0.8400845757074524 0.7940891040468462; 0.7260276789396128 0.6266174187625888; 0.9607582005024967 0.8138081223752274], [5 4; 2 3; 3 5])
+
+julia> topk_feature(g, g.ndata.x, 2; sortby=3)
+([0.3151452763177829 0.8400845757074524; 0.6266174187625888 0.09140519108918477; 0.9607582005024967 0.8138081223752274], [3, 5])
+
+```
+
+"""
+function topk_feature(g::GNNGraph, feat::AbstractArray, k::Int; rev::Bool = true,
+                      sortby::Union{Nothing, Int} = nothing)
+    if g.num_graphs == 1
+        return _topk_matrix(feat, k; rev, sortby)
+    else
+        # NOTE(review): columns are split with `g.graph_indicator`; confirm this is the
+        # intended mask when `feat` holds edge features rather than node features.
+        matrices = [feat[:, g.graph_indicator .== i] for i in 1:(g.num_graphs)]
+        return _topk_batch(matrices, k; rev, sortby)
+    end
+end
 expand_srcdst(g::AbstractGNNGraph, x) = throw(ArgumentError("Invalid input type, expected matrix or tuple of matrices."))
 expand_srcdst(g::AbstractGNNGraph, x::AbstractMatrix) = (x, x)
 
diff --git a/test/utils.jl b/test/utils.jl
index 7fb423c18..6b1d567e6 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,8 +1,8 @@
 De, Dx = 3, 2
 g = Flux.batch([GNNGraph(erdos_renyi(10, 30),
-                         ndata = rand(Dx, 10),
-                         edata = rand(De, 30),
-                         graph_type = GRAPH_T) for i in 1:5])
+                ndata = rand(Dx, 10),
+                edata = rand(De, 30),
+                graph_type = GRAPH_T) for i in 1:5])
 x = g.ndata.x
 e = g.edata.e
 
@@ -62,3 +62,51 @@ end
     @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2)
 end
 
+@testset "topk_feature" begin
+    A = [0.0297 0.5901 0.088 0.5171;
+         0.8307 0.303 0.6515 0.6379;
+         0.914 0.928 0.4451 0.2695;
+         0.6702 0.6893 0.7507 0.8954;
+         0.3346 0.7997 0.5297 0.5197]
+    B = [0.3168 0.1323 0.1752 0.1931 0.5065;
+         0.3174 0.2766 0.9105 0.4954 0.5182;
+         0.5303 0.4318 0.5692 0.3455 0.5418;
+         0.0804 0.6114 0.8489 0.3934 0.152;
+         0.3808 0.1458 0.0539 0.0857 0.3872]
+    g1 = rand_graph(4, 2, ndata = (x = A,))
+    g2 = rand_graph(5, 4, ndata = B)
+    g = Flux.batch([g1, g2])
+    output1 = topk_feature(g, g.ndata.x, 3)
+    @test output1[1][:, :, 1] == [0.5901 0.5171 0.088;
+                                  0.8307 0.6515 0.6379;
+                                  0.928 0.914 0.4451;
+                                  0.8954 0.7507 0.6893;
+                                  0.7997 0.5297 0.5197]
+    @test output1[1][:, :, 2] == [0.5065 0.3168 0.1931;
+                                  0.9105 0.5182 0.4954;
+                                  0.5692 0.5418 0.5303;
+                                  0.8489 0.6114 0.3934;
+                                  0.3872 0.3808 0.1458]
+    @test output1[2][:, :, 1] == [2 4 3;
+                                  1 3 4;
+                                  2 1 3;
+                                  4 3 2;
+                                  2 3 4]
+    @test output1[2][:, :, 2] == [5 1 4;
+                                  3 5 4;
+                                  3 5 1;
+                                  3 2 4;
+                                  5 1 2]
+    output2 = topk_feature(g, g.ndata.x, 2; sortby = 5)
+    @test output2[1][:, :, 1] == [0.5901 0.088;
+                                  0.303 0.6515;
+                                  0.928 0.4451;
+                                  0.6893 0.7507;
+                                  0.7997 0.5297]
+    @test output2[2][:, :, 1] == [2; 3;;]
+    # A single (non-batched) graph takes the `_topk_matrix` code path directly.
+    output3 = topk_feature(g1, g1.ndata.x, 3)
+    @test output3[1] == output1[1][:, :, 1]
+    @test output3[2] == output1[2][:, :, 1]
+    @test topk_feature(g1, g1.ndata.x, 2; sortby = 5)[2] == [2, 3]
+end