#!/usr/bin/env python

# -*- coding: utf-8 -*-

# vim: tabstop=4 shiftwidth=4 softtabstop=4

#    Copyright (C) 2013 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import collections
import functools
import getpass
import logging
import optparse
import os
import re
import sys
import threading
import time

import Queue

from datetime import datetime

from gerritlib import gerrit
import six
import urwid

# Log to stderr at ERROR level only; with this default the info/debug/warn
# calls made throughout this module are not emitted.
logging.basicConfig(level=logging.ERROR,
                    format='%(asctime)s %(levelname)s: %(message)s',
                    stream=sys.stderr)
# Module-wide logger shared by the watcher thread and the GUI callbacks.
LOG = logging.getLogger(__name__)

### DEFAULT SETTINGS

# Default values for the --server and --port command line options.
GERRIT_HOST = 'review.openstack.org'
GERRIT_PORT = 29418
# How many connection attempts the watcher makes per round; the wait after
# failed attempt number i (counting from zero) is 2**i seconds.
BACKOFF_ATTEMPTS = 5
# Default for --items (maximum number of rows kept in the table).
VISIBLE_LIST_LEN = 50
# Default for --prefetch (result limit of the first sanity query).
PREFETCH_LEN = VISIBLE_LIST_LEN
# Delay, in seconds, handed to the urwid alarm that polls the event queue.
ALARM_FREQ = 1.0
# Query template used by the watcher's sanity check; %s is the result limit.
SANITY_QUERY = 'status:open limit:%s'

### GUI CONSTANTS

# Display attributes handed to urwid.MainLoop. The entry names double as
# lookup keys: a row status is styled with its lowercased status text and
# highlighted words are styled with the names listed in HIGHLIGHT_WORDS.
PALETTE = (
    ('body', urwid.DEFAULT, urwid.DEFAULT),
    ('merged', urwid.LIGHT_CYAN, urwid.DEFAULT, 'bold'),
    ('approved', urwid.LIGHT_GREEN, urwid.DEFAULT),
    ('abandoned', urwid.YELLOW, urwid.DEFAULT),
    ('verified', urwid.LIGHT_GRAY, urwid.DEFAULT),
    ('restored', urwid.LIGHT_BLUE, urwid.DEFAULT),
    ('rejected', urwid.LIGHT_RED, urwid.DEFAULT, 'bold'),
    ('failed', urwid.LIGHT_RED, urwid.DEFAULT, 'bold'),
    ('succeeded', urwid.LIGHT_GREEN, urwid.DEFAULT),
    ('open', urwid.WHITE, urwid.DEFAULT),
)
# Table column titles, in display order. Rows are built positionally against
# this tuple, so the order here must match the order rows are assembled in.
COLUMNS = (
    'Username',
    "Topic",
    "Url",
    "Project",
    'Subject',
    'Created On',
    'Status',
    'Comment',
)
COLUMN_TRUNCATES = {
    # This determines how the columns will be truncated (at what length will
    # truncation be forced to avoid huge strings). Keys are lowercased field
    # names, matching the lowercased lookup done in _get_text().
    'comment': 120,
    'reason': 120,
    'subject': 60,
}
# Per-column sizing; the tuple for a column is prepended to that column's
# widget when the urwid.Columns entries are built.
COLUMN_ATTRIBUTES = {
    'Created On': (urwid.WEIGHT, 0.5),
    'Status': (urwid.FIXED, 9),
    'Username': (urwid.FIXED, 13),
    'Project': (urwid.WEIGHT, 0.5),
    'Topic': (urwid.WEIGHT, 0.33),
    'Url': (urwid.FIXED, 35),
    'Subject': (urwid.WEIGHT, 1.0),
    'Comment': (urwid.WEIGHT, 0.7),
}
HIGHLIGHT_WORDS = {
    # These words get special colored highlighting (matched
    # case-insensitively by _format_text()).
    #
    # word -> palette name
    'succeeded': 'succeeded',
    'success': 'succeeded',
    'successful': 'succeeded',
    'failure': 'failed',
    'failed': 'failed',
    'fails': 'failed',
}

