#!/usr/bin/python -t -3 -O

#
#            kd-Tree Implementation with Approximate Nearest-Neighbor Search
#
# This is a stand-alone kd-tree class that implements kd-trees,
# the exact NN search and the ANN search algorithm.
# It can handle (naturally) points in arbitrary dimensions.
#
# Usage: kdTree <dim> <num points>
# e.g., kdTree 3 1000
#
# Note that this class implements "longest-side" kd-trees,
# and the ANN algorithm is the standard (iterative) procedure 
# as described in Dickerson, Duncan & Goodrich: "Kd-Trees are
# better when cut on the longest side", ESA 2000.
# The exact NN algorithm is the classic, recursive procedure.
#
# (The python module from SciPy, for instance, uses an older,
# much more complicated method.)
#
# The code is fairly minimalistic, meaning there are only two source files,
# containing only three classes, with only the methods needed for this task.
# (Well, 2-3 methods are there only for testing purposes.)
#
# However, the kd-tree here is really meant for educational purposes only;
# it is not very efficient, because it stores one point per leaf!
# In practice, you'll probably want to store b points per leaf,
# where b is around 10-20.
# In addition, the "bounds overlap ball" test is a bit inefficient
# (in practice, one would pass the point p' that is closest to q but still
# within the cell's bounds down to the next level, and then move only one of
# the coordinates of p', namely the one along the split axis).
# But all these optimizations would complicate the code quite a bit,
# which would defeat its didactic purpose.
#
# On the other hand, my code is not entirely inefficient;
# for instance, in order to save space, I don't store the extent of each 
# kd-tree cell in the node (which can be a *lot* in high dimensions).
#
# Author: Gabriel Zachmann, Clausthal University  (zach tu-clausthal de)
# May 2011


import sys
from operator import itemgetter
import copy
from heapq import heappush, heappop
from math import sqrt
import random

from Vector import *

#random.seed( 4 )											# only for debugging


########################### HyperRectangle ######################

class HyperRectangle(object):

	# Axis-aligned box, stored as two corner points:
	#   _low  = per-axis minimum coordinates
	#   _high = per-axis maximum coordinates

	# Create the bounding box of a non-empty sequence of points
	# Usage: HyperRectangle( points )
	# Every point must support len(), indexing and copy().
	# Raises ValueError if no points are given; previously a warning was
	# printed and a half-initialised object (without _low/_high) was returned,
	# which only failed later with an unrelated AttributeError.
	def __init__( self, points ):
		if len(points) == 0:
			raise ValueError( "HyperRectangle: no points given!" )

		# start with a degenerate box at the first point, then grow it
		self._low  = points[0].copy()
		self._high = points[0].copy()
		dim = len( self._low )
		for p in points:
			for i in range(dim):
				self._low[i]  = min( self._low[i],  p[i] )
				self._high[i] = max( self._high[i], p[i] )


	# Return the index (= coord axis) of the longest side
	# On ties, the lowest axis index wins (max() keeps the first maximum).
	def longestSide(self):
		diag = self._high - self._low
		axis, _ = max( enumerate(diag), key = itemgetter(1) )
		return axis


	# Split a hyperrectangle in two along axis at value
	# Returns 2 new hyperrectangles (left, right); self is not modified
	# Assumption: _low[axis] <= value <= _high[axis]
	def split( self, axis, value ):
		l = copy.deepcopy( self )
		r = copy.deepcopy( self )
		l._high[axis] = value
		r._low[axis]  = value
		return l, r


	# Return true iff point is contained in the hyperrectangle
	# (the boundary counts as inside)
	def contains( self, point ):
		for i in range( len(point) ):
			if point[i] < self._low[i] or point[i] > self._high[i]:
				return False
		return True


	# Return distance between a hyperrectangle and a point
	# (0 if the point lies inside the hyperrectangle)
	# Assumption: both must live in the same space, i.e., have same dimension
	def dist( self, p ):
		# clamp a copy of p into the box; the clamped copy is the point of the
		# box closest to p, so its distance to p is the distance of the box
		pp = copy.deepcopy( p )
		for i in range( len(pp) ):								# could probably be optimized by list comprehensions
			if pp[i] < self._low[i]:   pp[i] = self._low[i]
			if pp[i] > self._high[i]:  pp[i] = self._high[i]
		return pp.dist( p )


	# Check whether or not a hyperrectangle is overlapped by a ball (center,radius)
	# The test is strict: a box that only touches the ball does not count.
	def overlaps_ball( self, center, radius ):
		return self.dist( center ) < radius


	def __str__(self):											# for 'print x'
		low  = ','.join( str(c) for c in self._low )
		high = ','.join( str(c) for c in self._high )
		return '[ ' + low + ' : ' + high + ' ]'




