#!/usr/bin/python -O
#-t -3 -O

#
# Exactly the same as kdTree.py, but with additional statistics collecting code
#
# Usage: kdTreeWithStatistics.py <start dim> <end dim> <epsilon> <num points> <distribution-type>
# e.g., kdTreeWithStatistics.py 3 40 0.2 10000 uniform
# 
# As you can see, points can be of one of three different distributions:
# distribution-type is in {"uniform", "clustered", "correlated"}.
# "Clustered" means that a number of cluster centers are created (inside the unit cube),
# and then 1000 points are randomly generated around each center,
# so that each cluster center has a "cloud" of points according to a normal distribution.
# With "correlated", the points are generated by an autocorrelation: the first
# point p_0 is generated randomly inside the unit cube;
# all other points are generated according to the autocorrelation
#    p_i+1 = 0.8*p_i + 0.2*w,
# where w is a random vector according to a Gaussian distribution.
#
# After a point set has been generated, all points are scaled such that
#    bbox(points) = unit cube [0,1]^dim .
#
# Author: Gabriel Zachmann, Clausthal University  (zach tu-clausthal de)
# May 2011


import sys
from operator import itemgetter
import copy
from heapq import heappush, heappop
from math import sqrt
import random
import functools

from Vector import *


# Global variable for statistics: number of kd-tree leaves visited by the most
# recent nearest-neighbor query. It is reset and incremented by the kdTree query
# methods (approx_nn, exact_nn, brute_force_nn) and read by the main loop below.
num_leaves_visited = 0



########################### HyperRectangle ######################

class HyperRectangle(object):

	# Axis-aligned box, stored as two corner points:
	#   _low  = corner with the smallest coordinate on every axis
	#   _high = corner with the largest coordinate on every axis


	# Create the bounding box of a set of points
	# Usage: HyperRectangle( points )
	# points must be a non-empty sequence; every point must support len(),
	# indexing and copy().
	# Raises ValueError if points is empty (previously, a message was printed and a
	# half-initialized object without _low/_high was returned).
	def __init__( self, points ):
		if len(points) == 0:
			raise ValueError( "HyperRectangle: no points given!" )

		# start with a degenerate box around the first point, then grow it
		self._low  = points[0].copy()
		self._high = self._low.copy()
		dim = len( self._low )
		for p in points:
			for i in range(dim):
				self._low[i]  = min( self._low[i],  p[i] )
				self._high[i] = max( self._high[i], p[i] )


	# Return the index (= coord axis) of the longest side
	# If several sides are equally long, the one with the smallest index wins.
	def longestSide(self):
		diag = self._high - self._low
		return max( enumerate(diag), key = itemgetter(1) )[0]


	# Split a hyperrectangle in two along axis at value
	# Returns 2 new hyperrectangles (left, right); self is not modified.
	# Assumption: _low[axis] <= value <= _high[axis]
	def split( self, axis, value ):
		l = copy.deepcopy( self )
		r = copy.deepcopy( self )
		l._high[axis] = value
		r._low[axis]  = value
		return l, r


	# Return true iff point is contained in the hyperrectangle
	# (the boundary counts as inside)
	def contains( self, point ):
		for i in range( len(point) ):
			if point[i] < self._low[i] or point[i] > self._high[i]:
				return False
		return True


	# Return distance between a hyperrectangle and a point
	# (0 if the point lies inside)
	# Assumption: both must live in the same space, i.e., have same dimension
	def dist( self, p ):
		# clamp a copy of p to the box; this yields the point of the box closest to p
		pp = copy.deepcopy( p )
		for i in range( len(pp) ):
			if pp[i] < self._low[i]:   pp[i] = self._low[i]
			if pp[i] > self._high[i]:  pp[i] = self._high[i]
		return pp.dist( p )


	# Check whether or not a hyperrectangle is overlapped by a ball (center,radius)
	# The test is strict: a ball that merely touches the box does not overlap it.
	def overlaps_ball( self, center, radius ):
		return ( self.dist(center) < radius )


	def __str__(self):											# for 'print x'
		return '[ ' + ','.join( str(c) for c in self._low ) + \
				' : ' + ','.join( str(c) for c in self._high ) + ' ]'




########################### kd-Tree ######################

