from wdmmg import model

class ValueCache(object):
    '''
    A utility for creating and retrieving EnumerationValues and their `id`s.

    To avoid running out of RAM, the caller will probably call
    `model.Session.remove()` one or more times during the lifetime of this
    ValueCache. Therefore, this class cannot retain domain objects between
    method calls. Instead, it records the database `id`s of the objects it
    needs.
    '''

    # Text type required for codes, names and notes: `unicode` on Python 2,
    # `str` on Python 3. Spelled as `type(u'')` so the class runs on both.
    _text_type = type(u'')

    def __init__(self, key, create=True):
        '''
        Constructs a ValueCache for `key`. (You need a separate ValueCache
        instance for each Key).

        Runs one query for all EnumerationValues belonging to `key` and
        indexes their `id`s by `code`. After that, looking up a known code
        does not touch the database.

        :param key: the Key to which the EnumerationValues belong. Only its
            `id` and `name` are retained, not the object itself.

        :param create: (optional, default=True) if `True`, then calling
            `get_value_id()` with an unknown code will cause a new
            EnumerationValue to be created. If `False`, then a `ValueError`
            will be raised instead.
        '''
        self.key_id = key.id
        self.key_name = key.name
        self.create = create
        # Maps EnumerationValue.code -> EnumerationValue.id for this Key.
        query = model.Session.query(model.EnumerationValue).filter_by(key=key)
        self.index = dict((ev.code, ev.id) for ev in query.all())

    def get_value_id(self, code, name=None, notes=u''):
        '''
        Returns the `id` of the EnumerationValue with the specified `code`,
        creating it if necessary.

        The whole point of this class is that it does not read from the
        database, except once at construction time (and once per newly
        created value). Therefore, do not create EnumerationValue records
        that this ValueCache is not aware of. If you do, this method could
        create a duplicate, or raise an exception.

        :param code: the `EnumerationValue.code` to look up. If falsy (`None`
            or an empty string), returns `None`.

        :param name: the `name` to use if the EnumerationValue needs to be
            created. If falsy, then the `code` will be used as the `name`.

        :param notes: the `notes` to use if the EnumerationValue needs to be
            created.

        :return: the `EnumerationValue.id`, or `None`.

        :raises ValueError: if `code` is unknown and this ValueCache was
            constructed with `create=False`.
        '''
        if not code:
            return None
        assert isinstance(code, self._text_type)
        if not name:
            name = code
        else:
            assert isinstance(name, self._text_type)
        if notes:
            assert isinstance(notes, self._text_type)
        if code not in self.index:
            if not self.create:
                raise ValueError(
                    'Unknown EnumerationValue (key_name=%r, code=%r)'
                    % (self.key_name, code))
            # Re-fetch the Key by id in the current session rather than
            # keeping the constructor's Key object, since the session may have
            # been removed since then. (Avoids an "orphan" error.)
            key = model.Session.query(model.Key).filter_by(id=self.key_id).one()
            ev = model.EnumerationValue(key=key, code=code, name=name, notes=notes)
            model.Session.add(ev)
            model.Session.commit()
            # Coerced to int, presumably so ids pass integer type checks such
            # as the one in `Loader.create_entry`.
            self.index[code] = int(ev.id)
        return self.index[code]

# TODO: Rethink the Loader API.
# Currently, the caller has to do three jobs for each Key:
# - Retrieve or create it.
# - Wrap it in a ValueCache.
# - Pass it to the Loader.
# The flexibility to do something different is not really useful.

