Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Hopcroft-Karp matching algorithm #15

Merged
merged 3 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ authors = ["Robert Parker and contributors"]
version = "0.1.0"

[deps]
BipartiteMatching = "79040ab4-24c8-4c92-950c-d48b5991a0f6"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
Expand Down
4 changes: 2 additions & 2 deletions src/dulmage_mendelsohn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ function dulmage_mendelsohn(graph::Graphs.Graph, set1::Set)
matched_with_reachable1 = [matching[n] for n in reachable1]
matched_with_reachable2 = [matching[n] for n in reachable2]

filter = cat(
filter = Set(cat(
unmatched1,
unmatched2,
reachable1,
reachable2,
matched_with_reachable1,
matched_with_reachable2;
dims=1,
)
))
other1 = [n for n in nodes1 if !(n in filter)]
other2 = [n for n in nodes2 if !(n in filter)]

Expand Down
152 changes: 140 additions & 12 deletions src/maximum_matching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
# ___________________________________________________________________________

import Graphs
import BipartiteMatching as BM

const UNMATCHED = nothing
MatchedNodeType{T} = Union{T,typeof(UNMATCHED)}

"""
TODO: This should probably be promoted to Graphs.jl
Expand All @@ -44,20 +45,147 @@ function _is_valid_bipartition(graph::Graphs.Graph, set1::Set)
return true
end

# The following three functions are copied from the branch in PR #291
# of Graphs.jl, https://github.com/JuliaGraphs/Graphs.jl/pull/291.
# They will be removed when/if this PR is merged in favor of using the
# Graphs.jl maximum_matching function.

"""
Determine whether an augmenting path exists and mark distances
so we can compute shortest-length augmenting paths in the DFS.
"""
function _hk_augmenting_bfs!(
graph::Graphs.AbstractGraph{T},
set1::Vector{T},
matching::Dict{T,MatchedNodeType{T}},
distance::Dict{MatchedNodeType{T},Float64},
)::Bool where {T<:Integer}
# Initialize queue with the unmatched nodes in set1
queue = Vector{MatchedNodeType{eltype(graph)}}([
n for n in set1 if matching[n] == UNMATCHED
])

distance[UNMATCHED] = Inf
for n in set1
if matching[n] == UNMATCHED
distance[n] = 0.0
else
distance[n] = Inf
end
end

while !isempty(queue)
n1 = popfirst!(queue)

# If n1 is (a) matched or (b) in set1
if distance[n1] < Inf && n1 != UNMATCHED
for n2 in Graphs.neighbors(graph, n1)
# If n2 has not been encountered
if distance[matching[n2]] == Inf
# Give it a distance
distance[matching[n2]] = distance[n1] + 1

# Note that n2 could be unmatched
push!(queue, matching[n2])
end
end
end
end

found_augmenting_path = (distance[UNMATCHED] < Inf)
# The distance to UNMATCHED is the length of the shortest augmenting path
return found_augmenting_path
end

"""
Compute augmenting paths and update the matching
"""
function _hk_augmenting_dfs!(
graph::Graphs.AbstractGraph{T},
root::MatchedNodeType{T},
matching::Dict{T,MatchedNodeType{T}},
distance::Dict{MatchedNodeType{T},Float64},
)::Bool where {T<:Integer}
if root != UNMATCHED
for n in Graphs.neighbors(graph, root)
# Traverse edges of the minimum-length alternating path
if distance[matching[n]] == distance[root] + 1
if _hk_augmenting_dfs!(graph, matching[n], matching, distance)
# If the edge is part of an augmenting path, update the
# matching
matching[root] = n
matching[n] = root
return true
end
end
end
# If we could not find a matched edge that was part of an augmenting
# path, we need to make sure we don't consider this vertex again
distance[root] = Inf
return false
else
# Return true to indicate that we are part of an augmenting path
return true
end
end

"""
hopcroft_karp_matching(graph::AbstractGraph)::Dict

Compute a maximum-cardinality matching of a bipartite graph via the
[Hopcroft-Karp algorithm](https://en.wikipedia.org/wiki/Hopcroft-Karp_algorithm).

The return type is a dict mapping nodes to nodes. All matched nodes are included
as keys. For example, if `i` is matched with `j`, `i => j` and `j => i` are both
included in the returned dict.

### Performance

The algorithms runs in O((m + n)n^0.5), where n is the number of vertices and
m is the number of edges. As it does not assume the number of edges is O(n^2),
this algorithm is particularly effective for sparse bipartite graphs.

### Arguments

* `graph`: The bipartite `Graph` for which a maximum matching is computed

### Exceptions

* `ArgumentError`: The provided graph is not bipartite

"""
function hopcroft_karp_matching(graph::Graphs.AbstractGraph{T})::Dict{T,T} where {T<:Integer}
bmap = Graphs.bipartite_map(graph)
if length(bmap) != Graphs.nv(graph)
throw(ArgumentError("Provided graph is not bipartite"))
end
set1 = [n for n in Graphs.vertices(graph) if bmap[n] == 1]

# Initialize "state" that is modified during the algorithm
matching = Dict{eltype(graph),MatchedNodeType{eltype(graph)}}(
n => UNMATCHED for n in Graphs.vertices(graph)
)
distance = Dict{MatchedNodeType{eltype(graph)},Float64}()

# BFS to determine whether any augmenting paths exist
while _hk_augmenting_bfs!(graph, set1, matching, distance)
for n1 in set1
if matching[n1] == UNMATCHED
# DFS to update the matching along a minimum-length
# augmenting path
_hk_augmenting_dfs!(graph, n1, matching, distance)
end
end
end
matching = Dict(i => j for (i, j) in matching if j != UNMATCHED)
return matching
end

function maximum_matching(graph::Graphs.Graph, set1::Set)
if !_is_valid_bipartition(graph, set1)
throw(Exception)
end
n_nodes = Graphs.nv(graph)
card1 = length(set1)
nodes1 = sort([node for node in set1])
set2 = setdiff(Set(1:n_nodes), set1)
nodes2 = sort([node for node in set2])
edge_set = Set((n1, n2) for n1 in nodes1 for n2 in Graphs.neighbors(graph, n1))
amat = BitArray{2}((r, c) in edge_set for r in nodes1, c in nodes2)
matching, _ = BM.findmaxcardinalitybipartitematching(amat)
# Translate row/column coordinates back into nodes of the graph
graph_matching = Dict(nodes1[r] => nodes2[c] for (r, c) in matching)
return graph_matching
matching = hopcroft_karp_matching(graph)
matching = Dict(i => j for (i, j) in matching if i in set1)
return matching
end