/*
 * Decompiled with CFR 0.152.
 */
package stallone.stat;

import java.util.Arrays;
import stallone.api.API;
import stallone.api.datasequence.IDataSequence;
import stallone.api.doubles.IDoubleArray;
import stallone.api.function.IParametricFunction;
import stallone.api.stat.IDiscreteDistribution;
import stallone.api.stat.IParameterEstimator;
import stallone.doubles.PrimitiveDoubleTools;

/**
 * Discrete probability distribution over the indices {@code 0 .. p.length - 1}.
 *
 * <p>The class keeps the probability vector {@code p}, its running (cumulative)
 * sums {@code pinc} used for sampling, and two count vectors: {@code priorCount}
 * and {@code count}. Estimation accumulates observations into {@code count} and
 * then sets {@code p} to the normalized sum of {@code priorCount} and
 * {@code count}.
 *
 * <p>Arrays handed to the constructor, {@link #setPrior(double[])} and
 * {@link #initialize(IDoubleArray)} are stored by reference, not copied.
 */
public class DiscreteDistribution
implements IParametricFunction,
IParameterEstimator,
IDiscreteDistribution {
    /** Counts added to the observed counts before normalizing; all zero unless set via setPrior. */
    private double[] priorCount;
    /** Observation counts (possibly weighted) accumulated by the addToEstimate methods. */
    private double[] count;
    /** Probability of each index. */
    private double[] p;
    /** Running sums of {@code p}: {@code pinc[j] = p[0] + ... + p[j]}. */
    private double[] pinc;

    /**
     * Creates a distribution with the given probabilities.
     *
     * @param _p probability of each index; stored by reference. The values are
     *           not normalized or validated here.
     */
    public DiscreteDistribution(double[] _p) {
        this.p = _p;
        this.pinc = new double[_p.length];
        this.priorCount = new double[_p.length];
        this.count = new double[_p.length];
        this.updateInc();
    }

    /**
     * Creates a distribution from the array returned by {@code arr.getArray()}.
     *
     * @param arr source of the probability values
     */
    public DiscreteDistribution(IDoubleArray arr) {
        this(arr.getArray());
    }

    /**
     * Sets the prior counts that are added to the observed counts when the
     * probabilities are re-estimated.
     *
     * @param _prior prior count per index; stored by reference.
     *               NOTE(review): no length check is made here - presumably it
     *               must have as many entries as the distribution; confirm
     *               against callers.
     */
    public void setPrior(double[] _prior) {
        this.priorCount = _prior;
    }

    /** Recomputes the running sums {@code pinc} from the current {@code p}. */
    private void updateInc() {
        this.pinc[0] = this.p[0];
        for (int j = 1; j < this.pinc.length; ++j) {
            this.pinc[j] = this.pinc[j - 1] + this.p[j];
        }
    }

    /**
     * Draws a random index by comparing a uniform random number with the
     * running sums of the probabilities.
     *
     * @return the first index whose running sum exceeds the random number, or
     *         the last index if none does
     */
    @Override
    public int sample() {
        double r = Math.random();
        // Stop at the last index: the final running sum can be smaller than r
        // (rounding in the summation, or probabilities that do not sum to 1),
        // and scanning past it would return an index outside the distribution.
        int last = this.pinc.length - 1;
        int to = 0;
        while (to < last && this.pinc[to] <= r) {
            ++to;
        }
        return to;
    }

    /**
     * @return the probability vector wrapped via {@code API.doublesNew.arrayFrom}
     */
    @Override
    public IDoubleArray getParameters() {
        return API.doublesNew.arrayFrom(this.p);
    }

    /**
     * Overwrites the probabilities element by element and refreshes the
     * running sums.
     *
     * @param par new probability values; entries {@code 0 .. p.length - 1} are read
     */
    @Override
    public void setParameters(IDoubleArray par) {
        for (int i = 0; i < this.p.length; ++i) {
            this.p[i] = par.get(i);
        }
        this.updateInc();
    }

    /**
     * @return the number of probabilities held by this distribution
     */
    @Override
    public int getNumberOfVariables() {
        return this.p.length;
    }

    /**
     * Evaluates the distribution at {@code x}.
     *
     * <p>If {@code x} has as many entries as there are probabilities, the
     * result is the product of {@code p[i]^x[i]} over all {@code i} (computed
     * in log space); an entry with zero probability and non-zero {@code x[i]}
     * makes the result 0. Otherwise, if {@code x} has exactly one entry, that
     * entry is truncated to an {@code int} and the probability stored at that
     * index is returned. The first form takes precedence when both apply.
     *
     * @param x count vector or single index
     * @return the value described above
     * @throws IllegalArgumentException if {@code x} has neither length
     */
    @Override
    public double f(double ... x) {
        if (x.length == this.p.length) {
            double logP = 0.0;
            for (int i = 0; i < x.length; ++i) {
                if (this.p[i] == 0.0) {
                    // log(0) is -infinity; handle the zero-probability case explicitly.
                    if (x[i] != 0.0) {
                        return 0.0;
                    }
                } else {
                    logP += x[i] * Math.log(this.p[i]);
                }
            }
            return Math.exp(logP);
        }
        if (x.length == 1) {
            return this.p[(int)x[0]];
        }
        throw new IllegalArgumentException("incompatible input vector");
    }

    /**
     * Evaluates {@link #f(double...)} on the array returned by {@code x.getArray()}.
     *
     * @param x count vector or single index
     * @return see {@link #f(double...)}
     */
    @Override
    public double f(IDoubleArray x) {
        return this.f(x.getArray());
    }

    /**
     * Creates a new distribution from copies of the probabilities and the
     * prior counts. The accumulated observation counts are not carried over.
     *
     * @return the new distribution
     */
    @Override
    public DiscreteDistribution copy() {
        DiscreteDistribution dd = new DiscreteDistribution(PrimitiveDoubleTools.copy(this.p));
        dd.setPrior(PrimitiveDoubleTools.copy(this.priorCount));
        return dd;
    }

    /**
     * Resets the observation counts, accumulates {@code data} and returns the
     * resulting estimate.
     *
     * @param data observations, see {@link #addToEstimate(IDataSequence)}
     * @return the estimated probabilities
     */
    @Override
    public IDoubleArray estimate(IDataSequence data) {
        this.initialize();
        this.addToEstimate(data);
        return this.getEstimate();
    }

    /**
     * Resets the observation counts, accumulates the weighted {@code data} and
     * returns the resulting estimate.
     *
     * @param data    observations, see {@link #addToEstimate(IDataSequence, IDoubleArray)}
     * @param weights weight of each observation
     * @return the estimated probabilities
     */
    @Override
    public IDoubleArray estimate(IDataSequence data, IDoubleArray weights) {
        this.initialize();
        this.addToEstimate(data, weights);
        return this.getEstimate();
    }

    /** Sets all observation counts to zero and refreshes the running sums. */
    @Override
    public void initialize() {
        Arrays.fill(this.count, 0.0);
        this.updateInc();
    }

    /**
     * Replaces the probability vector with the array returned by
     * {@code initPar.getArray()} (stored by reference) and refreshes the
     * running sums. The observation counts are left unchanged.
     *
     * @param initPar new probabilities.
     *                NOTE(review): the running-sum and count arrays keep their
     *                previous size - presumably callers always pass an array of
     *                the same length; confirm.
     */
    @Override
    public void initialize(IDoubleArray initPar) {
        this.p = initPar.getArray();
        this.updateInc();
    }

    /**
     * Sets the probabilities to the normalized sum of prior and observed
     * counts. If that sum totals exactly zero the current probabilities are
     * kept.
     */
    private void count2p() {
        double[] totalCounts = PrimitiveDoubleTools.add(this.priorCount, this.count);
        double total = PrimitiveDoubleTools.sum(totalCounts);
        if (total == 0.0) {
            // Nothing to normalize; dividing would turn every probability into NaN.
            return;
        }
        this.p = PrimitiveDoubleTools.multiply(1.0 / total, totalCounts);
        this.updateInc();
    }

    /**
     * Accumulates observations into the counts and re-estimates the
     * probabilities.
     *
     * <p>If the data dimension equals the number of probabilities, each
     * observation is passed to {@code PrimitiveDoubleTools.increment} together
     * with the count vector (presumably an element-wise addition - confirm
     * against that helper). If the dimension is 1, the single entry of each
     * observation is truncated to an {@code int} and the count at that index
     * is raised by one.
     *
     * @param data observations
     * @throws IllegalArgumentException if the data dimension matches neither case
     */
    @Override
    public void addToEstimate(IDataSequence data) {
        if (data.dimension() == this.p.length) {
            for (IDoubleArray arr : data) {
                PrimitiveDoubleTools.increment(this.count, arr.getArray());
            }
        } else if (data.dimension() == 1) {
            for (IDoubleArray arr : data) {
                int n = (int)arr.get(0);
                this.count[n] += 1.0;
            }
        } else {
            throw new IllegalArgumentException("incompatible dimension of observation");
        }
        this.count2p();
    }

    /**
     * Accumulates weighted observations into the counts and re-estimates the
     * probabilities.
     *
     * <p>If the data dimension equals the number of counts, entry {@code j} of
     * observation {@code i}, multiplied by {@code weights.get(i)}, is added to
     * count {@code j}. If the dimension is 1, the single entry of observation
     * {@code i} is truncated to an {@code int} and {@code weights.get(i)} is
     * added to the count at that index.
     *
     * @param data    observations
     * @param weights weight of each observation, indexed like {@code data}
     * @throws IllegalArgumentException if the data dimension matches neither case
     */
    @Override
    public void addToEstimate(IDataSequence data, IDoubleArray weights) {
        if (data.dimension() == this.count.length) {
            for (int i = 0; i < data.size(); ++i) {
                IDoubleArray arr = data.get(i);
                double w = weights.get(i);
                for (int j = 0; j < this.count.length; ++j) {
                    this.count[j] += w * arr.get(j);
                }
            }
        } else if (data.dimension() == 1) {
            for (int i = 0; i < data.size(); ++i) {
                IDoubleArray arr = data.get(i);
                int n = (int)arr.get(0);
                this.count[n] += weights.get(i);
            }
        } else {
            throw new IllegalArgumentException("incompatible dimension of observation");
        }
        this.count2p();
    }

    /**
     * @return the current probability vector wrapped via {@code API.doublesNew.arrayFrom}
     */
    @Override
    public IDoubleArray getEstimate() {
        return API.doublesNew.arrayFrom(this.p);
    }
}

