Skip to content

Commit

Permalink
Merge pull request #67 from OmkarPathak/dev
Browse files Browse the repository at this point in the history
Completed Quadtree implementation
  • Loading branch information
OmkarPathak authored Sep 4, 2017
2 parents 14e6b0f + ddc9f57 commit ef1f26c
Show file tree
Hide file tree
Showing 2 changed files with 368 additions and 134 deletions.
241 changes: 219 additions & 22 deletions pygorithm/data_structures/quadtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
depth and bucket size.
"""
import inspect
import math
from collections import deque

from pygorithm.geometry import (vector2, polygon2, rect2)

Expand All @@ -25,7 +27,7 @@ def __init__(self, aabb):
:param aabb: axis-aligned bounding box
:type aabb: :class:`pygorithm.geometry.rect2.Rect2`
"""
pass
self.aabb = aabb

def __repr__(self):
"""
Expand All @@ -46,7 +48,7 @@ def __repr__(self):
:returns: unambiguous representation of this quad tree entity
:rtype: string
"""
pass
return "quadtreeentity(aabb={})".format(repr(self.aabb))

def __str__(self):
"""
Expand All @@ -67,7 +69,7 @@ def __str__(self):
:returns: human readable representation of this entity
:rtype: string
"""
pass
return "entity(at {})".format(str(self.aabb))

class QuadTree(object):
"""
Expand Down Expand Up @@ -129,7 +131,12 @@ def __init__(self, bucket_size, max_depth, location, depth = 0, entities = None)
:param entities: the entities to initialize this quadtree with
:type entities: list of :class:`.QuadTreeEntity` or None for empty list
"""
pass
self.bucket_size = bucket_size
self.max_depth = max_depth
self.location = location
self.depth = depth
self.entities = entities if entities is not None else []
self.children = None

def think(self, recursive = False):
"""
Expand All @@ -145,7 +152,13 @@ def think(self, recursive = False):
:param recursive: if `think(True)` should be called on :py:attr:`.children` (if there are any)
:type recursive: bool
"""
pass
if not self.children and self.depth < self.max_depth and len(self.entities) > self.bucket_size:
self.split()

if recursive:
if self.children:
for child in self.children:
child.think(True)

def split(self):
"""
Expand All @@ -164,12 +177,43 @@ def split(self):
:raises ValueError: if :py:attr:`.children` is not empty
"""
pass
if self.children:
raise ValueError("cannot split twice")

_cls = type(self)
def _cstr(r):
return _cls(self.bucket_size, self.max_depth, r, self.depth + 1)

_halfwidth = self.location.width / 2
_halfheight = self.location.height / 2
_x = self.location.mincorner.x
_y = self.location.mincorner.y

self.children = [
_cstr(rect2.Rect2(_halfwidth, _halfheight, vector2.Vector2(_x, _y))),
_cstr(rect2.Rect2(_halfwidth, _halfheight, vector2.Vector2(_x + _halfwidth, _y))),
_cstr(rect2.Rect2(_halfwidth, _halfheight, vector2.Vector2(_x + _halfwidth, _y + _halfheight))),
_cstr(rect2.Rect2(_halfwidth, _halfheight, vector2.Vector2(_x, _y + _halfheight))) ]

_newents = []
for ent in self.entities:
quad = self.get_quadrant(ent)

if quad < 0:
_newents.append(ent)
else:
self.children[quad].entities.append(ent)
self.entities = _newents



def get_quadrant(self, entity):
"""
Calculate the quadrant that the specified entity belongs to.
Touching a line is considered overlapping a line. Touching is
determined using :py:meth:`math.isclose`
Quadrants are:
- -1: None (it overlaps 2 or more quadrants)
Expand All @@ -189,7 +233,48 @@ def get_quadrant(self, entity):
:returns: quadrant
:rtype: int
"""
pass

_aabb = entity.aabb
_halfwidth = self.location.width / 2
_halfheight = self.location.height / 2
_x = self.location.mincorner.x
_y = self.location.mincorner.y

if math.isclose(_aabb.mincorner.x, _x + _halfwidth):
return -1
if math.isclose(_aabb.mincorner.x + _aabb.width, _x + _halfwidth):
return -1
if math.isclose(_aabb.mincorner.y, _y + _halfheight):
return -1
if math.isclose(_aabb.mincorner.y + _aabb.height, _y + _halfheight):
return -1

_leftside_isleft = _aabb.mincorner.x < _x + _halfwidth
_rightside_isleft = _aabb.mincorner.x + _aabb.width < _x + _halfwidth

if _leftside_isleft != _rightside_isleft:
return -1

