import cStringIO, os, sys
import PIL.Image

from ctypes import *
from OpenGL.GL import *
from OpenGL.GLU import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import *
from OpenGL.GL.EXT.framebuffer_object import *
from OpenGL.GL.ARB.multitexture import *

# Local replacement for OpenGL.GL.shaders.compileProgram, whose program validation fails when multiple samplers are used.
def compileProgram(*shaders):
    """Link already-compiled *shaders* into a new GL program and return its handle.

    Unlike OpenGL.GL.shaders.compileProgram, this does not validate the
    program, because that validation fails when several samplers are used.
    Link failures are still detected.

    Raises:
        RuntimeError: if the program fails to link; the GL info log is
            included in the message.
    """
    program = glCreateProgram()
    for shader in shaders:
        glAttachShader(program, shader)

    glLinkProgram(program)

    # The shaders stay attached to the program, so deleting them here only
    # flags them for deletion; GL releases them together with the program.
    for shader in shaders:
        glDeleteShader(shader)

    # Validation is skipped (see docstring), but a program that failed to
    # link is unusable, so report it instead of returning a dead handle.
    if not glGetProgramiv(program, GL_LINK_STATUS):
        log = glGetProgramInfoLog(program)
        glDeleteProgram(program)
        raise RuntimeError('Shader program link failed: %s' % (log,))

    return program
    