### HELPERS


def _format_text(text):
    """Build a text widget from ``text`` with special words highlighted.

    The text is split on whitespace and on ``. - , !`` (the separators are
    kept as their own tokens). Any token whose lowercase form is a key of
    HIGHLIGHT_WORDS is emitted as a ``(palette name, token)`` markup pair;
    every other token is emitted unchanged.
    """

    def as_markup(token):
        lowered = token.lower()
        if lowered in HIGHLIGHT_WORDS:
            return (HIGHLIGHT_WORDS[lowered], token)
        return token

    tokens = re.split(r"([\s.\-,!])", text)
    return _make_text([as_markup(token) for token in tokens])


def _make_text(text):
    """Return a left-aligned urwid.Text for ``text`` using wrap mode 'any'."""
    widget = urwid.Text(text, align='left', wrap='any')
    return widget


def _get_key_path():
    home_dir = os.path.expanduser("~")
    ssh_dir = os.path.join(home_dir, ".ssh")
    if not os.path.isdir(ssh_dir):
        return None
    for k in ('id_rsa', 'id_dsa'):
        path = os.path.join(ssh_dir, k)
        if os.path.isfile(path):
            return path
    return None


def _format_date(when=None):
    if when is None:
        when = datetime.now()
    return when.strftime('%I:%M %p %m/%d/%Y')


def _get_date(k, row):
    """Read ``row[k]`` as an integer timestamp and return it as a datetime.

    Returns None when the key is missing, the value is empty, or the value
    cannot be converted to an integer.
    """
    raw = _get_text(k, row)
    if raw:
        try:
            return datetime.fromtimestamp(int(raw))
        except (ValueError, TypeError):
            pass
    return None


def _get_text(k, container):
    """Return ``container[k]`` as display text.

    A missing key yields an empty string and non-string values are passed
    through ``str()``. When COLUMN_TRUNCATES has a limit for the lowercased
    key, longer text is cut at that limit and suffixed with ``...``.
    """
    if k not in container:
        return ""
    value = container[k]
    text = value if isinstance(value, six.string_types) else str(value)
    limit = COLUMN_TRUNCATES.get(k.lower())
    if limit is None or len(text) <= limit:
        return text
    return text[:limit] + "..."


