Skip to content

Commit

Permalink
Added remove_edges function (#414)
Browse files Browse the repository at this point in the history
* added remove edge function

* tests

* added remove edge function

* fix

* fix

* fix

* fix

* fix

* Update src/GNNGraphs/transform.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* Update src/GNNGraphs/transform.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* Update transform.jl

* tests final

* Update Project.toml

* Update src/GNNGraphs/transform.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* done

* fixes

* more tests

---------

Co-authored-by: Carlo Lucibello <[email protected]>
  • Loading branch information
rbSparky and CarloLucibello authored Mar 21, 2024
1 parent 374d8fb commit 95a90fc
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 2 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@ Manifest.toml
/docs/build/
.vscode
LocalPreferences.toml
.DS_Store
/test.jl
.DS_Store
1 change: 1 addition & 0 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export add_nodes,
negative_sample,
rand_edge_split,
remove_self_loops,
remove_edges,
remove_multi_edges,
set_edge_weight,
to_bidirected,
Expand Down
51 changes: 51 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,57 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T})
g.ndata, g.edata, g.gdata)
end

"""
remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
Remove specified edges from a GNNGraph.
# Arguments
- `g`: The input graph from which edges will be removed.
- `edges_to_remove`: Vector of edge indices to be removed.
# Returns
A new GNNGraph with the specified edges removed.
# Example
```julia
julia> using GraphNeuralNetworks
# Construct a GNNGraph
julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])
GNNGraph:
num_nodes: 3
num_edges: 5
# Remove the second edge
julia> g_new = remove_edges(g, [2]);
julia> g_new
GNNGraph:
num_nodes: 3
num_edges: 4
```
"""
function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer})
s, t = edge_index(g)
w = get_edge_weight(g)
edata = g.edata

mask_to_keep = trues(length(s))

mask_to_keep[edges_to_remove] .= false

s = s[mask_to_keep]
t = t[mask_to_keep]
edata = getobs(edata, mask_to_keep)
w = isnothing(w) ? nothing : getobs(w, mask_to_keep)

return GNNGraph((s, t, w),
g.num_nodes, length(s), g.num_graphs,
g.graph_indicator,
g.ndata, edata, g.gdata)
end

"""
remove_multi_edges(g::GNNGraph; aggr=+)
Expand Down
28 changes: 28 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,34 @@ end
@test nodemap == 1:(g1.num_nodes)
end

@testset "remove_edges" begin
if GRAPH_T == :coo
s = [1, 1, 2, 3]
t = [2, 3, 4, 5]
w = [0.1, 0.2, 0.3, 0.4]
edata = ['a', 'b', 'c', 'd']
g = GNNGraph(s, t, w, edata = edata, graph_type = GRAPH_T)

# single edge removal
gnew = remove_edges(g, [1])
new_s, new_t = edge_index(gnew)
@test gnew.num_edges == 3
@test new_s == s[2:end]
@test new_t == t[2:end]

# multiple edge removal
gnew = remove_edges(g, [1,2,4])
new_s, new_t = edge_index(gnew)
new_w = get_edge_weight(gnew)
new_edata = gnew.edata.e
@test gnew.num_edges == 1
@test new_s == [2]
@test new_t == [4]
@test new_w == [0.3]
@test new_edata == ['c']
end
end

@testset "add_edges" begin
if GRAPH_T == :coo
s = [1, 1, 2, 3]
Expand Down

0 comments on commit 95a90fc

Please sign in to comment.