class kdTree(object):

	# Node of a kd-tree (the root node represents the whole tree).
	#   inner node: _split_axis / _split_value describe the splitting plane,
	#               _left / _right are the child trees, _point is None
	#   leaf:       _split_axis is None, _point holds (a copy of) exactly one point
	#   root:       additionally stores _cell_region, the bbox of all points


	# Create kd-tree over a set of points
	#
	# points      = non-empty list of points; the list itself is not modified
	# cell_region = the region (hyperrectangle) of the cell (not necessarily the bbox of the points!)
	#               (should not be set by the user)
	# (this is actually a recursive function)
	# Raises ValueError if points is None or empty.

	def __init__( self, points, cell_region = None ):
		if points is None or len(points) == 0:
			# should not happen
			raise ValueError( "kdTree: no points given!" )

		n = len(points)

		if cell_region is None:
			# we must be the root
			# (done before the leaf test, so that a tree over a single point can be queried, too)
			cell_region = HyperRectangle( points )
			self._cell_region = cell_region						# store the extent of the root (and only the root!)

		if n == 1:
			# create leaf
			self._point = points[0].copy()
			self._left = self._right = None
			self._split_axis = None								# indicates leaf, too
			self._split_value = 0								# don't care value
			return

		# now we have to create an inner node and two child kd-trees

		# determine splitting plane
		self._split_axis = cell_region.longestSide()

		# here, we do a "stupid" split by repeatedly sorting the points
		# (sorted() instead of sort(), so that the caller's list keeps its order)
		points = sorted( points, key = itemgetter(self._split_axis) )
		half = n // 2
		self._split_value = points[ half-1 ][self._split_axis]
		self._point = None										# we store points only at leaves

		# split hyperrectangle
		left_region, right_region = cell_region.split( self._split_axis, self._split_value )

		# create child trees
		self._left = kdTree( points[:half], left_region )
		self._right = kdTree( points[half:], right_region )
	# end def __init__


	# Approximate nearest neighbor
	# Returns the (1+epsilon)-nearest neighbor to query
	# Returns (dist, nearest-neighbor)
	# Must be called on the root. Assumption: epsilon >= 0
	# Side effect: num_leaves_visited is reset and then counts the leaves visited.

	def approx_nn( self, query, epsilon ):

		global num_leaves_visited
		num_leaves_visited = 0

		cell_region = self._cell_region							# root cell region
		queue = []												# p-queue of nodes, sorted by distance to query
		num_pushed = 0											# tie-breaker for queue entries of equal distance
		node = self
		current_nn = None										# current candidate for NN,
		current_d = sys.float_info.max							# start with "infinite" point

		while cell_region.dist( query ) < current_d/(1.0 + epsilon):

			# descend into closest leaf
			while node._split_axis is not None:
				left_region, right_region = cell_region.split( node._split_axis, node._split_value )
				dl = left_region.dist( query )
				dr = right_region.dist( query )
				if dl < dr:
					# left child is closer
					far_d, far_node, far_region = dr, node._right, right_region
					node, cell_region = node._left, left_region
				else:
					# right child is closer
					far_d, far_node, far_region = dl, node._left, left_region
					node, cell_region = node._right, right_region

				# remember the farther child for later; num_pushed is unique, so entries
				# with equal distance are ordered by insertion and node objects are never compared
				heappush( queue, (far_d, num_pushed, far_node, far_region) )
				num_pushed += 1

			# we are now at a leaf
			num_leaves_visited += 1

			d = query.dist( node._point )
			if d < current_d:
				current_nn = node._point
				current_d = d

			if not queue:
				break											# last node was processed

			# process next node, which is the closest of all unprocessed yet
			dn, order, node, cell_region = heappop( queue )
		# end while

		return current_d, current_nn
	# end def approx_nn


	# Exact nearest neighbor
	# This is the classic, recursive algorithm
	# Usage by the user: kdtree.exact_nn( query )
	# (cell_region, current_d, current_nn are used by the recursion only)
	# Returns (dist, nearest-neighbor)
	# Side effect: num_leaves_visited is reset at the root and counts the leaves visited.

	def exact_nn( self, query, cell_region = None, current_d = None, current_nn = None ):

		global num_leaves_visited

		if cell_region is None:
			# we must be at the root
			num_leaves_visited = 0
			cell_region = self._cell_region

		if self._split_axis is None:
			# leaf
			num_leaves_visited += 1

			d = query.dist( self._point )
			if current_d is None or d < current_d:
				# first point, or better than the current candidate
				return d, self._point
			return current_d, current_nn

		# inner node:
		# split current cell's hyperrectangle for following test and child nodes
		left_region, right_region = cell_region.split( self._split_axis, self._split_value )

		if query[ self._split_axis ] < self._split_value:
			# query is left of split plane
			near_node, near_region = self._left, left_region
			far_node, far_region = self._right, right_region
		else:
			near_node, near_region = self._right, right_region
			far_node, far_region = self._left, left_region

		# recurse into closer child first (this always yields a candidate)
		current_d, current_nn = near_node.exact_nn( query, near_region, current_d, current_nn )

		# recurse into the other (farther) child, but only if it passes the
		# "bounds overlap ball" test
		if far_region.overlaps_ball( query, current_d ):
			current_d, current_nn = far_node.exact_nn( query, far_region, current_d, current_nn )

		# no "ball within bounds" test here - I think, we don't need that
		return current_d, current_nn
	# end def exact_nn


	# Compute nearest neighbor to query using brute-force method:
	# we just visit every leaf
	# Usage by the user: kdtree.brute_force_nn( query )
	# Returns (dist, nearest-neighbor)
	# Side effect: num_leaves_visited is reset and counts the leaves visited.

	def brute_force_nn( self, query, current_d = None, current_nn = None ):

		global num_leaves_visited
		if current_d is None:
			# initial invocation at root
			# (also true on the way down to the very first leaf, where the counter is still 0)
			num_leaves_visited = 0

		if self._split_axis is None:
			# leaf
			num_leaves_visited += 1

			d = query.dist( self._point )
			if current_d is None or d < current_d:
				# first point, or better than the current candidate
				return d, self._point
			return current_d, current_nn

		# inner node
		current_d, current_nn = self._left.brute_force_nn( query, current_d, current_nn )
		current_d, current_nn = self._right.brute_force_nn( query, current_d, current_nn )
		return current_d, current_nn
	# end def brute_force_nn


	# Print a kd-tree
	# cell_region must be the same as HyperRectangle(points) when the constructor was called!
	# check_only = true -> nothing will be printed, only some consistency checks done
	#                      (problems found by the checks are always reported)
	# level = depth of this node (used by the recursion only)

	def out( self, cell_region, check_only, level = 0 ):

		# the original print statements emitted 3 blanks per level
		indent = "   " * level

		if self._split_axis is None:
			# leaf
			if not check_only:
				print( "%s%s" % (indent, self._point) )
			if self._point is None:
				print( "BUG: leaf does not contain a point!" )
			elif not cell_region.contains( self._point ):
				if check_only:
					print( "BUG: Leaf region does not contain its point!" )
					print( "level =  %s" % (level,) )
				else:
					print( indent + "BUG: Leaf region does not contain its point!" )
				print( "%s%s" % (indent, cell_region) )
				if check_only:
					print( self._point )
			return

		if not check_only:
			if level == 0:
				# print region of root = bbox(points)
				print( cell_region )

			# print splitting plane
			print( "%s%s  :  %s" % (indent, self._split_axis, self._split_value) )

		# split hyperrectangle (do it exactly as the build function did)
		left_region, right_region = cell_region.split( self._split_axis, self._split_value )

		# print left and right sub-trees
		self._left.out( left_region, check_only, level+1 )
		self._right.out( right_region, check_only, level+1 )
	# end def out


