import subprocess
from collections import Counter
import math
import re

import nltk
from nltk.corpus.reader import wordnet
from sklearn.feature_extraction.text import CountVectorizer



from media.stoplist import stop_list






# Absolute path of the SemPred.pl Perl script that perl_get_semantic_relatedness
# passes to the `perl` interpreter.
# NOTE(review): hard-coded to one developer's home directory -- must be adjusted
# before running on another machine.
SEM_PRED = '/home/michael/dev/Linguistics/linguistic-helper-functions/linghelper/media/SemPred.pl'

class SemanticPredictabilityAnalyzer:
    """Score relatedness of WordNet senses by the overlap of their glosses.

    Each sense is reduced to a bag of terms taken from its definition and
    example sentences; two senses are related to the extent that their bags
    (extended with the bags of directly related synsets) intersect.

    Parameters
    ----------
    ngram : bool
        When true, bags hold 1- to 5-grams instead of single words.
    use_idf : bool
        When true, shared terms are weighted by the precomputed scores from
        ``media.idf_scores`` instead of being counted.
    """

    def __init__(self, ngram=False, use_idf=False):
        self.ngram = ngram
        self.use_idf = use_idf
        # Locate the WordNet corpus, downloading it on first use.
        try:
            wordnet_path = nltk.data.find('corpora/wordnet')
        except LookupError:
            nltk.download('wordnet')
            wordnet_path = nltk.data.find('corpora/wordnet')
        self.wn = wordnet.WordNetCorpusReader(wordnet_path)
        # One shared pattern for both analyzers. A plain raw string is used
        # (not ``ur''``) so the module also parses on Python 3.
        token_pattern = r'\b[A-Za-z]+\b'
        pentagram_vectorizer = CountVectorizer(ngram_range=(1, 5),
                                               token_pattern=token_pattern,
                                               min_df=1, stop_words=stop_list)
        self.pent_analyze = pentagram_vectorizer.build_analyzer()
        unigram_vectorizer = CountVectorizer(ngram_range=(1, 1),
                                             token_pattern=token_pattern,
                                             min_df=1, stop_words=stop_list)
        self.uni_analyze = unigram_vectorizer.build_analyzer()
        # Imported here rather than at module level, so the scores module is
        # only required once an analyzer is actually constructed.
        from media.idf_scores import scores as IDF
        self.IDF = IDF

    def reduce_sense(self, sense):
        """Return the set of terms in a sense's definition and examples.

        Uses the 1-5-gram analyzer when ``self.ngram`` is set, otherwise the
        unigram analyzer.
        """
        analyze = self.pent_analyze if self.ngram else self.uni_analyze
        bag = set(analyze(sense.definition))
        for example in sense.examples:
            bag.update(analyze(example))
        return bag

    def get_idf_scores(self):
        """Compute ``-log(df / N)`` for every term over all synsets.

        ``df`` is the number of synsets whose reduced bag contains the term
        and ``N`` is the total number of synsets.

        Returns a ``Counter`` mapping term -> score.
        """
        scores = Counter()
        tot_count = 0
        # Count synsets in the same pass that collects document frequencies,
        # instead of walking the whole corpus a second time.
        for s in self.wn.all_synsets():
            scores.update(self.reduce_sense(s))
            tot_count += 1
        tot_count = float(tot_count)
        for k in list(scores):
            scores[k] = -1 * math.log(float(scores[k]) / tot_count)
        return scores

    def bag_of_words(self, sense):
        """Return the sense's bag extended with those of its related synsets."""
        bag = self.reduce_sense(sense)
        related_groups = (
            sense.hypernyms(), sense.hyponyms(),
            sense.member_holonyms(), sense.substance_holonyms(),
            sense.part_holonyms(), sense.member_meronyms(),
            sense.substance_meronyms(), sense.part_meronyms(),
            sense.topic_domains(), sense.region_domains(),
            sense.usage_domains(), sense.attributes(),
            sense.entailments(), sense.causes(),
            sense.also_sees(), sense.verb_groups(),
            sense.similar_tos(),
        )
        for group in related_groups:
            for related in group:
                bag.update(self.reduce_sense(related))
        return bag

    def relatedness(self, sense_one, sense_two):
        """Score the overlap between the extended bags of two senses.

        Unigram mode: the number of shared terms, or the sum of their IDF
        scores when ``use_idf`` is set.

        N-gram mode: each retained shared n-gram contributes the square of
        its length in words, or the product of its words' IDF scores when
        ``use_idf`` is set.
        """
        shared = self.bag_of_words(sense_one) & self.bag_of_words(sense_two)
        if not self.ngram:
            if self.use_idf:
                return sum([self.IDF[term] for term in shared])
            return len(shared)

        score = 0.0
        for gram in shared:
            # Skip a term that appears as a whole token of another shared
            # n-gram, so it is not counted twice.
            # NOTE(review): tokens here are split on single spaces, so this
            # presumably only ever filters out unigrams, never bigrams inside
            # trigrams -- confirm whether that is intended.
            if any(gram != other and gram in other.split(' ')
                   for other in shared):
                continue
            words = gram.split(' ')
            if self.use_idf:
                weight = self.IDF[words[0]]
                for w in words[1:]:
                    weight = weight * self.IDF[w]
                score += weight
            else:
                score += float(len(words)) * float(len(words))
        return score

    def get_semantic_predictability(self, word_sense, context_senses,
                                    debug=False, style='A'):
        """Sum the relatedness of ``word_sense`` to every context sense.

        Senses may be given as synset objects or as synset-name strings;
        names that cannot be resolved are skipped.  With ``style == 'A'`` the
        sum is divided by the number of context entries supplied (including
        unresolved ones).  Returns ``0.0`` when the word sense cannot be
        resolved or the context is empty.  ``debug`` is accepted but unused.
        """
        if isinstance(word_sense, str):
            word_sense = self.to_wordnet_sense(word_sense)
        score = 0.0
        if word_sense is None:
            return score
        # Materialise so that iterators/generators can be both iterated and
        # measured with len().
        context_senses = list(context_senses)
        for s in context_senses:
            if isinstance(s, str):
                s = self.to_wordnet_sense(s)
            if s is None:
                continue
            score += self.relatedness(word_sense, s)
        if style == 'A' and context_senses:
            score = float(score) / float(len(context_senses))
        return score

    def to_wordnet_sense(self, sense_string):
        """Look up a synset by name; return ``None`` if WordNet rejects it."""
        try:
            return self.wn.synset(sense_string)
        except wordnet.WordNetError:
            return None

    def disambiguate_sense(self, word, cat, context, to_string=False):
        """Pick the synset of ``word`` whose gloss best matches ``context``.

        Each candidate synset (restricted to part of speech ``cat``) is scored
        by the summed IDF of the context items found in its reduced bag.  The
        first listed synset is returned when no candidate scores above zero,
        and ``None`` when the word has no synsets.  With ``to_string`` the
        synset's ``name`` is returned instead of the synset itself.
        """
        synsets = self.wn.synsets(word, pos=cat)
        if not synsets:
            return None
        best_sense = synsets[0]
        best_score = 0
        for s in synsets:
            words = self.reduce_sense(s)
            score = sum([self.IDF[x] for x in context if x in words])
            if score > best_score:
                best_sense = s
                # Remember the winning score; without this every later
                # synset with any positive score would replace the best one.
                best_score = score
        if to_string:
            return best_sense.name
        return best_sense

