#!/usr/bin/env python

# Copyright (c) 2010 Leif Johnson <leif@leifjohnson.net>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

'''A simple OpenGL viewer for C3D files.'''

import argparse
import collections
import itertools
import c3d
import math
import OpenGL.GL as gl
import OpenGL.GLU as glu
import OpenGL.GLUT as glut
import sys
import time

# command-line interface: every remaining argument is treated as a c3d file.
FLAGS = argparse.ArgumentParser()

FLAGS.add_argument(
    'inputs',
    metavar='FILE',
    nargs=argparse.REMAINDER,
    help='show these c3d files',
)

# rgb triples with components in [0, 1]; an alpha value is appended when a
# color is actually used for drawing.
BLACK = (0, 0, 0)
WHITE = (1, 1, 1)
RED = (1, 0.2, 0.2)
YELLOW = (1, 1, 0.2)
ORANGE = (1, 0.7, 0.2)
GREEN = (0.2, 0.9, 0.2)
BLUE = (0.2, 0.3, 0.9)

# markers cycle through this palette by index (BLACK is not part of it).
COLORS = (
    WHITE,
    RED,
    YELLOW,
    GREEN,
    BLUE,
    ORANGE,
)


class Viewer(object):
    '''Render data from a C3D file using OpenGL.

    The window is split into a large perspective view of the markers (the left
    two thirds) and three small orthographic projections stacked on the right.
    Frames are pulled from the reader one at a time by the GLUT idle callback,
    and the most recent ``maxlen`` positions of each marker are kept for
    drawing.

    Keyboard controls:
      q          quit
      p          pause / resume playback
      0-9        toggle visibility of the marker with that index
      + or =     double the number of positions kept per marker
      - or _     halve the number of positions kept per marker (minimum 1)
      PgUp/PgDn  divide / multiply the scale factor ``rho`` by 1.5
      arrows     rotate the perspective view in steps of 5

    Mouse controls (while a button is held):
      left       rotate continuously; speed grows with distance from center
      right      scale continuously (also Alt + left button)
      middle     currently unused (also Ctrl + left button)
    '''

    def __init__(self, c3d_reader):
        '''Set up this viewer with some initial parameters.

        c3d_reader: a reader object providing read_frames(), frame_rate() and
            num_points(), as called below.

        Creating a viewer initializes GLUT, opens the window and registers all
        of the callbacks; call mainloop() to start rendering.
        '''
        # we get frames from the c3d file.
        self._frames = c3d_reader.read_frames()
        self._frame_rate = c3d_reader.frame_rate()

        num_points = c3d_reader.num_points()

        # number of recent positions remembered for each marker.
        self.maxlen = 1
        self.visible = [True] * num_points
        self._trails = [[] for _ in range(num_points)]
        self._reset_trails()

        # rendering state: theta and phi are the rotation angles about the z
        # and x axes of the perspective view, rho is a uniform scale factor.
        self.theta = 350.0
        self.phi = 300.0
        self.rho = 0.0001
        self.paused = False

        self._last_time = 0
        self._mouse_button = None
        self._width = 800
        self._height = 600
        # per-idle-call increments that are applied while a button is held.
        self._d_theta = 0
        self._d_phi = 0
        self._d_rho = 1.0

        # gl initialization.
        glut.glutInit([])
        glut.glutInitDisplayMode(
            glut.GLUT_RGBA | glut.GLUT_DEPTH | glut.GLUT_DOUBLE)
        glut.glutInitWindowPosition(0, 0)
        glut.glutInitWindowSize(self._width, self._height)
        # a byte string is accepted by the bindings on python 2 and 3 alike.
        glut.glutCreateWindow(b'C3D Viewer')
        glut.glutKeyboardFunc(self.handle_keypress)
        glut.glutSpecialFunc(self.handle_special_keypress)
        glut.glutDisplayFunc(self.handle_draw)
        glut.glutReshapeFunc(self.handle_reshape)
        glut.glutMotionFunc(self.handle_mouse_movement)
        glut.glutMouseFunc(self.handle_mouse_button)
        glut.glutIdleFunc(self.handle_idle)

        gl.glEnable(gl.GL_COLOR_MATERIAL)
        gl.glEnable(gl.GL_LINE_SMOOTH)
        gl.glEnable(gl.GL_NORMALIZE)
        gl.glEnable(gl.GL_DEPTH_TEST)
        gl.glEnable(gl.GL_BLEND)

        gl.glEnable(gl.GL_LIGHTING)
        gl.glEnable(gl.GL_LIGHT0)
        gl.glEnable(gl.GL_LIGHT1)
        gl.glEnable(gl.GL_LIGHT2)

        gl.glShadeModel(gl.GL_SMOOTH)
        gl.glDepthFunc(gl.GL_LEQUAL)
        gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)

        # the whole scene is compiled into this display list on every draw.
        self._model_list = gl.glGenLists(1)

    def _reset_trails(self):
        '''Rebuild every trail as a deque bounded by the current maxlen.

        Existing positions are carried over; when the bound shrinks only the
        most recent positions survive.
        '''
        self._trails = [
            collections.deque(trail, self.maxlen) for trail in self._trails]

    def mainloop(self):
        '''Run the GLUT main loop, blocking until it returns.'''
        glut.glutMainLoop()

    def handle_reshape(self, width, height):
        '''GLUT window reshape callback.

        The size is only recorded here; the viewports are set up on each draw.
        Both dimensions are clamped to at least 1 because they are used as
        divisors when drawing and when handling mouse movement.
        '''
        self._width = max(1, width)
        self._height = max(1, height)

    def handle_mouse_button(self, button, state, x, y):
        '''GLUT mouse button callback.

        Releasing a button stops any continuous rotation or scaling; pressing
        one records the button together with the active keyboard modifiers and
        immediately applies the pointer position.
        '''
        if state == glut.GLUT_UP:
            self._mouse_button = None
            self._d_theta = 0
            self._d_phi = 0
            self._d_rho = 1.0
            return
        self._mouse_button = (button, glut.glutGetModifiers())
        self.handle_mouse_movement(x, y)

    def handle_mouse_movement(self, x, y):
        '''GLUT mouse motion callback.

        Converts the pointer offset from the window center into the rotation
        or scaling increments that handle_idle applies on every call.
        '''
        # motion can still be reported after a release has cleared the button
        # state (e.g. when a second button remains held); nothing to do then.
        if self._mouse_button is None:
            return

        button, modifier = self._mouse_button
        if button == glut.GLUT_LEFT_BUTTON:
            # modifiers let a single-button mouse emulate the other buttons.
            if modifier & glut.GLUT_ACTIVE_CTRL:
                button = glut.GLUT_MIDDLE_BUTTON
            if modifier & glut.GLUT_ACTIVE_ALT:
                button = glut.GLUT_RIGHT_BUTTON

        cx = self._width // 2
        cy = self._height // 2
        if button == glut.GLUT_LEFT_BUTTON:
            self._d_phi = 5 * float(y - cy) / self._height
            self._d_theta = 5 * float(x - cx) / self._width
        if button == glut.GLUT_MIDDLE_BUTTON:
            pass
        if button == glut.GLUT_RIGHT_BUTTON:
            self._d_rho = math.exp(float(y - cy) / self._height / 1.3)

    def handle_keypress(self, char, x, y):
        '''GLUT keyboard callback.

        char: the key that was pressed, as text or as a single byte string.
        x, y: pointer position reported by GLUT (unused).

        Exits the process through sys.exit() when 'q' is pressed.
        '''
        # the key may arrive as a byte string (as under python 3); latin-1
        # maps every byte to exactly one character, so this cannot fail.
        if isinstance(char, bytes) and not isinstance(char, str):
            char = char.decode('latin-1')
        if not char:
            return

        if char == 'q':
            sys.exit(0)
        elif char in '0123456789':
            index = int(char)
            # ignore digits that do not correspond to a marker in this file.
            if index < len(self.visible):
                self.visible[index] ^= True
        elif char == 'p':
            self.paused ^= True
        elif char in '+=':
            self.maxlen *= 2
            self._reset_trails()
        elif char in '_-':
            # floor division keeps maxlen an integer, which deque requires.
            self.maxlen = max(1, self.maxlen // 2)
            self._reset_trails()

    def handle_special_keypress(self, key, x, y):
        '''GLUT special key callback.

        Page up / page down change the scale factor, the arrow keys rotate the
        perspective view.
        '''
        if key == glut.GLUT_KEY_PAGE_UP:
            self.rho /= 1.5
        elif key == glut.GLUT_KEY_PAGE_DOWN:
            self.rho *= 1.5
        elif key == glut.GLUT_KEY_UP:
            self.phi += 5
        elif key == glut.GLUT_KEY_DOWN:
            self.phi -= 5
        elif key == glut.GLUT_KEY_LEFT:
            self.theta -= 5
        elif key == glut.GLUT_KEY_RIGHT:
            self.theta += 5

    def handle_draw(self):
        '''GLUT render callback.

        Recompiles the model display list and then draws it four times: once
        in perspective on the left and three times orthographically in a
        column on the right.
        '''
        gl.glClearColor(0.9, 0.9, 0.9, 1)
        gl.glClearDepth(1)
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

        self.render_model()

        w = int(2 * self._width / 3.0)
        h = int(self._height / 3.0)

        # show a perspective rendering of the markers.
        gl.glViewport(0, 0, w, self._height)
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        glu.gluPerspective(45, float(w) / self._height, 0.01, 10)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glPushMatrix()
        gl.glTranslate(0, 0, -1)
        gl.glRotate(self.phi, 1, 0, 0)
        gl.glRotate(self.theta, 0, 0, 1)
        gl.glScalef(self.rho, self.rho, self.rho)
        gl.glCallList(self._model_list)
        gl.glPopMatrix()

        # render orthographic projections of the data.
        z = 1.0 / self.rho / 2
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        gl.glOrtho(-z, z, -z, z, -z, z)

        # viewport sizes must be integers, hence the floor division.
        side = w // 2

        # show the x-y plane.
        self._draw_orthographic(w, 0, side, h, ())

        # show the y-z plane.
        self._draw_orthographic(w, h, side, h, ((-90, 1, 0, 0), ))

        # show the x-z plane.
        self._draw_orthographic(
            w, 2 * h, side, h, ((-90, 1, 0, 0), (-90, 0, 0, 1)))

        glut.glutSwapBuffers()

    def _draw_orthographic(self, x, y, width, height, rotations):
        '''Draw the model into one viewport using the current projection.

        x, y, width, height: the viewport rectangle in window pixels.
        rotations: a sequence of (angle, x, y, z) tuples, applied in order
            with glRotate before the model is translated and scaled.
        '''
        gl.glViewport(x, y, width, height)

        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glPushMatrix()
        for angle, ax, ay, az in rotations:
            gl.glRotate(angle, ax, ay, az)
        gl.glTranslate(0, 0, -1)
        gl.glScalef(2, 2, 2)
        gl.glCallList(self._model_list)
        gl.glPopMatrix()

    def handle_idle(self):
        '''Redraw the scene.

        Applies the camera increments set by the mouse, advances to the next
        frame once per frame interval (unless paused) and requests a redraw.
        A redraw is requested even while paused so that camera changes remain
        visible.
        '''
        self.theta += self._d_theta
        self.phi += self._d_phi
        self.rho /= self._d_rho

        interval = 1.0 / self._frame_rate
        elapsed = time.time() - self._last_time
        if elapsed > interval:
            if not self.paused:
                self._advance_frame()
            self._last_time = time.time()
        else:
            # wait out the remainder of the frame interval.
            time.sleep(interval - elapsed)

        glut.glutPostRedisplay()

    def _advance_frame(self):
        '''Append the positions from the next frame, if any, to the trails.'''
        try:
            points, _analog = next(self._frames)
        except StopIteration:
            # no more frames: keep showing the last known positions.
            return
        for trail, point in zip(self._trails, points):
            # only the first three values of each point are kept.
            trail.append(point[:3])

    def render_model(self):
        '''Render the objects in the world.

        Everything is compiled into the display list self._model_list (and not
        executed here); handle_draw then calls the list once per viewport.
        '''
        gl.glNewList(self._model_list, gl.GL_COMPILE)

        # (light, position, color); the color is used for both the diffuse and
        # the specular component. NOTE(review): GL_LIGHT3 is configured here
        # but __init__ only enables GL_LIGHT0 through GL_LIGHT2 -- confirm
        # whether that is intended.
        lights = (
            (gl.GL_LIGHT0, [1, 1, 1, 0], [1, 1, 1, 0.5]),
            (gl.GL_LIGHT1, [1, 0, 0, 1], [1, 0, 0, 0.8]),
            (gl.GL_LIGHT2, [0, 1, 0, 1], [0, 1, 0, 0.8]),
            (gl.GL_LIGHT3, [0, 0, 1, 1], [0, 0, 1, 0.8]),
        )
        for light, position, color in lights:
            gl.glLightfv(light, gl.GL_POSITION, position)
            gl.glLightfv(light, gl.GL_DIFFUSE, color)
            gl.glLightfv(light, gl.GL_SPECULAR, color)

        # render a simple axis system
        gl.glLineWidth(1.0)
        gl.glBegin(gl.GL_LINES)
        gl.glColor4f(1, 0, 0, 1)
        gl.glVertex3f(0, 0, 0)
        gl.glVertex3f(100.0, 0, 0)
        gl.glColor4f(0, 1, 0, 1)
        gl.glVertex3f(0, 0, 0)
        gl.glVertex3f(0, 100.0, 0)
        gl.glColor4f(0, 0, 1, 1)
        gl.glVertex3f(0, 0, 0)
        gl.glVertex3f(0, 0, 100.0)
        gl.glEnd()

        # draw markers for the currently maintained data
        for i, points in enumerate(self._trails):
            if not self.visible[i]:
                continue
            self.render_marker_points(str(i), COLORS[i % len(COLORS)], points)

        gl.glEndList()

    def render_marker_points(self, label, color, points):
        '''Draw a translucent sphere at each of the given positions.

        label: a name for the marker (currently unused).
        color: an (r, g, b) tuple; an alpha of 0.7 is appended.
        points: an iterable of (x, y, z) positions.
        '''
        gl.glColor4f(*(color + (0.7, )))
        # the radius is inversely proportional to rho, so it is computed once.
        radius = 1.0 / self.rho / 200.0
        for point in points:
            gl.glPushMatrix()
            gl.glTranslated(*point)
            glut.glutSolidSphere(radius, 13, 13)
            gl.glPopMatrix()

    def render_marker_trails(self, color, points):
        '''Draw line segments between the given positions.

        color: an (r, g, b) tuple; an alpha of 0.7 is appended.
        points: an iterable of (x, y, z) positions.

        This method is not called from anywhere in this class.
        '''
        gl.glColor4f(*(color + (0.7, )))
        # NOTE(review): GL_LINES joins vertices pairwise (0-1, 2-3, ...), so
        # consecutive positions do not form one connected trail; GL_LINE_STRIP
        # may have been intended -- confirm before using this method.
        gl.glBegin(gl.GL_LINES)
        for point in points:
            gl.glVertex3f(*point)
        gl.glEnd()


if __name__ == '__main__':
    args = FLAGS.parse_args()
    # parse_args() returns a namespace object, which is always truthy, so the
    # list of files itself has to be checked.
    if not args.inputs:
        FLAGS.error('no input file provided!')

    # show the files one after another, each in its own viewer.
    for filename in args.inputs:
        # the file stays open for as long as the viewer runs, because frames
        # are pulled from the reader while the main loop is active.
        with open(filename, 'rb') as handle:
            try:
                Viewer(c3d.Reader(handle)).mainloop()
            except StopIteration:
                # presumably guards against the frame iterator running dry
                # outside of the idle callback; move on to the next file.
                pass