_topside_istop = _aabb.mincorner.y < _y + _halfheight
_botside_istop = _aabb.mincorner.y + _aabb.height < _y + _halfheight

if _topside_istop != _botside_istop:
return -1

_left = _leftside_isleft
_top = _topside_istop

if _left:
if _top:
return 0
else:
return 3
else:
if _top:
return 1
else:
return 2


def insert_and_think(self, entity):
"""
Expand All @@ -204,7 +289,14 @@ def insert_and_think(self, entity):
:param entity: the entity to insert
:type entity: :class:`.QuadTreeEntity`
"""
pass
if not self.children and len(self.entities) == self.bucket_size and self.depth < self.max_depth:
self.split()

quad = self.get_quadrant(entity) if self.children else -1
if quad < 0:
self.entities.append(entity)
else:
self.children[quad].insert_and_think(entity)

def retrieve_collidables(self, entity, predicate = None):
"""
Expand All @@ -227,19 +319,71 @@ def retrieve_collidables(self, entity, predicate = None):
:returns: potential collidables (never `None)
:rtype: list of :class:`.QuadTreeEntity`
"""
pass
result = list(filter(predicate, self.entities))
quadrant = self.get_quadrant(entity) if self.children else -1

if quadrant >= 0:
result.extend(self.children[quadrant].retrieve_collidables(entity, predicate))
elif self.children:
for child in self.children:
touching, overlapping, alwaysNone = rect2.Rect2.find_intersection(entity.aabb, child.location, find_mtv=False)
if touching or overlapping:
result.extend(child.retrieve_collidables(entity, predicate))

return result

def _iter_helper(self, pred):
"""
Calls pred on each child and childs child, iteratively.
pred takes one positional argument (the child).
:param pred: function to call
:type pred: `types.FunctionType`
"""

_stack = deque()
_stack.append(self)

while _stack:
curr = _stack.pop()
if curr.children:
for child in curr.children:
_stack.append(child)

pred(curr)

def find_entities_per_depth(self):
"""
Calculate the number of nodes and entities at each depth level in this
quad tree. Only returns for depth levels at or equal to this node.
This is implemented iteratively. See :py:meth:`.__str__` for usage example.
:returns: dict of depth level to (number of nodes, number of entities)
:rtype: dict int: (int, int)
:returns: dict of depth level to number of entities
:rtype: dict int: int
"""

container = { 'result': {} }
def handler(curr, container=container):
container['result'][curr.depth] = container['result'].get(curr.depth, 0) + len(curr.entities)
self._iter_helper(handler)

return container['result']

def find_nodes_per_depth(self):
"""
Calculate the number of nodes at each depth level.
This is implemented iteratively. See :py:meth:`.__str__` for usage example.
:returns: dict of depth level to number of nodes
:rtype: dict int: int
"""
pass

nodes_per_depth = {}
self._iter_helper(lambda curr, d=nodes_per_depth: d.update({ (curr.depth, d.get(curr.depth, 0) + 1) }))
return nodes_per_depth

def sum_entities(self, entities_per_depth=None):
"""
Expand All @@ -254,7 +398,15 @@ def sum_entities(self, entities_per_depth=None):
:returns: number of entities in this and child nodes
:rtype: int
"""
pass
if entities_per_depth is not None:
return sum(entities_per_depth.values())

container = { 'result': 0 }
def handler(curr, container=container):
container['result'] += len(curr.entities)
self._iter_helper(handler)

return container['result']

def calculate_avg_ents_per_leaf(self):
"""
Expand All @@ -270,7 +422,13 @@ def calculate_avg_ents_per_leaf(self):
:returns: average number of entities at each leaf node
:rtype: :class:`numbers.Number`
"""
pass
container = { 'leafs': 0, 'total': 0 }
def handler(curr, container=container):
if not curr.children:
container['leafs'] += 1
container['total'] += len(curr.entities)
self._iter_helper(handler)
return container['total'] / container['leafs']

def calculate_weight_misplaced_ents(self, sum_entities=None):
"""
Expand All @@ -293,11 +451,40 @@ def calculate_weight_misplaced_ents(self, sum_entities=None):
:returns: weight of misplaced entities
:rtype: :class:`numbers.Number`
"""
pass

# this iteration requires more context than _iter_helper provides.
# we must keep track of parents as well in order to correctly update
# weights

nonleaf_to_max_child_depth_dict = {}

# stack will be (quadtree, list (of parents) or None)
_stack = deque()
_stack.append((self, None))
while _stack:
curr, parents = _stack.pop()
if parents:
for p in parents:
nonleaf_to_max_child_depth_dict[p] = max(nonleaf_to_max_child_depth_dict.get(p, 0), curr.depth)

