from collections import deque

from khronos.statistics.tally import Tally
from khronos.statistics.plotter import get_plotter

class TSeries(Tally):
    """A series of values registered along a time axis, built on top of Tally.
    
    Every registered value is kept as the 'last value' until another one arrives. At that point 
    the time elapsed since the previous registration is sent to the collector that was obtained 
    from lazy_collect() for the previous value (presumably this becomes the weight of that value 
    in the underlying tally -- confirm against Tally). Times are either passed explicitly or 
    obtained from the time function, and are divided by the time scale (when it is not 0) before 
    being used."""
    
    def __init__(self, storing=False, numeric=True, time_fnc=None, time_scale=0):
        """Create an unstarted time series.
        
        storing, numeric -- forwarded unchanged to Tally.__init__().
        time_fnc -- callable taking no arguments, used to obtain the time whenever a value is 
            registered without an explicit time.
        time_scale -- divisor applied to every time handled by the series. The default value 
            of 0 means that times are used exactly as given."""
        Tally.__init__(self, storing, numeric)
        self.__collector = None   # collector waiting for the duration of the last value
        self.__last_value = None
        self.__last_time = None
        self.__start = None       # time of the first value; None while the series is unstarted
        self.__timefnc = time_fnc
        self.__timescale = time_scale  # used to convert time data during collection.
        
    def clear(self, storing=None, numeric=None, time_fnc=None, time_scale=None):
        """Discard the state of the series, returning it to the unstarted state. The 'storing' 
        and 'numeric' arguments are forwarded to Tally.clear(). The time function and time 
        scale are only replaced when the respective argument is not None."""
        Tally.clear(self, storing, numeric)
        self.__collector = None
        self.__last_value = None
        self.__last_time = None
        self.__start = None
        if time_fnc is not None:
            self.__timefnc = time_fnc
        if time_scale is not None:
            self.__timescale = time_scale
            
    def collect(self, value, time=None):
        """Register a new value in the time series. This method works for started and unstarted 
        tseries objects, having the advantage of not requiring the user to know whether the 
        timeseries has been started or not."""
        if self.started():
            self.append(value, time)
        else:
            self.start(value, time)
            
    def start(self, value, time=None):
        """Clear the time series and register a new starting value. Be careful when using this 
        method, since any values previously stored in the time series are discarded!
        
        Raises TypeError if no time is given and the series has no time function."""
        if self.started():
            self.clear()
        time = self.__resolve_time(time)
        self.__collector = self.lazy_collect(value)
        self.__last_value = value
        self.__last_time = time
        self.__start = time
        
    def append(self, value, time=None):
        """Register a new value into the time series. Note that the timeseries must be started 
        before calling this method. The collect() method can be used without this concern.
        
        Raises ValueError if the series has not been started or if the time is less than the 
        time of the previous value, and TypeError if no time is given and the series has no 
        time function."""
        if not self.started():
            raise ValueError("timeseries not started yet")
        time = self.__resolve_time(time)
        if time < self.__last_time:
            raise ValueError("invalid time (less than previous value)")
        # Close the previous value by telling its collector how long it lasted, and only then 
        # open a new collector for the value that has just arrived.
        self.__collector.send(time - self.__last_time)
        self.__collector = self.lazy_collect(value)
        self.__last_value = value
        self.__last_time = time
        
    def __resolve_time(self, time):
        """Return the (scaled) time to use for a registration. If 'time' is None, the time is 
        obtained by calling the time function."""
        if time is None:
            if self.__timefnc is None:
                raise TypeError("no time provided and no time function is set")
            time = self.__timefnc()
        return self.__scale_time(time)
        
    def __scale_time(self, time):
        """Divide 'time' by the time scale, unless the time scale is 0 (i.e. not defined)."""
        if self.__timescale != 0:
            time = time / self.__timescale
        return time
        
    # -----------------------------------------------
    def __iter__(self):
        """Iterate over (value, time) pairs. Each pair produced by Tally.__iter__() is reported 
        together with the time at which its value was registered, computed by accumulating the 
        lengths starting at the series' start time. The last value and last time are yielded 
        at the end. An unstarted series yields nothing."""
        if self.started():
            start = self.__start
            for value, length in Tally.__iter__(self):
                yield value, start
                start += length
            yield self.__last_value, self.__last_time
            
    def iter_values(self):
        """Iterate over the values of the (value, time) pairs of this timeseries."""
        for value, _ in self:
            yield value
            
    def iter_times(self):
        """Iterate over the times of the (value, time) pairs of this timeseries."""
        for _, time in self:
            yield time
            
    # -----------------------------------------------
    def last_value(self):
        """Return the most recently registered value (None if the series is unstarted)."""
        return self.__last_value
        
    def last_time(self):
        """Return the (scaled) time of the most recent registration (None if unstarted)."""
        return self.__last_time
        
    def started(self):
        """Return True if the series has a starting value, False otherwise."""
        return self.__start is not None
        
    def time_fnc(self, fnc=None):
        """Get or set the time function of a timeseries. When 'fnc' is not None it becomes the 
        new time function and None is returned; otherwise the current function is returned."""
        if fnc is not None:
            self.__timefnc = fnc
        else:
            return self.__timefnc
            
    def time_scale(self, time_scale=None):
        """Get or set the time_scale of a timeseries. When 'time_scale' is not None it becomes 
        the new scale and None is returned; otherwise the current scale is returned."""
        if time_scale is not None:
            self.__timescale = time_scale
        else:
            return self.__timescale
            
    def value_at(self, time):
        """Get the timeseries' value at the specified time. The time is divided by the current 
        time_scale if it is defined.
        
        Raises ValueError if the series has not been started or if the time lies outside the 
        interval covered by the series. If Tally.__iter__() produces no pairs (presumably the 
        case for non-storing series -- confirm against Tally), None is returned for any time 
        before the last one."""
        if not self.started():
            raise ValueError("timeseries not started yet")
        time = self.__scale_time(time)
        if time < self.__start:
            raise ValueError("no data before %s" % (self.__start,))
        if time > self.__last_time:
            raise ValueError("no data after %s" % (self.__last_time,))
        if time == self.__last_time:
            return self.__last_value
            
        value = None
        start = self.__start
        for value, length in Tally.__iter__(self):
            end = start + length
            if start <= time < end:
                return value
            start = end
        # The accumulated lengths may fall marginally short of the last time because of 
        # rounding. Since the intervals are contiguous and the time is within bounds, a time 
        # that matched no interval can only belong to the last one.
        return value
        
    # -----------------------------------------------
    def slice(self, start=None, end=None):
        """Create a time series which is a subset of this one, starting and ending at the 
        specified times. Note that the interval limits are divided by the current time_scale if it 
        is defined, so if the time_scale is a TimeDelta, the limits should also be TimeDeltas.
        An additional note is that this method ONLY works on storing timeseries.
        
        Raises ValueError if the series has not been started, if it is not a storing series, or 
        if the requested interval is not contained in the interval covered by the series."""
        if not self.started():
            raise ValueError("timeseries not started yet")
        if not self.is_storing():
            raise ValueError("storing timeseries required")
        # If the limits have not been provided to this method, the slice will use the starting 
        # and/or ending times of this timeseries as boundaries. Note that if none of the 
        # boundaries is provided, this method returns a copy of this timeseries. The internal 
        # boundaries are already scaled, so only user-provided limits are converted.
        if start is None:
            start = self.__start
        else:
            start = self.__scale_time(start)
        if end is None:
            end = self.__last_time
        else:
            end = self.__scale_time(end)
        # Exchange the limits if they are ordered incorrectly.
        if start > end:
            start, end = end, start
        # Check that the specified interval lies within the timeseries' boundaries.
        if start < self.__start:
            raise ValueError("no data before %s" % (self.__start,))
        if end > self.__last_time:
            raise ValueError("no data after %s" % (self.__last_time,))
            
        # Finally, we will create a new timeseries object and populate it with the data regarding 
        # the slice that was asked for. Note that the time_scale is not defined yet to avoid 
        # conversion of the dates we will be inserting in the new tseries. The scale is only set 
        # after the new series is finished, i.e. all the values are inserted.
        tseries = TSeries(storing=self.is_storing(), numeric=self.is_numeric())
        current = None  # value in effect at the beginning of the slice
        for value, time in TSeries.__iter__(self):
            if time <= start:
                # The first pair is located at this series' start time, which is never greater 
                # than 'start', so 'current' is always set before the slice is started.
                current = value
                continue
            if time > end:
                break
            if not tseries.started():
                tseries.start(current, start)
            if time < tseries.last_time():
                # Only the final pair carries an exact (not accumulated) time, which rounding 
                # may place marginally before the accumulated time of the previous pair.
                time = tseries.last_time()
            tseries.append(value, time)
        if not tseries.started():
            tseries.start(current, start)
        # Extend the last value up to the end of the slice, so that the new series covers the 
        # whole interval that was asked for.
        if tseries.last_time() < end:
            tseries.append(tseries.last_value(), end)
            
        # Only set the time_scale now to avoid conversion of the values inserted previously.
        tseries.__timescale = self.__timescale
        return tseries
        
    def merge(self, *tseries):
        """Merge several time series together, creating a new one."""
        raise NotImplementedError()
        
    def add(self, *tseries):
        """Create a new time series by adding this time series with others."""
        raise NotImplementedError()
        
    def invert(self):
        """Return a new time series equal to this one, but with inverted values. Note that the 
        time series must be numeric for this method to work."""
        raise NotImplementedError()
        
    def average(self, *tseries):
        """Create a new time series by averaging this time series with others."""
        raise NotImplementedError()
        
    # -----------------------------------------------
    # Each of the following methods builds a new timeseries where the value at each time is the 
    # respective indicator computed over the data registered up to that time. 
    def min_series(self):
        """Create a timeseries of the minimum values of this timeseries."""
        return self.__make_series(TSeries.min)
        
    def max_series(self):
        """Create a timeseries of the maximum values of this timeseries."""
        return self.__make_series(TSeries.max)
        
    def sum_series(self):
        """Create a timeseries of the sums of values of this timeseries."""
        return self.__make_series(TSeries.sum)
        
    def mean_series(self):
        """Create a timeseries of the mean of this timeseries."""
        return self.__make_series(TSeries.mean)
        
    def var_series(self):
        """Create a timeseries of the variance of this timeseries."""
        return self.__make_series(TSeries.var)
        
    def stddev_series(self):
        """Create a timeseries of the standard deviation of this timeseries."""
        return self.__make_series(TSeries.stddev)
        
    def wmean_series(self):
        """Create a timeseries of the weighted mean of this timeseries."""
        return self.__make_series(TSeries.wmean)
        
    def wvar_series(self):
        """Create a timeseries of the weighted variance of this timeseries."""
        return self.__make_series(TSeries.wvar)
        
    def wstddev_series(self):
        """Create a timeseries of the weighted standard deviation of this timeseries."""
        return self.__make_series(TSeries.wstddev)
        
    def __make_series(self, indicator):
        """Build the timeseries of a running indicator. The (value, time) pairs of this series 
        are replayed one at a time into a non-storing buffer series, and after each pair the 
        result of indicator(buffer) is registered in the resulting series at the same time.
        
        Raises ValueError if this series is not both numeric and storing."""
        if not self.is_numeric() or not self.is_storing():
            raise ValueError("numeric and storing timeseries required")
        tseries = TSeries(storing=True,  numeric=True)
        buffer  = TSeries(storing=False, numeric=True)
        for value, time in self:
            # collect() starts the series on the first pair and appends on the following ones. 
            # No time scale is set on these series, so the times are inserted unchanged.
            buffer.collect(value, time)
            tseries.collect(indicator(buffer), time)
        return tseries
        
    # -----------------------------------------------
    def run_chart(self, axes=None, show_mean=True, show_stddev=False, 
                  fill=False, *args, **kwargs):
        """Draw a run chart of this timeseries using the plotter obtained from get_plotter(). 
        For numeric series, the weighted mean is passed to the plotter if 'show_mean' is true, 
        and the weighted standard deviation is also passed if 'show_stddev' is true. Note that 
        the standard deviation is only considered when the mean is being shown. Any additional 
        arguments are forwarded to the plotter's run_chart() method."""
        mean = None
        stddev = None
        if self.is_numeric() and show_mean:
            mean = self.wmean()
            if show_stddev:
                stddev = self.wstddev()
                
        plotter, axes = get_plotter(axes)
        return plotter.run_chart(self.iter_values(), self.iter_times(), 
                                 numeric=self.is_numeric(), mean=mean, stddev=stddev, 
                                 fill=fill, axes=axes, *args, **kwargs)
                                 
    def simple_plot(self, axes=None, *args, **kwargs):
        """Plot the values of this timeseries against their times, using the plotter obtained 
        from get_plotter(). Any additional arguments are forwarded to the plotter's 
        simple_plot() method."""
        plotter, axes = get_plotter(axes)
        return plotter.simple_plot(self.iter_times(), self.iter_values(), axes, *args, **kwargs)
        