class GerritWatcher(threading.Thread):
    """Daemon thread that forwards gerrit events onto a queue.

    The thread (re)connects to gerrit as needed, runs a "sanity" query whose
    results are translated into synthetic ``patchset-created`` events, and
    then relays every event read from the gerrit connection to ``queue``.

    :param queue: object with a ``put()`` method that receives event dicts
    :param server: gerrit host name
    :param port: gerrit ssh port
    :param username: gerrit user name
    :param keyfile: path of the ssh key file (may be None)
    :param prefetch: result limit used for the first sanity query
    """

    def __init__(self, queue, server, port, username, keyfile, prefetch):
        super(GerritWatcher, self).__init__()
        self.queue = queue
        self.keyfile = keyfile
        self.port = port
        self.server = server
        self.username = username
        self.prefetch = prefetch
        self.daemon = True
        # Created lazily by _connect(); reset to None by _consume() when the
        # underlying watcher thread is found dead.
        self.gerrit = None
        self.has_prefetched = False

    def _sanity_check(self):
        """Query gerrit for open changes and queue them as fake events.

        The first successful query asks for ``self.prefetch`` results; later
        ones ask for a single result. Results are queued ordered by their
        ``createdOn`` value. Any exception is left to the caller.
        """
        fetch_am = 1
        if not self.has_prefetched:
            fetch_am = self.prefetch
        q = SANITY_QUERY % (fetch_am)
        LOG.info("Using '%s' for sanity query.", q)
        results = self.gerrit.bulk_query(q)
        self.has_prefetched = True

        translated = []
        for r in results:
            if not isinstance(r, (dict)):
                continue
            if r.get('type') == 'stats':
                continue
            # Translate it into what looks like a patch-set created
            # event and then send this via the queue to showup on the gui
            ev = {
                'type': 'patchset-created',
                'uploader': r.pop('owner'),
                'patchSet': {
                    'createdOn': r.pop('createdOn'),
                    'lastUpdated': r.pop('lastUpdated', None),
                },
            }
            ev['change'] = dict(r)
            translated.append(ev)

        # For some reason we can get more than requested, even though
        # we send a limit, huh??
        LOG.info("Received %s sanity check results.", len(translated))
        translated = translated[0:fetch_am]
        # A key function gives the same (stable) ordering as the previous
        # cmp-based sort and also works where ``cmp`` does not exist.
        translated.sort(key=lambda ev: ev['patchSet']['createdOn'])
        for e in translated:
            self.queue.put(e)

    def _connect(self):
        """Create the gerrit client if needed and start watching events.

        All exceptions are logged and swallowed; callers check ``connected``
        to find out whether the attempt worked.
        """
        try:
            if self.gerrit is None:
                self.gerrit = gerrit.Gerrit(self.server, self.username,
                                            self.port, self.keyfile)
            # NOTE(harlowja): only after the sanity query passes do we have
            # some level of confidence that the watcher thread will actually
            # correctly connect.
            self._sanity_check()
            self.gerrit.startWatching()
            LOG.info('Start watching gerrit event stream.')
        except Exception:
            LOG.exception('Exception while connecting to gerrit')

    @property
    def connected(self):
        """True when a gerrit client exists and its watcher thread is alive."""
        if self.gerrit is None:
            return False
        watcher = self.gerrit.watcher_thread
        if watcher is None:
            return False
        return bool(watcher.is_alive())

    def _ensure_connected(self):
        """Try to connect up to BACKOFF_ATTEMPTS times with 2**i backoff."""
        if self.connected:
            return
        for i in range(0, BACKOFF_ATTEMPTS):
            self._connect()
            if self.connected:
                break
            sleep_time = 2 ** i
            if i + 1 < BACKOFF_ATTEMPTS:
                LOG.warning("Trying connection again in %s seconds",
                            sleep_time)
            # The sleep also happens after the final attempt, which keeps
            # the outer run() loop from spinning while gerrit is unreachable.
            time.sleep(sleep_time)
        if not self.connected:
            LOG.fatal("Could not connect to %s:%s", self.server, self.port)

    def _handle_event(self, event):
        """Put a single event on the output queue."""
        LOG.debug('Placing event on producer queue: %s', event)
        self.queue.put(event)

    def _consume(self):
        """Read one event from gerrit and forward it to the queue.

        Does nothing when no gerrit client exists (its construction failed);
        previously that case raised AttributeError from inside the exception
        handler below and terminated the thread.
        """
        if self.gerrit is None:
            return
        try:
            event = self.gerrit.getEvent()
            self._handle_event(event)
        except Exception:
            LOG.exception('Exception encountered in event loop')
            watcher = self.gerrit.watcher_thread
            if watcher is not None and not watcher.is_alive():
                # Drop the client so the next _connect() builds a new one.
                self.gerrit = None

    def run(self):
        """Thread body: keep the connection up and relay events forever."""
        while True:
            self._ensure_connected()
            self._consume()


def _consume_queue(queue):
    # TODO(harlowja): consume many at once instead of just one?
    events = []
    ev = None
    try:
        ev = queue.get(block=False)
    except Queue.Empty:
        pass
    if ev is not None:
        events.append(ev)
    return events


