diff --git a/src/disjoint_set.jl b/src/disjoint_set.jl index 063522afe..40b640ab7 100644 --- a/src/disjoint_set.jl +++ b/src/disjoint_set.jl @@ -107,17 +107,21 @@ end @compat type DisjointSets{T} intmap::Dict{T,Int} + revmap::Vector{T} internal::IntDisjointSets function (::Type{DisjointSets{T}}){T}(xs) # xs must be iterable imap = Dict{T,Int}() + rmap = Vector{T}() n = length(xs) sizehint!(imap, n) + sizehint!(rmap, n) id = 0 for x in xs imap[x] = (id += 1) + push!(rmap,x) end - new{T}(imap, IntDisjointSets(n)) + new{T}(imap, rmap, IntDisjointSets(n)) end end @@ -129,17 +133,17 @@ num_groups(s::DisjointSets) = num_groups(s.internal) Finds the root element of the subset in `s` which has the element `x` as a member. """ -find_root{T}(s::DisjointSets{T}, x::T) = find_root(s.internal, s.intmap[x]) +find_root{T}(s::DisjointSets{T}, x::T) = s.revmap[find_root(s.internal, s.intmap[x])] in_same_set{T}(s::DisjointSets{T}, x::T, y::T) = in_same_set(s.internal, s.intmap[x], s.intmap[y]) -for f in (:union!, :root_union!) - @eval begin - ($f){T}(s::DisjointSets{T}, x::T, y::T) = ($f)(s.internal, s.intmap[x], s.intmap[y]) - end -end +union!{T}(s::DisjointSets{T}, x::T, y::T) = s.revmap[union!(s.internal, s.intmap[x], s.intmap[y])] + +root_union!{T}(s::DisjointSets{T}, x::T, y::T) = s.revmap[root_union!(s.internal, s.intmap[x], s.intmap[y])] function push!{T}(s::DisjointSets{T}, x::T) id = push!(s.internal) s.intmap[x] = id + push!(s.revmap,x) # Note, this assumes invariant: length(s.revmap) == id + x end diff --git a/test/test_disjoint_set.jl b/test/test_disjoint_set.jl index 895674abe..f78481381 100644 --- a/test/test_disjoint_set.jl +++ b/test/test_disjoint_set.jl @@ -40,53 +40,66 @@ root2 = find_root(s, 2) @test root_union!(s, root1, root2) == 8 @test union!(s, 5, 6) == 8 +# DisjointSets supports arbitrary indices s = DisjointSets{Int}(1:10) @test length(s) == 10 @test num_groups(s) == 10 r = [find_root(s, i) for i in 1 : 10] -@test isa(r, Vector{Int}) @test isequal(r, collect(1:10)) for i = 1 : 5 x = 2 * i - 1 y = 2 * i union!(s, x, y) + @test find_root(s, x) == find_root(s, y) end -@test length(s) == 10 -@test num_groups(s) == 5 - -r0 = [1, 1, 3, 3, 5, 5, 7, 7, 9, 9] -r = [find_root(s, i) for i in 1 : 10] -@test isa(r, Vector{Int}) -@test isequal(r, r0) -@test union!(s, 1, 4) == 1 -@test union!(s, 3, 5) == 1 -@test union!(s, 7, 9) == 7 +@test union!(s, 1, 4) == find_root(s,1) +@test union!(s, 3, 5) == find_root(s,1) +@test union!(s, 7, 9) == find_root(s,7) @test length(s) == 10 @test num_groups(s) == 2 -r0 = [1, 1, 1, 1, 1, 1, 7, 7, 7, 7] -r = [find_root(s, i) for i in 1 : 10] -@test isa(r, Vector{Int}) -@test isequal(r, r0) +r0 = [ find_root(s,i) for i in 1:10 ] +# Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 push!(s, 17) @test length(s) == 11 @test num_groups(s) == 3 -r0 = [1, 1, 1, 1, 1, 1, 7, 7, 7, 7, 11] +r0 = [ r0 ; 17] r = [find_root(s, i) for i in [1 : 10; 17] ] -@test isa(r, Vector{Int}) @test isequal(r, r0) root1 = find_root(s, 7) root2 = find_root(s, 3) -@test root_union!(s, root1, root2) == 7 -@test find_root(s, 7) == 7 -@test find_root(s, 3) == 7 +@test root1 != root2 +root_union!(s, 7, 3) +@test find_root(s, 7) == find_root(s, 3) + + +## Some tests using non-integer disjoint sets +elems = ["a", "b", "c", "d"] +a = DisjointSets{AbstractString}(elems) +union!(a, "a", "b") +@test in_same_set(a,"a","b") +@test find_root(a,"a") == find_root(a,"b") +@test find_root(a,"a") in elems +@test !in_same_set(a, "c", "d") +# union returns new root +@test find_root(a,"a") == union!(a,"b","c") +union!(a,"c","d") +# Now they should be in same set, and a is transitively connected to d +@test in_same_set(a,"a", "d") +# Root element should thus be same for all: +@test all(find_root(a,first(elems)) .== map(x->find_root(a,x),elems)) + +#@test_throws KeyError find_root(a,"f") + +push!(a, "f") +@test find_root(a,"a") != find_root(a,"f")