#!/usr/bin/env python3
#
# Maintain a MEDLINE/PubMed repository
#
"""
parse:   Medline XML files into raw table files for DB dumping; ==
insert:  PubMed XML files or a list of PMIDs (contacting EUtils) into the DB
         (slower than using "parse" and a DB dump); ==
update:  existing records or add new records from PubMed XML files or a list
         of PMIDs (slow!); ==
write:   records in various formats for a given list of PMIDs (implicitly sets
         --pmid-lists); ==
delete:  records from the DB for a given list of PMIDs (implicitly sets
         --pmid-lists)
"""
import logging
import os
import sys

from sqlalchemy.exc import OperationalError

__author__ = 'Florian Leitner'
__version__ = '2.2.0'


def Main(command, files_or_pmids, session, unique=True):
    """
    Dispatch a DB `command` to the matching ``medic.crud`` function.

    :param command: one of insert/write/update/delete
    :param files_or_pmids: the list of files or PMIDs to process; for
                           write and delete every item is converted to int
    :param session: the DB session
    :param unique: flag to skip versioned records if VersionID != "1"
                   (only handed on to insert and update)
    :return: whatever the called CRUD function returns, or ``None`` if
             `command` is none of the four commands above
    """
    from medic.crud import insert, select, update, delete

    # insert/update work on the raw input list and honor the `unique` flag
    if command in ('insert', 'update'):
        handler = insert if command == 'insert' else update
        return handler(session, files_or_pmids, unique)

    # write/delete work on integer PMIDs only
    if command in ('write', 'delete'):
        pmids = [int(item) for item in files_or_pmids]
        handler = select if command == 'write' else delete
        return handler(session, pmids)

    return None


def WriteTabular(query, output_file: str):
    """
    Write `query` results as TSV to a file or STDOUT if ``output_file == '.'``.

    One line is written per record: PMID, title, and abstract, separated by
    tabs. The abstract is the NLM abstract if the record has one; otherwise
    the record's only abstract if it has exactly one; otherwise it is empty.
    Newlines and tabs inside the title and abstract are replaced by spaces.

    :param query: an iterable of records having the attributes ``pmid``,
                  ``title``, and ``abstracts`` (a mapping of source names to
                  abstracts that each have ``sections`` with ``content``)
    :param output_file: path of the file to write (UTF-8 encoded, like the
                        other writers in this module) or ``'.'`` for STDOUT
    """
    def prune(string):
        # keep each record on one line and each value in one column
        return string.replace('\n', ' ').replace('\t', ' ')

    def flatten(abstract):
        return ' '.join(prune(sec.content) for sec in abstract.sections)

    logging.debug("writing to TSV file %s", output_file)
    to_stdout = output_file == '.'

    if to_stdout:
        file = sys.stdout
    else:
        file = open(output_file, 'wt', encoding='utf-8')

    try:
        for rec in query:
            if 'NLM' in rec.abstracts:
                abstract = flatten(rec.abstracts['NLM'])
            elif len(rec.abstracts) == 1:
                only_source = next(iter(rec.abstracts))
                abstract = flatten(rec.abstracts[only_source])
            else:
                abstract = ''

            print(rec.pmid, prune(rec.title), abstract, sep='\t', file=file)
    finally:
        # never close STDOUT, only files opened here
        if not to_stdout:
            file.close()


def WriteHTML(query, output_file: str):
    """
    Write `query` results as HTML to a file or STDOUT if ``output_file == '.'``.

    The whole `query` is handed to ``medic.web.FormatHTML`` and the value it
    returns is printed in one go.

    :param query: the query results to format
    :param output_file: path of the (UTF-8 encoded) file to write or ``'.'``
                        for STDOUT
    """
    from medic.web import FormatHTML

    logging.debug("writing to HTML file %s", output_file)
    to_stdout = output_file == '.'

    if to_stdout:
        file = sys.stdout
    else:
        file = open(output_file, 'wt', encoding='utf-8')

    try:
        print(FormatHTML(query), file=file)
    finally:
        # never close STDOUT, only files opened here
        if not to_stdout:
            file.close()


def WriteTIAB(query, output_dir: str):
    """
    Write `query` results as TIAB (title and abstract) plain-text files.

    For every record a UTF-8 file named ``<PMID>.txt`` is created in
    `output_dir`. It holds the title followed by the sections of the NLM
    abstract or, if there is no NLM abstract, of the record's only abstract;
    records with several non-NLM abstracts get the title only.

    :param query: an iterable of records
    :param output_dir: an existing directory to place the files in
    """
    assert os.path.isdir(output_dir), '%s not a directory' % output_dir
    logging.debug("writing TIAB to directory %s", output_dir)

    for rec in query:
        logging.debug("writing PMID %i as TIAB", rec.pmid)
        path = os.path.join(output_dir, "{}.txt".format(rec.pmid))

        with open(path, 'wt', encoding='utf-8') as handle:
            print(rec.title, file=handle)

            # pick the abstract source(s) to write: NLM wins, else a lone one
            if 'NLM' in rec.abstracts:
                chosen = ['NLM']
            elif len(rec.abstracts) == 1:
                chosen = list(rec.abstracts)
            else:
                chosen = []

            for name in chosen:
                for sec in rec.abstracts[name].sections:
                    WriteSection(sec, handle)