def _get_change_status(event):
    change_type = None
    for approval in event.get('approvals', []):
        if not isinstance(approval, (dict)):
            continue
        try:
            approval_value = int(approval['value'])
        except (ValueError, TypeError, KeyError):
            approval_value = None
        if approval.get('type') == 'VRIF':
            if approval_value == -2:
                change_type = 'Failed'
            if approval_value == -1:
                change_type = 'Verified'
            if approval_value == 2:
                change_type = 'Succeeded'
        if approval.get('type') == 'CRVW':
            if approval_value == -2:
                change_type = 'Rejected'
            if approval_value == 2:
                change_type = 'Approved'
    return change_type


class ReviewDate(urwid.Text):
    """Text widget that also keeps the datetime it is displaying.

    ``when`` is stored as-is on the instance (None is allowed, in which case
    the widget shows an empty string) so that rows can be sorted by it.
    """

    def __init__(self, when=None):
        initial_text = '' if when is None else _format_date(when)
        super(ReviewDate, self).__init__(initial_text)
        self.when = when


class ReviewTable(urwid.ListBox):
    """List box showing one row per gerrit change.

    Rows are ``urwid.Columns`` widgets whose entries follow COLUMNS. At most
    ``max_size`` rows are kept; the newest row is placed first and the
    oldest one is dropped when the table is full. Pressing ``s``/``S``
    cycles through the available sort orders.

    :param max_size: maximum number of rows to keep (must be > 0)
    """

    def __init__(self, max_size=1):
        super(ReviewTable, self).__init__(urwid.SimpleListWalker([]))
        assert int(max_size) > 0, "Max size must be > 0"
        self._columns = tuple(COLUMNS)
        self._column_attributes = dict(COLUMN_ATTRIBUTES)
        self._max_size = int(max_size)
        self._header = None
        self._footer_pieces = [None, None, None]
        self._column_2_idx = dict((k, i) for (i, k) in enumerate(COLUMNS))
        # (title, sorter) pairs cycled through by keypress(); a sorter takes
        # the list of rows and returns a new, sorted list.
        self._sort_by = [
            (None, None),  # no sorting
            ('Created On (Desc)', self._sort_date("Created On", False)),
            ('Created On (Asc)', self._sort_date("Created On", True)),
            ('Subject (Desc)', self._sort_text("Subject", False)),
            ('Subject (Asc)', self._sort_text("Subject", True)),
            ('Username (Desc)', self._sort_text("Username", False)),
            ('Username (Asc)', self._sort_text("Username", True)),
            ('Project (Desc)', self._sort_text("Project", False)),
            ('Project (Asc)', self._sort_text("Project", True)),
            ('Topic (Desc)', self._sort_text("Topic", False)),
            ('Topic (Asc)', self._sort_text("Topic", True)),
        ]
        self._sort_idx = 0
        # Rows in arrival order (newest first), independent of the sort
        # order currently shown in self.body.
        self._rows = []

    @staticmethod
    def _make_sorter(key, asc):
        """Return ``sorter(rows) -> list`` ordering rows by ``key``."""

        def sorter(rows):
            # reverse=True keeps rows with equal keys in their original
            # relative order, matching a negated comparison function.
            return sorted(rows, key=key, reverse=not asc)

        return sorter

    def _sort_text(self, col_name, asc):
        """Build a sorter comparing the text of column ``col_name``."""
        col_idx = self._column_2_idx[col_name]

        def text_key(row):
            return row.contents[col_idx][0].text

        return self._make_sorter(text_key, asc)

    def _sort_date(self, col_name, asc):
        """Build a sorter comparing the ``when`` of column ``col_name``.

        Rows without a date order before rows that have one (ascending).
        """
        col_idx = self._column_2_idx[col_name]

        def date_key(row):
            when = row.contents[col_idx][0].when
            # The leading flag decides whenever exactly one date is missing,
            # so None is never ordered against a datetime.
            return (when is not None, when)

        return self._make_sorter(date_key, asc)

    def _column_spec(self, col_name, widget):
        """Return the urwid.Columns entry for ``widget`` in ``col_name``.

        The sizing attributes configured for the column (if any) are placed
        in front of the widget.
        """
        try:
            spec = list(self._column_attributes[col_name])
        except (KeyError, TypeError):
            spec = []
        spec.append(widget)
        return tuple(spec)

    @property
    def columns(self):
        """Tuple of column titles, in display order."""
        return self._columns

    @property
    def max_size(self):
        """Maximum number of rows kept in the table."""
        return self._max_size

    @property
    def header(self):
        """A new header widget: the column titles above a divider line."""
        if self._header is None:
            self._header = [self._column_spec(col_name, _make_text(col_name))
                            for col_name in self.columns]
        cols = urwid.Columns(self._header, dividechars=1)
        sep = urwid.AttrWrap(urwid.Divider('-'), 'body')
        return urwid.Pile([urwid.AttrWrap(cols, 'body'), sep])

    def _footer_piece(self, idx, align):
        """Return (creating on first use) the footer text widget ``idx``."""
        if self._footer_pieces[idx] is None:
            self._footer_pieces[idx] = urwid.Text('', align=align)
        return self._footer_pieces[idx]

    @property
    def right_footer(self):
        """Right-aligned footer text widget (created on first access)."""
        return self._footer_piece(2, 'right')

    @property
    def left_footer(self):
        """Left-aligned footer text widget (created on first access)."""
        return self._footer_piece(0, 'left')

    @property
    def center_footer(self):
        """Centered footer text widget (created on first access)."""
        return self._footer_piece(1, 'center')

    @property
    def footer(self):
        """A new footer widget: a divider above the three footer texts."""
        sep = urwid.AttrWrap(urwid.Divider('-'), 'body')
        footer_pieces = [
            self.left_footer,
            self.center_footer,
            self.right_footer,
        ]
        cols = urwid.Columns(footer_pieces)
        return urwid.Pile([sep, urwid.AttrWrap(cols, 'body')])

    def _add_row(self, row):
        """Insert ``row`` as the newest row, evicting the oldest if full.

        :raises RuntimeError: if the row's column count does not match
        """
        if len(row.contents) != len(self.columns):
            raise RuntimeError("Attempt to add a row with differing"
                               " column count")
        if len(self._rows) >= self.max_size:
            self._rows.pop()
        self._rows.insert(0, row)
        (_sort_title, sorter) = self._sort_by[self._sort_idx]
        if sorter:
            self._refill(sorter(self._rows))
        else:
            # Unsorted: mirror the change made to self._rows in the body.
            if len(self.body) >= self.max_size:
                self.body.pop()
            self.body.insert(0, row)

    def _find_change(self, change):
        """Return the displayed row whose Url equals ``change['url']``.

        Returns None when no displayed row matches.
        """
        url_i = self._column_2_idx['Url']
        wanted = change.get('url')
        for row in self.body:
            if row.contents[url_i][0].text == wanted:
                return row
        return None

    def _replace_widget(self, row, col_name, widget):
        """Swap the widget shown in ``col_name`` of ``row``.

        The other members of the column entry are kept as they are.
        """
        col_i = self._column_2_idx[col_name]
        entry = list(row.contents[col_i])
        entry[0] = widget
        row.contents[col_i] = tuple(entry)

    def _set_status(self, match, text):
        """Show ``text`` in the Status column of ``match``.

        The text is styled with the palette entry named after its lowercase
        form. Returns the row, or None when there is nothing to do.
        """
        if not text or match is None:
            return None
        self._replace_widget(match, 'Status',
                             urwid.AttrWrap(_make_text(text), text.lower()))
        return match

    def _set_comment(self, match, text):
        """Show ``text`` in the Comment column of ``match`` if non-empty."""
        if text:
            self._replace_widget(match, 'Comment', _format_text(text))

    def on_change_merged(self, event):
        """Mark the displayed row for the event's change as merged."""
        match = self._find_change(event['change'])
        if match is not None:
            self._set_status(match, 'Merged')

    def on_change_restored(self, event):
        """Mark the row as restored and show the restore reason."""
        match = self._find_change(event['change'])
        if match is not None:
            self._set_comment(match, _get_text('reason', event))
            self._set_status(match, 'Restored')

    def on_comment_added(self, event):
        """Show the new comment and any status derived from approvals."""
        match = self._find_change(event['change'])
        if match is not None:
            self._set_comment(match, _get_text('comment', event))
            self._set_status(match, _get_change_status(event))

    def on_change_abandoned(self, event):
        """Mark the row as abandoned and show the abandon reason."""
        match = self._find_change(event['change'])
        if match is not None:
            self._set_comment(match, _get_text('reason', event))
            self._set_status(match, 'Abandoned')

    def on_patchset_created(self, event):
        """Add a row with status 'Open' for a change not yet displayed."""
        change = event['change']
        if self._find_change(change) is not None:
            # NOTE(harlowja): already being actively displayed
            return
        patch_set = event['patchSet']
        uploader = event['uploader']
        # One value per entry of COLUMNS, in the same order.
        row = [
            _get_text('username', uploader),
            _get_text('topic', change),
            _get_text('url', change),
            _get_text('project', change),
            _get_text('subject', change),
            ReviewDate(_get_date('createdOn', patch_set)),
            "",  # status
            "",  # comment
        ]
        attr_row = []
        for (i, v) in enumerate(row):
            widget = v if isinstance(v, urwid.Text) else _format_text(v)
            attr_row.append(self._column_spec(self.columns[i], widget))
        cols = urwid.Columns(attr_row, dividechars=1)
        self._set_status(cols, 'Open')
        self._add_row(cols)

    def _refill(self, new_body):
        """Replace everything shown in the list box with ``new_body``."""
        while len(self.body):
            self.body.pop()
        self.body.extend(new_body)

    def keypress(self, size, key):
        """Handle ``s``/``S`` (cycle sort order) after the list box itself.

        Returns None when the key was consumed, otherwise the key.
        """
        handled = super(ReviewTable, self).keypress(size, key)
        if handled is None:
            return None
        if key in ('s', 'S'):
            self._sort_idx = (self._sort_idx + 1) % len(self._sort_by)
            (sort_title, sorter) = self._sort_by[self._sort_idx]
            if not all([sort_title, sorter]):
                self.center_footer.set_text("")
                self._refill(self._rows)
            else:
                self.center_footer.set_text("Sort: %s" % (sort_title))
                self._refill(sorter(self._rows))
            return None
        return key

