Skip to content

Commit

Permalink
Fix batch case and reorder
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Mar 10, 2023
1 parent e10c4a9 commit eec3a46
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,43 +105,34 @@ function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
return matrix[:, index]
end

function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
if sortby === nothing
return sort(matrix, dims = 2; rev)[:, 1:k]
else
return _sort_col(matrix; rev, sortby)[:, 1:k]
end
end

function _sort_batch(matrices, k::Int; rev::Bool = true, sortby = nothing)
return map(x -> _sort_matrix(x, k; rev, sortby), matrices)
end

function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
function _topk_batch(matrices::AbstractArray, k::Int; rev::Bool = true,
sortby = nothing)
tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs,
number_graphs)
sorted_matrix = _sort_batch(eachslice(tensor_matrix, dims = 3), k; rev, sortby)
sorted_matrix = map(x -> _topk_matrix(x, k; rev, sortby), matrices)
return reduce(hcat, sorted_matrix)
end

function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
sortby = nothing)
if number_graphs == 1
return _sort_matrix(matrix, k; rev, sortby)
else
return _topk_batch(matrix, number_graphs, k; rev, sortby)
end
end

"""
topk_nodes(g, feat, k; rev = true, sortby = nothing)
Graph-wise top-k on node features `feat` according to the `sortby` feature index.
"""
function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
matrix = getproperty(g.ndata, feat)
return _topk(matrix, g.num_graphs, k; rev, sortby)
if g.num_graphs == 1
matrix = getproperty(g.ndata, feat)
return _topk_matrix(matrix, k; rev, sortby)
else
graphs = [getgraph(g, i) for i in 1:(g.num_graphs)]
matrices = map(graph -> getproperty(graph.ndata, feat), graphs)
return _topk_batch(matrices, k; rev, sortby)
end
end

"""
Expand All @@ -150,6 +141,12 @@ end
Graph-wise top-k on edge features `feat` according to the `sortby` feature index.
"""
function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
matrix = getproperty(g.edata, feat)
return _topk(matrix, g.num_graphs, k; rev, sortby)
if g.num_graphs == 1
matrix = getproperty(g.edata, feat)
return _topk_matrix(matrix, k; rev, sortby)
else
graphs = [getgraph(g, i) for i in 1:(g.num_graphs)]
matrices = map(graph -> getproperty(graph.edata, feat), graphs)
return _topk_batch(matrices, k; rev, sortby)
end
end

0 comments on commit eec3a46

Please sign in to comment.