def perl_get_semantic_relatedness(word,context,debug=False,style='A'):
    """Score ``word`` against ``context`` by running the SemPred.pl script.

    The script is invoked as ``perl SEM_PRED <word> <comma-joined context>``
    and its standard output is parsed as a comma-separated list of numbers.

    Parameters
    ----------
    word : str
        Target word, passed verbatim to the script.
    context : iterable of str
        Context words; joined with commas into a single argument.
    debug : bool
        When true, echo the script's stdout and stderr.
    style : str
        ``'A'`` returns the mean of the values (``0.0`` if their sum is not
        positive); any other value returns the plain sum.

    Returns ``0.0`` when the script prints nothing but whitespace.
    """
    com = ["perl", SEM_PRED, word, ','.join(context)]
    # universal_newlines=True makes the pipes yield text rather than bytes,
    # so the parsing below behaves the same on Python 2 and Python 3.
    p = subprocess.Popen(com, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                         stdin=subprocess.PIPE, universal_newlines=True)
    stdout, stderr = p.communicate()
    if debug:
        print(stdout)
        print(stderr)
    # Ignore blank fields so whitespace-only output or a trailing comma does
    # not blow up in float('').
    values = [float(field) for field in stdout.split(",") if field.strip()]
    if not values:
        return 0.0
    spsum = sum(values)
    if style == 'A':
        if spsum > 0:
            return spsum / float(len(values))
        return 0.0
    return spsum

