Skip to content

Commit

Permalink
Merge branch 'intersection_independent'
Browse files Browse the repository at this point in the history
  • Loading branch information
itsdfish committed Jul 8, 2022
2 parents 31f8ea8 + 96ff3e0 commit 9dfca19
Show file tree
Hide file tree
Showing 14 changed files with 675 additions and 359 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
QHull = "a8468747-bd6f-53ef-9e5c-744dbc5c59e7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
ThreadedIterables = "11d239b0-c0b9-11e8-1935-d5cfa53abb03"

Expand All @@ -23,6 +25,7 @@ ConcreteStructs = "v0.2.3"
DataFrames = "1.3.0"
Distributions = "0.23.0,0.24.0,v0.25.37"
QHull = "0.2"
ProgressMeter = "1.7.2"
Requires = "1.3.0"
SafeTestsets = "0.0.1"
SpecialFunctions = "2"
Expand Down
4 changes: 2 additions & 2 deletions src/ParameterSpacePartitions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module ParameterSpacePartitions
using Requires, Distributions, ConcreteStructs, LinearAlgebra
using ThreadedIterables, SpecialFunctions, ComponentArrays
using Requires, ProgressMeter, Distributions, ConcreteStructs, LinearAlgebra
using ThreadedIterables, SpecialFunctions, ComponentArrays

export find_partitions,
adapt!,
Expand Down
2 changes: 0 additions & 2 deletions src/TestModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ module TestModels
- 'location`: location of chain in the hypercube
- `p_bounds`: boundaries of partitions for each dimension
"""
function p_fun(location, hypercube::HyperCube, args...; kwargs...)
p_bounds = hypercube.p_bounds
Expand Down Expand Up @@ -87,7 +86,6 @@ module TestModels
- 'location`: location of chain in the hypercube
- `points`: a vector of polytopes containing locations
"""
function p_fun(location, points::Vector{Polytope}, args...; kwargs...)
distances = map(p -> norm(location .- p.location), points)
Expand Down
216 changes: 212 additions & 4 deletions src/intersection_test.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# https://www.youtube.com/watch?v=OPSCKXXvWiM&ab_channel=TheOrganicChemistryTutor
# https://math.stackexchange.com/questions/1447730/drawing-ellipse-from-eigenvalue-eigenvector
# https://stats.stackexchange.com/questions/164741/how-to-find-the-maximum-axis-of-ellipsoid-given-the-covariance-matrix
"""
intersects(μ1, μ2, cov1, cov2)
intersects(μ1, μ2, cov1, cov2, c=2)
Tests whether two hyperellipsoids intersect. The test returns true if
the hyperellipsoids intersection and false otherwise.
Expand All @@ -13,6 +10,7 @@ the hyperellipsoids intersection and false otherwise.
- `μ2`: centroid of ellipsoid 2
- `cov1`: covariance matrix of ellipsoid 1
- `cov2`: covariance matrix of ellipsoid 2
- `c=2`: ellipse sclar
"""
function intersects(μ1, μ2, cov1, cov2, c=2)
cov1 .*= c^2
Expand All @@ -22,6 +20,16 @@ function intersects(μ1, μ2, cov1, cov2, c=2)

_Q2b = inv1' * cov2 * inv1
Q2b = Symmetric(_Q2b)

if !isposdef(Q2b)
println("cov1")
cov1 ./= c^2
println(cov1)
println("cov2")
cov2 ./= c^2
println(cov2)
end

