from csc.divisi.tensor import DictTensor
from csc.divisi.labeled_view import LabeledView
from csc.divisi.ordered_set import OrderedSet
from csc.divisi.blend import Blend
from csc.conceptnet4.models import Assertion, Relation
from math import log, sqrt
import logging

# Default value given to each identity assertion by
# AnalogySpaceBuilder.add_identities (sqrt(5), roughly 2.236).
DEFAULT_IDENTITY_WEIGHT = sqrt(5)
# Default concept cutoff: AnalogySpaceBuilder.queryset keeps only assertions
# whose two concepts each have strictly more than this many assertions.
DEFAULT_CUTOFF = 5

# ln(2), precomputed so get_value can turn a natural log into a base-2 log.
log_2 = log(2)

###
### General utilities
###

relation_name_cache = None
def get_relation_names():
    '''Return a dict mapping Relation ids to Relation names.

    The mapping is read from the database on the first call and then kept
    in the module-level ``relation_name_cache``; later calls return that
    same dict without querying again.
    '''
    global relation_name_cache
    if relation_name_cache is not None:
        return relation_name_cache
    relation_name_cache = dict(
        (relation.id, relation.name) for relation in Relation.objects.all())
    return relation_name_cache


###
### Building AnalogySpace
###

class ConceptNet2DTensor(LabeledView):
    '''Concept-by-feature labeled view over a 2-D DictTensor.

    Each assertion (relation, left, right) is stored as two cells:
    row ``left`` / feature ``('right', relation, right)`` and
    row ``right`` / feature ``('left', relation, left)``.
    '''
    def __init__(self):
        concept_labels = OrderedSet()
        feature_labels = OrderedSet()
        super(ConceptNet2DTensor, self).__init__(
            DictTensor(2), [concept_labels, feature_labels])

    def add_assertion(self, relation, left, right, value):
        '''Set (not accumulate) both feature cells for one assertion to ``value``.'''
        # FIXME: doesn't actually add. Need to evaluate the impact of that before changing.
        self[left, ('right', relation, right)] = value
        self[right, ('left', relation, left)] = value

    def add_identity_assertion(self, relation, text, value):
        '''Record ``text`` related to itself via ``relation``.'''
        self.add_assertion(relation, text, text, value)


class ConceptNet3DTensor(LabeledView):
    '''Concept x relation x concept labeled view over a 3-D DictTensor.

    Axes 0 and 2 are given the same OrderedSet of concept labels; axis 1
    has its own OrderedSet of relation labels.
    '''
    def __init__(self):
        shared_concepts = OrderedSet()
        relation_labels = OrderedSet()
        label_lists = [shared_concepts, relation_labels, shared_concepts]
        super(ConceptNet3DTensor, self).__init__(DictTensor(3), label_lists)

    def add_assertion(self, relation, left, right, value):
        '''Accumulate ``value`` into the (left, relation, right) cell.'''
        self[left, relation, right] += value

    def add_identity_assertion(self, relation, text, value):
        '''Accumulate ``value`` into the (text, relation, text) cell.'''
        self[text, relation, text] += value


class MirroringCNet3DTensor(ConceptNet3DTensor):
    '''
    A ConceptNet3DTensor whose add_assertion stores each assertion twice:
    as given, (c1, r, c2), and inverted, (c2, r', c1).

    This parallels how the 2D tensor makes both left and right features.

    The inverse relation r' is the relation name prefixed with '-'.
    Note: add_identity_assertion is inherited from ConceptNet3DTensor and
    writes only (text, r, text), so identity assertions are not mirrored.
    '''
    def add_assertion(self, relation, left, right, value):
        # Forward direction, exactly as given.
        self[left, relation, right] += value
        # Mirrored direction under the '-'-prefixed relation name.
        inverse_relation = '-' + relation
        self[right, inverse_relation, left] += value
        
    
    
