Skip to content

Commit

Permalink
Merge pull request #295 from cth/master
Browse files Browse the repository at this point in the history
fix broken find_root for non-int disjoint sets
  • Loading branch information
kmsquire authored Jul 13, 2017
2 parents ba69667 + 717bd2f commit 7f5032d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 27 deletions.
18 changes: 11 additions & 7 deletions src/disjoint_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,21 @@ end

@compat type DisjointSets{T}
intmap::Dict{T,Int}
revmap::Vector{T}
internal::IntDisjointSets

function (::Type{DisjointSets{T}}){T}(xs) # xs must be iterable
imap = Dict{T,Int}()
rmap = Vector{T}()
n = length(xs)
sizehint!(imap, n)
sizehint!(rmap, n)
id = 0
for x in xs
imap[x] = (id += 1)
push!(rmap,x)
end
new{T}(imap, IntDisjointSets(n))
new{T}(imap, rmap, IntDisjointSets(n))
end
end

Expand All @@ -129,17 +133,17 @@ num_groups(s::DisjointSets) = num_groups(s.internal)
Finds the root element of the subset in `s` which has the element `x` as a member.
"""
find_root{T}(s::DisjointSets{T}, x::T) = find_root(s.internal, s.intmap[x])
find_root{T}(s::DisjointSets{T}, x::T) = s.revmap[find_root(s.internal, s.intmap[x])]

in_same_set{T}(s::DisjointSets{T}, x::T, y::T) = in_same_set(s.internal, s.intmap[x], s.intmap[y])

for f in (:union!, :root_union!)
@eval begin
($f){T}(s::DisjointSets{T}, x::T, y::T) = ($f)(s.internal, s.intmap[x], s.intmap[y])
end
end
union!{T}(s::DisjointSets{T}, x::T, y::T) = s.revmap[union!(s.internal, s.intmap[x], s.intmap[y])]

root_union!{T}(s::DisjointSets{T}, x::T, y::T) = s.revmap[root_union!(s.internal, s.intmap[x], s.intmap[y])]

function push!{T}(s::DisjointSets{T}, x::T)
id = push!(s.internal)
s.intmap[x] = id
push!(s.revmap,x) # Note, this assumes invariant: length(s.revmap) == id
x
end
53 changes: 33 additions & 20 deletions test/test_disjoint_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,53 +40,66 @@ root2 = find_root(s, 2)
@test root_union!(s, root1, root2) == 8
@test union!(s, 5, 6) == 8

# DisjointSets supports arbitrary indices
s = DisjointSets{Int}(1:10)

@test length(s) == 10
@test num_groups(s) == 10

r = [find_root(s, i) for i in 1 : 10]
@test isa(r, Vector{Int})
@test isequal(r, collect(1:10))

for i = 1 : 5
x = 2 * i - 1
y = 2 * i
union!(s, x, y)
@test find_root(s, x) == find_root(s, y)
end

@test length(s) == 10
@test num_groups(s) == 5

r0 = [1, 1, 3, 3, 5, 5, 7, 7, 9, 9]
r = [find_root(s, i) for i in 1 : 10]
@test isa(r, Vector{Int})
@test isequal(r, r0)

@test union!(s, 1, 4) == 1
@test union!(s, 3, 5) == 1
@test union!(s, 7, 9) == 7
@test union!(s, 1, 4) == find_root(s,1)
@test union!(s, 3, 5) == find_root(s,1)
@test union!(s, 7, 9) == find_root(s,7)

@test length(s) == 10
@test num_groups(s) == 2

r0 = [1, 1, 1, 1, 1, 1, 7, 7, 7, 7]
r = [find_root(s, i) for i in 1 : 10]
@test isa(r, Vector{Int})
@test isequal(r, r0)

r0 = [ find_root(s,i) for i in 1:10 ]
# Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11
push!(s, 17)

@test length(s) == 11
@test num_groups(s) == 3

r0 = [1, 1, 1, 1, 1, 1, 7, 7, 7, 7, 11]
r0 = [ r0 ; 17]
r = [find_root(s, i) for i in [1 : 10; 17] ]
@test isa(r, Vector{Int})
@test isequal(r, r0)

root1 = find_root(s, 7)
root2 = find_root(s, 3)
@test root_union!(s, root1, root2) == 7
@test find_root(s, 7) == 7
@test find_root(s, 3) == 7
@test root1 != root2
root_union!(s, 7, 3)
@test find_root(s, 7) == find_root(s, 3)


## Some tests using non-integer disjoint sets
elems = ["a", "b", "c", "d"]
a = DisjointSets{AbstractString}(elems)
union!(a, "a", "b")
@test in_same_set(a,"a","b")
@test find_root(a,"a") == find_root(a,"b")
@test find_root(a,"a") in elems
@test !in_same_set(a, "c", "d")
# union returns new root
@test find_root(a,"a") == union!(a,"b","c")
union!(a,"c","d")
# Now they should be in same set, and a is transitively connected to d
@test in_same_set(a,"a", "d")
# Root element should thus be same for all:
@test all(find_root(a,first(elems)) .== map(x->find_root(a,x),elems))

#@test_throws KeyError find_root(a,"f")

push!(a, "f")
@test find_root(a,"a") != find_root(a,"f")

0 comments on commit 7f5032d

Please sign in to comment.