#!/usr/bin/env python

import argparse
import contextlib
import sys
import os
import shutil

import mammoth


def main():
    """Convert the .docx file named on the command line to HTML.

    The optional style map file is read in full and passed to
    ``mammoth.convert_to_html``. Every message in the conversion result is
    written to stderr on its own line.

    Without ``--output-dir`` the HTML goes to the ``output`` argument (which
    may be None) and images are left to mammoth's default handling. With
    ``--output-dir`` the HTML is written to ``<output-dir>/<input name>.html``
    and images are written into the same directory by ``ImageWriter``.
    """
    args = _parse_args()

    style_map = None
    if args.style_map is not None:
        with open(args.style_map) as style_map_fileobj:
            style_map = style_map_fileobj.read()

    with open(args.path, "rb") as docx_fileobj:
        if args.output_dir is None:
            convert_image = None
            output_path = args.output
        else:
            convert_image = mammoth.images.inline(ImageWriter(args.output_dir))
            # splitext keeps the whole name when the input has no extension;
            # rpartition(".")[0] would give "" and produce a bare ".html".
            input_stem = os.path.splitext(os.path.basename(args.path))[0]
            output_filename = "{0}.html".format(input_stem)
            output_path = os.path.join(args.output_dir, output_filename)

        result = mammoth.convert_to_html(
            docx_fileobj,
            style_map=style_map,
            convert_image=convert_image,
        )
        for message in result.messages:
            sys.stderr.write(message.message)
            sys.stderr.write("\n")

        with _open_output(output_path) as output:
            output.write(result.value)


class ImageWriter(object):
    """Callable that saves each image it is given into a directory.

    Files are named ``<n>.<subtype>``, where ``n`` counts up from 1 and
    ``subtype`` is whatever follows the first "/" in the image's content type.
    """

    def __init__(self, output_dir):
        self._output_dir = output_dir
        self._image_number = 1

    def __call__(self, element):
        """Copy the image's bytes to a new numbered file.

        ``element`` must provide a ``content_type`` string and an ``open()``
        method returning a context manager that yields a readable binary
        stream. Returns a dict whose ``"src"`` entry is the new file's name
        (without the directory).
        """
        _, _, subtype = element.content_type.partition("/")
        image_filename = "{0}.{1}".format(self._image_number, subtype)
        destination_path = os.path.join(self._output_dir, image_filename)

        # The destination is opened before the source, as in nested blocks.
        with open(destination_path, "wb") as sink, element.open() as source:
            shutil.copyfileobj(source, sink)

        # Only advance the counter once the copy has succeeded.
        self._image_number += 1
        return {"src": image_filename}


@contextlib.contextmanager
def _open_output(name):
    """Yield an ``Output`` wrapping the destination called ``name``.

    When ``name`` is None the output wraps ``sys.stdout``, which is flushed
    (not closed) on exit. Otherwise the file at ``name`` is opened for
    writing and closed on exit.

    On Python 3 the file is opened as UTF-8 text, so the bytes written do not
    depend on the locale's default encoding and match the UTF-8 that
    ``Output`` produces on Python 2.
    """
    if name is None:
        try:
            yield Output(sys.stdout)
        finally:
            sys.stdout.flush()
    elif sys.version_info[0] <= 2:
        # Python 2's built-in open() takes no encoding argument; Output
        # encodes to UTF-8 itself before writing.
        with open(name, "w") as output:
            yield Output(output)
    else:
        with open(name, "w", encoding="utf-8") as output:
            yield Output(output)


class Output(object):
    """Thin wrapper around a file object that accepts text on Python 2 and 3."""

    def __init__(self, fileobj):
        self._stream = fileobj

    def write(self, value):
        """Write ``value`` to the wrapped file object.

        On Python 2 the value is encoded to UTF-8 first; on later versions it
        is passed through unchanged.
        """
        on_python_2 = sys.version_info[0] <= 2
        payload = value.encode("utf8") if on_python_2 else value
        self._stream.write(payload)
            


def _parse_args():
    """Parse the process's command-line arguments.

    Accepts a required ``path``, then either an optional positional
    ``output`` or ``--output-dir`` (the two are mutually exclusive), plus an
    optional ``--style-map``. Returns the resulting argparse namespace.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("path")

    # The destination is an output file or an output directory, never both.
    destination = arg_parser.add_mutually_exclusive_group()
    destination.add_argument("output", nargs="?")
    destination.add_argument("--output-dir", nargs="?")

    arg_parser.add_argument("--style-map", required=False)

    parsed = arg_parser.parse_args()
    return parsed


# Run the converter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
