import re
import roman
from linfir.mdtex.utils import *

# Names exported by a star-import of this module.
__all__ = ['parse_inline', 'parse', 'Parser', 'ParseError']

# --- Whitespace and inline tokens -------------------------------------------
# One newline that is followed by a non-newline character (so not a blank
# line); group 1 captures the spaces indenting the following line.
R_newline = re.compile(r"\n(?=[^\n])( *)")
# One or more spaces / zero or more spaces.
R_space = re.compile(r" +")
R_space0 = re.compile(r" *")
# One or more newlines / zero or more newlines.
R_nl = re.compile(r"\n+")
R_nl0 = re.compile(r"\n*")
# One or more whitespace characters of any kind.
R_ws = re.compile(r"\s+")
# Optional whitespace ending in a newline.
R_ws_nl = re.compile(r"\s*\n")
# A run of plain text: anything except a newline and the characters
# * $ { } \ < >.
R_text = re.compile(r"[^\n\*\$\{\}\\<>]+")
# An HTML-style comment; (?s) lets the body (group 1) span several lines.
R_comment = re.compile(r"(?s)<!--(.*?)-->")
# A backslash followed by ASCII letters; group 1 is the name without the
# backslash.
R_abbrev = re.compile(r"\\([a-zA-Z]+)")

# --- Metadata block ----------------------------------------------------------
# A "% Title" line; group 1 is the title text, trailing newlines are consumed.
R_meta1 = re.compile(r"%\s+(\S.*)\n+")
# The "---" line that opens and closes the metadata block.
R_meta = re.compile(r"---\n")
# A "key: value" line; group 1 is the key, group 2 the unstripped value.
R_meta_line = re.compile(r"(\w+)\s*:(.*)\n")
# A newline, then free text (group 1, may span lines) up to a "---" line.
R_meta_preamble = re.compile(r"(?s)\n(.*?)\n---\n")

# --- Headers -----------------------------------------------------------------
# One or more '#' (group 1), optional whitespace, the header text (group 2)
# and the newlines that follow it.
R_header = re.compile(r"(#+)\s*(.*)\n+")
# "!!" followed by any mix of whitespace, digits and '!', then the text
# (group 1) and the newlines that follow it.
R_header_thm = re.compile(r"!!(?:[\s\d!]*)(.*)\n+")
# The first characters of either header form.
R_maybe_hdr = re.compile(r"#|!!")

# --- List bullets ------------------------------------------------------------
# Lowercase i/v/x letters (group 1), then '.' or ')' and at least one space.
R_bullet_roman = re.compile(r"([ivx]+)[.)] +")
# A decimal number or a single lowercase letter, then '.' or ')' and spaces.
R_bullet_enum = re.compile(r"(?:\d+|[a-z])[.)] +")
# '-', '*' or '+' followed by at least one space.
R_bullet_item = re.compile(r"[-*+] +")

# A \newcommand or \renewcommand (not followed by another letter), the rest
# of its line and the newlines that follow it.
R_newcommand = re.compile(r"\\(?:re)?newcommand(?![a-zA-Z]).*\n+")


class ParseError(Exception):
    """Raised by the parser when the input text cannot be parsed."""


class Parser:
    def __init__(self, text):
        """Set up a parser positioned at the beginning of *text*."""
        # Input and cursor.
        self.text, self.len = text, len(text)
        self.pos = 0
        # Added to every reported line number (see `line`).
        self.line_delta = 0
        # Inline-parsing state flags.
        self.in_emph, self.in_math = False, False
        # Indentation (in spaces) expected for the current block.
        self.indent = 0
        # Metadata collected while parsing.
        self.meta = {}

    def __repr__(self):  # pragma: no cover
        """Show the part of the input that has not been consumed yet."""
        remaining = self.text[self.pos:]
        return "Parser(rest={!r})".format(remaining)

    def startswith(self, prefix):
        """Consume *prefix* if the remaining input begins with it.

        Returns True and advances the cursor on a match; otherwise returns
        None and leaves the cursor untouched.
        """
        if not self.text.startswith(prefix, self.pos):
            return None
        self.pos += len(prefix)
        return True

    def re(self, regex):
        """Match the compiled *regex* at the cursor and consume the match.

        Returns the match object and moves the cursor to its end, or returns
        None without moving when the pattern does not match here.
        """
        found = regex.match(self.text, self.pos)
        if found is None:
            return None
        self.pos = found.end()
        return found

    def at_eof(self):
        """Return True when the cursor sits exactly at the end of the text."""
        remaining = self.len - self.pos
        return remaining == 0

    def expect_eof(self):
        """Fail with "Expecting EOF" unless all input has been consumed."""
        if self.at_eof():
            return
        self.fail("Expecting EOF")

    def line(self):
        """Return the 1-based line number of the cursor, offset by line_delta."""
        newlines_before = self.text.count('\n', 0, self.pos)
        return newlines_before + 1 + self.line_delta

    def fail(self, msg=None):
        """Raise ParseError mentioning the current line and, if given, *msg*."""
        pieces = ["ERROR ", "line {} ".format(self.line())]
        if msg:
            pieces.append(": " + msg.strip())
        raise ParseError("".join(pieces))

