import numpy as np
import scipy.io
import caffe
import skimage
import os

from PIL import Image
from interface_pb2 import WorkRequest, ResultList
try:
    from cStringIO import StringIO
except:
    from StringIO import StringIO

class SimilarSearcher(object):
    """Rank a fixed set of model entries by similarity to a query image.

    Attributes set by the constructor:
        classifier: object whose ``predict([img], oversample=True)`` yields the
            query feature vector (presumably a caffe classifier -- confirm).
        imageDim: side length the query image is letterboxed to.
        image_dims: ``[imageDim, imageDim]`` (height, width) target shape.
        modelData: matrix whose rows are dotted with the query features;
            assumed to hold one (unit-length) row per entry of ``modelURL``.
        modelURL: URLs indexed by ``modelData`` row.
    """

    def __init__(self, classifier, imageDim, modelData="", modelURL=""):
        self.classifier = classifier
        self.modelData = modelData
        self.modelURL = modelURL
        self.imageDim = imageDim
        self.image_dims = [imageDim, imageDim]

    def on_request(self, request):
        """Answer a similarity-search request in place.

        Decodes ``request.image`` with PIL, ranks the model entries against it
        via ``do_work``, clears the ``image`` field and appends the URL and
        distance of the best matches to ``request.result``.  The number of
        matches is ``request.configuration[0].caffeSearch.resultSize``, capped
        at the number of ranked entries.
        """
        print("have request")
        image = Image.open(StringIO(request.image))

        order, distances = self.do_work(image)

        request.ClearField('image')
        # Requesting more results than there are entries used to raise IndexError.
        count = min(request.configuration[0].caffeSearch.resultSize, order.shape[1])
        for rank in range(count):
            index = order[0, rank]
            score = distances[0, rank]
            print("%s   %s" % (index, score))
            request.result.url.append(self.modelURL[index])
            request.result.score.append(score)

    def do_work(self, image):
        """Rank the rows of ``self.modelData`` by distance to *image*.

        The classifier output for *image* is L2-normalised per row and dotted
        with ``self.modelData``; the distance is ``1 - dot`` (a cosine distance
        if the ``modelData`` rows are unit-length -- assumed, not checked).

        Returns:
            ``(order, distances)``: ``order[0, r]`` is the ``modelData`` row at
            rank ``r`` (ascending distance) and ``distances[0, r]`` its distance.
        """
        predictions = self.processImage(image)

        norms = np.linalg.norm(predictions, axis=1, keepdims=True)
        # An all-zero output has no direction: keep it instead of dividing by
        # zero and ranking NaNs.
        norms[norms == 0] = 1.0
        predictions = predictions / norms

        distances = 1 - self.modelData.dot(predictions.T).T
        order = np.argsort(distances)
        distances = distances[0, order]

        return (order, distances)

    def processImage(self, image):
        """Letterbox *image* to ``image_dims`` and run the classifier on it.

        The image is converted to RGB and resized so that its longer side equals
        ``imageDim`` (aspect ratio kept). It is then centred and padded with its
        mean pixel value to ``image_dims``. Finally it is scaled to [0, 1]
        float32 and passed to ``classifier.predict([img], oversample=True)``.

        Returns whatever ``self.classifier.predict`` returns.
        """
        # The padding assumes exactly three channels; grayscale, palette and
        # RGBA inputs used to crash when the borders were concatenated.
        if image.mode != "RGB":
            image = image.convert("RGB")

        scale = self.imageDim * 1.0 / max(image.size)
        # Clamp to 1 so extremely thin images never round to a zero-sized side.
        scaledSize = tuple(max(1, int(round(scale * x))) for x in image.size)

        # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resample = getattr(Image, "LANCZOS", None)
        if resample is None:
            resample = Image.ANTIALIAS
        image = image.resize(scaledSize, resample)  # TODO: This resize is really slow for large images

        # float64 so the (fractional) mean fill value is stored exactly.
        np_image = np.asarray(image, dtype=np.float64)
        height, width = np_image.shape[:2]
        meanVal = np_image.mean()

        # Centre the image; an odd leftover pixel of padding goes right/bottom.
        pad_height = self.image_dims[0] - height
        pad_width = self.image_dims[1] - width
        top = pad_height // 2
        left = pad_width // 2
        np_image = np.pad(
            np_image,
            ((top, pad_height - top), (left, pad_width - left), (0, 0)),
            mode="constant",
            constant_values=meanVal,
        )

        img = skimage.img_as_float(np_image / 255.).astype(np.float32)

        # Classify.
        return self.classifier.predict([img], oversample=True)
