from contextlib import contextmanager

from khronos.utils import Namespace

def mk_collect_fnc(name, aggregator):
    """Build a context manager that aggregates an indicator read around a block.

    The returned context manager is called as
    ``collect_fnc(indicator_fnc, collector_fnc)``. It behaves as follows:

    * It calls ``indicator_fnc()`` once on entry and once on normal exit.
    * It combines the two readings as ``aggregator(start, end)``.
    * It hands the result to ``collector_fnc``.

    If the ``with`` body raises, the exit reading and the collection are
    skipped. The returned object's ``__name__`` is ``"collect_" + name``.
    """
    def collect_fnc(indicator_fnc, collector_fnc):
        start = indicator_fnc()
        yield
        # The end reading is taken before ``aggregator`` is invoked.
        collector_fnc(aggregator(start, indicator_fnc()))

    managed = contextmanager(collect_fnc)
    managed.__name__ = "collect_" + name
    return managed
    
@contextmanager
def collect_both(indicator_fnc, collector_fnc):
    """Report the indicator's value both on entry and on normal exit.

    ``collector_fnc`` receives ``indicator_fnc()`` before the ``with`` body runs
    and again after it completes. The second report is skipped if the body
    raises.
    """
    def snapshot():
        collector_fnc(indicator_fnc())

    snapshot()
    yield
    snapshot()
    
# Context managers that read an indicator around a ``with`` body and feed the
# result to a stat collector. All except ``both`` combine the entry and exit
# readings into a single value; ``both`` reports each reading separately.
collect = Namespace(
    diff=mk_collect_fnc("diff", lambda before, after: after - before),
    sum=mk_collect_fnc("sum", lambda before, after: after + before),
    mean=mk_collect_fnc("mean", lambda before, after: (before + after) / 2.0),
    min=mk_collect_fnc("min", min),
    max=mk_collect_fnc("max", max),
    both=collect_both,
)
                    
def mk_sample_fnc(name, condition):
    """Build a sampler that redraws from a distribution until ``condition`` holds.

    The returned function is called as ``sample_fnc(distr, *args, **kwargs)``.
    It calls ``distr(*args, **kwargs)`` repeatedly and returns the first value
    for which ``condition(value)`` is true.

    The optional keyword argument ``tries`` (default 100) caps the number of
    draws. It is consumed by the sampler and is NOT forwarded to ``distr``.
    At least one draw is always made.

    Raises:
        ValueError: if no value satisfying ``condition`` is drawn within
            ``tries`` draws.

    The returned function's ``__name__`` is ``"sample_" + name``.
    """
    def sample_fnc(distr, *args, **kwargs):
        # Pop rather than get: ``tries`` configures the sampler itself, and
        # forwarding it to ``distr`` would break ordinary distribution callables.
        tries = kwargs.pop("tries", 100)
        value = distr(*args, **kwargs)
        draws = 1
        while not condition(value):
            if draws >= tries:
                # Report the number of draws actually made, not a decremented counter.
                raise ValueError("unable to get valid sample in %d tries" % (draws,))
            value = distr(*args, **kwargs)
            draws += 1
        return value
    sample_fnc.__name__ = "sample_" + name
    return sample_fnc
    
# Samplers that keep drawing from a probability distribution until they get a
# value with the required sign (e.g. nonnegative). Each one gives up after a
# bounded number of draws.
sample = Namespace(
    positive=mk_sample_fnc("pos", lambda v: v > 0.0),
    negative=mk_sample_fnc("neg", lambda v: v < 0.0),
    nonpositive=mk_sample_fnc("nonpos", lambda v: v <= 0.0),
    nonnegative=mk_sample_fnc("nonneg", lambda v: v >= 0.0),
)
