#!/usr/bin/env python

# -*- coding: utf-8 -*-

# vim: tabstop=4 shiftwidth=4 softtabstop=4

#    Copyright (C) 2013 Yahoo! Inc. All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

from itertools import ifilter, islice
import json
import logging
import optparse
import os
import re
import subprocess
import sys

from datetime import datetime

import prettytable
import six

# Only errors (and worse) are reported by default; everything goes to stderr
# so that the result table printed on stdout stays clean.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.ERROR,
    format='%(asctime)s %(levelname)s: %(message)s',
)

# Module-wide logger used by the helpers below.
LOG = logging.getLogger(__name__)

### DEFAULT SETTINGS/CONSTANTS

# Gerrit server contacted when no --gerrit/--port options are given.
GERRIT_HOST = 'review.openstack.org'
GERRIT_PORT = 29418

# Field used to order the results when --sort is not given.
DEFAULT_SORT = 'createdOn'

# Remote command (and fixed flags) executed over ssh for every query.
GERRIT_CMD = (
    'gerrit', 'query',
    '--format=JSON',
    '--current-patch-set',
    '--files',
)

# Prefix repeated once per nesting level when printing dependency trees.
INDENT = ' '

###


def tiny_p(cmd, capture=True, ok_rcs=(0,)):
    """Run ``cmd`` (an argument list) and return its standard output.

    :param cmd: command and arguments, passed to ``subprocess`` without a
                shell.
    :param capture: when true the child's stdout is captured and returned;
                    when false it is inherited and ``None`` is returned.
    :param ok_rcs: iterable of exit codes that count as success.
    :returns: the captured stdout, or ``None`` when ``capture`` is false.
    :raises subprocess.CalledProcessError: on failure when the
            ``check_output`` fast path is used (default arguments).
    :raises RuntimeError: on failure in every other case.
    """
    LOG.debug("Running cmd %s", cmd)

    ok_rcs = tuple(ok_rcs)

    # check_output() always captures stdout and only accepts exit code 0, so
    # it is only equivalent when the caller asked for exactly that.
    # Otherwise ``capture`` and ``ok_rcs`` would be silently ignored.
    if capture and ok_rcs == (0,) and hasattr(subprocess, 'check_output'):
        return subprocess.check_output(cmd)

    # Borrowed from cloud-init...
    #
    # Also used on python 2.6, which doesn't have check_output.
    stdout = subprocess.PIPE if capture else None
    sp = subprocess.Popen(cmd, stdout=stdout,
                          stderr=None, stdin=None)
    (out, err) = sp.communicate()
    ret = sp.returncode  # pylint: disable=E1101
    if ret not in ok_rcs:
        raise RuntimeError("Failed running %s [rc=%s] (%s, %s)"
                           % (cmd, ret, out, err))
    return out


def server_cmd(login_name, port=GERRIT_PORT, host=GERRIT_HOST):
    """Build the ssh destination arguments for the gerrit server.

    Returns a list of strings of the form ``[destination, '-p', port]``
    where the destination is ``login@host`` when a login name is given and
    plain ``host`` otherwise.
    """
    LOG.debug("port = %s", port)
    LOG.debug("host = %s", host)
    destination = '%s@%s' % (login_name, host) if login_name else host
    return [str(destination), '-p', str(port)]