########################### Point distributions ######################


# Compute the bounding box of a non-empty set of points.
# Returns the pair (low, high), obtained by folding all points with their
# min() and max() methods, respectively (coord-wise minimum / maximum).
def bbox( points ):
	def lower( a, b ):
		return a.min( b )

	def upper( a, b ):
		return a.max( b )

	low  = functools.reduce( lower, points )		# minimum of all points (coord-wise)
	high = functools.reduce( upper, points )		# maximum of all points (coord-wise)
	return (low, high)



# Create num_points points which are distributed in clusters.
# Each cluster contains 1000 points, generated by Vector.randomGauss() around the
# cluster center with parameter 0.001; only the last cluster is smaller if
# num_points is not a multiple of 1000 (so exactly num_points points are returned,
# also for num_points < 1000).
# The center of each cluster is picked by Vector.random(dim).
# For testing, the longest and shortest bbox side over all clusters is printed.

def createClusteredPoints( num_points, dim ):
	num_points_per_cluster = 1000
	points = []
	longest_side = 0.0
	shortest_side = 1000.0

	num_missing = num_points
	while num_missing > 0:
		cluster_size = min( num_points_per_cluster, num_missing )
		cluster_center = Vector.random(dim)
		cluster = [ Vector.randomGauss( dim, cluster_center, 0.001 ) for j in range(cluster_size) ]
		points.extend( cluster )
		num_missing -= cluster_size

		# for testing: extent of this cluster only (not of all points created so far)
		mini, maxi = bbox( cluster )
		diag = maxi - mini
		longest_side = max( max(diag), longest_side )
		shortest_side = min( min(diag), shortest_side )

	print( "Longest cluster side = %s" % (longest_side,) )
	print( "Shortest cluster side = %s" % (shortest_side,) )

	return points