# -----------------------------------------------------------------------------

    def bullet(self):
        """Consume a list bullet at the cursor and return its kind.

        Returns 'item' for -, * or + bullets and 'enum' for numbered,
        lettered or roman-numeral bullets. Returns None, with the cursor
        unchanged, when there is no bullet here or when the roman-looking
        bullet is not a valid roman numeral.
        """
        start = self.pos
        for regex, kind in ((R_bullet_item, 'item'), (R_bullet_enum, 'enum')):
            if self.re(regex):
                return kind
        numeral = self.re(R_bullet_roman)
        if not numeral:
            return None
        try:
            roman.fromRoman(numeral.group(1).upper())
        except roman.InvalidRomanNumeralError:
            # Looked like a roman bullet but is not one: un-consume it.
            self.pos = start
            return None
        return 'enum'

    def is_bullet(self):
        """Report the bullet kind at the cursor without consuming anything."""
        saved = self.pos
        kind = self.bullet()
        # Pure lookahead: always rewind, whether or not a bullet matched.
        self.pos = saved
        return kind

# -----------------------------------------------------------------------------

    def i_parse(self):
        """Parse a run of inline elements starting at the cursor.

        Returns None without consuming anything if the cursor is on
        whitespace; otherwise returns the list of parsed inline nodes.
        """
        if R_ws.match(self.text, self.pos) is not None:
            return None
        nodes = self.i_many()
        assert not self.in_emph and not self.in_math
        return nodes

    def i_newline(self):
        """Consume a line break that continues the current inline text.

        The break is accepted only when the next line is indented by exactly
        `self.indent` spaces and does not begin with a list bullet. Returns
        True when accepted; otherwise rewinds and returns None.
        """
        start = self.pos
        brk = self.re(R_newline)
        if brk is not None:
            same_indent = len(brk.group(1)) == self.indent
            if same_indent and not self.is_bullet():
                return True
        self.pos = start
        return None

    def i_text(self):
        """Parse plain text, possibly spanning several continuation lines.

        Whitespace runs inside a line collapse to one space, and each accepted
        line break becomes a single newline with trailing whitespace before
        it removed. Returns ('text', string), or None if nothing was read.
        """
        collected = "\n" if self.i_newline() else ""
        chunk = self.re(R_text)
        while chunk:
            collected += R_ws.sub(' ', chunk.group(0))
            if not self.i_newline():
                break
            collected = collected.rstrip() + "\n"
            chunk = self.re(R_text)
        if not collected:
            return None
        return ('text', collected)

    def i_comment(self):
        """Parse a <!-- ... --> comment; returns ('comment', body) or None."""
        found = self.re(R_comment)
        if not found:
            return None
        return ('comment', found.group(1))

    def i_emph(self):
        """Parse *emphasised* inline content.

        Returns ('emph', nodes), or None when already inside emphasis or when
        the cursor is not on '*'. Fails if the emphasis is empty or is not
        closed by another '*'.
        """
        if self.in_emph:
            return None
        if not self.startswith('*'):
            return None
        self.in_emph = True
        body = self.i_many()
        # The closing '*' is only looked for when the body is non-empty.
        closed = bool(body) and self.startswith('*')
        if not closed:
            self.fail('Unfinished *_*')
        self.in_emph = False
        return ('emph', body)

    def i_many(self):
        """Collect inline elements until `i_any` yields nothing; returns a list."""
        nodes = []
        node = self.i_any()
        while node:
            nodes.append(node)
            node = self.i_any()
        return nodes

    def i_any(self):
        """Parse one inline element, trying each alternative in order.

        Returns the first truthy result; if none succeeds, returns whatever
        the last alternative (`i_math`) returned.
        """
        alternatives = (
            self.i_text,
            self.i_comment,
            self.i_emph,
            self.i_abbrev,
            self.i_math,
        )
        node = None
        for alternative in alternatives:
            node = alternative()
            if node:
                break
        return node

    def i_abbrev(self):
        """Parse a backslash-name; returns ('abbrev', name) or None."""
        found = self.re(R_abbrev)
        if not found:
            return None
        return ('abbrev', found.group(1))

    def i_math(self):
        """Parse inline ($...$) or display ($$...$$) math at the cursor.

        Returns ('math', tex) or ('mathfull', tex), where tex is the source
        between the delimiters after `symbols_to_unicode` and stripping.
        Returns None, consuming nothing, when the cursor is not on '$'.

        A '$' only closes the formula when it is outside every {...} group,
        and a backslash escapes the character that follows it.

        Raises ParseError (via `fail`) for a trailing backslash, an unmatched
        '}', a display formula closed by a single '$', or a formula that is
        never closed.
        """
        if not self.startswith("$"):
            return None
        # A second '$' right after the first one selects display math.
        full = self.startswith("$")
        depth = 0
        i = self.pos
        while i < self.len:
            c = self.text[i]
            if c == '\\':
                # Skip the escaped character (the shared `i += 1` below
                # steps over it).
                i += 1
                if i == self.len:
                    self.fail("Trailing backslash")
            elif c == '{':
                depth += 1
            elif c == '}':
                if depth == 0:
                    self.fail("Unbalanced braces")
                depth -= 1
            elif c == '$' and depth == 0:
                tex = self.text[self.pos:i]
                self.pos = i + 1
                if full and not self.startswith("$"):
                    self.fail("Invalid $")
                tex = symbols_to_unicode(tex)
                return ('mathfull' if full else 'math', tex.strip())
            i += 1
        # Reached the end of the input without a closing '$'. The opening
        # delimiter has already been consumed, so returning None here would
        # silently drop it; report the problem instead.
        self.fail("Unfinished $_$")