def run_query(login_name, query, key_file, port, hostname, dependencies):
    """Run a gerrit query over ssh and yield each change as a dict.

    The query is re-issued with ``resume_sortkey:<last key>`` until a page
    comes back without changes, so that all pages of results are returned.

    :param login_name: ssh login name, or a false value to use ssh's default.
    :param query: the gerrit query string.
    :param key_file: ssh identity file; only used if the file exists.
    :param port: ssh port of the gerrit server.
    :param hostname: host name of the gerrit server.
    :param dependencies: when true, pass ``--dependencies`` to gerrit.
    :returns: a generator of decoded change dicts. Every yielded dict gets
              an empty ``'children'`` list added.
    :raises KeyError: if a non-stats row has no ``'sortKey'`` entry.
    """
    cmd = ['ssh']
    if key_file and os.path.isfile(key_file):
        cmd.extend(['-i', str(key_file)])
    cmd.extend(server_cmd(login_name, port=port, host=hostname))
    cmd.extend(GERRIT_CMD)

    if dependencies:
        cmd.append('--dependencies')

    cmd.append(query)

    sortkey = None
    while True:
        if sortkey is None:
            stdout = tiny_p(cmd)
        else:
            stdout = tiny_p(cmd + ['resume_sortkey:%s' % sortkey])

        changes_in_page = 0
        for e in stdout.splitlines():
            # Only the decoding is guarded; the yield stays outside the try
            # so that unrelated errors are never swallowed here.
            try:
                dec = json.loads(e)
                if not isinstance(dec, dict):
                    raise TypeError("Expected decoded dict, not %s" %
                                    (type(dec)))
            except (ValueError, TypeError):
                LOG.exception("Failure decoding received entry")
                continue

            # The stats row comes last
            # Stop if we returned 0 results
            if dec.get('type') == 'stats':
                if dec.get('rowCount') == 0:
                    return
                continue

            sortkey = dec['sortKey']

            # Always initialise the children of this change so we can
            # assume it later
            dec['children'] = []

            changes_in_page += 1
            yield dec

        # A page without any change cannot advance the sort key. Asking
        # again would repeat the same request forever (e.g. when the output
        # is empty or the stats row is missing), so stop here.
        if not changes_in_page:
            return


def uniq_itr(itr, keyfunc=lambda x: x):
    """Lazily yield the items of ``itr``, dropping later duplicates.

    Two items are duplicates when ``keyfunc`` returns equal (hashable) keys
    for them; the first item seen with a given key is the one kept.
    """
    emitted = set()
    for item in itr:
        marker = keyfunc(item)
        if marker not in emitted:
            emitted.add(marker)
            yield item


def get_key(k, row):
    """Return the value stored under ``k`` in ``row`` as a native ``str``.

    A missing key and a ``None`` value both give the empty string. On
    Python 2, unicode values are encoded as UTF-8 before conversion.
    """
    v = row.get(k)
    if v is None:
        return ""

    # ``type(u'')`` is ``unicode`` on Python 2 and ``str`` on Python 3, so
    # this only encodes Python 2 unicode objects and never needs the
    # Python-2-only ``unicode`` builtin (a NameError on Python 3).
    if not isinstance(v, str) and isinstance(v, type(u'')):
        v = v.encode('utf8')
    return str(v)


def get_date(k, row):
    """Render the epoch timestamp stored under ``k`` as a coarse age.

    The result is the largest applicable unit only: ``"3 days"``,
    ``"1 hour"``, ``"12 mins"`` or ``"just now"``. A missing or non-numeric
    value gives the empty string.
    """
    v = get_key(k, row)
    try:
        then = datetime.fromtimestamp(int(v))
    except (TypeError, ValueError, OverflowError, OSError):
        return ''

    delta = datetime.now() - then

    # A timestamp slightly in the future (clock skew) gives days == -1 with
    # a large positive ``seconds``; without this guard it would be shown as
    # "23 hours".
    if delta.days < 0:
        return "just now"

    # Floor division keeps these integers on Python 3 as well, so the
    # singular/plural comparisons below behave the same everywhere.
    hours = delta.seconds // (60 * 60)
    mins = delta.seconds // 60

    for amount, unit in ((delta.days, "day"), (hours, "hour"), (mins, "min")):
        if amount == 1:
            return "%d %s" % (amount, unit)
        if amount > 1:
            return "%d %ss" % (amount, unit)
    return "just now"


def get_approvals(k, row):
    """Summarise the votes on the current patch set of ``row``.

    Votes are grouped by the lower-cased first letter of their type and the
    groups are listed in reverse alphabetical order, e.g. ``"v=1 c=2,-1"``.
    A row without approvals gives the empty string.

    ``k`` is unused; it exists so the function has the same signature as
    the other field mapping functions.
    """
    try:
        approvals = row["currentPatchSet"]["approvals"]
    except (KeyError, TypeError):
        approvals = []

    vals = {}
    for approval in approvals:
        got_type = approval["type"][0:1].lower()
        # str() so that numeric vote values can be joined as well.
        vals.setdefault(got_type, []).append(str(approval["value"]))

    # sorted() instead of ``keys().sort()``: dict views have no sort()
    # method on Python 3.
    return " ".join("%s=%s" % (got_type, ",".join(vals[got_type]))
                    for got_type in sorted(vals, reverse=True))


