#!/usr/bin/env python

# BEGIN_COPYRIGHT
# 
# Copyright 2009-2013 CRS4.
# 
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# 
# END_COPYRIGHT

import logging
logging.basicConfig(level=logging.INFO)

import pydoop.test_support as pts
from psrunner import PydoopScriptRunner


# Name of the script that main() hands to the runner via runner.run(SCRIPT).
SCRIPT = "base_histogram.py"
# Input file: passed to runner.set_input(..., put=True) in main() and also
# read locally by get_expected_res() to compute the reference counts.
LOCAL_INPUT = "example.sam"


def get_expected_res():
  """
  Compute the reference base counts directly from LOCAL_INPUT.

  Each line is split on at most 10 tabs and the second-to-last piece is
  kept (for a record with 11 or more columns that is the 10th column,
  presumably the SAM read sequence).  The kept pieces are concatenated
  and the occurrences of A, C, G and T are counted.

  Returns a dict mapping each of "A", "C", "G", "T" to its count.
  """
  with open(LOCAL_INPUT) as sam_file:
    sequences = [record.split("\t", 10)[-2] for record in sam_file]
  all_bases = "".join(sequences)
  # TODO: when we drop support for Python 2.6, use collections.Counter
  counts = {}
  for base in "ACGT":
    counts[base] = all_bases.count(base)
  return counts


def main():
  """
  Run SCRIPT through a PydoopScriptRunner and check its output.

  LOCAL_INPUT is set as the job input (with put=True), SCRIPT is run, and
  the collected output is parsed with pts.parse_mr_output (values
  converted with int).  The parsed result is compared with the counts
  computed locally by get_expected_res().

  Prints "OK." when the two match.  Otherwise prints "ERROR" and writes
  both mappings, sorted by key, to "result.txt" and "expected_result.txt"
  in the current directory so they can be diffed.
  """
  logger = logging.getLogger("main")
  logger.setLevel(logging.INFO)
  runner = PydoopScriptRunner(logger=logger)
  runner.set_input(LOCAL_INPUT, put=True)
  runner.run(SCRIPT)
  res = runner.collect_output()
  runner.clean()
  logger.info("checking results")
  res = pts.parse_mr_output(res, vtype=int)
  expected_res = get_expected_res()
  # NOTE: single-argument print(...) produces the same output whether it is
  # parsed as the Python 2 print statement or called as the Python 3
  # function, so this block is valid under both.
  if res == expected_res:
    print("OK.")
    return
  print("ERROR")
  for r, n in (res, 'result'), (expected_res, 'expected_result'):
    out_fn = "%s.txt" % n
    with open(out_fn, "w") as fo:
      # items() (rather than the Python 2-only iteritems()) sorts the same.
      for b, c in sorted(r.items()):
        fo.write("%s: %s\n" % (b, c))
    # Report only after the file has been closed (and thus flushed).
    print("wrote %r" % (out_fn,))


# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
  main()