# -----------------------------------------------------------------------------

    def p_meta(self):
        """Parse the optional title line and metadata block into `self.meta`.

        A leading "% Title" line sets meta['title']. A block opened by a
        "---" line is then read as "key: value" lines until a closing "---".
        A blank line inside the block switches to a free-form preamble, which
        runs to the closing "---" and is stored in meta['preamble'].
        Fails with "Invalid meta line" on anything else.
        """
        title = self.re(R_meta1)
        if title:
            self.meta['title'] = title.group(1).strip()
            self.re(R_ws)

        if not self.re(R_meta):
            return

        while True:
            if self.re(R_meta):
                # Closing delimiter: the block is complete.
                return
            preamble = self.re(R_meta_preamble)
            if preamble:
                self.meta['preamble'] = preamble.group(1).strip()
                return
            entry = self.re(R_meta_line)
            if not entry:
                self.fail("Invalid meta line")
            self.meta[entry.group(1)] = entry.group(2).strip()

    def p_block(self):
        """Parse one top-level block.

        Any \\newcommand lines at the cursor are absorbed first; then a
        header, a theorem header, a list and a paragraph are tried in that
        order. Returns the first truthy result, else the paragraph result.
        """
        assert self.indent == 0
        self.p_newcommand()
        block = None
        for alternative in (self.p_header, self.p_header_thm,
                            self.p_list, self.p_paragraph):
            block = alternative()
            if block:
                break
        return block

# -----------------------------------------------------------------------------

    def p_newcommand(self):
        """Append consecutive \\(re)newcommand lines to meta['preamble']."""
        found = self.re(R_newcommand)
        while found:
            so_far = self.meta.get('preamble', "")
            self.meta['preamble'] = so_far + found.group(0)
            found = self.re(R_newcommand)

# -----------------------------------------------------------------------------

    def p_indent(self):
        """Consume leading spaces if there are exactly `self.indent` of them.

        Returns True when the run of spaces at the cursor has the expected
        width; otherwise rewinds and returns False.
        """
        start = self.pos
        spaces = self.re(R_space0)
        assert spaces
        if len(spaces.group(0)) == self.indent:
            return True
        self.pos = start
        return False

    def many_indented(self, f):
        """Apply *f* repeatedly, requiring the current indent before each repeat.

        The first call to *f* happens at the cursor as-is; returns None if it
        yields nothing. Each further call is preceded by `p_indent`. The loop
        ends, with the cursor rewound to before the indent, as soon as the
        indent does not match or *f* yields nothing. Returns the results.
        """
        node = f()
        if not node:
            return None
        nodes = []
        while node:
            nodes.append(node)
            checkpoint = self.pos
            node = f() if self.p_indent() else None
            if not node:
                # Give back the indentation consumed for the failed attempt.
                self.pos = checkpoint
        return nodes

