from PIL import Image

import numpy as np

def to_graph(im, connectivity=4):
    """
    Convert an image to a weighted pixel-adjacency graph.

    Every pixel ``(row, col)`` is a node.  With the default ``connectivity=4``
    each pixel is joined to its lower and right neighbours, so every
    4-neighbour pair appears exactly once.  ``connectivity=8`` additionally
    joins each pixel to its lower-right and upper-right diagonal neighbours.
    The weight of an edge is the squared Euclidean distance between the two
    pixels' three channel values.

    :param im: image exposing ``height``, ``width`` and ``getdata()``, where
        ``getdata()`` yields pixel values in row-major order.  A pixel value is
        written into a 3-channel float array, so a scalar pixel is broadcast
        to all three channels.
    :param connectivity: 4 (default) or 8.
    :returns: dict mapping ``((row, col), (row2, col2))`` to the edge weight.
    :raises ValueError: if ``connectivity`` is not 4 or 8.
    """
    if connectivity not in (4, 8):
        raise ValueError('connectivity must be 4 or 8, got %r' % (connectivity,))

    num_rows, num_cols = im.height, im.width

    # Copy the pixels into a float (rows, cols, 3) array.
    image = np.empty((num_rows, num_cols, 3))
    pixels = iter(im.getdata())
    for row in range(num_rows):
        for col in range(num_cols):
            image[row, col, :] = next(pixels)

    def diff(a, b):
        # Squared Euclidean distance over the three channels.
        return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2 + (a[2] - b[2]) ** 2

    edges = {}
    last_row, last_col = num_rows - 1, num_cols - 1
    for row in range(num_rows):
        for col in range(num_cols):
            cur = image[row, col]
            if row < last_row:
                #   @   *
                #   |
                #   @   *
                edges[((row, col), (row + 1, col))] = diff(cur, image[row + 1, col])
            if col < last_col:
                #   @ - @
                #
                #   *   *
                edges[((row, col), (row, col + 1))] = diff(cur, image[row, col + 1])
            if connectivity == 8 and col < last_col:
                if row < last_row:
                    #  @   *
                    #     \
                    #  *   @
                    edges[((row, col), (row + 1, col + 1))] = diff(cur, image[row + 1, col + 1])
                if row > 0:
                    #  *   @
                    #     /
                    #  @   *
                    edges[((row, col), (row - 1, col + 1))] = diff(cur, image[row - 1, col + 1])
    return edges

from collections import defaultdict

class UF:
    """
    Union-find (disjoint-set) over hashable symbols, with union by rank and
    path halving.

    Adapted from
    http://python-algorithms.readthedocs.org/en/latest/_modules/python_algorithms/basic/union_find.html

    The structure has ``N`` slots.  An ``int`` symbol ``p`` with
    ``0 <= p < N`` uses slot ``p`` if that slot is still free.  Every other
    symbol gets the lowest free slot the first time it is seen.  ``find``
    returns the slot index of the root of the symbol's set.
    """

    def __init__(self, N):
        self._id = list(range(N))   # parent pointer for every slot
        self._count = N             # number of disjoint sets among the N slots
        self._rank = [0] * N        # union-by-rank bookkeeping
        self._N = N
        self._symbol_to_index = {}
        self._index_to_symbol = {}
        self._next_free = 0         # every slot below this index is assigned

    def _index_of(self, p):
        """
        Return the slot index of symbol ``p``, assigning one on first sight.

        :raises IndexError: if ``p`` is new and no free slot is left.  Nothing
            is registered in that case.
        """
        if p in self._symbol_to_index:
            return self._symbol_to_index[p]
        if isinstance(p, int) and 0 <= p < self._N and p not in self._index_to_symbol:
            index = p
        else:
            # Scan for the lowest unassigned slot.  Using len(mapping) here
            # would collide with slots already taken by int symbols.
            index = self._next_free
            while index in self._index_to_symbol:
                index += 1
            if index >= self._N:
                raise IndexError('You have been exceeding the UF capacity')
            self._next_free = index + 1
        self._symbol_to_index[p] = index
        self._index_to_symbol[index] = p
        return index

    def find(self, p):
        """
        Return the root slot index of the set containing ``p``.

        ``p`` is registered on first use.  Raises ``IndexError`` if there is
        no free slot for a new symbol.
        """
        i = self._index_of(p)
        parent = self._id
        while i != parent[i]:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    def union(self, p, q):
        """Merge the sets containing ``p`` and ``q``; no-op if already joined."""
        i = self.find(p)
        j = self.find(q)
        if i == j:
            return

        parent = self._id
        rank = self._rank
        self._count -= 1
        if rank[i] < rank[j]:
            parent[i] = j
        elif rank[i] > rank[j]:
            parent[j] = i
        else:
            # On a rank tie, p's root becomes the parent.
            parent[j] = i
            rank[i] += 1

    def count(self):
        """Return the number of disjoint sets among all N slots."""
        return self._count