class HeatMap(object):
    """Render a heatmap of 2-D points with OpenGL and export it as a PNG.

    All GL resources (GLUT window/context, shader programs, palette texture
    and offscreen framebuffer) live on the class and are shared by every
    instance. They are rebuilt only when a new instance needs a different
    framebuffer size.

    Usage: construct with the world bounds, call ``add_points`` one or more
    times, then ``transform_color``, then ``get_image``.
    """

    # Size in pixels of the currently allocated offscreen framebuffer.
    width = None
    height = None
    # GLUT window id. The window only provides a GL context, so it is
    # created once per process (freeglut aborts on a second glutInit()).
    window = None
    # Offscreen framebuffer object and the RGBA texture attached to it.
    fbo = None
    texture = None
    # 1-D palette texture sampled by the colour-transform shader.
    palette = None
    # Linked shader programs built by _compile_programs().
    color_transform_program = None
    faded_points_program = None

    def __init__(self, left, right, bottom, top):
        """Prepare a heatmap covering the world rectangle [left, right] x [bottom, top].

        The framebuffer has one pixel per world unit, i.e.
        ``abs(right - left)`` by ``abs(top - bottom)`` pixels. Shared GL
        state is rebuilt if it is missing or sized differently. The
        framebuffer is then cleared.
        """
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top

        width = abs(right - left)
        height = abs(top - bottom)

        # Only rebuild the shared GL state when it is missing or was
        # allocated for a differently sized rectangle.
        if (self.width != width or
                self.height != height or
                None in (self.fbo, self.texture, self.palette)):
            self.prepare(width, height)

        glMatrixMode(GL_PROJECTION)
        glLoadIdentity()
        gluOrtho2D(left, right, bottom, top)

        # All drawing goes to the FBO, so the viewport must match the FBO size,
        # not whatever size the (reused) GLUT window happens to have.
        glViewport(0, 0, self.width, self.height)

        self._clear_framebuffer()

    @classmethod
    def cleanup(cls):
        """Delete the shared GL objects (FBO, textures and shader programs).

        Safe to call before anything has been allocated. The GLUT window is
        kept so that later ``prepare`` calls can reuse its context.
        """
        # NOTE(review): PyOpenGL's generated glDeleteFramebuffersEXT wrapper
        # may expect (n, framebuffers) rather than a single id - confirm
        # against the installed PyOpenGL version.
        if cls.fbo: glDeleteFramebuffersEXT(cls.fbo)
        if cls.texture: glDeleteTextures(cls.texture)
        if cls.palette: glDeleteTextures(cls.palette)
        # Every prepare() recompiles the programs, so free the old ones to
        # avoid leaking a pair each time the framebuffer is resized.
        if cls.color_transform_program: glDeleteProgram(cls.color_transform_program)
        if cls.faded_points_program: glDeleteProgram(cls.faded_points_program)
        cls.fbo = cls.texture = cls.palette = None
        cls.color_transform_program = cls.faded_points_program = None

    @classmethod
    def prepare(cls, width, height):
        """(Re)build all shared GL state for a *width* x *height* framebuffer."""
        cls.cleanup()

        cls.width = width
        cls.height = height

        # GLUT may be initialised only once per process (freeglut aborts on
        # re-initialisation). One window/context suffices because all
        # rendering targets the offscreen FBO.
        if cls.window is None:
            glutInit(sys.argv)
            glutInitDisplayMode(GLUT_RGBA)
            glutInitWindowSize(cls.width, cls.height)
            cls.window = glutCreateWindow("HeatMap")

        # Render Flags
        glEnable(GL_BLEND)
        glEnable(GL_TEXTURE_1D)
        glEnable(GL_TEXTURE_2D)
        # Lets the point vertex shader set gl_PointSize from the radius.
        glEnable(GL_VERTEX_PROGRAM_POINT_SIZE)
        glEnableClientState(GL_VERTEX_ARRAY)

        cls._compile_programs()
        cls._load_palette()
        cls._create_framebuffer()

    @classmethod
    def _load_palette(cls, path='palettes/classic.png'):
        """Load the palette image at *path* into a 1-D texture on texture unit 0.

        *path* is opened as given, i.e. relative paths resolve against the
        current working directory.
        """
        image = PIL.Image.open(path)

        cls.palette = glGenTextures(1)
        glActiveTextureARB(GL_TEXTURE0)
        glBindTexture(GL_TEXTURE_1D, cls.palette)
        # NOTE(review): the texture length is the image *height* and rows are
        # read bottom-up, so this presumably expects a one-pixel-wide vertical
        # strip whose top row is the hottest colour - confirm against the
        # palette files.
        glTexImage1D(GL_TEXTURE_1D, 0, GL_RGB,
                     image.size[1],
                     0, GL_RGB, GL_UNSIGNED_BYTE,
                     image.tostring('raw', 'RGB', 0, -1))

        glTexParameter(GL_TEXTURE_1D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)
        glTexParameter(GL_TEXTURE_1D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)
        glTexParameter(GL_TEXTURE_1D, GL_TEXTURE_WRAP_S, GL_CLAMP)

    @classmethod
    def _create_framebuffer(cls):
        """Create the RGBA render texture (unit 1) and attach it to a new FBO.

        Raises:
            RuntimeError: if the resulting framebuffer is not complete.
        """
        cls.texture = glGenTextures(1)
        cls.fbo = glGenFramebuffersEXT(1)

        # The texture stays bound on unit 1; transform_color() and
        # get_image() rely on that binding.
        glActiveTextureARB(GL_TEXTURE1)
        glBindTexture(GL_TEXTURE_2D, cls.texture)
        glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA,
                     cls.width, cls.height, 0, GL_RGBA,
                     GL_UNSIGNED_BYTE, None)

        glTexParameter(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)
        glTexParameter(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)
        glTexParameter(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP)
        glTexParameter(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP)

        glBindFramebufferEXT(GL_FRAMEBUFFER_EXT, cls.fbo)
        glFramebufferTexture2DEXT(GL_FRAMEBUFFER_EXT, GL_COLOR_ATTACHMENT0_EXT,
                                  GL_TEXTURE_2D, cls.texture, 0)

        # An explicit check instead of assert, which is stripped under -O.
        status = glCheckFramebufferStatusEXT(GL_FRAMEBUFFER_EXT)
        if status != GL_FRAMEBUFFER_COMPLETE_EXT:
            raise RuntimeError('Framebuffer incomplete: %r' % (status,))

    @classmethod
    def _clear_framebuffer(cls):
        """Bind the heatmap FBO and clear its colour buffer."""
        glBindFramebufferEXT(GL_FRAMEBUFFER_EXT, cls.fbo)
        glClear(GL_COLOR_BUFFER_BIT)

    @classmethod
    def _compile_programs(cls):
        """Compile and link the two shader programs used by the heatmap."""
        # Shader program to transform color into the proper palette based on the alpha channel
        cls.color_transform_program = compileProgram(
            compileShader('''
                void main() {
                    gl_Position = ftransform();
                }
            ''', GL_VERTEX_SHADER),
            compileShader('''
                uniform float alpha;
                uniform sampler1D palette;
                uniform sampler2D framebuffer;
                uniform vec2 windowSize;
    
                void main() {
                    gl_FragColor.rgb = texture1D(palette, texture2D(framebuffer,  gl_FragCoord.xy / windowSize).a).rgb;
                    gl_FragColor.a = alpha;
                }
            ''', GL_FRAGMENT_SHADER))

        # Shader program to place heat points. The vertex shader converts the
        # point's clip-space position to window pixels so the fragment shader
        # can measure each fragment's pixel distance from the point centre.
        cls.faded_points_program = compileProgram(
            compileShader('''
                uniform float r;
                uniform vec2 windowSize;
                varying vec2 pos;

                void main() {
                    gl_PointSize = 2.0 * r + 2.0;
                    gl_Position = ftransform();
                    pos = (gl_Position.xy * windowSize + windowSize) * 0.5;
                }
            ''', GL_VERTEX_SHADER),
            compileShader('''
                uniform float r;
                varying vec2 pos;

                void main() {
                    float d = distance(gl_FragCoord.xy, pos);
                    if (d > r) discard;
                
                    gl_FragColor.rgb = vec3(1.0, 1.0, 1.0);
                    gl_FragColor.a = (0.5 + cos(d * 3.14159265 / r) * 0.5) * 0.25;
                
                    // Alternate fading algorithms
                    //gl_FragColor.a = (1.0 - (log(1.1+d) / log(1.1+r)));
                    //gl_FragColor.a = (1.0 - (pow(d, 0.5) / pow(r, 0.5)));
                    //gl_FragColor.a = (1.0 - ((d*d) / (r*r))) / 2.0;
                    //gl_FragColor.a = (1.0 - (d / r)) / 2.0;
                    
                    gl_FragColor.a = clamp(gl_FragColor.a, 0.0, 1.0);
                }
            ''', GL_FRAGMENT_SHADER))

    def add_points(self, points, radius):
        """Accumulate heat for each (x, y) world coordinate in *points*.

        Each point is drawn as a point sprite of *radius* pixels. Its alpha
        falls off from the centre with a cosine profile, and overlapping
        points build up alpha through the blend function below.
        """
        # Nothing to draw: skip the GL calls rather than hand an empty
        # array to glVertexPointerd. len() also works for sequence types
        # such as numpy arrays, where truth-testing would be ambiguous.
        if len(points) == 0:
            return

        # Render all points with the specified radius
        glUseProgram(self.faded_points_program)
        glUniform1f(glGetUniformLocation(self.faded_points_program, 'r'), radius)
        glUniform2f(glGetUniformLocation(self.faded_points_program, 'windowSize'), self.width, self.height)

        glBindFramebufferEXT(GL_FRAMEBUFFER_EXT, self.fbo)
        # dst = src + dst * (1 - src.a): alpha accumulates towards 1.
        glBlendFunc(GL_ONE, GL_ONE_MINUS_SRC_ALPHA)

        glVertexPointerd(points)
        glDrawArrays(GL_POINTS, 0, len(points))
        glFlush()

    def transform_color(self, alpha):
        """Replace the accumulated heat with palette colours at constant *alpha*.

        Draws a quad over the whole world rectangle. Each fragment looks up
        the palette colour indexed by the framebuffer's alpha at that pixel.
        """
        # Transform the color into the proper palette
        glUseProgram(self.color_transform_program)
        glUniform1f(glGetUniformLocation(self.color_transform_program, 'alpha'), alpha)
        # Texture units match the bindings in _load_palette / _create_framebuffer.
        glUniform1i(glGetUniformLocation(self.color_transform_program, 'palette'), 0)
        glUniform1i(glGetUniformLocation(self.color_transform_program, 'framebuffer'), 1)
        glUniform2f(glGetUniformLocation(self.color_transform_program, 'windowSize'), self.width, self.height)

        # NOTE(review): the shader samples self.texture while it is the colour
        # attachment of the bound FBO. OpenGL leaves such feedback loops
        # undefined; this relies on each fragment reading only its own texel.
        # Consider rendering into a second texture.
        glBindFramebufferEXT(GL_FRAMEBUFFER_EXT, self.fbo)
        glBlendFunc(GL_ONE, GL_ZERO)

        vertices = [(self.left, self.bottom), (self.right, self.bottom),
                    (self.right, self.top), (self.left, self.top)]
        glVertexPointerd(vertices)
        glDrawArrays(GL_QUADS, 0, len(vertices))
        glFlush()

    def get_image(self):
        """Return a cStringIO buffer, rewound to the start, holding the heatmap as a PNG."""
        # Read back the texture bound on unit 1 (the FBO colour attachment).
        glActiveTextureARB(GL_TEXTURE1)
        data = glGetTexImage(GL_TEXTURE_2D, 0, GL_RGBA, GL_UNSIGNED_BYTE)
        # Orientation -1 flips rows: GL rows run bottom-up, image rows top-down.
        im = PIL.Image.frombuffer('RGBA', (self.width, self.height), data, 'raw', 'RGBA', 0, -1)

        # Write the image to a buffer as a PNG
        f = cStringIO.StringIO()
        im.save(f, 'png')
        f.seek(0)

        return f

        

