diff --git a/src/utils.jl b/src/utils.jl index 42a7f6b05..e3578517b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -105,7 +105,7 @@ function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) return matrix[:, index] end -function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) +function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) if sortby === nothing return sort(matrix, dims = 2; rev)[:, 1:k] else @@ -113,35 +113,26 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = end end -function _sort_batch(matrices, k::Int; rev::Bool = true, sortby = nothing) - return map(x -> _sort_matrix(x, k; rev, sortby), matrices) -end - -function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, +function _topk_batch(matrices::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) - tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) รท number_graphs, - number_graphs) - sorted_matrix = _sort_batch(eachslice(tensor_matrix, dims = 3), k; rev, sortby) + sorted_matrix = map(x -> _topk_matrix(x, k; rev, sortby), matrices) return reduce(hcat, sorted_matrix) end -function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, - sortby = nothing) - if number_graphs == 1 - return _sort_matrix(matrix, k; rev, sortby) - else - return _topk_batch(matrix, number_graphs, k; rev, sortby) - end -end - """ topk_nodes(g, feat, k; rev = true, sortby = nothing) Graph-wise top-k on node features `feat` according to the `sortby` feature index. """ function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) - matrix = getproperty(g.ndata, feat) - return _topk(matrix, g.num_graphs, k; rev, sortby) + if g.num_graphs == 1 + matrix = getproperty(g.ndata, feat) + return _topk_matrix(matrix, k; rev, sortby) + else + graphs = [getgraph(g, i) for i in 1:(g.num_graphs)] + matrices = map(graph -> getproperty(graph.ndata, feat), graphs) + return _topk_batch(matrices, k; rev, sortby) + end end """ @@ -150,6 +141,12 @@ end Graph-wise top-k on edge features `feat` according to the `sortby` feature index. """ function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) - matrix = getproperty(g.edata, feat) - return _topk(matrix, g.num_graphs, k; rev, sortby) + if g.num_graphs == 1 + matrix = getproperty(g.edata, feat) + return _topk_matrix(matrix, k; rev, sortby) + else + graphs = [getgraph(g, i) for i in 1:(g.num_graphs)] + matrices = map(graph -> getproperty(graph.edata, feat), graphs) + return _topk_batch(matrices, k; rev, sortby) + end end