Skip to content

Commit

Permalink
perturb_nodes -> remove_nodes (#454)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Jul 24, 2024
1 parent 1823e0d commit 7e7e202
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 61 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests_GNNGraphs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
matrix:
version:
- '1.10' # Replace this with the minimum Julia version that your package supports.
- '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
# - '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
# - 'pre'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests_GraphNeuralNetworks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
matrix:
version:
- '1.10' # Replace this with the minimum Julia version that your package supports.
- '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
# - '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
# - 'pre'
os:
- ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GNNGraphs"
uuid = "aed8fd31-079b-4b5a-b342-a13352159b8c"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"
version = "1.0.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -46,7 +46,7 @@ SparseArrays = "1"
Statistics = "1"
StatsBase = "0.34"
cuDNN = "1"
julia = "1.9"
julia = "1.10"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
33 changes: 33 additions & 0 deletions GNNGraphs/ext/GNNGraphsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module GNNGraphsCUDAExt

using CUDA
using Random, Statistics, LinearAlgebra
using GNNGraphs
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T

const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}

# Query

GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1))

# Transform

GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz)


# Utils

GNNGraphs.iscuarray(x::AnyCuArray) = true


function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
dev = get_device(u)
cdev = cpu_device()
u, v = u |> cdev, v |> cdev
#TODO proper cuda friendly implementation
sort_edge_index(u, v) |> dev
end


end #module
14 changes: 0 additions & 14 deletions GNNGraphs/ext/GNNGraphsCUDAExt/GNNGraphsCUDAExt.jl

This file was deleted.

2 changes: 0 additions & 2 deletions GNNGraphs/ext/GNNGraphsCUDAExt/query.jl

This file was deleted.

2 changes: 0 additions & 2 deletions GNNGraphs/ext/GNNGraphsCUDAExt/transform.jl

This file was deleted.

11 changes: 0 additions & 11 deletions GNNGraphs/ext/GNNGraphsCUDAExt/utils.jl

This file was deleted.

1 change: 0 additions & 1 deletion GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ export add_nodes,
perturb_edges,
remove_nodes,
ppr_diffusion,
drop_nodes,
# from MLUtils
batch,
unbatch,
Expand Down
38 changes: 15 additions & 23 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,35 +307,27 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
end

"""
drop_nodes(g::GNNGraph{<:COO_T}, p)
remove_nodes(g::GNNGraph, p)
Randomly drop nodes (and their associated edges) from a GNNGraph based on a given probability.
Dropping nodes is a technique that can be used for graph data augmentation, refering paper [DropNode](https://arxiv.org/pdf/2008.12578.pdf).
Returns a new graph obtained by dropping nodes from `g` with independent probabilities `p`.
# Arguments
- `g`: The input graph from which nodes (and their associated edges) will be dropped.
- `p`: The probability of dropping each node. Default value is `0.5`.
# Returns
A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability.
# Examples
# Example
```julia
using GraphNeuralNetworks
# Construct a GNNGraph
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes=3)
# Drop nodes with a probability of 0.5
g_new = drop_node(g, 0.5)
println(g_new)
julia> g = GNNGraph([1, 1, 2, 2, 3, 4], [1, 2, 3, 1, 3, 1])
GNNGraph:
num_nodes: 4
num_edges: 6
julia> g_new = remove_nodes(g, 0.5)
GNNGraph:
num_nodes: 2
num_edges: 2
```
"""
function drop_nodes(g::GNNGraph{<:COO_T}, p = 0.5)
num_nodes = g.num_nodes
nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes)

new_g = remove_nodes(g, nodes_to_remove)

return new_g
function remove_nodes(g::GNNGraph, p::AbstractFloat)
nodes_to_remove = filter(_ -> rand() < p, 1:g.num_nodes)
return remove_nodes(g, nodes_to_remove)
end

"""
Expand Down
8 changes: 4 additions & 4 deletions GNNGraphs/test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,20 @@ end end
@test edata_new == edatatest
end end

@testset "drop_nodes" begin
@testset "remove_nodes(g, p)" begin
if GRAPH_T == :coo
Random.seed!(42)
s = [1, 1, 2, 3]
t = [2, 3, 4, 5]
g = GNNGraph(s, t, graph_type = GRAPH_T)

gnew = drop_nodes(g, Float32(0.5))
gnew = remove_nodes(g, 0.5)
@test gnew.num_nodes == 3

gnew = drop_nodes(g, Float32(1.0))
gnew = remove_nodes(g, 1.0)
@test gnew.num_nodes == 0

gnew = drop_nodes(g, Float32(0.0))
gnew = remove_nodes(g, 0.0)
@test gnew.num_nodes == 5
end
end
Expand Down

4 comments on commit 7e7e202

@CarloLucibello
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=GNNGraphs

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An unexpected error occurred during registration.

@CarloLucibello
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=GNNGraphs

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 1.0.0 already exists

Please sign in to comment.