###


def main():
    """Parse the command line, start the gerrit watcher and run the GUI.

    Events produced by the watcher thread are taken off a queue by a timer
    callback (every ALARM_FREQ seconds) and applied to the review table.
    The GUI exits when ``q``, ``Q`` or ``esc`` is pressed.
    """
    parser = optparse.OptionParser()
    parser.add_option("-u", "--user", dest="username", action='store',
                      help="gerrit user [default: %default]", metavar="USER",
                      default=getpass.getuser())
    parser.add_option("-s", "--server", dest="server", action='store',
                      help="gerrit server [default: %default]",
                      metavar="SERVER", default=GERRIT_HOST)
    parser.add_option("-p", "--port", dest="port", action='store',
                      type="int", help="gerrit port [default: %default]",
                      metavar="PORT", default=GERRIT_PORT)
    parser.add_option("--prefetch", dest="prefetch", action='store',
                      type="int", help="prefetch amount [default: %default]",
                      metavar="COUNT", default=PREFETCH_LEN)
    parser.add_option("-k", "--keyfile", dest="keyfile", action='store',
                      help="gerrit ssh keyfile [default: %default]",
                      metavar="FILE", default=_get_key_path())
    parser.add_option("--project", dest="projects", action='append',
                      help="only show given projects reviews",
                      metavar="PROJECT", default=[])
    parser.add_option("-i", "--items", dest="back", action='store',
                      type="int",
                      help="how many items to keep visible"
                           " [default: %default]",
                      metavar="COUNT", default=VISIBLE_LIST_LEN)
    (options, args) = parser.parse_args()
    if options.back <= 0:
        parser.error("Item count must be greater or equal to one.")

    gerrit_config = {
        'keyfile': options.keyfile,
        'port': int(options.port),
        'server': options.server,
        'username': options.username,
        'prefetch': max(0, options.prefetch),
    }
    event_queue = Queue.Queue()
    gerrit_reader = GerritWatcher(event_queue, **gerrit_config)
    # event type -> number of events of that type seen so far
    gerrit_details = collections.defaultdict(int)

    review_table = ReviewTable(max_size=options.back)
    review_table.left_footer.set_text("Initializing...")
    frame = urwid.Frame(urwid.AttrWrap(review_table, 'body'),
                        footer=review_table.footer,
                        header=review_table.header)

    event_handlers = {
        'patchset-created': review_table.on_patchset_created,
        'comment-added': review_table.on_comment_added,
        'change-merged': review_table.on_change_merged,
        'change-restored': review_table.on_change_restored,
        'change-abandoned': review_table.on_change_abandoned,
    }

    def filter_event(event):
        """Return True when the event must be dropped (--project filter)."""
        if not options.projects:
            return False
        try:
            project = event['change']['project']
        except (KeyError, TypeError, ValueError):
            project = None
        return project not in options.projects

    def on_unhandled_input(key):
        if key in ('q', 'Q', 'esc'):
            raise urwid.ExitMainLoop()

    def process_event(event):
        """Count the event and hand it to the matching table handler."""
        if not isinstance(event, (dict)) or 'type' not in event:
            return
        if filter_event(event):
            return
        event_type = str(event['type'])
        gerrit_details[event_type] += 1
        handler = event_handlers.get(event_type)
        if handler is None:
            raise RuntimeError("Unknown event type: '%s'" % (event_type))
        handler(event)

    def process_gerrit(loop, user_data):
        """Timer callback: apply queued events and refresh the footer."""
        evs = _consume_queue(event_queue)
        for e in evs:
            try:
                process_event(e)
            except Exception:
                LOG.exception("Failed handling event: %s", e)
        detail_text = "%s, %s events received (%sp, %sc, %sm, %sr, %sa)"
        detail_text = detail_text % (_format_date(),
                                     sum(gerrit_details.values()),
                                     gerrit_details.get('patchset-created', 0),
                                     gerrit_details.get('comment-added', 0),
                                     gerrit_details.get('change-merged', 0),
                                     gerrit_details.get('change-restored', 0),
                                     gerrit_details.get('change-abandoned', 0))
        review_table.right_footer.set_text(detail_text)
        if not gerrit_reader.is_alive():
            status_text = "Initializing..."
        elif not gerrit_reader.connected:
            if not gerrit_reader.has_prefetched:
                status_text = "Connecting & prefetching..."
            else:
                status_text = "Connecting..."
        elif not evs:
            status_text = "Waiting for events..."
        else:
            status_text = "Processing events..."
        review_table.left_footer.set_text(status_text)
        # Re-arm exactly one alarm per run. Registering a new alarm from an
        # idle callback instead (as was done before) adds an extra, never
        # ending alarm chain for every key press.
        loop.set_alarm_in(ALARM_FREQ, process_gerrit)

    loop = urwid.MainLoop(urwid.LineBox(frame), PALETTE,
                          handle_mouse=False,
                          unhandled_input=on_unhandled_input)
    gerrit_reader.start()
    loop.set_alarm_in(ALARM_FREQ, process_gerrit)
    loop.run()


# Only start the GUI when executed as a script, not when imported.
if __name__ == "__main__":
    main()