from collections import defaultdict
def tao(size, k=5000.):
    """
    Threshold function used by ``segment``: ``k / size``.

    The threshold shrinks as a component grows, so small components merge
    more easily than large ones.

    :param size: component size.
    :param k: scale parameter.
    """
    threshold = k / size
    return threshold

def segment(image, k=5000.):
    """
    Segment ``image`` with graph-based merging over the edges of ``to_graph``.

    Edges are visited in non-decreasing weight order.  The components C1 and C2
    joined by an edge of weight w are merged when

        w <= min(Int(C1) + tao(|C1|, k), Int(C2) + tao(|C2|, k))

    Here Int(C) is the weight of the last edge merged into C (0 for a single
    pixel) and |C| is the component size in pixels.

    :param image: image accepted by ``to_graph`` and exposing ``width``/``height``.
    :param k: scale parameter passed to ``tao``; larger values favour larger
        segments.
    :returns: the ``UF`` over ``(row, col)`` pixel keys.  Pixels whose ``find``
        results are equal belong to the same segment.
    """
    uf_nodes = UF(image.width * image.height)
    # root index -> (internal difference, size in pixels).  A root that has
    # never been merged is a single pixel: Int = 0, size = 1.
    internal = defaultdict(lambda: (0, 1))
    graph = to_graph(image)
    for (from_node, to_node), weight in sorted(graph.items(), key=lambda item: item[1]):
        # to_node is looked up first.  UF hands out slot indices to new keys
        # in first-seen order, so this order determines the label values.
        root1 = uf_nodes.find(to_node)
        root2 = uf_nodes.find(from_node)
        if root1 == root2:
            continue

        int1, size1 = internal[root1]
        int2, size2 = internal[root2]
        if weight <= min(int1 + tao(size1, k=k), int2 + tao(size2, k=k)):
            uf_nodes.union(to_node, from_node)
            new_root = uf_nodes.find(to_node)
            del internal[root1]
            del internal[root2]
            # Edges arrive in non-decreasing order, so the merging edge is the
            # heaviest edge in the new component and becomes its Int.  The
            # merged size is exactly size1 + size2; the old "+1" overcounted.
            internal[new_root] = weight, size1 + size2
    return uf_nodes


def create_segment_image(union_find, im):
    """
    Render a segmentation as a 2-D array of segment labels.

    :param union_find: object whose ``find((row, col))`` returns an integer
        label for a pixel, e.g. the ``UF`` returned by ``segment``.
    :param im: image exposing ``height`` and ``width``.
    :returns: ``(height, width)`` label array.  The dtype is ``uint16`` when
        every label fits in it; otherwise it is ``int64``.  UF labels are slot
        indices up to ``width * height - 1``, so images with more than 65536
        pixels would overflow ``uint16``.
    """
    uf = union_find
    labels = [uf.find((row, col))
              for row in range(im.height)
              for col in range(im.width)]
    fits_uint16 = not labels or (0 <= min(labels) and max(labels) <= np.iinfo(np.uint16).max)
    dtype = np.uint16 if fits_uint16 else np.int64
    return np.array(labels, dtype=dtype).reshape(im.height, im.width)


try:
    from PIL import ImageFilter
except ImportError:
    # Legacy standalone PIL installs expose ImageFilter as a top-level module.
    import ImageFilter


if __name__ == '__main__':
    # Usage: python <this file> [IMAGE_PATH [K]]
    import sys
    import matplotlib.pyplot as plt

    path = sys.argv[1] if len(sys.argv) > 1 else '/home/coda/Desktop/download.jpg'
    # Use a float so tao's k / size is true division on Python 2 as well.
    k = float(sys.argv[2]) if len(sys.argv) > 2 else 10000.
    # Convert to RGB so palette/RGBA inputs can be median-filtered and match
    # the 3-channel layout used by to_graph.
    image = Image.open(path).convert('RGB').filter(ImageFilter.MedianFilter)
    image.thumbnail((400, 400))
    plt.imshow(create_segment_image(segment(image, k), image), cmap='Accent')
    plt.show()