"""
"""
import graphlab as _graphlab
import graphlab.connect as _mt
from graphlab.toolkits.model import Model as _Model
from graphlab.data_structures.sframe import SFrame as _SFrame
from graphlab.data_structures.sarray import SArray as _SArray
import array as _array
from itertools import izip


class TopicModel(_Model):
    """
    TopicModel is a class of objects returned by `topic_model.create()`.

    An instance wraps a model proxy object; every method forwards its work to
    a toolkit function through `graphlab.toolkits.main.run`.
    """

    def __init__(self, model_proxy):
        """
        Wrap an existing model proxy.

        Parameters
        ----------
        model_proxy : object
            The proxy returned by a toolkit call (e.g. ``response['model']``).
        """
        self.__proxy__ = model_proxy

    def _get_wrapper(self):
        """
        Return a function that builds a TopicModel from a model proxy.
        """
        def model_wrapper(model_proxy):
            return TopicModel(model_proxy)
        return model_wrapper

    def list_fields(self):
        """
        List the fields that can be queried on this model.

        Returns
        -------
        out : list
            The keys of the response to the ``text_topicmodel_list_fields``
            toolkit call.
        """
        _mt._get_metric_tracker().track('toolkit.text.topic_model.list_fields')
        opts = {'model': self.__proxy__}
        response = _graphlab.toolkits.main.run("text_topicmodel_list_fields", opts)
        return response.keys()

    def get(self, field):
        """
        Return the value of a given model field.

        Parameters
        ----------
        field : str
            Name of the field to retrieve.

        Returns
        -------
        out
            For 'vocabulary' the value is wrapped in an SArray, for 'topics'
            it is wrapped in an SFrame; any other field is returned as
            provided by the toolkit.
        """
        _mt._get_metric_tracker().track('toolkit.text.topic_model.get')
        opts = {'model': self.__proxy__, 'field': field}
        response = _graphlab.toolkits.main.run("text_topicmodel_get_value", opts)
        value = response['value']
        # These two fields come back as raw proxies and need a Python wrapper.
        if field == 'vocabulary':
            return _SArray(None, _proxy=value)
        if field == 'topics':
            return _SFrame(None, _proxy=value)
        return value

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """
        Build a multi-line, human-readable description of the model.
        """
        lines = [
            "Topic Model\n",
            "  Data:\n",
            "      Vocabulary size:  %i\n" % len(self['vocabulary']),
            "  Settings:\n",
            "      Number of topics: {0}\n".format(self['num_topics']),
            "      alpha:            {0}\n".format(self['alpha']),
            "      beta:             {0}\n".format(self['beta']),
            "      Iterations:       {0}\n".format(self['num_iterations']),
            "      Verbose:          {0}\n".format(self['verbose'] == 1),
            "  Accessible attributes:\n",
            "      m['topics']         An SFrame containing the topics.\n",
            "      m['vocabulary']     An SArray containing the vocabulary.\n",
            "  Useful methods:\n",
            "      m.get_topics()       Get the most probable words per topic.\n",
            "      m.predict(new_docs)  Make predictions for new documents.\n",
        ]
        return "".join(lines)

    def summary(self):
        """
        Print a summary of the model.
        """

        _mt._get_metric_tracker().track('toolkit.text.topic_model.summary')
        # Parenthesised single-argument form prints the same text whether or
        # not `print` is a statement.
        print(self.__repr__())

    def get_topics(self, topic_ids=None, num_words=5, cdf_cutoff=1.0):
        """
        Get the words associated with a given topic. The score column is the
        probability of choosing that word given that you have chosen a
        particular topic.

        Parameters
        ----------
        topic_ids : list of int, optional
            The topics to retrieve words. Topic ids are zero-based.
            When omitted, all topics are used.
            Every id must satisfy 0 <= id < m['num_topics'], otherwise a
            ValueError is raised. Topic names given as strings (e.g.
            `topic_0`) are not supported at this point in time and also
            raise a ValueError.

        num_words : int, optional
            The number of words to show.

        cdf_cutoff : float
            Allows one to only show the most probable words whose cumulative
            probability is below this cutoff. For example if there exist
            three words where
               p(word_1 | topic_k) = .1
               p(word_2 | topic_k) = .2
               p(word_3 | topic_k) = .05
            then setting cdf_cutoff=.3 would return only word_1 and word_2
            since p(word_1 | topic_k) + p(word_2 | topic_k) <= cdf_cutoff.

        Returns
        -------
        out : SFrame
            An SFrame with a column of words ranked by a column of scores for
            each topic.

        Raises
        ------
        AssertionError
            If `topic_ids` is provided but is not a list.
        ValueError
            If any topic id is a string or lies outside [0, num_topics).

        Examples
        --------

        Get the highest ranked words for all topics.

        >>> m.get_topics()

        Get the highest ranked words for topics 0 and 1. Show 15 words.

        >>> m.get_topics([0, 1], num_words=15)

        If one wants to rearrange this into dictionaries, one can use
        :py:func:`~graphlab.SFrame.unstack`:

        >>> topics = m.get_topics()
        >>> topics.unstack(['word', 'score'], 'word_score')['word_score']

        """
        _mt._get_metric_tracker().track('toolkit.text.topic_model.get_topics')

        # Fetched once and reused for both the default ids and the range check.
        num_topics = self.get('num_topics')

        if topic_ids is None:
            # list(...) keeps the isinstance check below valid even where
            # range() returns a lazy sequence.
            topic_ids = list(range(num_topics))

        assert isinstance(topic_ids, list), \
            "The provided topic_ids is not a list."

        if any(isinstance(x, str) for x in topic_ids):
            raise ValueError(
                "Only integer topic_ids can be used at this point in time.")
        # Check every id; the previous code only tested a single variable
        # leaked from an earlier list comprehension.
        if not all(0 <= x < num_topics for x in topic_ids):
            raise ValueError(
                "Topic id values must be non-negative and less than the "
                "number of topics used to fit the model.")

        opts = {'model': self.__proxy__,
                'topic_ids': topic_ids,
                'num_words': num_words,
                'cdf_cutoff': cdf_cutoff}
        response = _graphlab.toolkits.main.run('text_topicmodel_get_topic',
                                               opts)
        return _SFrame(None, _proxy=response['top_words'])

    def predict(self, dataset, output_type='assignment'):
        r"""
        Use the model to predict topics for each document. The provided
        `dataset` should be an SArray object where each element is a dict
        representing a single document in bag-of-words format, where keys
        are words and values are their corresponding counts.

        The current implementation will make inferences about each document
        given its estimates of the topics learned when creating the model.
        This is done via Gibbs sampling.

        Parameters
        ----------
        dataset : SArray
            A set of documents to predict topics for.

        output_type : str, optional
            The type of output desired. This can either be

            - assignment: the returned values are integers in [0, num_topics)
            - probability: each returned prediction is a vector with length
              num_topics, where element k represents the probability that
              document belongs to topic k.

            Any value other than 'probability' is treated as 'assignment'.

        Returns
        -------
        out : SArray

        Raises
        ------
        AssertionError
            If `dataset` is not an SArray.

        Examples
        --------

        Make predictions about which topic each document belongs to.

        >>> pred = m.predict(docs)

        If one is interested in the probability of each topic

        >>> pred = m.predict(docs, output_type='probability')

        Notes
        -----
        For each unique word w in a document d, we sample an assignment to
        topic k with probability proportional to

        .. math::
            p(z_{dw} = k) \propto (n_{d,k} + alpha) * Phi_{w,k}

        where

        - W is the size of the vocabulary,
        - n_{d,k} is the number of other times we have assigned a word in
          document to d to topic k,
        - Phi_{w,k} is the probability under the model of choosing word w
          given the word is of topic k. This is the matrix returned by calling
          `m['topics']`.

        This represents a collapsed Gibbs sampler for the document assignments
        while we keep the topics learned during training fixed.
        This process is done in parallel across all documents, five times per
        document.

        """
        _mt._get_metric_tracker().track('toolkit.text.topic_model.predict')

        assert isinstance(dataset, _SArray), \
            "Provided test documents argument must be an SArray."

        opts = {'model': self.__proxy__,
                'data': dataset}
        response = _graphlab.toolkits.main.run("text_topicmodel_predict", opts)
        preds = _SArray(None, _proxy=response['predictions'])

        # Get most likely topic if probabilities are not requested
        if output_type != 'probability':
            # Pair each probability with its index and take the max pair;
            # on equal probabilities the larger index wins.
            preds = preds.apply(lambda x: max(izip(x, xrange(len(x))))[1])

        return preds

    def evaluate(self, train_data, test_data=None):
        """
        Estimate the model's ability to predict new data.
        Imagine you have a corpus of books. One common approach to evaluating
        topic models is to train on the first half of all of the books
        and see how well the model predicts the second half of each book.

        This method returns a metric called perplexity, which is related
        to the likelihood of observing these words under the given model.
        See :py:func:`~graphlab.text.topic_model.perplexity` for more details.

        The provided `train_data` and `test_data` must have the same length,
        i.e., both data sets must have the same number of documents;
        the model will use train_data to estimate which topic the
        document belongs to, and this is used to estimate the model's
        performance at predicting the unseen words in the test data.

        See :py:func:`~graphlab.text.topic_model.TopicModel.predict` for
        details on how these predictions are made, and see
        :py:func:`~graphlab.text.util.random_split` for a helper function
        that can be used for making train/test splits.

        Parameters
        ----------
        train_data : SArray
            A set of documents to predict topics for.

        test_data : SArray, optional
            A set of documents to evaluate performance on.
            By default this will set to be the same as train_data.

        Returns
        -------
        out
            The perplexity value computed by `perplexity` from the
            probability predictions on `train_data` and the model's topics.

        Examples
        --------

        >>> train_data, test_data = graphlab.text.util.random_split(docs)
        >>> m = topic_model.create(train_data)
        >>> m.evaluate(train_data, test_data)

        """
        _mt._get_metric_tracker().track('toolkit.text.topic_model.evaluate')

        if test_data is None:
            test_data = train_data

        predictions = self.predict(train_data, output_type='probability')
        # The topics SFrame carries both the per-word probabilities and the
        # vocabulary column, so no separate vocabulary lookup is needed.
        topics = self.get('topics')

        return perplexity(test_data,
                          predictions,
                          topics['topic_probabilities'],
                          topics['vocabulary'])