########################### kd-Tree ######################

class kdTree(object):

	# Create kd-tree over a set of points
	# A kdTree object should be instantiated by the user *only* for point sets with more than 1 point!
	#
	# cell_region = the region (hyperrectangle) of the cell (not neccessarily the bbox of the points!)
	#               (should not be set by the user)
	# (this is actually a recursive function)

	def __init__( self, points, cell_region = None ):
		if points is None or len(points) == 0:
			# should not happen
			print >> sys.stderr, "kdTree: init'ed with None!"
			return

		n = len(points)
		if n == 1:
			# create leaf
			self._point = points[0].copy()
			self._left = self._right = None
			self._split_axis = None								# indicates leaf, too
			self._split_value = 0								# don't care value
			return

		# now we have to create an inner node and two child kd-trees
		if cell_region is None:
			# we must be the root
			cell_region = HyperRectangle( points )
			self._cell_region = cell_region						# store the extent of the root (and only the root!)

		# determine splitting plane
		self._split_axis = cell_region.longestSide()

		# here, we do a "stupid" split by repeatedly sorting the points
		points.sort( key = itemgetter(self._split_axis) )
		self._split_value = points[ n/2-1 ][self._split_axis]
		self._point = None										# we store points only at leaves

		# split hyperrectangle
		left_region , right_region = cell_region.split( self._split_axis, self._split_value )

		# create child trees
		self._left = kdTree( points[0:n/2], left_region )
		self._right = kdTree( points[n/2:n], right_region )
	# end def __init__


	# Approximate nearest neighbor
	# Returns the (1+epsilon)-nearest neighbor to query
	# Returns (dist, nearest-neighbor)
	# Assumption: epsilon > 0

	def approx_nn( self, query, epsilon ):

		cell_region = self._cell_region							# root cell region
		queue = []												# p-queue of nodes, sorted by distance to query
		node = self
		current_nn = None										# current candidate for NN, 
		current_d = sys.float_info.max							# start with "infinite" point

		#import pdb; pdb.set_trace()

		while cell_region.dist( query ) < current_d/(1.0 + epsilon):

			# descend into closest leaf
			while node._split_axis is not None:
				left_region, right_region = cell_region.split( node._split_axis, node._split_value )
				dl = left_region.dist( query )
				dr = right_region.dist( query )
				if dl < dr:
					# left child is closer
					heappush( queue, (dr, node._right, right_region) )
					node = node._left
					cell_region = left_region
				else:
					# right child is closer
					heappush( queue, (dl, node._left, left_region) )
					node = node._right
					cell_region = right_region

			# we are now at a leaf
			d = query.dist( node._point )
			if d < current_d:
				current_nn = node._point
				current_d = d

			if not queue:
				break											# last node was processed

			# process next node, which is the closest of all unprocessed yet
			dn, node, cell_region = heappop( queue )
		# end while

		return current_d, current_nn
	# end def approx_nn


	# Exact nearest neighbor
	# This is the classic, recursive algorithm
	# Usage by the user: kdtree.exact_nn( query )
	# Returns (dist, nearest-neighbor)

	def exact_nn( self, query, cell_region = None, current_d = None, current_nn = None ):

		if cell_region is None:
			# we must be at the root
			cell_region = self._cell_region

		if self._split_axis is None:
			# leaf
			if current_d is None:
				# first point
				return query.dist(self._point), self._point
			else:
				d = query.dist( self._point )
				if d < current_d:
					return d, self._point
				else:
					return current_d, current_nn

		else:
			query_is_left_of_split_plane = ( query[ self._split_axis ] < self._split_value )

			# split current cell's hyperrectangle for following test and child nodes
			left_region, right_region = cell_region.split( self._split_axis, self._split_value )

			# recurse into closer child first
			if query_is_left_of_split_plane:
				current_d, current_nn = self._left.exact_nn( query, left_region, current_d, current_nn )
			else:
				current_d, current_nn = self._right.exact_nn( query, right_region, current_d, current_nn )

			# recurse into the other (farther) child
			if query_is_left_of_split_plane:
				# "bounds overlap ball" test
				if right_region.overlaps_ball( query, current_d ):
					current_d, current_nn = self._right.exact_nn( query, right_region, current_d, current_nn )
			else:
				if left_region.overlaps_ball( query, current_d ):
					current_d, current_nn = self._left.exact_nn( query, left_region, current_d, current_nn )

			# no "ball within bounds" test here - I think, we don't need that
			return current_d, current_nn

		# can't be reached
		print >> sys.stderr, "brute_force_nn: BUG!"
	# end def brute_force_nn


	# Compute nearest neighbor to query using brute-force method:
	# we just visit every leaf

	def brute_force_nn( self, query, current_d = None, current_nn = None ):

		if self._split_axis is None:
			# leaf
			if current_d is None:
				# first point
				return query.dist(self._point), self._point
			else:
				d = query.dist( self._point )
				if d < current_d:
					return d, self._point
				else:
					return current_d, current_nn
		else:
			# inner node
			current_d, current_nn = self._left.brute_force_nn( query, current_d, current_nn )
			current_d, current_nn = self._right.brute_force_nn( query, current_d, current_nn )
			return current_d, current_nn
		# can't be reached
		print >> sys.stderr, "brute_force_nn: BUG!"
	# end def brute_force_nn


	# Print a kd-tree
	# cell_region must be the same as HyperRectangle(points) when the constructor was called!
	# check_only = true -> nothing will be print, only some consistency checks done

	def out( self, cell_region, check_only, level = 0 ):

		if not check_only:
			# print indent
			for i in xrange(level):
				print "  ",

		if self._split_axis is None:
			# leaf
			if not check_only:
				print self._point
			if self._point is None:
				print "BUG: leaf does not contain a point!"
			elif not cell_region.contains( self._point ):
				if not check_only:
					for i in xrange(level):
						print "  ",
				print "BUG: Leaf region does not contain its point!"
				if check_only:
					print "level = ", level
				for i in xrange(level):
					print "  ",
				print cell_region
				if check_only:
					print self._point
			return

		if level == 0:
			# print region of root = boox(points)
			if not check_only:
				print cell_region

		# print splitting plane
		if not check_only:
			print self._split_axis, " : ", self._split_value

		# split hyperrectangle (do it exactly as the build function did)
		left_region, right_region = cell_region.split( self._split_axis, self._split_value )

		# print left and right sub-trees
		self._left.out( left_region, check_only, level+1 )
		self._right.out( right_region, check_only, level+1 )
	# end def out


