Skip to content

Commit

Permalink
pairlist and itemlist
Browse files Browse the repository at this point in the history
  • Loading branch information
guo-yong-zhi committed Dec 31, 2021
1 parent 0328888 commit 95ad19a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 35 deletions.
11 changes: 7 additions & 4 deletions src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ end
"element-wise trainer"
trainepoch_E!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(),
:queue => QTrees.thread_queue(),
:pairlist => Vector{QTrees.CoItem}(),
:itemlist => Vector{QTrees.CoItem}(),
:pairlist => Vector{Tuple{Int, Int}}(),
:updated => Set{Int}(),
:spqtree => QTrees.hash_spacial_qtree(inputs))
trainepoch_E!(s::Symbol) = get(Dict(:patient => 20, :nepoch => 2000), s, nothing)
Expand All @@ -101,7 +102,8 @@ end
trainepoch_EM!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(),
:queue => QTrees.thread_queue(),
:memory => intlru(length(inputs)),
:pairlist => Vector{QTrees.CoItem}(),
:itemlist => Vector{QTrees.CoItem}(),
:pairlist => Vector{Tuple{Int, Int}}(),
:updated => Set{Int}(),
:spqtree => QTrees.hash_spacial_qtree(inputs))
trainepoch_EM!(s::Symbol) = get(Dict(:patient => 10, :nepoch => 1000), s, nothing)
Expand Down Expand Up @@ -212,11 +214,12 @@ end
"dynamic trainer"
trainepoch_D!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(),
:queue => QTrees.thread_queue(),
:pairlist => Vector{QTrees.CoItem}(),
:itemlist => Vector{QTrees.CoItem}(),
:pairlist => Vector{Tuple{Int, Int}}(),
:updated => QTrees.UpdatedSet(1:length(inputs)),
:loops => 10,
:spqtree => QTrees.linked_spacial_qtree(inputs), #fllowing 4 tiems: pre-allocating for dynamiccollisions
:sptree2 => QTrees.hash_spacial_qtree(inputs),
:sptqree2 => QTrees.hash_spacial_qtree(inputs),
:lbcollector => Vector{Int}(),
:treenodestack => Vector{QTrees.SpacialQTreeNode}())
trainepoch_D!(s::Symbol) = get(Dict(:patient => 10, :nepoch => 2000), s, nothing)
Expand Down
66 changes: 35 additions & 31 deletions src/qtree_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,12 @@ function _totalcollisions_native(qtrees::AbstractVector, coitems::Vector{CoItem}
colist
end
function _totalcollisions_native(qtrees::AbstractVector,
labels::AbstractVector{<:Integer}=1:length(qtrees); kargs...)
labels::AbstractVector{<:Integer}=1:length(qtrees);
pairlist::AbstractVector{Tuple{Int, Int}}=Vector{Tuple{Int, Int}}(), kargs...)
l = length(labels)
_totalcollisions_native(qtrees, [@inbounds (labels[i], labels[j]) for i in 1:l for j in l:-1:i + 1]; kargs...)
empty!(pairlist)
append!(pairlist, (@inbounds (labels[i], labels[j]) for i in 1:l for j in l:-1:i + 1))
_totalcollisions_native(qtrees, pairlist; kargs...)
end
function _totalcollisions_native(qtrees::AbstractVector, labels::AbstractSet{<:Integer}; kargs...)
_totalcollisions_native(qtrees, labels |> collect; kargs...)
Expand Down Expand Up @@ -155,15 +158,15 @@ function locate!(qts::AbstractVector, labels::Union{AbstractVector{Int},Abstract
spqtree
end

function collisions_boundsfilter(qtrees, spindex, lowlabels, higlabels, pairlist, colist)
function collisions_boundsfilter(qtrees, spindex, lowlabels, higlabels, itemlist, colist)
for hlb in higlabels
# check here because there are no bounds checking in _collision_randbfs
collisions_boundsfilter(qtrees, spindex, lowlabels, hlb, pairlist, colist)
collisions_boundsfilter(qtrees, spindex, lowlabels, hlb, itemlist, colist)
end
end
function collisions_boundsfilter(qtrees, spindex, lowlabels, hlb::Int, pairlist, colist)
function collisions_boundsfilter(qtrees, spindex, lowlabels, hlb::Int, itemlist, colist)
if inkernelbounds(@inbounds(qtrees[hlb][spindex[1]]), spindex[2], spindex[3])
append!(pairlist, ((llb, hlb)=>spindex for llb in lowlabels))
append!(itemlist, ((llb, hlb)=>spindex for llb in lowlabels))
elseif getdefault(@inbounds(qtrees[hlb][1])) == QTrees.FULL
for llb in lowlabels
if @inbounds(qtrees[llb][spindex]) != QTrees.EMPTY
Expand All @@ -173,22 +176,22 @@ function collisions_boundsfilter(qtrees, spindex, lowlabels, hlb::Int, pairlist,
end
end
end
function collisions_boundsfilter(qtrees, spindex, llb::Int, higlabels, pairlist, colist)
collisions_boundsfilter(qtrees, spindex, (llb,), higlabels, pairlist, colist)
function collisions_boundsfilter(qtrees, spindex, llb::Int, higlabels, itemlist, colist)
collisions_boundsfilter(qtrees, spindex, (llb,), higlabels, itemlist, colist)
end
@assert collect(Iterators.product(1:2, 4:6))[1] == (1, 4)
function totalcollisions_spacial(qtrees::AbstractVector, spqtree::HashSpacialQTree;
colist=Vector{CoItem}(), pairlist::AbstractVector{CoItem}=Vector{CoItem}(), unique=true, kargs...)
colist=Vector{CoItem}(), itemlist::AbstractVector{CoItem}=Vector{CoItem}(), unique=true, kargs...)
length(qtrees) > 1 || return colist
nlevel = length(@inbounds qtrees[1])
empty!(pairlist)
empty!(itemlist)
for spindex in keys(spqtree)
labels = spqtree[spindex]
labelslen = length(labels)
if labelslen > 1
for i in 1:labelslen
for j in labelslen:-1:i+1
push!(pairlist, (@inbounds labels[i], @inbounds labels[j]) => spindex)
push!(itemlist, (@inbounds labels[i], @inbounds labels[j]) => spindex)
end
end
end
Expand All @@ -198,12 +201,12 @@ function totalcollisions_spacial(qtrees::AbstractVector, spqtree::HashSpacialQTr
(@inbounds pspindex[1] > nlevel) && break
if haskey(spqtree, pspindex)
plbs = spqtree[pspindex]
collisions_boundsfilter(qtrees, spindex, labels, plbs, pairlist, colist)
collisions_boundsfilter(qtrees, spindex, labels, plbs, itemlist, colist)
end
end
end
# @show length(pairlist), length(colist)
r = _totalcollisions_native(qtrees, pairlist; colist=colist, kargs...)
# @show length(itemlist), length(colist)
r = _totalcollisions_native(qtrees, itemlist; colist=colist, kargs...)
unique ? unique!(first, sort!(r)) : r
end
function totalcollisions_spacial(qtrees::AbstractVector{U8SQTree};
Expand All @@ -218,24 +221,25 @@ function totalcollisions_spacial(qtrees::AbstractVector{U8SQTree}, labels::Union
end

const SPACIAL_ENABLE_THRESHOLD = round(Int, 10+10log2(Threads.nthreads()))
function totalcollisions_native_kw(qtrees, args...; pairlist=nothing, unique=true, spqtree=nothing, kargs...)
totalcollisions_native(qtrees, args...; kargs...)
function totalcollisions_native_kw(args...; itemlist=nothing, unique=true, spqtree=nothing, kargs...)
totalcollisions_native(args...; kargs...)
end
function totalcollisions(qtrees::AbstractVector{U8SQTree}, args...; kargs...)
if length(qtrees) > SPACIAL_ENABLE_THRESHOLD
return totalcollisions_spacial(qtrees, args...; kargs...)
totalcollisions_spacial_kw(args...; pairlist=nothing, kargs...) = totalcollisions_spacial(args...; kargs...)
function totalcollisions(args...; kargs...)
if length(args[end]) > SPACIAL_ENABLE_THRESHOLD
return totalcollisions_spacial_kw(args...; kargs...)
else
return totalcollisions_native_kw(qtrees, args...; kargs...)
return totalcollisions_native_kw(args...; kargs...)
end
end
function partialcollisions(qtrees::AbstractVector,
spqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees),
labels::AbstractSet{Int}=Set(1:length(qtrees));
colist=Vector{CoItem}(), pairlist::AbstractVector{CoItem}=Vector{CoItem}(),
colist=Vector{CoItem}(), itemlist::AbstractVector{CoItem}=Vector{CoItem}(),
lbcollector = Vector{Int}(),
treenodestack = Vector{SpacialQTreeNode}(),
unique=true, kargs...)
empty!(pairlist)
empty!(itemlist)
locate!(qtrees, labels, spqtree) #需要将labels中的label移动到链表首
for label in labels
# @show label
Expand All @@ -250,13 +254,13 @@ function partialcollisions(qtrees::AbstractVector,
# 但要保证更prev的node在`labels`中
treenode = seek_treenode(ln)
spindex = spacial_index(treenode)
append!(pairlist, (((label, lb) => spindex) for lb in lbcollector))
append!(itemlist, (((label, lb) => spindex) for lb in lbcollector))
tn = treenode
while !isroot(tn)
tn = tn.parent #root不是哨兵,值需要遍历
if !isemptylabels(tn)
plbs = Iterators.filter(!in(labels), labelsof(tn)) #moved了的plb不加入,等候其向下遍历时加,避免重复
collisions_boundsfilter(qtrees, spindex, label, plbs, pairlist, colist)
collisions_boundsfilter(qtrees, spindex, label, plbs, itemlist, colist)
end
end
empty!(treenodestack)
Expand All @@ -271,22 +275,22 @@ function partialcollisions(qtrees::AbstractVector,
cspindex = spacial_index(tn)
clbs = labelsof(tn)
# @show cspindex clbs
collisions_boundsfilter(qtrees, cspindex, clbs, label, pairlist, colist)
collisions_boundsfilter(qtrees, cspindex, clbs, label, itemlist, colist)
end
for c in tn.children
if !isemptychild(tn, c) #如果isemptychild则该child无意义
emptyflag = false
push!(treenodestack, c)
# @show pairlist
# @show itemlist
end
end
emptyflag && remove_tree_node(spqtree, tn)
end
end
end
empty!(labels)
# @show length(pairlist), length(colist)
r = _totalcollisions_native(qtrees, pairlist; colist=colist, kargs...)
# @show length(itemlist), length(colist)
r = _totalcollisions_native(qtrees, itemlist; colist=colist, kargs...)
unique ? unique!(first, sort!(r)) : r
end
mutable struct UpdatedSet{T} <: AbstractSet{T}
Expand All @@ -305,11 +309,11 @@ Base.length(s::UpdatedSet) = length(s.set)
Base.iterate(s::UpdatedSet, args...) = iterate(s.set, args...)
Base.in(item, s::UpdatedSet) = in(item, s.set)
Base.in(s::UpdatedSet) = in(s.set)
function totalcollisions_kw(qtrees; sptree2=hash_spacial_qtree(qtrees),
function totalcollisions_kw(qtrees; sptqree2=hash_spacial_qtree(qtrees),
spqtree=nothing, lbcollector=nothing, treenodestack=nothing, kargs...)
totalcollisions(qtrees; spqtree=sptree2, kargs...)
totalcollisions(qtrees; spqtree=sptqree2, kargs...)
end
partialcollisions_kw(qtrees, spqtree, updated; sptree2=nothing, kargs...) = partialcollisions(qtrees, spqtree, updated; kargs...)
partialcollisions_kw(qtrees, spqtree, updated; sptqree2=nothing, pairlist=nothing, kargs...) = partialcollisions(qtrees, spqtree, updated; kargs...)
function dynamiccollisions(qtrees::AbstractVector,
spqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees),
updated::UpdatedSet{Int}=UpdatedSet(1:length(qtrees));
Expand Down

2 comments on commit 95ad19a

@guo-yong-zhi
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/51478

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.2 -m "<description of version>" 95ad19aa68a3e232249573990016c739bba8a8c5
git push origin v0.8.2

Please sign in to comment.