import itertools
from kessel import *

# Forward declaration: JSON values are recursive (arrays and objects contain
# values), so the parser is created empty here and bound later via value.set(...).
value = forward()

# Skippable insignificant whitespace: ``word`` over space, LF, CR and tab,
# wrapped in ``optional`` so that no whitespace at all is also accepted.
wspace = optional(word(" \n\r\t"))

# A single hexadecimal digit in either case (the digits of a \uXXXX escape).
hex_ = one_of(*"0123456789abcdefABCDEF")


@mapf(count(4, hex_))
def unicode_ord(xs):
    """Decode the four hex digits of a ``\\uXXXX`` escape into one character.

    *xs* is the sequence of four hex-digit characters matched by ``count``.
    """
    hex_text = "".join(xs)
    code_point = int(hex_text, 16)
    return chr(code_point)

# A backslash escape inside a JSON string literal. After the leading backslash,
# each alternative produces the character the escape stands for:
#   \"  \\  \/   -> the matched character itself (presumably what ``single``
#                   returns — TODO confirm against kessel)
#   \b \f \n \r \t -> the corresponding control character, supplied by ``unit``
#   \uXXXX       -> decoded by ``unicode_ord`` (one UTF-16 code unit)
escape = single("\\") >> ( single("\"")
                         | single("\\")
                         | single("/")
                         | single("b") >> unit(lambda: "\b")
                         | single("f") >> unit(lambda: "\f")
                         | single("n") >> unit(lambda: "\n")
                         | single("r") >> unit(lambda: "\r")
                         | single("t") >> unit(lambda: "\t")
                         | single("u") >> unicode_ord
                         )


def _combine_surrogate_pairs(text):
    """Return *text* with each UTF-16 surrogate pair merged into one code point.

    JSON writes characters outside the Basic Multilingual Plane as two
    ``\\uXXXX`` escapes (a high then a low surrogate, RFC 8259 section 7), and
    ``unicode_ord`` decodes each escape on its own.  Round-tripping through
    UTF-16 with the ``surrogatepass`` error handler joins well-formed pairs.
    A lone surrogate is passed through as-is rather than raising.
    """
    return text.encode("utf-16-le", "surrogatepass").decode(
        "utf-16-le", "surrogatepass")


# A JSON string literal: a double quote, then any mix of escape sequences and
# ordinary characters, then a closing double quote. Ordinary characters exclude
# the backslash, the double quote and the control characters U+0000..U+001F,
# all of which RFC 8259 requires to be escaped.
@mapf(between(single("\""), single("\""),
              many(escape | none_of("\\", "\"", *map(chr, range(0x20))))))
def string(xs):
    """Join the decoded characters of a JSON string literal into a ``str``.

    Surrogate pairs that came from two consecutive ``\\uXXXX`` escapes are
    merged into the single character they encode.
    """
    return _combine_surrogate_pairs("".join(xs))


# A single decimal digit, yielded as a one-character string (``number`` joins
# runs of these with "".join).
digit = one_of(*"0123456789")


@gen_parser
def number():
    """Parse a JSON number and return it as an ``int`` or a ``float``.

    Grammar (RFC 8259 section 6)::

        number = [ "-" ] int [ frac ] [ exp ]
        int    = "0" / ( digit1-9 *digit )
        frac   = "." 1*digit
        exp    = ( "e" / "E" ) [ "+" / "-" ] 1*digit

    The result is a ``float`` if a fraction or an exponent is present,
    otherwise an ``int``.  A fraction or exponent with no digits (``1.``,
    ``1e``, ``1e+``) is rejected as a parse error.
    """
    sign = yield optional(single("-") >> unit(lambda: -1),
                          lambda: 1)

    is_float = False

    # Integer part: a lone "0", or a nonzero digit followed by any digits
    # (JSON forbids leading zeros such as "01").
    try:
        text = yield single("0")
    except Unexpected:
        text = yield one_of(*"123456789")
        text += "".join((yield many(digit)))

    # Optional fraction.  At least one digit must follow the "."; "1." is
    # not valid JSON.
    try:
        text += (yield single("."))
    except Unexpected:
        pass
    else:
        is_float = True
        text += (yield digit)
        text += "".join((yield many(digit)))

    # Optional exponent.  At least one digit is required.  Otherwise text such
    # as "1e" or "1e+" would reach float() below and raise ValueError instead
    # of failing the parse.
    try:
        text += (yield one_of(*"eE"))
    except Unexpected:
        pass
    else:
        is_float = True
        text += (yield ( one_of(*"+-")
                       | unit(lambda: "")
                       ))
        text += (yield digit)
        text += "".join((yield many(digit)))

    return sign * (float if is_float else int)(text)


# One or more array elements separated by commas.  Insignificant whitespace
# is allowed on either side of each comma.
array_items = sep_by1(value,
                      wspace >> single(",") >> wspace)

# "[" elements "]", with optional whitespace just inside the brackets.  When
# there are no elements, the ``list`` factory supplies a fresh empty list.
array = between(single("[") >> wspace,
                wspace << single("]"),
                optional(array_items, list))


# One object member: key, optional whitespace, ":", optional whitespace, value.
# The key must be a JSON string (RFC 8259 section 4).  Parsing it with ``value``
# would accept non-string keys such as {1: 2}, and unhashable keys such as
# {[1]: 2} would then raise TypeError when the members are collected into a dict.
@mapf_star((string << wspace << single(":")) + (wspace >> value))
def object_item(k, v):
    """Return one object member as a ``(key, value)`` pair."""
    return (k, v)


@mapf(sep_by1(object_item, wspace >> single(",") >> wspace))
def object_items(xs):
    """Build a dict from the comma-separated ``(key, value)`` members.

    If a key appears more than once, the last occurrence wins.
    """
    return {key: member for key, member in xs}

# "{" members "}", with optional whitespace just inside the braces.  When there
# are no members, the ``dict`` factory supplies a fresh empty dict.
object_ = between(single("{") >> wspace,
                  wspace << single("}"),
                  optional(object_items, dict))

# Bind the forward-declared ``value`` parser.  A JSON value is a string, number,
# object, array, or one of the literals true/false/null, which map to Python's
# True/False/None.  The alternatives are presumably tried left to right by
# kessel's ``|``; TODO confirm.
value.set( string
         | number
         | object_
         | array
         | literal("true") >> unit(lambda: True)
         | literal("false") >> unit(lambda: False)
         | literal("null") >> unit(lambda: None)
         )

# Whole-document parser: exactly one value, optionally surrounded by
# whitespace, followed by end of input (``eof``).
json = wspace >> value << wspace << eof


def load(f):
    """Parse a complete JSON document from *f* and return the Python value.

    *f* is iterated as a sequence of text chunks, e.g. the lines of a file
    opened in text mode.  The characters of every chunk are fed to the
    parser in order.
    """
    characters = (ch for chunk in f for ch in chunk)
    return json.parse(characters)

if __name__ == "__main__":  # pragma: no cover
    import pprint

    # JSON text is UTF-8 (RFC 8259 section 8.1).  Name the encoding explicitly
    # rather than relying on the platform default, which differs on e.g. Windows.
    with open("test.json", encoding="utf-8") as f:
        pprint.pprint(load(f))
