module MapTree(
    Map
  , empty
  , lookup
  , put
  , toList
  ) where

-- Weight-balanced trees in Haskell
--  Very much inspired by the Data.Map module of the containers library.

import Prelude hiding(lookup)


-- The map is implemented as a tree decorated with the keys and values.
{-- datamap --}
data Map alpha beta
  = Empty
    -- ^ The empty map.
  | Node alpha beta !Int (Map alpha beta) (Map alpha beta)
    -- ^ Key, value, number of nodes in this subtree, left subtree (smaller
    --   keys) and right subtree (larger keys). The size field is strict:
    --   balance reads it after every update, so storing it as an
    --   unevaluated thunk only costs an extra allocation per node.
{-- end --}




-- I. The Tree functions                            

-- Number of nodes (key/value bindings) in the tree. Empty has none; a Node
-- carries the size of its whole subtree in its third field, so this is O(1).
size :: Map alpha beta-> Int
size t = case t of
  Empty          -> 0
  Node _ _ n _ _ -> n

-- Balance factor: a tree counts as unbalanced once one subtree holds more
-- than weight times as many nodes as the other (see balance).
weight :: Int
weight = 3

-- Rotation threshold: when rebalancing, a single rotation is used if the
-- inner grandchild on the heavy side holds fewer than ratio times as many
-- nodes as the outer one; otherwise a double rotation is needed.
ratio :: Int
ratio = 2

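-- Main function: construct a balanced tree.
-- Four cases:
-- a. Tree small enough (fewer than two nodes below the root)
-- b. Imbalanced to the right (right subtree too heavy)
-- c. Imbalanced to the left (left subtree too heavy)
-- d. Balanced
-- Cases b and c require rotations (single or double, depending on which
--  grandchild carries the extra weight).
-- Preconditions: lt and rt are balanced, and their weight difference is not
--  too large (i.e. at most one node too many/too few, as after a single
--  insertion or deletion).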
{-- balance --}
balance :: alpha-> beta-> Map alpha beta-> Map alpha beta-> Map alpha beta
balance k x lt rt
  | small || not (rightHeavy || leftHeavy) = node k x lt rt
  | rightHeavy                             = balanceL k x lt rt
  | otherwise                              = balanceR k x lt rt
  where
    ls = size lt
    rs = size rt
    -- With fewer than two nodes below there is nothing to rebalance.
    small      = ls + rs < 2
    -- The right subtree outweighs the left by more than the factor weight.
    rightHeavy = weight * ls < rs
    -- The left subtree outweighs the right by more than the factor weight.
    leftHeavy  = ls > weight * rs
{-- end --}

-- Shift balance to the left (two subcases)
{-- balanceL --}
balanceL :: alpha-> beta-> Map alpha beta-> Map alpha beta-> Map alpha beta
balanceL k x lt rt@(Node kr xr _ rlt rrt) = rotl k x lt rt'
  where
    -- If the inner grandchild rlt is small enough, one left rotation
    -- suffices; otherwise rt is first rotated right (double rotation).
    rt' | size rlt < ratio * size rrt = rt
        | otherwise                   = rotr kr xr rlt rrt
{-- end --}

-- Shift balance to the right (two subcases)
{-- balanceR --}
balanceR :: alpha-> beta-> Map alpha beta-> Map alpha beta-> Map alpha beta
balanceR k x lt@(Node lk lx _ llt lrt) rt = rotr k x lt' rt
  where
    -- If the inner grandchild lrt is small enough, one right rotation
    -- suffices; otherwise lt is first rotated left (double rotation).
    lt' | size lrt < ratio * size llt = lt
        | otherwise                   = rotl lk lx llt lrt
{-- end --}

-- One single rotation to the left/right.
{-- rot --}
rotl, rotr :: alpha-> beta-> Map alpha beta-> Map alpha beta-> Map alpha beta
-- rotl: the right child becomes the new root; its former left subtree b
-- becomes the right subtree of the old root (k, x).
rotl k x a (Node kb xb _ b c) = node kb xb (node k x a b) c
-- rotr: mirror image; the left child becomes the new root and its former
-- right subtree b becomes the left subtree of the old root (k, x).
rotr k x (Node ka xa _ a b) c = node ka xa a (node k x b c)
{-- end --}

-- Builds a Node from a key, a value and two subtrees, recording the size
-- as one plus the sizes of the subtrees. No rebalancing happens here.
-- Pre: the subtrees are balanced with respect to each other.
node :: alpha-> beta-> Map alpha beta-> Map alpha beta-> Map alpha beta
node k x lt rt = Node k x total lt rt
  where
    total = size lt + size rt + 1

-- II. The map functions.

-- The map with no bindings.
empty :: Map alpha beta
empty = Empty

-- Returns the value bound to a key, or Nothing if the key is absent.
-- The tree is a search tree (smaller keys left, larger keys right). One
-- 'compare' per level decides where to continue, so the match is
-- exhaustive and each level costs a single key comparison.
lookup :: Ord alpha=> alpha-> Map alpha beta-> Maybe beta
lookup _ Empty = Nothing
lookup k (Node kv x _ lt rt) = case compare k kv of
  LT -> lookup k lt
  EQ -> Just x
  GT -> lookup k rt

-- Updates the binding for a key:
--   put k (Just v) t  binds k to v, replacing any existing binding of k;
--   put k Nothing  t  removes k from t (the contents of t are unchanged if
--                     k is absent).
-- Every node on the search path is rebuilt with balance, which restores the
-- weight invariant after the single insertion or deletion. A single
-- 'compare' per level keeps the case analysis exhaustive.
put :: Ord alpha=> alpha-> Maybe beta-> Map alpha beta-> Map alpha beta
put k (Just vv) Empty = node k vv Empty Empty
put _ Nothing   Empty = Empty
put k v (Node nk nx sz lt rt) = case compare k nk of
  LT -> balance nk nx (put k v lt) rt
  GT -> balance nk nx lt (put k v rt)
  EQ -> case v of
    Just vv -> Node k vv sz lt rt
    Nothing -> join lt rt

-- Joins two trees into one tree holding the nodes of both.
-- Preconditions:
--   l and r are balanced with respect to each other, as the two subtrees
--   of a balanced node are (this is how put uses join when deleting);
--   for all keys n in l and m in r, n < m.
-- We split off the leftmost (smallest) node from the right tree and make
-- it the new top node (function splitMap below -- we could equally take
-- the rightmost node from the left tree).
join :: Map alpha beta-> Map alpha beta-> Map alpha beta
join xt Empty = xt
join lt rt    = balance k x lt rest
  where
    (k, x, rest) = splitMap rt

    -- Removes the leftmost node of a non-empty tree; returns its key, its
    -- value, and the rebalanced remaining tree.
    splitMap :: Map alpha beta-> (alpha, beta, Map alpha beta)
    splitMap (Node sk sx _ Empty sr) = (sk, sx, sr)
    splitMap (Node sk sx _ sl sr)    = (mk, mx, balance sk sx sl' sr)
      where
        (mk, mx, sl') = splitMap sl
    -- Unreachable: join only calls splitMap on a non-empty tree, and the
    -- recursion only descends into a left subtree known to be non-empty.
    splitMap Empty =
      error "MapTree.join: splitMap applied to an empty tree (unreachable)"

-- Structural fold: replaces Empty by e and every Node by f applied to its
-- key, its value and the folded left and right subtrees.
fold :: (alpha-> beta-> gamma-> gamma-> gamma)-> gamma-> Map alpha beta-> gamma
fold f e = go
  where
    go Empty            = e
    go (Node k x _ l r) = f k x (go l) (go r)

-- Slightly fancy Show instance: one "(key, value)" line per binding, in
-- in-order (ascending key) order, indented three spaces per tree level so
-- the shape of the tree is visible.
instance (Show alpha, Show beta)=> Show (Map alpha beta) where
  show tree = render 0 tree ""
    where
      render _ Empty = id
      render depth (Node key val _ left right) =
          render (depth + 1) left
        . showString (indent depth)
        . showString "("
        . showString (show key)
        . showString ", "
        . showString (show val)
        . showString ")\n"
        . render (depth + 1) right
      indent depth = concat (replicate depth "   ")

{-- tolist --}
-- In-order traversal: all key/value pairs in ascending key order.
-- Each subtree is folded into a list transformer (a difference list), so the
-- result is built with (:) only, in O(n), instead of re-copying the left
-- subtree with (++) at every level.
toList :: Map alpha beta -> [(alpha, beta)]
toList t = fold (\k x l r -> l . ((k, x) :) . r) id t []
{-- end --}

-- Equality on weighted trees as maps is extensional: two maps are equal
-- iff they hold the same key/value pairs; the tree structure is irrelevant.
{-- eqinstance --}
instance (Eq alpha, Eq beta)=> Eq (Map alpha beta) where
  -- Maps of different sizes can never be equal. Comparing the stored sizes
  -- is O(1) and avoids flattening both trees in that case. Otherwise the
  -- in-order lists of bindings are compared, so tree shape does not matter.
  t1 == t2 = size t1 == size t2 && toList t1 == toList t2
{-- end --}