class AnalogySpaceBuilder(object):
    '''Builds a ConceptNet tensor from the Assertion table.

    The pipeline run by ``__call__`` is: queryset() selects assertions,
    queryset_to_tensor() turns them into tensor entries, and
    add_identities() adds a self-assertion for every concept. Subclasses
    override individual steps.
    '''
    @classmethod
    def build(cls, **kw):
        '''Construct a builder from ``kw`` and immediately run it.'''
        builder = cls(**kw)
        return builder()


    def __init__(self,
                 identity_weight=DEFAULT_IDENTITY_WEIGHT,
                 identity_relation=u'InheritsFrom',
                 cutoff=DEFAULT_CUTOFF,
                 tensor_class=ConceptNet2DTensor):
        # tensor_class: zero-argument callable returning an empty tensor.
        self.tensor_class = tensor_class
        # cutoff: both concepts need strictly more assertions than this.
        self.cutoff = cutoff
        # identity_relation: relation name used for identity assertions.
        self.identity_relation = identity_relation
        # identity_weight: value of each identity assertion; 0 skips them.
        self.identity_weight = identity_weight


    def __call__(self):
        '''Builds a ConceptNet tensor from the database.'''
        assertions = self.queryset()
        tensor = self.queryset_to_tensor(assertions)
        self.add_identities(tensor)
        return tensor


    def queryset(self):
        '''Return assertions with a positive score whose concepts both
        have more than ``self.cutoff`` assertions.'''
        min_count = self.cutoff
        return Assertion.objects.filter(
            score__gt=0,
            concept1__num_assertions__gt=min_count,
            concept2__num_assertions__gt=min_count)

    @staticmethod
    def get_value(score, freq):
        '''Return freq * log2(max(score + 1, 1)) / 10.'''
        return freq * log(max(score + 1, 1)) / log_2 / 10.0


    def queryset_to_tensor(self, queryset):
        '''Returns the tensor (without identities) built from the
        assertions in the given queryset.'''
        tensor = self.tensor_class()
        add = tensor.add_assertion
        value_of = self.get_value
        rows = queryset.values_list(
            'relation__name', 'concept1__text', 'concept2__text', 'score',
            'frequency__value').iterator()
        for relation, concept1, concept2, score, freq in rows:
            add(relation, concept1, concept2, value_of(score, freq))
        return tensor


    def add_identities(self, tensor):
        '''Add an identity assertion for every label on axis 0 of ``tensor``,
        unless identity_weight is 0.'''
        weight = self.identity_weight
        if weight == 0:
            logging.info('Skipping adding zero-weight identities.')
            return

        relation = self.identity_relation
        add_identity = tensor.add_identity_assertion
        logging.info('Adding identities, weight=%s', weight)
        # Iterate over a snapshot of the labels, presumably so that labels
        # added while looping cannot disturb the iteration.
        concepts = list(tensor.label_list(0))
        for text in concepts:
            add_identity(relation, text, weight)


class MonolingualAnalogySpaceBuilder(AnalogySpaceBuilder):
    '''AnalogySpaceBuilder restricted to assertions in one language.'''
    @classmethod
    def build(cls, lang, **kw):
        '''Construct a builder for ``lang`` from ``kw`` and run it.'''
        builder = cls(lang=lang, **kw)
        return builder()


    def __init__(self, lang, **kw):
        super(MonolingualAnalogySpaceBuilder, self).__init__(**kw)
        # Language value matched against the Assertion ``language`` field.
        self.lang = lang


    def queryset(self):
        '''Return the base queryset filtered to ``self.lang``.'''
        base = super(MonolingualAnalogySpaceBuilder, self).queryset()
        return base.filter(language=self.lang)
    
            
# Experiment: an AnalogySpace from frames
class FramedTensorBuilder(MonolingualAnalogySpaceBuilder):
    '''Experimental builder that stores, for each assertion, the raw
    surface texts keyed by frame id, the normalized concept assertion, and
    'NormalizesTo' links from each concept to its surface text and to
    itself.'''
    def queryset_to_tensor(self, queryset):
        tensor = self.tensor_class()
        add = tensor.add_assertion
        value_of = self.get_value

        columns = ('relation__name', 'concept1__text', 'concept2__text',
                   'text1', 'text2', 'frame_id', 'score', 'frequency__value')
        for row in queryset.values_list(*columns).iterator():
            rel, concept1, concept2, text1, text2, frame_id, score, freq = row
            value = value_of(score, freq)
            # Raw surface texts, with the frame id used as the relation.
            add(frame_id, text1, text2, value)
            # Normalized assertion between the two concepts.
            add(rel, concept1, concept2, value)
            # NormalizesTo links, in a fixed order (this order determines
            # the order in which new labels are first seen).
            normalizations = ((concept1, text1), (concept2, text2),
                              (concept1, concept1), (concept2, concept2))
            for source, target in normalizations:
                add('NormalizesTo', source, target, 1)
        return tensor