# Create num_points points which arise from an auto-correlation.
# The first point is picked by Vector.random(dim);
# all others obey the following autocorrelation:
#   p[i+1] = correlation * p[i] + noise_weight * w,  p[i] and w in R^dim,
# where w is noise generated by Vector.randomGauss() around the origin with parameter 0.1.
# The defaults (0.8 and 0.2) are the weights that used to be hard-coded here.
# At least one point is returned, even if num_points < 1.
# For testing, the longest and shortest side of the bbox of the points is printed.

def createCorrelatedPoints( num_points, dim, correlation = 0.8, noise_weight = 0.2 ):
	points = [ Vector.random(dim) ]
	noise_center = Vector( dim, 0.0 )

	p = points[0]
	for i in range( num_points-1 ):
		p = correlation*p + noise_weight * Vector.randomGauss( dim, noise_center, 0.1 )
		points.append( p )

	# for testing
	mini, maxi = bbox( points )
	diag = maxi - mini
	longest_side = max( diag )
	shortest_side = min( diag )

	print( "Longest side of correlated points = %s" % (longest_side,) )
	print( "Shortest side of correlated points = %s" % (shortest_side,) )

	return points



# Scale all points (in place) such that the bbox(points) = unit cube [0,1]^d
# An axis along which all points have the same coordinate cannot be stretched;
# on such an axis all coordinates are set to 0 (instead of dividing by zero).
# An empty list of points is left alone.
def scalePointsToUnitCube( points ):
	if len(points) == 0:
		return

	mini, maxi = bbox( points )
	diag = maxi - mini
	scal = [ (1.0/x if x != 0 else 0.0) for x in diag ]
	for p in points:
		for i in range( len(p) ):
			p[i] = ( p[i] - mini[i] ) * scal[i]



########################### Main ######################

# retrieve command line arguments
if len( sys.argv ) != 6:
	sys.stderr.write( "Usage:  %s  start_dimension end_dimension epsilon #points distribution-type\n" % (sys.argv[0],) )
	sys.exit( 2 )

start_dim = int( sys.argv[1] )
if start_dim < 2:  start_dim = 2
end_dim = int( sys.argv[2] )
end_dim += 1 		# add 1, otherwise invocations like 'kdTree.py 10 10' would produce an empty range below
epsilon = float( sys.argv[3] )
if epsilon < 0.0:  epsilon = 0.1
num_points = int( sys.argv[4] )
if num_points < 10:  num_points = 100
distrib = str( sys.argv[5] )
distrib = distrib[:3]										# only the first 3 letters are significant
if distrib not in ("uni", "clu", "cor"):
	sys.stderr.write( "Distribution-type: must be in {uniform, clustered, correlated}!\n" )
	distrib = "uni"

# init statistics (indexed by dimension)
avg_leaves_exact = [0] * end_dim
avg_leaves_approx = [0] * end_dim

# number of queries per dimension
num_tests = 100