def perplexity(test_data, predictions, topics, vocabulary):
    r"""
    Compute the perplexity of a set of test documents given a set
    of predicted topics.

    Let theta be the matrix of document-topic probabilities, where
    theta_ik = p(topic k | document i). Let Phi be the matrix of term-topic
    probabilities, where phi_jk = p(word j | topic k).

    Then for each word in each document, we compute for a given word w
    and document d

    .. math::
        p(word | \theta[doc_id,:], \phi[word_id,:]) =
       \sum_k \theta[doc_id, k] * \phi[word_id, k]

    We compute loglikelihood to be:

    .. math::
        l(D) = \sum_{i \in D} \sum_{j \in D_i} count_{i,j} * log Pr(word_{i,j} | \theta, \phi)

    and perplexity to be

    .. math::
        \exp \{ - l(D) / \sum_i \sum_j count_{i,j} \}

    For more information, see http://en.wikipedia.org/wiki/Perplexity.

    Parameters
    ----------
    test_data : SArray
        An SArray of documents in bag-of-words format.

    predictions : SArray
        An SArray of vector type, where each vector contains estimates of the
        probability that this document belongs to each of the topics.
        This must have the same size as test_data; otherwise an exception
        occurs. This can be the output of
        :py:func:`~graphlab.toolkit.text.TopicModel.predict` called with
        output_type='probability', for example.

    topics : SArray
        The term-topic probabilities, e.g. the `topic_probabilities` column
        of `m['topics']`. Passed unchanged to the toolkit.

    vocabulary: SArray
        An SArray of words to use. All words in test_data that are not in this
        vocabulary will be ignored.

    Returns
    -------
    out
        The `perplexity` entry of the toolkit response. Lower values are
        better.

    Raises
    ------
    AssertionError
        If `test_data` or `predictions` is not an SArray, or if
        `predictions` is not of vector (array) type.

    Examples
    --------

    >>> from graphlab.text import topic_model
    >>> m = topic_model.create(train_data)
    >>> pred = m.predict(train_data, output_type='probability')
    >>> topics = m['topics']
    >>> p = topic_model.perplexity(test_data, pred,
                                   topics['topic_probabilities'],
                                   topics['vocabulary'])
    >>> p
    1720.7  # lower values are better

    Notes
    -----
    For more details, see equations 13-16 of the following paper:

    Patterson and Teh. NIPS, 2013.
    http://www.stats.ox.ac.uk/~teh/research/compstats/PatTeh2013a.pdf
    """
    _mt._get_metric_tracker().track('toolkit.text.perplexity')

    assert isinstance(test_data, _SArray), \
        "Test data must be an SArray."
    assert isinstance(predictions, _SArray), \
        "Predictions must be an SArray of vector type."
    # Integer topic assignments are not accepted here; per-topic probability
    # vectors are required.
    assert predictions.dtype() == _array.array, \
        ("Predictions must be probabilities. Try using m.predict() with "
         "output_type='probability'.")

    opts = {'test_data': test_data,
            'predictions': predictions,
            'topics': topics,
            'vocabulary': vocabulary}
    response = _graphlab.toolkits.main.run("text_topicmodel_get_perplexity", opts)
    return response['perplexity']