# Experiment: Multilingual AnalogySpace
class MultilingualAnalogySpaceBuilder(AnalogySpaceBuilder):
    '''Experimental builder without a language filter.

    Concepts are labeled ``(text, language_id)`` so that the same text in
    different languages becomes distinct labels.
    '''
    def queryset_to_tensor(self, queryset):
        tensor = self.tensor_class()
        add = tensor.add_assertion
        value_of = self.get_value

        names_by_id = get_relation_names()
        rows = queryset.values_list(
            'relation_id', 'concept1__text', 'concept2__text', 'score',
            'frequency__value', 'language_id').iterator()
        for rel_id, name1, name2, score, freq, lang in rows:
            value = value_of(score, freq)
            relation = names_by_id[rel_id]
            add(relation, (name1, lang), (name2, lang), value)
        return tensor


# Experiment: One relation type at a time
class ByRelationBuilder(MonolingualAnalogySpaceBuilder):
    '''Builds one tensor per relation and blends them.

    Identities are placed in a separate tensor so the per-relation tensors
    in the blend do not include them.
    '''
    def queryset_for_relation(self, relation_id):
        '''Returns a QuerySet of just the assertions having the
        specified relation.'''
        base = self.queryset()
        return base.filter(relation__id=relation_id)

    def tensor_for_relation(self, relation_id):
        '''Returns a tensor built only from assertions of the given relation.'''
        assertions = self.queryset_for_relation(relation_id)
        return self.queryset_to_tensor(assertions)

    def get_tensors_by_relation(self):
        '''Returns a dictionary mapping names of relations to tensors of
        that kind of data, omitting 'InheritsFrom' and empty tensors.'''
        # NOTE(review): 'InheritsFrom' is also the default identity_relation;
        # confirm whether this should follow self.identity_relation instead.
        tensors = {}
        for rel_id, name in get_relation_names().items():
            if name == 'InheritsFrom':
                continue
            rel_tensor = self.tensor_for_relation(rel_id)
            if len(rel_tensor) > 0:
                tensors[name] = rel_tensor
        return tensors

    def identities_for_all_relations(self, byrel):
        '''Returns a tensor containing identity relations for the
        concepts in all tensors. We handle this separately so the
        blends don't include identities.'''
        identity_tensor = self.tensor_class()
        # Collect every concept label from the per-relation tensors.
        concept_labels = identity_tensor._labels[0]
        for rel_tensor in byrel.itervalues():
            concept_labels.extend(rel_tensor._labels[0])
        # Then add identities the usual way.
        self.add_identities(identity_tensor)
        return identity_tensor

    def __call__(self):
        byrel = self.get_tensors_by_relation()
        identity_tensor = self.identities_for_all_relations(byrel)
        components = list(byrel.values())
        components.append(identity_tensor)
        return Blend(components)

# Backwards-compatibility wrapper function around ByRelationBuilder.
def load_one_type(lang, relation, identities, cutoff):
    '''Return the tensor for a single relation id in language ``lang``.

    NOTE(review): tensor_for_relation does not call add_identities, so
    ``identities`` is stored on the builder but does not affect the
    result; confirm whether legacy callers expect identities here.
    '''
    builder = ByRelationBuilder(lang, identity_weight=identities, cutoff=cutoff)
    return builder.tensor_for_relation(relation)

# Experiment: Flatten (just by weighted concepts)
class FlatConceptNetTensor(LabeledView):
    '''Concept-by-concept labeled view over a 2-D DictTensor.

    Rows and columns are given the same OrderedSet of concept labels, so
    ``tensor[c1, c2]`` holds the weighted link from concept c1 to concept c2.
    '''
    def __init__(self):
        concepts = OrderedSet()
        # Fixed: this used super(ConceptNet2DTensor, self), but this class is
        # not a ConceptNet2DTensor subclass, so construction raised TypeError.
        super(FlatConceptNetTensor, self).__init__(DictTensor(2), [concepts, concepts])