def print_results(results, fields):
    """Print ``results`` as a table with one column per entry of ``fields``.

    Each field's ``mapfunc`` is called with the field key and the result to
    produce the cell text. Children of a result are printed below it, with
    the first column indented one extra level. Nothing is printed when
    there are no results.
    """
    table = prettytable.PrettyTable([field["label"] for field in fields])
    table.padding_width = 1
    table.align = "l"  # left alignment

    # A one-element list is used as a mutable cell: the nested function can
    # change its contents even though it cannot rebind the outer name.
    seen_any = [False]

    def render(entries, depth=0):
        for entry in entries:
            seen_any[0] = True

            cells = []
            for position, field in enumerate(fields):
                text = field["mapfunc"](field["key"], entry)

                # Indent the first value to indicate dependency trees
                if position == 0:
                    text = INDENT * depth + text

                if "truncate" in field and len(text) > field["truncate"]:
                    text = text[:field["truncate"]] + "..."
                cells.append(text)
            table.add_row(cells)

            render(entry['children'], depth + 1)

    render(results)

    # ``results`` may be an iterator, so its length cannot be checked up
    # front; rely on what was actually seen while rendering.
    if seen_any[0]:
        print(table.get_string())

def matches_file(result, files):
    """Tell whether any file of the current patch set matches a pattern.

    ``files`` is a sequence of regular expressions; each is searched for
    (``re.search``) in every file name of ``result``. Entries whose file
    name is not a string are ignored. Returns ``False`` when the result has
    no file list.
    """
    touched = []
    try:
        for entry in result["currentPatchSet"]["files"]:
            touched.append(entry.get("file"))
    except (KeyError, TypeError):
        pass

    return any(
        re.search(pattern, path)
        for pattern in files
        for path in touched
        if isinstance(path, six.string_types)
    )


#
# approval == {v,c,a}NN  (v == verified, c == code review, a == approved)
#
# NN == -1 -> patches without any -2
# NN == 0 -> patches without any -1 or -2
# NN == =0 -> no values other than 0, useful for querying for patches not
#             approved using "a=0".
# NN == 1 -> patches without any -1 or -2 and at least one +1
# NN == 2 -> patches without any -1 or -2 and at least one +2
#
def matches_approval(result, approval):
    """Check the votes on ``result`` against the ``approval`` rules.

    ``approval`` is a comma separated list of rules. Each rule is a one
    letter category (``c``, ``v`` or ``a``) followed by either ``NN``
    (a minimum rule) or ``=NN`` (an exact rule):

    * minimum rule, ``NN >= 0``: no negative vote in the category and the
      highest vote is at least ``NN``.
    * minimum rule, ``NN < 0``: no vote lower than ``NN``.
    * exact rule: lowest and highest vote both equal ``NN``. A category
      without votes counts as 0.

    A rule for a category that is not tracked can only be satisfied by
    ``=0``.

    :returns: ``True`` when every rule is satisfied, else ``False``.
    :raises ValueError: if a rule's value is not an integer.
    """
    requires = {}
    for rule in approval.split(","):
        category = rule[0:1]
        try:
            requires[category] = {
                'value': int(rule[1:]),
                'type': 'min',
            }
        except ValueError:
            if rule[1:2] != '=':
                raise
            requires[category] = {
                'value': int(rule[2:]),
                'type': 'exact',
            }

    try:
        approvals = result["currentPatchSet"]["approvals"]
    except (KeyError, TypeError):
        approvals = []

    # Record lowest and highest values per category; no vote counts as 0.
    got = {
        'c': [0, 0],
        'v': [0, 0],
        'a': [0, 0],
    }
    for vote in approvals:
        got_type = vote["type"][0:1].lower()
        got_val = int(vote["value"])

        if got_type not in got:
            continue

        bounds = got[got_type]
        bounds[0] = min(bounds[0], got_val)
        bounds[1] = max(bounds[1], got_val)

    for category, need in requires.items():
        wanted = need['value']

        if category not in got:
            # An untracked category never has votes, so only "=0" can hold.
            # Previously an exact non-zero rule raised KeyError here.
            if need['type'] == 'exact' and wanted == 0:
                continue
            return False

        lowest, highest = got[category]
        if need['type'] == 'min':
            # A non-negative minimum also forbids any negative vote.
            floor = 0 if wanted >= 0 else wanted
            if lowest < floor or highest < wanted:
                return False
        elif lowest != wanted or highest != wanted:
            return False

    return True