def test1():
    """Demo: render 300 points from three Gaussian clusters and write ``test.png``."""
    import random, time

    hm = HeatMap(0, 256, 0, 256)

    # Three clusters of 100 points each, in world coordinates.
    points = [(random.gauss(.5, .08) * 200, random.gauss(.5, .075) * 100) for i in xrange(100)]
    points += [(random.gauss(.7, .075) * 250, random.gauss(.2, .042) * 150) for i in xrange(100)]
    points += [(random.gauss(.4, .06) * 175, random.gauss(.3, .04) * 200) for i in xrange(100)]

    def run():
        """Render the points, colour them and save the PNG to test.png."""
        hm.add_points(points, 10.5)
        hm.transform_color(1.0)
        image = hm.get_image()
        with open("test.png", "wb") as out:
            out.write(image.read())

    def runtimes(ntimes):
        """Time *ntimes* clear-and-render cycles; print total and per-run seconds."""
        start = time.time()
        for i in xrange(ntimes):
            hm._clear_framebuffer()
            run()

        total = time.time() - start
        # Same output as the Python 2 print statement, and valid in Python 3.
        print('%s %s' % (total, total / ntimes))

    #runtimes(500)
    run()

    # NOTE(review): HeatMap defines no render_to_screen(); this on-screen
    # preview cannot be re-enabled until such a method exists.
    #glutDisplayFunc(lambda:(hm.render_to_screen(), glutSwapBuffers()))
    #glutMainLoop()


if __name__ == "__main__":
    # Executed as a script: run the demo, which writes its output to test.png.
    test1()


    
