#!/usr/bin/env python

# BEGIN_COPYRIGHT
# 
# Copyright 2009-2013 CRS4.
# 
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# 
# END_COPYRIGHT

import os, tempfile, shutil, operator, logging
logging.basicConfig(level=logging.INFO)

from ipcount import main as ipcount


# Log file that main() passes to both run_ipcount() and local_ipcount().
# The path is relative, so it is resolved against the current directory.
INPUT = "input/access.log"
# File that main() passes as the excludes file to both counting functions;
# local_ipcount() reads it as one entry per non-blank line.
EXCLUDES = "excludes.txt"


def local_ipcount(input_fn, excludes_fn=None):
  """
  Count, in pure Python, how often each leading field occurs in a file.

  The first whitespace-separated token of every non-blank line of
  ``input_fn`` is used as the key (the callers in this file treat it as
  an IP address).  If ``excludes_fn`` is given, each of its non-blank
  lines, stripped of surrounding whitespace, names a key to be ignored.

  Returns a list of ``(key, occurrences)`` tuples sorted by decreasing
  number of occurrences.  The relative order of keys with the same
  number of occurrences is not specified.
  """
  if excludes_fn:
    with open(excludes_fn) as f:
      excludes = set(l.strip() for l in f if not l.isspace())
  else:
    excludes = set()
  count = {}
  with open(input_fn) as f:
    for line in f:
      if line.isspace():
        continue
      ip = line.split(None, 1)[0]
      if ip in excludes:
        continue
      count[ip] = count.get(ip, 0) + 1
  # items() is available on both Python 2 and Python 3 dictionaries,
  # whereas iteritems() only exists on Python 2
  return sorted(count.items(), key=operator.itemgetter(1), reverse=True)


def run_ipcount(input_fn, excludes_fn=None):
  """
  Run the ipcount application and return the counts it wrote.

  ipcount is invoked with ``-i input_fn``, ``-o <temporary file>`` and
  ``-n 0``; ``-e excludes_fn`` is added when an excludes file is given.
  NOTE(review): the meaning of ``-n 0`` is defined by ipcount itself and
  is not visible here -- presumably "do not limit the output"; confirm.

  The output file is expected to hold one tab-separated ``ip<TAB>count``
  pair per non-blank line.  Returns a list of ``(ip, int(count))``
  tuples in the order they appear in that file.
  """
  temp_dir = tempfile.mkdtemp(prefix="pydoop_test_ipcount_")
  try:
    outfn = os.path.join(temp_dir, "ipcount.out")
    args = ["-i", input_fn, "-o", outfn, "-n", "0"]
    if excludes_fn:
      # forward the caller's file, not the module-level default
      args.extend(["-e", excludes_fn])
    ipcount(args)
    count = []
    with open(outfn) as f:
      for l in f:
        if not l.isspace():
          ip, c = l.strip().split("\t")
          count.append((ip, int(c)))
  finally:
    # remove the scratch directory even if ipcount or the parsing fails
    shutil.rmtree(temp_dir)
  return count


def normalize_count(count):
  """
  Return the items of ``count`` as a new list in a canonical order.

  Comparisons between results need a secondary ordering on IPs, so the
  items are first put in their natural order and then stably re-sorted
  on their second field: the result is ordered by that field, with ties
  keeping their natural order.
  """
  pairs = list(count)
  pairs.sort()
  pairs.sort(key=operator.itemgetter(1))
  return pairs


def main():
  """
  Compare the ipcount application's output with a local reference count.

  Both results are computed from INPUT with EXCLUDES applied and are
  normalized before being compared.  Prints "OK." when they match.
  Otherwise both lists are written, as their repr(), to "count.dump" and
  "expected_count.dump" in the current directory and an error is logged.
  """
  count = normalize_count(run_ipcount(INPUT, EXCLUDES))
  logging.info("checking results")
  expected_count = normalize_count(local_ipcount(INPUT, EXCLUDES))
  if count == expected_count:
    # parenthesized form is valid on both Python 2 and Python 3
    print("OK.")
  else:
    with open("count.dump", "w") as fo:
      fo.write(repr(count)+"\n")
    with open("expected_count.dump", "w") as fo:
      fo.write(repr(expected_count)+"\n")
    # make the mismatch visible instead of failing silently
    logging.error(
      "results differ: see count.dump and expected_count.dump"
    )


# Run the comparison only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
  main()