def matches_reviewer(result, reviewers):
    """Check who voted on the current patch set of ``result``.

    ``reviewers`` is a sequence of user names. A plain name must have voted;
    a name prefixed with ``^`` must not have voted. Blank entries are
    ignored. Votes without a ``username`` are not attributed to anyone.

    :returns: ``True`` when every requirement is met, else ``False``.
    """
    include = set()
    exclude = set()
    for reviewer in reviewers:
        if not reviewer:
            # Nothing to match; indexing an empty name used to raise.
            continue
        if reviewer.startswith('^'):
            exclude.add(reviewer[1:])
        else:
            include.add(reviewer)

    try:
        approvals = result["currentPatchSet"]["approvals"]
    except (KeyError, TypeError):
        approvals = []

    voters = set()
    for approval in approvals:
        by = approval.get("by") or {}
        if "username" in by:
            voters.add(by["username"])

    return include <= voters and not (exclude & voters)


def get_info(loginname, keyfile, terms, approval, reviewers, port, hostname,
             dependencies, files):
    """Query gerrit and lazily filter the returned changes.

    ``terms`` maps a gerrit search operator to a list of values. The values
    of one operator are OR-ed together and the operators are AND-ed. The
    query results are then filtered, in this order, by changed files (when
    ``files`` is non-empty), by ``approval`` rules (when not ``None``) and
    by ``reviewers`` (when non-empty).

    Returns an iterator over the matching change dicts.
    """
    clauses = []
    for term, values in terms.items():
        if len(values) == 0:
            continue
        clause = " OR ".join("%s:%s" % (term, value) for value in values)
        if clause != "":
            clauses.append(clause)
    query = " AND ".join("(%s)" % clause for clause in clauses)

    results = run_query(loginname, query, keyfile, port, hostname,
                        dependencies)
    if files is not None and len(files) > 0:
        results = (change for change in results
                   if matches_file(change, files))
    if approval is not None:
        results = (change for change in results
                   if matches_approval(change, approval))
    if len(reviewers) > 0:
        results = (change for change in results
                   if matches_reviewer(change, reviewers))
    return results


def get_key_path():
    """Return the path of the user's default ssh private key, if any.

    Looks for ``id_rsa`` and then ``id_dsa`` inside ``~/.ssh`` and returns
    the first one that is a file, or ``None`` when neither exists (or there
    is no ``~/.ssh`` directory).
    """
    ssh_dir = os.path.join(os.path.expanduser("~"), ".ssh")
    if os.path.isdir(ssh_dir):
        for basename in ('id_rsa', 'id_dsa'):
            candidate = os.path.join(ssh_dir, basename)
            if os.path.isfile(candidate):
                return candidate
    return None


def valid_field(name, fields):
    """Return ``True`` if some field descriptor has ``name`` as its key."""
    return any(descriptor["key"] == name for descriptor in fields)


def format_fields(fields):
    """Return the sorted names, single-quoted and joined with commas."""
    return ", ".join("'%s'" % name for name in sorted(fields))


def get_owner(k, row):
    """Return the ``username`` of the account dict stored under ``k``.

    Gives the empty string when the key is missing, the value is not a
    dict, or it has no (non-empty) ``username``.
    """
    account = row.get(k)
    # A missing owner used to raise AttributeError (``None.get``), which the
    # old ``except (TypeError, KeyError)`` clause did not catch.
    username = account.get('username') if isinstance(account, dict) else None
    if not username:
        return ''
    return str(username)