# -----------------------------------------------------------------------------

    def p_paragraph(self):
        """Parse a paragraph of inline content ending in one or more newlines.

        Returns ('Paragraph', nodes), or None when the cursor is on a list
        bullet or no inline content could be parsed.

        Raises ParseError (via `fail`) when the inline content is not
        followed by a newline, e.g. because of a character no inline rule
        accepts.
        """
        # Callers are expected to have consumed any leading whitespace.
        assert not R_ws.match(self.text, self.pos)
        if self.is_bullet():
            return None
        x = self.i_parse()
        if not x:
            return None
        # Consuming the newline must not live inside an `assert`: asserts are
        # removed under `python -O`, which would skip the consumption, and a
        # failure here is an input error rather than a programming error.
        if not self.re(R_nl):
            self.fail("Expecting end of line")
        return ('Paragraph', x)

    def p_list(self):
        """Parse a list whose items all use the same kind of bullet.

        Returns ('Itemize', items) or ('Enumerate', items), where each item
        is ('Item', content); returns None if the cursor is not on a bullet.
        Fails with "Heterogeneous list" when a later item uses another kind.
        """
        kind = self.is_bullet()
        if not kind:
            return None

        def same_kind_item():
            parsed = self.p_item()
            if not parsed:
                return None
            item_kind, content = parsed
            if item_kind != kind:
                self.fail("Heterogeneous list")
            return ('Item', content)

        tags = {'item': 'Itemize', 'enum': 'Enumerate'}
        return (tags[kind], self.many_indented(same_kind_item))

    def p_item(self):
        """Parse one list item: a bullet followed by paragraphs and sub-lists.

        The item's content must be indented by the width of its bullet, so
        `self.indent` is raised by that width while the content is parsed.
        Returns (bullet_kind, content), or None if there is no bullet here.
        """
        start = self.pos
        kind = self.bullet()
        if not kind:
            return None
        width = self.pos - start
        assert width > 0

        self.indent += width
        content = self.many_indented(
            lambda: self.p_paragraph() or self.p_list())
        self.indent -= width

        return (kind, content)

    def p_many_blocks(self):
        """Parse blocks until `p_block` yields nothing; returns them as a list."""
        blocks = []
        block = self.p_block()
        while block:
            blocks.append(block)
            block = self.p_block()
        return blocks

    def p_header(self):
        """Parse a '#'-style header line.

        Returns ('Header', level, nodes), where level is the number of '#'
        characters and nodes is the inline parse of the header text, or None
        if the cursor is not on a header.

        Raises ParseError if the header text is not valid inline content.
        """
        m = self.re(R_header)
        if not m:
            return None
        n = len(m.group(1))
        # The sub-parser reports lines as (newlines before its cursor) + 1 +
        # delta, so the delta is the number of lines preceding the header
        # text. It is measured at the text itself because the cursor has
        # already moved past the header and its trailing blank lines.
        delta = self.text.count('\n', 0, m.start(2)) + self.line_delta
        x = parse_inline(m.group(2).strip(), delta)
        return ('Header', n, x)

    def p_header_thm(self):
        """Parse a '!!'-style theorem header line.

        Returns ('Theorem', nodes), where nodes is the inline parse of the
        header text, or None if the cursor is not on such a header.

        Raises ParseError if the header text is not valid inline content.
        """
        m = self.re(R_header_thm)
        if not m:
            return None
        # The sub-parser reports lines as (newlines before its cursor) + 1 +
        # delta, so the delta is the number of lines preceding the header
        # text. It is measured at the text itself because the cursor has
        # already moved past the header and its trailing blank lines.
        delta = self.text.count('\n', 0, m.start(1)) + self.line_delta
        x = parse_inline(m.group(1).strip(), delta)
        return ('Theorem', x)


def parse_inline(txt, line=0):
    """Parse *txt* as inline content and return the parser's result.

    *line* is added to the line numbers reported in error messages. The
    whole of *txt* must be consumed; otherwise ParseError is raised.
    """
    parser = Parser(txt)
    parser.line_delta = line
    nodes = parser.i_parse()
    parser.expect_eof()
    return nodes


def parse(txt):
    """Parse a whole document.

    The text goes through `preprocess` first. Returns a dict with the
    collected metadata under 'meta' and the list of blocks under 'doc'.
    Raises ParseError if the input cannot be consumed completely.
    """
    parser = Parser(preprocess(txt))
    parser.p_meta()
    # Skip blank lines between the metadata and the first block.
    parser.re(R_nl0)
    doc = parser.p_many_blocks()
    parser.expect_eof()
    return dict(meta=parser.meta, doc=doc)


def many(f):
    """Call *f* until it returns a falsy value; return the truthy results."""
    results = []
    while True:
        item = f()
        if not item:
            return results
        results.append(item)

from linfir.mdtex.abbrev import symbols_to_unicode