def WriteSection(section, file):
    """
    Write a MEDLINE record's (plain-text) `section` to a `file` handle.

    The output is an empty line, then the section's label on its own line
    (unless the label is ``None``), and finally the section's content.
    """
    lines = ['']

    if section.label is not None:
        lines.append(section.label)

    lines.append(section.content)

    for line in lines:
        print(line, file=file)


def WriteMedline(query, output_file: str):
    """
    Write `query` results to a MEDLINE file or STDOUT if ``output_file == '.'``.

    :param query: an iterable of records; each one is written with
                  `WriteMedlineRecord`
    :param output_file: path of the (UTF-8 encoded) file to write or ``'.'``
                        for STDOUT
    """
    logging.debug("writing MEDLINE to file %s", output_file)
    to_stdout = output_file == '.'

    if to_stdout:
        file = sys.stdout
    else:
        file = open(output_file, 'wt', encoding='utf-8')

    try:
        for rec in query:
            WriteMedlineRecord(rec, file)
    finally:
        # only close files opened here; closing STDOUT would break any
        # output the caller produces afterwards
        if not to_stdout:
            file.close()


def WriteMedlineRecord(record, file):
    """
    Write a MEDLINE `record` to a `file` handle.

    The record is written as one ``KEY - value`` item per line, followed by
    an empty line. Items for the revised and completed dates, the issue, the
    pagination, and an abstract's copyright are only written if the record
    has a (non-empty) value for them.

    :param record: the record to write
    :param file: a writable text file handle
    """
    def WriteMedlineItem(key: str, value):
        # keys are cut/padded to four characters; backslashes and newlines
        # in the value are escaped so that every item stays on one line
        print("{:<4}- {}".format(
            key[:4], str(value).replace('\\', '\\\\').replace('\n', '\\n')
        ), file=file)

    def addIfExists(val):
        # a space-prefixed value, or nothing at all if the value is empty
        return ' {}'.format(val) if val else ''

    def formatMajor(el):
        # major elements are marked with a leading asterisk
        return '*{}'.format(el.name) if el.major else el.name

    def medlineDate(dt):
        # drop the dashes from the date's string form
        return str(dt).replace('-', '')

    owner = None
    source = "{}. {}".format(record.journal, record.pub_date)

    logging.debug("writing PMID %i as MEDLINE", record.pmid)
    WriteMedlineItem('PMID', record.pmid)
    WriteMedlineItem('STAT', record.status)
    WriteMedlineItem('DA', medlineDate(record.created))

    if record.revised:
        WriteMedlineItem('LR', medlineDate(record.revised))

    if record.completed:
        WriteMedlineItem('DCOM', medlineDate(record.completed))

    WriteMedlineItem('TA', record.journal)
    WriteMedlineItem('DP', record.pub_date)

    # the SO item is assembled from journal, date, issue, and pagination
    if record.issue:
        source += ';' + record.issue
        WriteMedlineItem('VI', record.issue)

    if record.pagination:
        source += ':' + record.pagination
        WriteMedlineItem('PG', record.pagination)

    WriteMedlineItem('SO', source)

    for pub_type in record.publication_types:
        WriteMedlineItem('PT', pub_type.value)

    WriteMedlineItem('TI', record.title)

    for author in record.authors:
        if author.initials == '':  # corporate authors
            WriteMedlineItem('CN', author.name)
        else:
            WriteMedlineItem('AU', '{}{}{}'.format(
                author.name, addIfExists(author.initials),
                addIfExists(author.suffix)
            ))

            if author.forename:
                WriteMedlineItem('FAU', '{}, {}{}'.format(
                    author.name, author.forename,
                    addIfExists(author.suffix)
                ))

    for abstract_source in record.abstracts:
        abstract = record.abstracts[abstract_source]
        text = ' '.join(sec.content for sec in abstract.sections)

        if abstract_source == 'NLM':
            WriteMedlineItem('AB', text)

            # do not emit a literal "None" for a missing copyright
            if abstract.copyright:
                WriteMedlineItem('CI', abstract.copyright)
        else:
            WriteMedlineItem('OAB', "{}: {}".format(abstract_source, text))

            if abstract.copyright:
                WriteMedlineItem('OCI', "{} {}".format(
                    abstract_source, abstract.copyright
                ))

    for desc in record.descriptors:
        heading = [formatMajor(desc)]
        heading.extend(formatMajor(qual) for qual in desc.qualifiers)
        WriteMedlineItem('MH', '/'.join(heading))

    for kwd in record.keywords:
        # write the owner once per run of keywords with the same owner
        if kwd.owner != owner:
            WriteMedlineItem('OTO', kwd.owner)
            owner = kwd.owner

        WriteMedlineItem('OT', formatMajor(kwd))

    for chem in record.chemicals:
        if chem.uid:
            WriteMedlineItem('RN', '{} ({})'.format(chem.uid, chem.name))
        else:
            WriteMedlineItem('RN', chem.name)

    for xref in record.databases:
        WriteMedlineItem('SI', '{}/{}'.format(xref.name, xref.accession))

    for ns, uid in record.identifiers.items():
        if ns == 'pmc':
            WriteMedlineItem('PMC', uid.value)
        else:
            WriteMedlineItem('AID', "{} [{}]".format(uid.value, ns))

    print('', file=file)  # empty line