def sort_results(results, key, reverse):
    """Yield ``results`` ordered by ``result[key]``, sorting children too.

    Results without ``key`` no longer raise KeyError: they sort before all
    results that have it (after them when ``reverse`` is true). The
    ``'children'`` entry of every yielded result is replaced by a sorted
    list, so it can be iterated more than once.
    """
    def sort_key(result):
        # The leading flag means a missing value is never compared against
        # a real one, which Python 3 would reject for e.g. None vs. str.
        return (key in result, result.get(key))

    for result in sorted(results, key=sort_key, reverse=reverse):
        result['children'] = list(
            sort_results(result['children'], key, reverse))
        yield result


# Every column that can be displayed, in display order. Defined this late in
# the module because the 'mapfunc' entries refer to functions defined above.
#
#   key      -- name used on the command line (--field / --sort)
#   label    -- column header
#   mapfunc  -- called as mapfunc(key, result) to produce the cell text
#   truncate -- optional maximum cell length before "..." is appended
ALL_FIELDS = [
    {'key': 'status', 'label': 'Status', 'mapfunc': get_key},
    {'key': 'topic', 'label': 'Topic', 'mapfunc': get_key, 'truncate': 20},
    {'key': 'url', 'label': 'URL', 'mapfunc': get_key},
    {'key': 'owner', 'label': 'Owner', 'mapfunc': get_owner},
    {'key': 'project', 'label': 'Project', 'mapfunc': get_key},
    {'key': 'branch', 'label': 'Branch', 'mapfunc': get_key},
    {'key': 'subject', 'label': 'Subject', 'mapfunc': get_key, 'truncate': 50},
    {'key': 'createdOn', 'label': 'Created', 'mapfunc': get_date},
    {'key': 'lastUpdated', 'label': 'Updated', 'mapfunc': get_date},
    {'key': 'approvals', 'label': 'Approvals', 'mapfunc': get_approvals},
]

# Sorted list of the keys above, used in help and error messages.
ALL_FIELD_NAMES = sorted(field['key'] for field in ALL_FIELDS)

# Values accepted by --status, kept as a sorted list.
# See: https://git.eclipse.org/r/Documentation/user-search.html#status
ALL_STATUS = sorted((
    'open',
    'reviewed',
    'submitted',
    'closed',
    'merged',
    'abandoned',
))


def _build_parser():
    """Create the option parser describing every command line flag."""
    help_formatter = optparse.IndentedHelpFormatter(width=79)
    parser = optparse.OptionParser(formatter=help_formatter)
    parser.add_option("-l", "--login", dest="login", action='store',
                      help="connect to gerrit with USER", metavar="USER")
    parser.add_option("-u", "--user", dest="users", action='append',
                      help="gather information on given USER", metavar="USER",
                      default=[])
    parser.add_option("-w", "--reviewer", dest="reviewers", action='append',
                      help="filter to patches with reviewer USER", metavar="USER",
                      default=[])
    parser.add_option("-s", "--status", dest="statuses", action='append',
                      help="gather information on given status",
                      metavar="STATUS", default=[])
    parser.add_option("-m", "--message", dest="messages", action='append',
                      help="filter on message", metavar="MESSAGE",
                      default=[])
    parser.add_option("-p", "--project", dest="projects", action='append',
                      help="gather information on given project",
                      metavar="PROJECT",
                      default=[])
    parser.add_option("-b", "--branch", dest="branches", action='append',
                      help="filter on branch", metavar="BRANCH",
                      default=[])
    parser.add_option("-a", "--approval", dest="approval", action="store",
                      help="filter on approval value min %n"
                           " [default: no filter]",
                      metavar="APPROVAL", default=None)
    parser.add_option("-g", "--gerrit", dest="instance", action='store',
                      help="server address of the gerrit instance",
                      metavar="INSTANCE", default=GERRIT_HOST)
    parser.add_option("-r", "--port", dest="port", action='store',
                      help="port of the gerrit instance",
                      metavar="PORT", default=GERRIT_PORT)
    parser.add_option("-k", "--keyfile", dest="keyfile", action='store',
                      help="gerrit ssh keyfile [default: %default]",
                      metavar="FILE", default=get_key_path())
    parser.add_option("-t", "--sort", dest="sort", action='store',
                      help="sort order for results [default: %default]",
                      metavar="SORT", default=DEFAULT_SORT)
    parser.add_option("-n", "--limit", dest="limit", action="store",
                      type="int", default=0,
                      help="Limit the number of returned results. Note that "
                           "this limit is applied before sorting")
    parser.add_option("-d", "--deps", dest="dependencies",
                      action="store_const", const=True, default=False,
                      help="Display results as a dependency tree")
    parser.add_option("-f", "--field", dest="fields", action='append',
                      help="display field in results"
                           " [default: %s]" % format_fields(ALL_FIELD_NAMES),
                      metavar="FIELD", default=[])
    return parser


