#!/usr/bin/env python
'''
Gets experimental results out of the database and table described in config.txt
and writes the data to a csv file.

Usage: speriment-output filename [-t TAG [TAG ...]] [-e WORKER_ID [WORKER_ID ...]]'''

from sqlalchemy import create_engine, MetaData, Table
import json
import pandas as pd
import sys
from psiturk.psiturk_config import PsiturkConfig
import argparse

def parse(argv=None):
    '''Parse the command-line arguments of this script.

    Args:
        argv: list of argument strings to parse. None (the default) makes
            argparse read sys.argv[1:], which is what the script entry
            point relies on; passing a list allows reuse and testing.

    Returns:
        argparse.Namespace with ``filename`` (str), ``tags`` (list of str,
        empty if -t is not given) and ``exclude`` (list of str, empty if
        -e is not given).
    '''
    parser = argparse.ArgumentParser(description='''Retrieve and format the
            data gathered in an experiment and write it to a csv file.''')
    parser.add_argument('filename', type=str, help = '''File
            to write experimental results to in csv format.''')
    parser.add_argument('-t', '--tags', nargs='*',
            default=[], help = '''Names of all page tags and option tags, in that
            order, that were passed into the experiment.''')
    parser.add_argument('-e', '--exclude', nargs='*',
            default=[], help = '''Worker IDs of any participants whose data you don't
            want to write to the output file.''')
    return parser.parse_args(argv)

def get_credentials():
    '''Read the database connection settings from the psiturk config.

    Returns:
        (database_url, table_name): the 'database_url' and 'table_name'
        options of the 'Database Parameters' section, in that order.
    '''
    config = PsiturkConfig()
    config.load_config()
    section = 'Database Parameters'
    db_url = config.get(section, 'database_url')
    table_name = config.get(section, 'table_name')
    return db_url, table_name

def retrieve(db_url, table_name, exclude=None):
    '''Fetch the raw data strings of participants who finished the experiment.

    Reflects ``table_name`` from the database at ``db_url`` and returns the
    ``datastring`` column of every row whose ``status`` is one of the
    completion codes, skipping rows whose worker ID is listed in ``exclude``.

    Args:
        db_url: SQLAlchemy database URL, e.g.
            "mysql://username:password@host.org/database_name".
        table_name: name of the table holding participant records,
            e.g. 'my_experiment_table'.
        exclude: iterable of worker IDs to drop. None (the default) or an
            empty iterable keeps every completed participant.

    Returns:
        list of datastring values, one per kept participant, in the order
        the database returned the rows.
    '''
    # status codes of subjects who completed experiment
    # (presumably psiturk's COMPLETED/SUBMITTED/CREDITED/BONUSED -- confirm)
    statuses = (3, 4, 5, 7)
    # column to retrieve
    data_column_name = 'datastring'
    # Build the exclusion set once: O(1) lookups per row, and a None default
    # avoids sharing one mutable list across calls.
    excluded = set(exclude) if exclude is not None else set()

    # boilerplate sqlalchemy setup (bound-metadata style)
    engine = create_engine(db_url)
    try:
        metadata = MetaData()
        metadata.bind = engine
        table = Table(table_name, metadata, autoload=True)
        rows = table.select().execute()
        # The first ':'-separated field of uniqueid is what gets compared
        # against the excluded worker IDs.
        return [row[data_column_name] for row in rows
                if row['status'] in statuses
                and row['uniqueid'].split(':')[0] not in excluded]
    finally:
        # The comprehension has fully consumed the result set, so the pooled
        # connections can be released now.
        engine.dispose()

def format_data(data, user_defined_columns):
    '''Flatten participants' JSON datastrings into one DataFrame row per trial.

    Each element of ``data`` is parsed with ``json.loads``. The resulting
    object is read for ``condition``, ``counterbalance``, ``hitId``,
    ``workerId`` and a ``data`` list of trials. Each trial is read for
    ``uniqueid``, ``current_trial`` and a ``trialdata`` list.

    A trial's output row is laid out as follows:
      * the first 12 ``trialdata`` values, padded with None when fewer were
        recorded, so the following fields stay under their own column names;
      * UniqueID, TrialNumber, Version (participant 'condition'),
        Permutation (participant 'counterbalance'), HIT and WorkerID;
      * any remaining ``trialdata`` values, labelled by
        ``user_defined_columns``.

    Args:
        data: iterable of JSON strings, one per participant.
        user_defined_columns: names for the trailing trialdata values (page
            tags then option tags); any sequence of strings.

    Returns:
        pandas.DataFrame with the columns above plus ``ReactionTime``,
        computed as EndTime - StartTime.
    '''
    # Leading trialdata values recorded for every trial, in order.
    trialdata_column_names = [
        'PageID',
        'PageText',
        'BlockIDs',
        'StartTime',
        'EndTime',
        'Iteration',
        'Condition',
        'SelectedID',
        'SelectedText',
        'Correct',
        'OptionOrder',
        'SelectedPosition'
    ]
    num_trialdata_names = len(trialdata_column_names)
    # Trial/participant fields spliced in after the leading trialdata values.
    trial_column_names = [
        'UniqueID',
        'TrialNumber',
        'Version',
        'Permutation',
        'HIT',
        'WorkerID',
    ]
    # list() so tuples or other sequences of names concatenate cleanly.
    column_names = (trialdata_column_names + trial_column_names
                    + list(user_defined_columns))

    rows = []
    for raw in data:
        # parse each participant's datastring as json object
        participant = json.loads(raw)
        participant_fields = [
            participant['condition'],
            participant['counterbalance'],
            participant['hitId'],
            participant['workerId'],
        ]
        for trial in participant['data']:
            trialdata = trial['trialdata']
            head = trialdata[:num_trialdata_names]
            # Without padding, a short record would shift UniqueID etc. left
            # into the wrong columns (pandas NaN-pads ragged rows at the end).
            head = head + [None] * (num_trialdata_names - len(head))
            rows.append(head
                        + [trial['uniqueid'], trial['current_trial']]
                        + participant_fields
                        + trialdata[num_trialdata_names:])

    data_frame = pd.DataFrame(rows, columns=column_names)
    data_frame['ReactionTime'] = data_frame['EndTime'] - data_frame['StartTime']
    return data_frame

if __name__ == '__main__':
    # usage: speriment-output filename [-t TAG ...] [-e WORKER_ID ...]
    # Pipeline: CLI args -> psiturk DB settings -> raw datastrings -> CSV.
    cli_args = parse()
    db_url, table_name = get_credentials()
    raw_datastrings = retrieve(db_url, table_name, cli_args.exclude)
    format_data(raw_datastrings, cli_args.tags).to_csv(cli_args.filename)
