#!/usr/bin/env python3

"""\
PicoSAT solver
"""

import argparse
import sys

from pyeda.boolalg.expr import ast2expr
from pyeda.boolalg import picosat
from pyeda.parsing import dimacs

# Process exit codes, one per solver outcome
# (standardized by http://www.satcompetition.org).
UNKNOWN = 0
SATISFIABLE = 10
UNSATISFIABLE = 20

class PicosatArgumentParser(argparse.ArgumentParser):
    """Argument parser whose diagnostics mimic the PicoSAT executable."""

    def exit(self, status=0, message=None):
        """Terminate the process, optionally printing a PicoSAT-style message.

        A non-empty *message* is written to standard output behind the
        ``*** picosat:`` prefix; the process then exits with *status*.
        """
        has_message = bool(message)
        if has_message:
            prefix = "*** picosat:"
            print(prefix, message)
        sys.exit(status)

    def error(self, message):
        """Report a command-line parsing error and exit with UNKNOWN."""
        self.exit(status=UNKNOWN, message=message)

# Module-level command-line parser; the module docstring doubles as the
# description shown in the --help output.
PARSER = PicosatArgumentParser(description=__doc__)

def _nonzero(s):
    """Convert *s* to an integer literal, rejecting zero.

    Used as an argparse ``type`` callback.

    Raises:
        argparse.ArgumentTypeError: if *s* is not an integer, or is zero.
    """
    try:
        literal = int(s)
    except ValueError:
        raise argparse.ArgumentTypeError("invalid assumption: {}".format(s))
    if not literal:
        raise argparse.ArgumentTypeError("zero assumption")
    return literal

# Command-line options. Each add_argument call registers one option whose
# parsed value is stored on the namespace under its ``dest`` name.
PARSER.add_argument(
    '--version', action='version', version=picosat.VERSION,
    help="print version and exit"
)
PARSER.add_argument(
    '-n', action='store_false', dest='printsoln',
    help="do not print satisfying assignment"
)
# May be repeated; each literal is validated by _nonzero and appended.
PARSER.add_argument(
    '-a', metavar='LIT', action='append', type=_nonzero, dest='assumptions',
    help="start with one or more literal assumptions"
)
PARSER.add_argument(
    '-v', action='store_true', dest='verbose',
    help="enable verbose output"
)
# type=int is required: argparse checks ``choices`` against the converted
# value, so without it the string "2" would never match the integer choices
# and every explicit -i would be rejected as an invalid choice.
PARSER.add_argument(
    '-i', type=int, choices=[0, 1, 2, 3], default=2, dest='default_phase',
    help=("initial phase {0: FALSE, 1: TRUE, 2: JWH, 3: RAND} "
          "(default: %(default)s)")
)
PARSER.add_argument(
    '-P', metavar='LIMIT', type=int, default=-1, dest='propagation_limit',
    help='set propagation limit; negative means "none" (default: %(default)s)'
)
PARSER.add_argument(
    '-l', metavar='LIMIT', type=int, default=-1, dest='decision_limit',
    help='set decision limit; negative means "none" (default: %(default)s)'
)
PARSER.add_argument(
    '--all', action='store_true',
    help="enumerate all solutions"
)
# Optional positional input; falls back to standard input when omitted.
PARSER.add_argument(
    'file', nargs='?', default=sys.stdin, type=argparse.FileType('r'),
    help="CNF input file (default: stdin)"
)

def main():
    """Parse the command line, solve the CNF input, and print the result.

    The input is read as DIMACS CNF, converted to an expression, simplified
    and encoded before being handed to PicoSAT. With ``--all`` every
    solution is enumerated; otherwise a single solution is searched for,
    honouring any ``-a`` assumptions (assumptions are only passed to
    ``satisfy_one``).

    Returns:
        One of UNKNOWN, SATISFIABLE or UNSATISFIABLE, intended to be used
        as the process exit status.
    """
    opts = PARSER.parse_args()

    # Read the whole input up front and release the handle right away.
    # Standard input is left open because this function did not open it.
    try:
        text = opts.file.read()
    finally:
        if opts.file is not sys.stdin:
            opts.file.close()

    try:
        ast = dimacs.parse_cnf(text)
    except dimacs.Error as exc:
        print("error parsing file:", opts.file.name)
        print(exc)
        return UNKNOWN

    expr = ast2expr(ast).simplify()

    # Zero doesn't have a CNF encoding
    if expr.is_zero():
        print('s', 'UNSATISFIABLE')
        return UNSATISFIABLE

    # NOTE(review): the first value returned by encode_cnf is discarded, so
    # solutions are printed using the encoder's variable numbering -- confirm
    # that this always matches the numbering used in the DIMACS input.
    _, nvars, clauses = expr.encode_cnf()

    config = {
        'verbosity' : opts.verbose,
        'default_phase' : opts.default_phase,
        'propagation_limit' : opts.propagation_limit,
        'decision_limit' : opts.decision_limit
    }

    try:
        if opts.all:
            cnt = 0
            for soln in picosat.satisfy_all(nvars, clauses, **config):
                if opts.printsoln:
                    print('s', 'SATISFIABLE')
                    print('v', *_soln2values(soln))
                cnt += 1
            print('s', 'SOLUTIONS', cnt)
            # This unexpected result matches PicoSAT
            return UNSATISFIABLE

        soln = picosat.satisfy_one(nvars, clauses,
                                   assumptions=opts.assumptions,
                                   **config)
        if soln is None:
            print('s', 'UNSATISFIABLE')
            return UNSATISFIABLE

        print('s', 'SATISFIABLE')
        if opts.printsoln:
            print('v', *_soln2values(soln))
        return SATISFIABLE
    except picosat.Error as exc:
        print(exc)
        return UNKNOWN

def _soln2values(soln):
    """Turn a PicoSAT solution into a zero-terminated list of literals.

    The entry at 1-based position ``k`` becomes ``k`` when its value is
    positive and ``-k`` otherwise; a trailing ``0`` ends the list.
    """
    literals = []
    for index, val in enumerate(soln, start=1):
        literals.append(index if val > 0 else -index)
    return literals + [0]


if __name__ == '__main__':
    # Hand main()'s return value to the shell as the process exit status.
    exit_status = main()
    sys.exit(exit_status)