class FlatASpaceBuilder(MonolingualAnalogySpaceBuilder):
    '''Builds a flat concept-by-concept tensor.

    Each assertion (c1, relation, c2) adds ``forward_weight * value`` to
    cell [c1, c2] and ``inverse_weight * value`` to cell [c2, c1]. The
    weights are looked up per relation name, with defaults for relations
    not listed. A zero weight skips that direction.
    '''
    def __init__(self, forward_weight_by_relation, forward_default_weight,
                 inverse_weight_by_relation, inverse_default_weight,
                 min_identity_weight=0.0, tensor_class=FlatConceptNetTensor, **kw):
        super(FlatASpaceBuilder, self).__init__(tensor_class=tensor_class, **kw)

        def get_forward_weight(relation):
            return forward_weight_by_relation.get(relation, forward_default_weight)

        def get_inverse_weight(relation):
            return inverse_weight_by_relation.get(relation, inverse_default_weight)

        self.get_forward_weight = get_forward_weight
        self.get_inverse_weight = get_inverse_weight
        # Floor applied to each diagonal cell by add_identities (falsy = off).
        self.min_identity_weight = min_identity_weight

    def queryset_to_tensor(self, queryset):
        tensor = self.tensor_class()
        forward_weight = self.get_forward_weight
        inverse_weight = self.get_inverse_weight
        value_of = self.get_value

        rows = queryset.values_list(
            'relation__name', 'concept1__text', 'concept2__text', 'score',
            'frequency__value').iterator()
        for relation, concept1, concept2, score, freq in rows:
            value = value_of(score, freq)
            # Forward link c1 -> c2.
            fwd = forward_weight(relation)
            if fwd:
                tensor[concept1, concept2] += fwd * value
            # Reverse link c2 -> c1.
            rev = inverse_weight(relation)
            if rev:
                tensor[concept2, concept1] += rev * value
        return tensor

    def add_identities(self, tensor):
        '''Raise each concept's self-link to at least min_identity_weight.'''
        floor = self.min_identity_weight
        if not floor:
            return
        # Assumes reading an unset cell yields 0 (DictTensor default) -- TODO confirm.
        for concept in tensor.label_list(0):
            if tensor[concept, concept] < floor:
                tensor[concept, concept] = floor
    


# Compatibility API
def rename_elt(dct, old, new):
    '''Rename key ``old`` in ``dct`` to ``new``, in place.

    Does nothing if ``old`` is absent. An existing value under ``new`` is
    overwritten. Renaming a key to itself leaves the entry unchanged.
    '''
    if old in dct:
        # Pop first, then assign: with old == new, the former
        # assign-then-delete order removed the entry entirely.
        dct[new] = dct.pop(old)
        
def conceptnet_2d_from_db(lang, builder=MonolingualAnalogySpaceBuilder, **kw):
    '''Build a ConceptNet tensor in the given language.

    ``builder`` is called with ``lang=lang`` plus ``kw``, and the resulting
    builder object is invoked to produce the tensor.
    '''
    # Accept the old 'identities' keyword as an alias for 'identity_weight'.
    rename_elt(kw, 'identities', 'identity_weight')
    instance = builder(lang=lang, **kw)
    return instance()


def conceptnet_selfblend(lang, **kw):
    '''Build a per-relation Blend for ``lang`` using ByRelationBuilder.'''
    return conceptnet_2d_from_db(lang=lang, builder=ByRelationBuilder, **kw)


### Experiment: Add cooccurrences
class ConceptNetTensorWithCooccurrences(ConceptNet2DTensor):
    '''ConceptNet2DTensor that also records word-level co-occurrences.

    Each assertion is stored normally. In addition, a 'CooccursWith'
    assertion is stored from the left concept to every whitespace-separated
    word of the right concept, with value ``value * cooccurrence_weight``.
    Identity assertions go through add_assertion as well (the inherited
    add_identity_assertion delegates to it), so they also get co-occurrence
    entries.
    '''
    @classmethod
    def get_constructor(cls, cooccurrence_weight):
        '''Return a zero-argument factory, usable as a builder's tensor_class.'''
        def constructor():
            return cls(cooccurrence_weight)
        return constructor


    def __init__(self, cooccurrence_weight, *a, **kw):
        # Fixed: this used super(ConceptNet2DTensor, self), which skipped
        # ConceptNet2DTensor.__init__ and therefore never created the
        # DictTensor and label sets.
        super(ConceptNetTensorWithCooccurrences, self).__init__(*a, **kw)
        self.cooccurrence_weight = cooccurrence_weight

    def add_assertion(self, relation, left, right, value):
        parent = super(ConceptNetTensorWithCooccurrences, self)
        # Add the normal assertion.
        parent.add_assertion(relation, left, right, value)

        # Split apart right-side concepts. Fixed: the loop previously called
        # the super object itself (TypeError), and the stored
        # cooccurrence_weight was never applied.
        rel = 'CooccursWith'
        cooccurrence_value = value * self.cooccurrence_weight
        for right_side_word in right.split():
            parent.add_assertion(rel, left, right_side_word, cooccurrence_value)

