#!/usr/bin/env python
"""
This script will generate a csv of course content statistics.
"""

import csv
import StringIO
import sys

# Setup environment here, before importing project specific stuff
from xsiftx.tools import enter_lms
enter_lms(sys.argv[1], sys.argv[2])

from collections import OrderedDict
from django.core.cache import get_cache
from django.dispatch import Signal
from request_cache.middleware import RequestCache

from courseware.views import *
from opaque_keys.edx.locator import CourseLocator
from student.roles import CourseStaffRole
from xmodule.modulestore.django import modulestore


# Wire up the modulestore by hand: fetch the Django cache registered under
# 'mongo_metadata_inheritance' and attach it to the modulestore instance,
# together with the cache returned by RequestCache.get_request_cache() and
# an update signal.
# NOTE(review): presumably this mirrors setup that the LMS normally performs
# itself, needed here because the script runs standalone -- confirm.
cache = get_cache('mongo_metadata_inheritance')
store = modulestore()
store.metadata_inheritance_cache_subsystem = cache
store.request_cache = RequestCache.get_request_cache()
# Signal declared with the argument names its receivers are sent.
modulestore_update_signal = Signal(providing_args=[
    'modulestore', 'course_id', 'location',
])
store.modulestore_update_signal = modulestore_update_signal


class CourseAxis(object):
    """
    Per-module usage statistics for a single course.

    On construction the course's modules are fetched from the modulestore
    and StudentModule rows are counted for each of them.  The result is kept
    in ``self.stats`` and can be written out with ``dump_csv``.
    """

    # Column order of the CSV produced by dump_csv.
    _FIELDNAMES = (
        'category',
        'display_name',
        'naccess',
        'nengaged',
        'location',
        'naccess_staff',
        'nengaged_staff',
    )

    def __init__(self, course_id):
        '''
        Load the modules of ``course_id`` and compute their statistics.

        course_id should be eg: MITx/3.091r/2013_Fall
        '''
        self.course_id = course_id

        # Get course locator
        self.course = CourseLocator(*course_id.split('/'))  # org, course, run

        # Get flat XModuleDescriptor list for course
        self.ms = modulestore()
        self.modules = self.ms.get_items(self.course)

        self.compute_statistics()

    def compute_statistics(self):
        '''
        Find number of StudentModule instances for each module - and
        differentiate those with empty state.

        Fills ``self.stats``, an OrderedDict keyed by module location whose
        values are dicts with the keys:

        * ``category``, ``display_name`` - copied from the module
        * ``location`` - the module location as a string, with the leading
          ``i4x://<org>/<course>/`` removed
        * ``naccess`` / ``nengaged`` - rows belonging to users without the
          course staff role; ``nengaged`` only counts rows whose state is
          not ``'{}'``
        * ``naccess_staff`` / ``nengaged_staff`` - the same two counts for
          users holding the course staff role

        Rows of users flagged ``is_staff`` are excluded from every count.
        '''
        self.stats = OrderedDict()

        # Neither value depends on the module being counted, so compute
        # them once instead of on every loop iteration.
        staff_users = CourseStaffRole(self.course).users_with_role()
        loc_prefix = 'i4x://%s/' % self.course_id.rsplit('/', 1)[0]

        for node in self.modules:
            smq = StudentModule.objects.filter(
                module_state_key=node.location, student__is_staff=False
            )
            smq_staff = smq.filter(student__in=staff_users)
            smq_student = smq.exclude(student__in=staff_users)

            self.stats[node.location] = {
                'naccess': smq_student.count(),
                'nengaged': smq_student.exclude(state='{}').count(),
                'display_name': node.display_name,
                'location': str(node.location).replace(loc_prefix, ''),
                'category': node.category,
                'naccess_staff': smq_staff.count(),
                'nengaged_staff': smq_staff.exclude(state='{}').count(),
            }

    @staticmethod
    def _csv_safe(value):
        '''
        Return ``value`` in a form the csv writer can serialise.

        The Python 2 csv module cannot write non-ASCII unicode objects, so
        on Python 2 unicode values are encoded to UTF-8 byte strings.  Every
        other value, and every value on Python 3, is returned unchanged.
        '''
        if sys.version_info[0] < 3 and isinstance(value, type(u'')):
            return value.encode('utf-8')
        return value

    def dump_csv(self, fp):
        '''
        Dump statistics as a CSV file
        fp = file pointer (open and available for writing)

        A header row is written first, then one fully quoted row per entry
        of ``self.stats``.  Writing is best effort: a row that cannot be
        written is reported on stderr and skipped.
        '''
        writer = csv.DictWriter(
            fp, list(self._FIELDNAMES),
            dialect='excel', quotechar='"',
            quoting=csv.QUOTE_ALL
        )
        writer.writeheader()
        for row in self.stats.values():
            try:
                writer.writerow(dict(
                    (key, self._csv_safe(value))
                    for key, value in row.items()
                ))
            except Exception as err:
                # Deliberately broad: one bad row must not abort the dump.
                sys.stderr.write("Oops, failed to write, "
                                 "error={0!r}\n".format(err))

if __name__ == "__main__":
    # Command-line layout, as used by the code in this file:
    #   argv[1], argv[2] -- passed to enter_lms() near the top of the file
    #   argv[3]          -- course id, e.g. MITx/3.091r/2013_Fall
    #   argv[4]          -- output file name, must end in .csv
    # Note that enter_lms() runs before this check, so a call with fewer
    # than two arguments fails there before the usage text can be shown.

    help_txt = ("Generate CSV file with course content usage statistics; "
                "CSV file columns are category, display_name, naccess, "
                "nengaged, location, naccess_staff, nengaged_staff for all "
                "modules in the course.  \n"
                "Arguments: <lms arg 1> <lms arg 2> course_id filename.csv\n"
                "(the first two arguments are passed to "
                "xsiftx.tools.enter_lms)")

    if len(sys.argv) != 5:
        sys.stderr.write('Usage:\n{0}\n'.format(help_txt))
        sys.stderr.write('Exactly four arguments are required.\n')
        sys.exit(-1)
    filename = sys.argv[4]
    if not filename.endswith('.csv'):
        sys.stderr.write('Usage:\n{0}\n'.format(help_txt))
        sys.stderr.write('File name must end in .csv\n')
        sys.exit(-1)
    course_id = sys.argv[3]
    ca = CourseAxis(course_id)
    output_buf = StringIO.StringIO()
    ca.dump_csv(output_buf)
    # The file name goes to stdout first, followed by the CSV content.
    # NOTE(review): presumably the caller reads both from stdout -- confirm.
    print(filename)
    print(output_buf.getvalue())