def get_scores():
    """Compute IDF scores over WordNet and dump them next to this module.

    Builds a default (unigram) SemanticPredictabilityAnalyzer, prints the
    number of scored terms, and pretty-prints the scores as a plain dict
    literal into ``idf_scores.py`` in this file's directory.

    NOTE(review): the file written here contains only the dict literal and
    lives beside this module, whereas the analyzer imports ``scores`` from
    ``media.idf_scores`` -- confirm how the generated file is meant to be
    installed.
    """
    import os
    import pprint
    # get_idf_scores is a method of the analyzer, not a module-level function.
    scores = SemanticPredictabilityAnalyzer().get_idf_scores()
    print(len(scores))
    out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'idf_scores.py')
    # Context manager guarantees the output file is flushed and closed.
    with open(out_path, 'w') as out:
        pprint.pprint(dict(scores), out, indent=4, width=80)

def _read_tsv_rows(path, keep=None):
    """Parse a tab-separated file into ``(header, rows)``.

    Lines that yield a single field after stripping are skipped; the first
    remaining line is the header.  Each later line becomes a dict keyed by
    header name.  When ``keep`` is given, only columns named in it are kept.
    Fields beyond the header's length are ignored.
    """
    header = None
    rows = []
    with open(path, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) == 1:
                continue
            if header is None:
                header = fields
                continue
            rows.append({name: value for name, value in zip(header, fields)
                         if keep is None or name in keep})
    return header, rows


def evaluate_sentences(test_dir='/home/michael/devR/503Project'):
    """Score every test sentence with all five measures and write a table.

    Reads ``scarborough.txt`` and ``kalikow.txt`` from ``test_dir`` (the
    latter restricted to the columns of the former), computes the Perl-based
    measure and the four bag-overlap measures for each sentence's
    ``'Final word'`` given its comma-separated ``'Context'``, and writes the
    combined tab-separated table to ``measures.txt`` in the same directory.

    Parameters
    ----------
    test_dir : str
        Directory holding the input files and receiving the output file.

    Raises
    ------
    ValueError
        If ``scarborough.txt`` contains no header line.
    """
    import os
    # Output column name -> analyzer, in the order the columns are written.
    analyzers = [
        ('bag_words_noidf', SemanticPredictabilityAnalyzer()),
        ('bag_ngrams_noidf', SemanticPredictabilityAnalyzer(ngram=True)),
        ('bag_words_idf', SemanticPredictabilityAnalyzer(use_idf=True)),
        ('bag_ngrams_idf', SemanticPredictabilityAnalyzer(ngram=True, use_idf=True)),
    ]

    head_scar, sentences = _read_tsv_rows(os.path.join(test_dir, 'scarborough.txt'))
    if head_scar is None:
        raise ValueError('scarborough.txt has no header line')
    _, kalikow_rows = _read_tsv_rows(os.path.join(test_dir, 'kalikow.txt'),
                                     keep=set(head_scar))
    sentences.extend(kalikow_rows)

    with open(os.path.join(test_dir, 'measures.txt'), 'w') as f:
        head = head_scar + ['perl_sim'] + [column for column, _ in analyzers]
        f.write('\t'.join(head))
        for s in sentences:
            f.write('\n')
            # The Perl script takes 'word#pos#sense' style identifiers ...
            context = [c + '#1' for c in s['Context'].split(',')]
            word = s['Final word'] + '#n#1'
            s['perl_sim'] = perl_get_semantic_relatedness(word, context)
            # ... while the analyzers take the same identifiers dot-separated.
            # Converted once per sentence and shared by all four analyzers.
            wn_word = word.replace('#', '.')
            wn_context = [c.replace('#', '.') for c in context]
            for column, analyzer in analyzers:
                s[column] = analyzer.get_semantic_predictability(wn_word, wn_context)
            f.write('\t'.join([str(s[x]) for x in head]))

# Script entry point: executing this module directly runs the full sentence
# evaluation (reads the test files and writes measures.txt).
if __name__ == '__main__':
    evaluate_sentences()
    # Disabled ad-hoc debugging calls, kept for reference.
    # NOTE(review): they use module-level names (get_semantic_predictability,
    # pent_analyze) that now exist only as analyzer methods/attributes, so
    # they would need updating before being re-enabled.
    #t = get_semantic_predictability('jaguar.n.1',['leopard.n.2'],ngram=True)
    #print t
    #print pent_analyze('hello hi how')