def conceptnet_2d_with_cooccurrences(lang, cooccurrence_weight=1.0, **kw):
    '''Build a ConceptNet tensor for ``lang`` whose tensor also records
    co-occurrences, weighted by ``cooccurrence_weight``.'''
    factory = ConceptNetTensorWithCooccurrences.get_constructor(cooccurrence_weight)
    kw['tensor_class'] = factory
    return conceptnet_2d_from_db(lang, **kw)

###
### Analysis helpers
###

def concept_similarity(svd, concept):
    '''Score concepts against ``concept``: svd.u_dotproducts_with applied
    to the concept's weighted U vector.'''
    concept_vec = svd.weighted_u_vec(concept)
    return svd.u_dotproducts_with(concept_vec)

def predict_features(svd, concept):
    '''Score features for ``concept``: svd.v_dotproducts_with applied to
    the concept's weighted U vector.'''
    concept_vec = svd.weighted_u_vec(concept)
    return svd.v_dotproducts_with(concept_vec)

def feature_similarity(svd, feature):
    '''Score features against ``feature``: svd.v_dotproducts_with applied
    to the feature's weighted V vector.'''
    feature_vec = svd.weighted_v_vec(feature)
    return svd.v_dotproducts_with(feature_vec)

def predict_concepts(svd, feature):
    '''Score concepts for ``feature``: svd.u_dotproducts_with applied to
    the feature's weighted V vector.'''
    feature_vec = svd.weighted_v_vec(feature)
    return svd.u_dotproducts_with(feature_vec)

def make_category(svd, concepts=(), features=()):
    '''Return a category vector: the sum of svd.weighted_u_vec(c) for each
    concept c plus svd.weighted_v_vec(f) for each feature f.

    Raises TypeError if both ``concepts`` and ``features`` are empty, since
    there is nothing to sum.
    '''
    # Immutable defaults avoid the shared-mutable-default pitfall.
    components = (
        [svd.weighted_u_vec(concept) for concept in concepts] +
        [svd.weighted_v_vec(feature) for feature in features])
    if not components:
        raise TypeError('make_category() needs at least one concept or feature')
    # Left-to-right '+' fold; same as reduce(operator.add, ...), but without
    # the builtin reduce, which does not exist on Python 3.
    total = components[0]
    for component in components[1:]:
        total = total + component
    return total


def category_similarity(svd, cat):
    '''Return (concept_scores, feature_scores) for the category vector
    ``cat``: the results of svd.u_dotproducts_with(cat) and
    svd.v_dotproducts_with(cat), in that order.

    Example usage:
    concepts, features = category_similarity(svd, cat)
    concepts.top_items(10)
    features.top_items(10)'''
    concept_scores = svd.u_dotproducts_with(cat)
    feature_scores = svd.v_dotproducts_with(cat)
    return concept_scores, feature_scores

def eval_assertion(svd, relation, left, right):
    '''Return (lfeature_val, rfeature_val) for an assertion.

    The two values are svd.get_ahat for the cells
    (right, ('left', relation, left)) and (left, ('right', relation, right)).
    These use the same feature labels ConceptNet2DTensor.add_assertion
    writes. A KeyError from get_ahat yields 0 for that value.
    '''
    def ahat_or_zero(key):
        try:
            return svd.get_ahat(key)
        except KeyError:
            return 0

    # The right feature is queried first, then the left one.
    rfeature_val = ahat_or_zero((left, ('right', relation, right)))
    lfeature_val = ahat_or_zero((right, ('left', relation, left)))
    return lfeature_val, rfeature_val