# check that d is really the closest distance to q: report every point that is nearer
def checkNN( points, q, d ):
	for p in points:
		if q.dist( p ) < d:
			print( "BUG:  %s" % (p,) )
			print( "   is nearer (%f  vs  %f)!" % (q.dist(p), d) )


# make experiments for a range of dimensions (roughly ten of them, evenly spaced)
dim_step = (end_dim - start_dim) // 10
if dim_step < 1:
	dim_step = 1
for dim in range(start_dim, end_dim, dim_step):

	print( "\nDimension =  %d" % (dim,) )

	# create random points
	print( "Creating random points ..." )
	if distrib == "uni":
		points = [ Vector.random(dim) for i in range(num_points) ]
	elif distrib == "clu":
		points = createClusteredPoints( num_points, dim )
	else: # distrib == "cor"
		points = createCorrelatedPoints( num_points, dim )
	scalePointsToUnitCube( points )

	# create kd-tree
	print( "Building kd-tree ..." )
	kdtree = kdTree( points )

	# check consistency
	print( "Checking ..." )
	kdtree.out( HyperRectangle(points), True )

	# Use known query point
	query = Vector( dim, 0.5 )

	print( "Starting queries ..." )

	# do a number of tests
	for j in range(num_tests):

		print( "\n%d ." % (j,) )

		# compute nearest neighbor by brute force, but only once, in order to save time
		if j == 0:
			d, nn = kdtree.brute_force_nn( query )
			print( "Exact nn  =  %s" % (nn,) )
			print( " distance =  %s" % (d,) )
			print( " #leaves  =  %s" % (num_leaves_visited,) )		# must always = num_points

			checkNN( points, query, d )

		# compute exact nearest neighbor
		d, nn = kdtree.exact_nn( query )
		print( "Exact nn  =  %s" % (nn,) )
		print( " distance =  %s" % (d,) )
		print( " #leaves  =  %s" % (num_leaves_visited,) )

		avg_leaves_exact[dim] += num_leaves_visited

		checkNN( points, query, d )

		# compute approximate nearest neighbor
		# (d keeps the exact distance, it is needed for the checks below)
		approx_d, ann = kdtree.approx_nn( query, epsilon )
		nn_dist = query.dist( nn )
		ann_dist = query.dist( ann )
		if nn_dist > 0.0:
			ratio = ann_dist / nn_dist
		elif ann_dist == 0.0:
			ratio = 1.0										# query coincides with both nn and ann
		else:
			ratio = float( "inf" )
		print( "Approx nn =  %s" % (ann,) )
		print( " distance =  %s" % (approx_d,) )
		print( " |ann-q| / |nn-q| =  %s" % (ratio,) )
		print( " #leaves  =  %s" % (num_leaves_visited,) )

		avg_leaves_approx[dim] += num_leaves_visited

		# check against the distance reported by the exact search
		if ann_dist > (1.0 + epsilon) * d:
			print( "BUG: ANN too far away!" )
			print( "ANN     =  %s" % (ann,) )
			print( "|q-ann| =  %s" % (ann_dist,) )
			print( "d       =  %s" % (d,) )


		# check the statement of the proof (the raison d'etre)
		if ann_dist > (1.0 + epsilon) * nn_dist:
			print( "BUG: ann is farther away than (1+eps)*dist(nn,q)!" )
			print( "ann     =  %s" % (ann,) )
			print( "|q-ann| =  %s" % (ann_dist,) )
			print( "|q-nn|  =  %s" % (nn_dist,) )

		# new query point for next round of tests
		query = Vector.random(dim)

# print statistics
# (end_dim is the exclusive bound of the range, so the last dimension is end_dim - 1)
print( "\n#Start dimension = %d   last dimension = %d" % (start_dim, end_dim - 1) )
if distrib == "uni":     distrib_name = "uniform"
elif distrib == "clu":   distrib_name = "clustered"
else:                    distrib_name = "correlated"
print( "#Epsilon = %s    # points = %d    distribution = %s" % (epsilon, num_points, distrib_name) )
print( "\n#Avg # leaves visited" )
print( "#Dim.    Exact-NN   Approx-NN" )
for dim in range(start_dim, end_dim, dim_step):
	print( "%4d    %8d   %8d" % \
			( dim, (avg_leaves_exact[dim]//num_tests), (avg_leaves_approx[dim]//num_tests) ) )