def create(dataset,
           num_topics=10,
           initial_topics=None,
           alpha=None, beta=.1,
           num_iterations=10,
           associations=None,
           verbose=True,
           print_interval=10,
           method='auto'):
    """
    Create a topic model from the given data set.

    A topic model assumes each document is a mixture of a set of topics,
    where for each topic some words are more likely than others. One
    statistical approach to do this is called a "topic model". This method
    learns a topic model for the given document collection.

    Parameters
    ----------
    dataset : SArray of type dict
        A bag of words representation of a document corpus.
        Each element is a dictionary representing a single document, where
        the keys are words and the values are the number of times that word
        occurs in that document.

    num_topics : int
        The number of topics to learn.

    initial_topics : SFrame, optional
        An SFrame with a column of unique words representing the vocabulary
        and a column of dense vectors representing
        probability of that word given each topic. When provided,
        these values are used to initialize the algorithm. The columns must
        be named `vocabulary` and `topic_probabilities`, and each vector
        must have length `num_topics`.

    num_iterations : int, optional
        The number of iterations to perform.

    alpha : float, optional
        Hyperparameter that controls the diversity of topics in a document.
        Smaller values encourage fewer topics per document.
        Provided value must be positive. Default value is 50/num_topics
        (computed with true division).

    beta : float, optional
        Hyperparameter that controls the diversity of words in a topic.
        Smaller values encourage fewer words per topic. Provided value
        must be positive.

    verbose: bool, optional
        Print progress when True.

    print_interval : int, optional
        The number of iterations to wait between progress reports.

    associations : SFrame optional
        An SFrame with two columns named "word" and "topic" containing words
        and the topic id that the word should be associated with. These words
        are not considered during learning.

    method : str, optional
        The algorithm used for learning the model, e.g. 'cgs'. Defaults to
        'auto'. The value is forwarded unchanged to the training call.

    Returns
    -------
    out : TopicModel
        A fitted topic model. This can be used with get_topics() and predict().

    Raises
    ------
    TypeError
        If `dataset` is not an SArray.
    AssertionError
        If a non-empty `associations` SFrame or a provided `initial_topics`
        SFrame does not have the expected columns/types.

    Examples
    --------

    The following example includes an SArray of documents, where
    each element represents a document in "bag of words" representation
    -- a dictionary with word keys and whose values are the number of times
    that word occurred in the document:

    >>> docs = graphlab.SArray('s3://GraphLab-Datasets/nytimes')

    Once in this form, it is straightforward to learn a topic model.

    >>> from graphlab.toolkits.text import topic_model
    >>> m = topic_model.create(docs)

    The returned object is a :py:class:`~graphlab.toolkits.text.TopicModel`
    object, which exposes several useful methods. For example,
    :py:func:`graphlab.toolkits.text.TopicModel.get_topics` returns an
    :class:`graphlab.SFrame` containing the most probable words for each
    topic and a score related to how high that word ranks for that topic.

    >>> m.get_topics()
        Columns:
                topic_name    str
                word          str
                score         float
        Rows: 10
        Data:
          topic        word   score
        0     0     percent  .32815
        1     0     million  .17627
        2     0         new  .16467
        3     0       stock  .05729
        4     0    exchange  .02868
        5     1   hurricane  .18396
        6     1       storm  .16498
        7     1        rain  .13967
        8     1        wind  .11335
        9     1     beatles  .19824
        10    2        hair  .17446

    You may get details on a subset of topics by supplying a list of topic
    indices, as well as restrict the number of words returned per topic.

    >>> m.get_topics([0, 1, 3], num_words=5)

    To predict the topic of a given document, one can get an SArray of
    integers containing the most probable topic ids:

    >>> topic_ids = m.predict(documents.head(5))
    >>> topic_ids.size()  # 5

    Combining the above method with standard SFrame capabilities, one can use
    predict to find documents related to a particular topic

    >>> documents[m.predict(documents) == topic_id]

    or join with other data in order to analyze an author's typical topics or
    how topics change over time. For example,

    >>> doc_data.column_names()
    ['timestamp', 'author', 'text']
    >>> m = topic_model.create(doc_data['text'])
    >>> doc_data['topic'] = m.predict(doc_data['text'])
    >>> doc_data['author'][doc_data['topic'] == 1] # authors of docs in topic 1

    Sometimes you want to know how certain the model's predictions are. One
    can optionally also get the probability of each topic for a set of
    documents. Each element of the returned SArray is a vector containing the
    probability of each topic.

    >>> topic_probs = m.predict(documents.head(5), output_type='probability')

    The model object keeps track of various useful statistics about how the
    model was trained and its current status.

    >>> print m
    Topic Model
       Data:
            vocab_size:10473
            num_words:435838
       Settings:
            num_topics:20
            beta:0.1
            alpha:0.1
            num_iterations:100
            verbose:1
       Accessible attributes:
            m['topics']: An SFrame containing the topics.

    The value for each metadata field is accessible via m[field]. As with
    other models in GraphLab Create, it's also easy to save and load models.

    >>> m.save('my_model')
    >>> m2 = graphlab.load_model('my_model')

    It is also easy to create a new topic model from an old one -- whether
    it was created using GraphLab Create or another package.

    >>> m3 = topic_model.create(documents,
                                initial_topics=m['topics'])

    To manually fix several words to always be assigned to a topic, use
    the `associations` argument. The following will ensure that topic 0
    has the most probability for each of the provided words:

    >>> associations = SFrame({'word':['hurricane', 'wind', 'storm'],
                               'topic': [0, 0, 0]})
    >>> m = topic_model.create(docs,
                               associations=associations)

    More advanced usage allows you to control aspects of the model and the
    learning method.

    >>> m = topic_model.create(docs,
                               num_topics=20,       # number of topics
                               num_iterations=10,   # algorithm parameters
                               alpha=.01, beta=.1)  # hyperparameters

    For evaluating the performance of a learned TopicModel object, see
    :py:func:`~graphlab.toolkits.text.TopicModel.evaluate` for more
    information.
    """
    _mt._get_metric_tracker().track('toolkit.text.topic_model.create')

    if not isinstance(dataset, _SArray):
        raise TypeError('dataset input must be an SArray')

    # If associations are provided, check they are in the proper format
    if associations is None:
        associations = _graphlab.SFrame({'word': [], 'topic': []})
    if isinstance(associations, _graphlab.SFrame) and \
            associations.num_rows() > 0:
        assert set(associations.column_names()) == set(['word', 'topic']), \
            ("Provided associations must be an SFrame containing a word "
             "column and a topic column.")
        assert associations['word'].dtype() == str, \
            "Words must be strings."
        assert associations['topic'].dtype() == int, \
            "Topic ids must be of int type."

    if alpha is None:
        # Float literal forces true division: with integer division any
        # num_topics > 50 would yield alpha == 0, which is not positive.
        alpha = 50.0 / num_topics

    # NOTE(review): print_interval is accepted by this function but is not
    # forwarded to any toolkit call below -- confirm whether it should be.
    opts = {'data': dataset,
            'verbose': verbose,
            'num_topics': num_topics,
            'num_iterations': num_iterations,
            'alpha': alpha,
            'beta': beta,
            'associations': associations}

    # Initialize the model with basic parameters
    response = _graphlab.toolkits.main.run("text_topicmodel_init", opts)
    m = TopicModel(response['model'])

    # If initial_topics provided, load it into the model
    if isinstance(initial_topics, _graphlab.SFrame):
        assert set(['vocabulary', 'topic_probabilities']) == \
            set(initial_topics.column_names()), \
            ("The provided initial_topics does not have the proper format, "
             "e.g. wrong column names.")
        observed_topics = initial_topics['topic_probabilities'].apply(lambda x: len(x))
        assert all(observed_topics == num_topics), \
            "Provided num_topics value does not match the number of provided initial_topics."

        # Rough estimate of total number of words
        weight = dataset.size() * 1000

        opts = {'model': m.__proxy__,
                'topics': initial_topics['topic_probabilities'],
                'vocabulary': initial_topics['vocabulary'],
                'weight': weight}
        response = _graphlab.toolkits.main.run("text_topicmodel_set_topics", opts)
        m = TopicModel(response['model'])

    # Train the model on the given data set and retrieve predictions
    opts = {'model': m.__proxy__,
            'data': dataset,
            'method': method,
            'verbose': verbose}
    response = _graphlab.toolkits.main.run("text_topicmodel_train", opts)
    m = TopicModel(response['model'])

    return m