########################### Main ######################

# retrieve command line arguments
if len( sys.argv ) < 2:
	print >> sys.stderr, "Usage: ", sys.argv[0], " dimension [num points]"
	sys.exit( 2 )

dimension = int( sys.argv[1] )
if dimension < 2:
	print >> sys.stderr, "Dimension (%d) < 2!" % dimension
	sys.exit( 2 )

if len( sys.argv ) > 2:
	num_points = int( sys.argv[2] )
else:
	num_points = 100


# create random points
points = [ Vector.random(dimension) for i in xrange(num_points) ]

# create kd-tree
print "Building kd-tree ..."
kdtree = kdTree( points )

# check consistency
print "Checking ..."
kdtree.out( HyperRectangle(points), True )

# Use known query point
query = Vector( dimension, 0.5 )

print "Starting queries ..."

# do a number of tests
for j in xrange(10):

	print "\n", j, "."

	# compute nearest neighbor by brute force
	d, nn = kdtree.brute_force_nn( query )
	print "Exact nn  = ", nn
	print " distance = ", d

	# check that d is really the closest distance to query
	def checkNN( points, q, d ):
		for p in points:
			if q.dist( p ) < d:
				print "BUG: ", p
				print "   is nearer (%f  vs  %f)!" % (q.dist(p), d)

	checkNN( points, query, d )

	# compute exact nearest neighbor
	d, nn = kdtree.exact_nn( query )
	print "Exact nn  = ", nn
	print " distance = ", d

	checkNN( points, query, d )

	# compute approximate nearest neighbor
	eps = 0.1 + 0.4*random.random()										# eps in [0.1,0.5]
	d, ann = kdtree.approx_nn( query, eps )
	print "Approx nn = ", nn
	print " distance = ", d
	print " epsilon  = ", eps
	print " |ann-q| / |nn-q| = ", query.dist(ann) / query.dist(nn)

	# check
	if query.dist(ann) > (1.0 + eps) * d:
		print "BUG: ANN too far away!"
		print "ANN     = ", ann
		print "|q-ann| = ", query.dist( ann )
		print "d       = ", d


	# check the statement of the proof (the raison d'etre)
	if query.dist(ann) > (1.0 + eps) * query.dist(nn):
		print "BUG: ann is farther away than (1+eps)*dist(nn,q)!"
		print "ann     = ", ann
		print "|q-ann| = ", query.dist( ann )
		print "|q-nn|  = ", query.dist( nn )

	# new query point for next round of tests
	query = Vector.random(dimension)