if __name__ == '__main__':
    # Command-line entry point: parse the arguments, set up logging, expand
    # PMID list files if requested, run the command, and exit with status 0
    # if the command produced a truthy result and 1 otherwise.
    from argparse import ArgumentParser
    from medic.orm import InitDb, Session

    epilog = 'system (default) encoding: {}'.format(sys.getdefaultencoding())

    parser = ArgumentParser(
        usage='%(prog)s [options] CMD FILE/PMID ...',
        description=__doc__, epilog=epilog,
        prog=os.path.basename(sys.argv[0])
    )

    parser.set_defaults(loglevel=logging.WARNING)

    parser.add_argument(
        'command', metavar='CMD', choices=[
            'parse', 'insert', 'write', 'update', 'delete'
        ],
        help='one of {parse,insert,write,update,delete}; see above'
    )
    parser.add_argument(
        'files', metavar='FILE/PMID', nargs='+',
        help='MEDLINE XML file, PMID (integer), or PMID list file'
    )
    parser.add_argument('--version', action='version', version=__version__)
    parser.add_argument(
        '--url', metavar='URL',
        help='a database URL string, e.g., sqlite:///tmp.db '
             '[postgresql://localhost/medline]',
        default='postgresql://localhost/medline'
    )
    # 'medline' is the documented default and therefore has to be a legal
    # choice, too; 'full' stays accepted and also produces MEDLINE output
    parser.add_argument(
        '--format', choices=['medline', 'full', 'html', 'tiab', 'tsv'],
        default='medline',
        help='medline: [default] write all content to one MEDLINE file; '
             'html: write all content to one HTML file; '
             'tsv: write one TSV line per PMID, title, and abstract; '
             'tiab: write title and abstract only plain-text to files; '
    )
    parser.add_argument(
        '--all', action='store_true',
        help='parse records with VersionID != "1"'
    )
    parser.add_argument(
        '--update', action='store_true',
        help='parsing MEDLINE XML files: '
             'delete all parsed records prior to inserting them'
    )
    parser.add_argument(
        '--pmid-lists', action='store_true',
        help='assume input files are lists of PMIDs, not XML files ' +
             '(implicitly set for delete and write operations)'
    )
    parser.add_argument(
        '--output', metavar='DIR', default=os.path.curdir,
        help='dump/write to a specific directory or file '
             '(for formats "medline", "tsv" and "html")'
    )
    parser.add_argument(
        '--error', action='store_const', const=logging.ERROR,
        dest='loglevel', help='error log level only [warn]'
    )
    parser.add_argument(
        '--info', action='store_const', const=logging.INFO,
        dest='loglevel', help='info log level [warn]'
    )
    parser.add_argument(
        '--debug', action='store_const', const=logging.DEBUG,
        dest='loglevel', help='debug log level [warn]'
    )
    parser.add_argument('--logfile', metavar='FILE',
                        help='log to file, not STDERR')

    # illegal commands are already rejected by the parser via `choices`
    args = parser.parse_args()
    logging.basicConfig(
        filename=args.logfile, level=args.loglevel,
        format='%(asctime)s %(name)s %(levelname)s: %(message)s'
    )

    if args.command in ('write', 'delete'):
        args.pmid_lists = True

    def ParseListOrYield(file):
        """
        Helper to read PMID lists if `file` indeed is a file.

        Yields the stripped, non-empty lines of `file` if it is an existing
        file; otherwise yields `file` itself (i.e., treats it as a PMID).
        """
        if os.path.isfile(file):
            # close the list file once it has been read completely
            with open(file) as handle:
                for line in handle:
                    pmid = line.strip()

                    # blank lines (e.g., at the end of a file) are no PMIDs
                    if pmid:
                        yield pmid
        else:
            yield file

    if args.pmid_lists:
        args.files = [
            pmid for f in args.files for pmid in ParseListOrYield(f)
        ]

    if args.command == 'parse':
        from medic.crud import dump

        result = dump(args.files, args.output, not args.all, args.update)
    else:
        try:
            InitDb(args.url)
        except OperationalError as e:
            parser.error(str(e))

        result = Main(args.command, args.files, Session(), not args.all)

        if args.command == 'write':
            if args.format == 'tsv':
                WriteTabular(result, args.output)
            elif args.format == 'html':
                WriteHTML(result, args.output)
            elif args.format == 'tiab':
                WriteTIAB(result, args.output)
            else:
                WriteMedline(result, args.output)

            # writing itself did not fail, so report success
            result = True

    sys.exit(0 if result else 1)