class Loader(object):
    '''
    Represents the RAM-resident state of a process that loads an OLAP-like
    data set into the store. The data set is assumed to be too big to fit in
    RAM, so it is streamed in and simultaneously written out to the database.

    To avoid running out of RAM, the caller will probably call
    `model.Session.remove()` one or more times during the lifetime of this
    Loader. Therefore, this class cannot retain domain objects between method
    calls. Instead, it records the database `id`s of the objects it needs.

    The intended usage is something like this:

        my_loader = Loader('my_dataset', [key_1, ..., key_n], commit_every=1000)
        for row in fetch_my_data():
            entry = my_loader.create_entry(row.amount, [
                row.enumeration_value_1_id,
                ...,
                row.enumeration_value_n_id,
            ])
            # Optionally do stuff to `entry` here.
        my_loader.compute_aggregates()

    The caller is responsible for setting up the Keys and EnumerationValues.
    It is recommended that instances of ValueCache be used to retrieve the
    `EnumerationValue.id`s passed to `create_entry()`, as this will
    avoid database traffic.
    '''

    def __init__(self, dataset_name, axes, notes=u'', metadata=None,
                 commit_every=None, dataset_long_name=None, currency=u'gbp'):
        '''
        Constructs a Loader for a new Dataset `dataset_name`. (Raises an
        AssertionError if a Dataset already exists with that name). Calling
        the constructor creates and commits the Dataset object.

        dataset_name - the unique name of the Dataset (a unicode string).

        axes - a list of Keys which will be used to classify spending. Only
            their `id`s are retained.

        notes (optional) - the `notes` to use when creating the Dataset.

        metadata (optional) - the `metadata` to use when creating the Dataset.

        commit_every (optional) - if not None, the number of calls of
            `create_entry()` between which the Loader calls
            `model.Session.commit()` followed by `model.Session.remove()`,
            discarding the session to bound RAM use. (Each entry is also
            committed individually as it is created.)

        dataset_long_name (optional) - currently accepted but not used.
            TODO(review): confirm whether it should be stored on the Dataset.

        currency (optional, default u'gbp') - the `currency` of the Dataset.
        '''
        text_type = type(u'')  # `unicode` on Python 2, `str` on Python 3.
        assert isinstance(dataset_name, text_type)
        for key in axes:
            assert isinstance(key, model.Key)
        existing = (model.Session.query(model.Dataset)
                    .filter_by(name=dataset_name)
                    .first())
        if existing is not None:
            # Raised explicitly rather than via `assert` so that this
            # uniqueness guard survives `python -O`. AssertionError is kept
            # for compatibility with the previous behaviour.
            raise AssertionError("Dataset '%s' already loaded" % dataset_name)
        # Create dataset.
        dataset_ = model.Dataset(name=dataset_name, currency=currency,
                                 metadata=metadata, notes=notes)
        model.Session.add(dataset_)
        model.Session.commit()
        # Keep only ids, never domain objects (see class docstring).
        self.dataset_name = dataset_name
        self.dataset_id = dataset_.id
        self.key_ids = [key.id for key in axes]
        self.commit_every = commit_every
        self.commit_count = 0

    def create_entry(self, amount, value_ids, currency=None):
        '''
        Creates an Entry record and its associated ClassificationItems, and
        commits them together.

        amount - the amount spent (must be a float).

        value_ids - a list of `EnumerationValue.id`s, one for each Key passed
            to the constructor. Can pass `None` for missing data.

        currency (optional) - the `currency` of the Entry; None (null) by
            default.

        Returns the new Entry.
        '''
        assert isinstance(amount, float)
        assert len(value_ids) == len(self.key_ids)
        for value_id in value_ids:
            if value_id:
                assert isinstance(value_id, int)
        # Periodically commit and discard the session to bound RAM use.
        # Note that this also fires on the very first call (count 0).
        if self.commit_every and self.commit_count % self.commit_every == 0:
            print("Committing before row %d" % self.commit_count)
            model.Session.commit()
            model.Session.remove()
        self.commit_count += 1
        # Create Entry and ClassificationItems.
        txn = model.Entry(dataset_id=self.dataset_id, amount=amount,
                          currency=currency)
        model.Session.add(txn)
        # flush() is enough for the database to assign `txn.id`. The single
        # commit below then stores the Entry and its ClassificationItems
        # atomically, instead of committing a bare Entry first.
        model.Session.flush()
        for value_id in value_ids:
            if value_id:
                model.Session.add(model.ClassificationItem(
                    entry_id=txn.id, value_id=value_id))
        model.Session.commit()
        return txn

    def compute_aggregates(self):
        '''
        Intended to compute aggregates over the loaded Dataset.

        Not implemented yet: this method currently does nothing.
        '''
        # TODO.
        pass

