#!/usr/bin/env python

# BEGIN_COPYRIGHT
# 
# Copyright 2009-2013 CRS4.
# 
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# 
# END_COPYRIGHT

import logging, csv
logging.basicConfig(level=logging.INFO)

from psrunner import PydoopScriptRunner


# Script handed to the runner's run() call in main()
SCRIPT = "transpose.py"
# Local tab-delimited matrix file: given to the runner as job input and also
# read directly to compute the expected result
LOCAL_INPUT = "matrix.txt"


# We don't want to depend on NumPy just for this example
def transpose(a):
  """Return the transpose of matrix ``a`` as a new list of lists.

  ``a`` is a sequence of rows. The number of output rows is taken from the
  length of the first input row; a row longer than the first one raises
  IndexError. An empty matrix yields an empty list.
  """
  if not a:
    # No rows means no first row to size the output from
    return []
  # One output row per input column
  t_a = [[] for _ in a[0]]
  for r in a:
    for j, x in enumerate(r):
      t_a[j].append(x)
  return t_a


def get_expected_res():
  """Compute the reference result: the transpose of the LOCAL_INPUT matrix.

  The file is parsed as tab-delimited text via the csv module, so every
  cell is kept as a string.
  """
  with open(LOCAL_INPUT) as f:
    rows = list(csv.reader(f, delimiter="\t"))
  return transpose(rows)


def parse_mr_output(res):
  """Turn the job's text output into a list of rows ordered by row index.

  Each line of ``res`` is split on whitespace; its first field is read as
  an integer index and used for sorting, then dropped from the returned
  rows. The remaining fields are left as strings.
  """
  rows = []
  for line in res.splitlines():
    fields = line.split()
    # Convert the index so that sorting is numeric, not lexicographic
    rows.append([int(fields[0])] + fields[1:])
  rows.sort()
  return [row[1:] for row in rows]


def main():
  """Run the transpose example job and check its output.

  Hands LOCAL_INPUT to a PydoopScriptRunner (with ``put=True``), runs SCRIPT
  with four reducers and ``mapred.map.tasks=2``, then compares the parsed job
  output with a locally computed transpose of the same input. Prints "OK."
  on a match; otherwise prints "ERROR" and writes both matrices, tab
  separated, to ``result.txt`` and ``expected_result.txt`` in the current
  directory.
  """
  logger = logging.getLogger("main")
  logger.setLevel(logging.INFO)
  opts = [
    '--num-reducers', '4',
    '-D', 'mapred.map.tasks=2'
    ]
  runner = PydoopScriptRunner(logger=logger)
  runner.set_input(LOCAL_INPUT, put=True)
  try:
    runner.run(SCRIPT, more_args=opts)
    res = runner.collect_output()
  finally:
    # Call clean() even when the job or the output collection fails, so a
    # failed run does not skip the runner's cleanup step
    runner.clean()
  logger.info("checking results")
  res = parse_mr_output(res)
  expected_res = get_expected_res()
  # Single-argument print(...) gives the same output under Python 2 and 3
  if res == expected_res:
    print("OK.")
    return
  print("ERROR")
  # Dump both matrices so that they can be diffed by hand
  for r, n in (res, 'result'), (expected_res, 'expected_result'):
    out_fn = "%s.txt" % n
    with open(out_fn, "w") as fo:
      for row in r:
        fo.write("\t".join(map(str, row))+"\n")
    # Report the file only once it has been closed
    print("wrote %r" % (out_fn,))


# Run the example only when this file is executed as a script
if __name__ == "__main__":
  main()