if curr.children:
new_parents = list(parents) if parents else []
new_parents.append(curr)
for child in curr.children:
_stack.append((child, new_parents))

_weight = 0
for nonleaf, maxchilddepth in nonleaf_to_max_child_depth_dict.items():
_weight += len(nonleaf.entities) * 4 * (maxchilddepth - nonleaf.depth)

_sum = self.sum_entities() if sum_entities is None else sum_entities
return _weight / _sum

def __repr__(self):
"""
Create an unambiguous, recursive representation of this quad tree.
Create an unambiguous representation of this quad tree.
This is implemented iteratively.
Example:
Expand All @@ -308,19 +495,18 @@ def __repr__(self):
# create a tree with a up to 2 entities in a bucket that
# can have a depth of up to 5.
_tree = quadtree.QuadTree(2, 5, rect2.Rect2(100, 100))
_tree = quadtree.QuadTree(1, 5, rect2.Rect2(100, 100))
# add a few entities to the tree
_tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(5, 5))))
_tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(95, 5))))
# prints quadtree(bucket_size=2, max_depth=5, location=rect2(width=100, height=100, mincorner=vector2(x=0, y=0)), depth=0, entities=[], children=[ quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=0, y=0)), depth=1, entities=[ quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=5, y=5))) ], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=50, y=0)), depth=1, entities=[ quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=95, y=5))) ], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=50, y=50)), depth=1, entities=[], children=[]), quadtree(bucket_size=2, max_depth=5, location=rect2(width=50, height=50, mincorner=vector2(x=0, y=50)), depth=1, entities=[], children=[]) ])
print(repr(_tree))
# prints quadtree(bucket_size=1, max_depth=5, location=rect2(width=100, height=100, mincorner=vector2(x=0, y=0)), depth=0, entities=[], children=[quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=0, y=0)), depth=1, entities=[quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=5, y=5)))], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=50.0, y=0)), depth=1, entities=[quadtreeentity(aabb=rect2(width=2, height=2, mincorner=vector2(x=95, y=5)))], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=50.0, y=50.0)), depth=1, entities=[], children=None), quadtree(bucket_size=1, max_depth=5, location=rect2(width=50.0, height=50.0, mincorner=vector2(x=0, y=50.0)), depth=1, entities=[], children=None)])
:returns: unambiguous, recursive representation of this quad tree
:rtype: string
"""
pass
return "quadtree(bucket_size={}, max_depth={}, location={}, depth={}, entities={}, children={})".format(self.bucket_size, self.max_depth, repr(self.location), self.depth, self.entities, self.children)

def __str__(self):
"""
Expand All @@ -347,12 +533,23 @@ def __str__(self):
_tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(5, 5))))
_tree.insert_and_think(quadtree.QuadTreeEntity(rect2.Rect2(2, 2, vector2.Vector2(95, 5))))
# prints quadtree(at rect(100x100 at <0, 0>) with 0 entities here (2 in total); (nodes, entities) per depth: [ 0: (1, 0), 1: (4, 2) ] (max depth: 5), avg ent/leaf: 0.5 (target 2), misplaced weight = 0 (0 best, >1 bad))
# prints quadtree(at rect(100x100 at <0, 0>) with 0 entities here (2 in total); (nodes, entities) per depth: [ 0: (1, 0), 1: (4, 2) ] (allowed max depth: 5, actual: 1), avg ent/leaf: 0.5 (target 1), misplaced weight 0.0 (0 best, >1 bad)
print(_tree)
:returns: human-readable representation of this quad tree
:rtype: string
"""
pass

nodes_per_depth = self.find_nodes_per_depth()
_ents_per_depth = self.find_entities_per_depth()

_nodes_ents_per_depth_str = "[ {} ]".format(', '.join("{}: ({}, {})".format(dep, nodes_per_depth[dep], _ents_per_depth[dep]) for dep in nodes_per_depth.keys()))

_sum = self.sum_entities(entities_per_depth=_ents_per_depth)
_max_depth = max(_ents_per_depth.keys())
_avg_ent_leaf = self.calculate_avg_ents_per_leaf()
_mispl_weight = self.calculate_weight_misplaced_ents(sum_entities=_sum)
return "quadtree(at {} with {} entities here ({} in total); (nodes, entities) per depth: {} (allowed max depth: {}, actual: {}), avg ent/leaf: {} (target {}), misplaced weight {} (0 best, >1 bad)".format(self.location, len(self.entities), _sum, _nodes_ents_per_depth_str, self.max_depth, _max_depth, _avg_ent_leaf, self.bucket_size, _mispl_weight)

@staticmethod
def get_code():
Expand Down
Loading

0 comments on commit ef1f26c

Please sign in to comment.