def _select_fields(parser, requested):
    """Turn ``--field NAME[:WIDTH]`` values into field descriptors.

    Returns ``ALL_FIELDS`` when nothing was requested; otherwise the
    requested descriptors in ``ALL_FIELDS`` order, with ``WIDTH`` applied as
    their truncation length. Invalid input is reported via ``parser.error``.
    """
    if len(requested) == 0:
        return ALL_FIELDS

    trunc = {}
    field_names = []
    for spec in requested:
        name, sep, width = spec.partition(":")
        if sep:
            try:
                trunc[name] = int(width)
            except ValueError:
                # Report bad widths as a usage error, not a traceback.
                parser.error("Invalid truncate length '%s' for field '%s'"
                             % (width, name))
        field_names.append(name)

    for n in field_names:
        if not valid_field(n, ALL_FIELDS):
            parser.error("Invalid field name '%s', valid options: %s"
                         % (n, format_fields(ALL_FIELD_NAMES)))

    fields = []
    for field in ALL_FIELDS:
        if field["key"] not in field_names:
            continue
        if field["key"] in trunc:
            # Copy so the shared ALL_FIELDS descriptors are left untouched.
            field = dict(field, truncate=trunc[field["key"]])
        fields.append(field)
    return fields


def _build_dependency_tree(entries):
    """Nest each entry under the change it depends on; return the roots.

    Entries whose dependency is not among ``entries`` stay at the top level.
    """
    # Index all entries by change id
    all_entries = {}
    for entry in entries:
        all_entries[entry['id']] = entry

    roots = []
    for entry in all_entries.values():
        depends = None
        if entry.get('dependsOn'):
            # N.B. We only consider the first dependency. We don't allow
            # merge commits anyway, so this should never be an issue.
            depends = entry['dependsOn'][0]['id']

        if depends is not None and depends in all_entries:
            all_entries[depends]['children'].append(entry)
        else:
            roots.append(entry)
    return roots


def main():
    """Parse the command line, query gerrit and print the result table."""
    parser = _build_parser()
    (options, args) = parser.parse_args()

    if len(options.statuses) == 0:
        options.statuses = ["open"]

    for s in options.statuses:
        if s not in ALL_STATUS:
            parser.error("Invalid status '%s', valid options: %s"
                         % (s, format_fields(ALL_STATUS)))

    fields = _select_fields(parser, options.fields)

    # --sort takes NAME or NAME:rev; any other suffix is ignored.
    sort_by, _, direction = options.sort.partition(":")
    reverse = direction == "rev"
    if not valid_field(sort_by, fields):
        parser.error("Invalid sort key '%s', valid options: %s"
                     % (sort_by, format_fields([f['key'] for f in fields])))

    # Positional arguments are file patterns to filter on.
    entries = get_info(options.login,
                       options.keyfile,
                       {
                           "owner": list(uniq_itr(options.users)),
                           "status": list(uniq_itr(options.statuses)),
                           'message': list(uniq_itr(options.messages)),
                           'project': list(uniq_itr(options.projects)),
                           'branch': list(uniq_itr(options.branches)),
                       },
                       options.approval,
                       options.reviewers,
                       options.port,
                       options.instance,
                       options.dependencies,
                       args)

    # There shouldn't be any duplicates, but apparently there occasionally are
    # anyway. Filter them out.
    entries = uniq_itr(entries, lambda x: x['number'])

    if options.limit != 0:
        entries = islice(entries, options.limit)

    if options.dependencies:
        entries = _build_dependency_tree(entries)

    entries = sort_results(entries, sort_by, reverse)
    print_results(entries, fields)


if __name__ == '__main__':
    main()