if !issymmetric(_Q2b)
ϵ = sum(abs.(_Q2b .- _Q2b')) / length(_Q2b)
if ϵ > 1e-10
Expand All @@ -40,4 +48,204 @@ function intersects(μ1, μ2, cov1, cov2, c=2)
else
return false
end
end

"""
intersects(chain1, chain2, c=2)
Tests whether two hyperellipsoids intersect. The test returns true if
the hyperellipsoids intersection and false otherwise.
# Arguments
- `chain1`: chain object
- `chain2`: chain object
- `c=2`: ellipse sclar
"""
function intersects(chain1, chain2, c=2)
mat1 = to_matrix(chain1)
mat2 = to_matrix(chain2)
μ1 = mean(mat1, dims=1)[:]
μ2 = mean(mat2, dims=1)[:]
cov1 = cov(mat1)
add_variance!(cov1)
cov2 = cov(mat2)
add_variance!(cov2)
if !isposdef(cov1)
println("cov1")
println(cov1)
end
if !isposdef(cov2)
println("cov2")
println(cov2)
end
return intersects(μ1, μ2, cov1, cov2)
end

function add_variance!(x)
if any(x -> isapprox(x, 0; atol = 1e-10), diag(x))
x[diagind(x)] .= eps()
end
return nothing
end

to_matrix(x) = reduce(vcat, transpose.(x.all_parms))

"""
get_group_indices(chains, chain_indices)
Sorts chains of the same pattern into non-overlapping groups. The vector
[[1,2],[3,4]] indices chains 1 and 2 are located in region R₁ and chains
3 and 4 are located in region R₂.
# Arguments
- `chains`: a vector of chains
- `group`: a vector of indices corresponding to chains with the same pattern
"""
function get_group_indices(chains, chain_indices)
# group chain indices according to region
indices = Vector{Vector{Int}}()
# first index group will have c = 1
push!(indices, [chain_indices[1]])
n_groups = length(indices)
n_chains = length(chain_indices)
# group index
g = 1
# loop through each chain
for i 2:n_chains
# loop through each index group
c = chain_indices[i]
while g n_groups
# if chain c matches region index in group g,
# push c into index group g
if intersects(chains[indices[g][1]], chains[c])
push!(indices[g], c)
merge_chains!(chains[indices[g][1]], chains[c])
break
end
g += 1
end
# add new index group
if g > n_groups
# add new group
push!(indices, [c])
# increment number of groups
n_groups += 1
end
# reset group counter
g = 1
end
return indices
end

"""
remove_redundant_chains!(chains, indices)
Removes chains that have the same pattern and location in the parameter space.
For example, in the vector `indices` [[1,3],[34,50]], chains indexed in positions 1 and 3 have the same
pattern and location, as do chains indexed at positions 34 and 50. Only the first element of each sub-vector
is retained (i.e., [1,50])
# Arguments
- `chains`: a vector containing all chain objects
- `indices`: a nested vector that maps chains of the same pattern and location to `chains`
"""
function remove_redundant_chains!(chains, indices)
k_indices = Int[]
for i in indices
push!(k_indices, i[1])
end
r_indices = setdiff(vcat(indices...), k_indices)
sort!(r_indices)
deleteat!(chains, r_indices)
return nothing
end

"""
group_by_pattern(chains)
Groups chains according to pattern and returns a nested vector of chain indices. The vector
[[1,2],[3,4]] indices chains 1 and 2 are located in region R₁ and chains
3 and 4 are located in region R₂.
# Arguments
- `chains`: a vector of all chains
"""
function group_by_pattern(chains)
patterns = map(c -> c.pattern, chains)
u_patterns = unique(patterns)
n_patterns = length(u_patterns)
chain_indices = [Vector{Int}() for _ in 1:n_patterns]
for c in 1:length(chains)
g = findfirst(p -> chains[c].pattern == p, u_patterns)
push!(chain_indices[g], c)
end
return chain_indices
end

function remove_nonposdef!(chains)
n_chains = length(chains)
test_vals = fill(false, n_chains)
for c in 1:n_chains
mat = to_matrix(chains[c])
covar = cov(mat)
test_vals[c] = !isposdef(covar)
end
deleteat!(chains, test_vals)
return nothing
end

"""
make_unique!(chains, options; timer=nothing, show_time=false)
This function sorts chains by pattern and merges chains that are in the same region
# Arguments
- `chains`: a vector of all chains
- `options`: an Options configuration object for the PSP algorithm
# Keywords
- `timer`: monitors time
- `show_time=false`: shows time if true
"""
function make_unique!(chains, options; timer=nothing, show_timer=false)
for iter in 1:2
_make_unique(chains, options; iter, timer, show_timer)
end
return nothing
end

function _make_unique(chains, options; iter=0, timer=nothing, show_timer = false)
remove_nonposdef!(chains)
chain_indices = group_by_pattern(chains)
all_indices = Vector{Vector{Int}}()
for c in chain_indices
temp = get_group_indices(chains, c)
push!(all_indices, temp...)
show_timer ? next!(timer; showvalues=show_values(iter)) : nothing
end
remove_redundant_chains!(chains, all_indices)
return nothing
end

"""
merge_chains!(chain1, chain2)
Merges chain2 into chain1 on fields `all_parms`, `acceptance`, and `radii`.
# Arguments
- `chain1`: a chain object
- `chain2`: a chain object
"""
function merge_chains!(chain1, chain2)
push!(chain1.all_parms, chain2.all_parms...)
push!(chain1.acceptance, chain2.acceptance...)
push!(chain1.radii, chain2.radii...)
return nothing
end
Loading

0 comments on commit 9dfca19

Please sign in to comment.