/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.boosting;

import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.ContinuousDistribution;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/**
 * Wagging (weight aggregation) ensemble usable for both classification and regression.
 * <p>
 * Each of the {@link #getIterations() iterations} ensemble members is a clone of the weak
 * learner, trained on a private copy of the data set. In that copy every data point's weight
 * has been redrawn by inverse-CDF sampling from the configured {@link ContinuousDistribution}.
 * Sampled weights are clamped below at {@code 1e-6}. Classification returns the normalized
 * vote counts of the members' most likely labels. Regression returns the mean of the
 * members' predictions.
 */
public class Wagging
implements Classifier,
Regressor,
Parameterized {
    private static final long serialVersionUID = 4999034730848794619L;
    /** Lower bound applied to every sampled data point weight. */
    private static final double MIN_WEIGHT = 1.0E-6;
    /** Distribution the per-point weights are sampled from. */
    private ContinuousDistribution dist;
    /** Number of ensemble members to train. */
    private int iterations;
    /** Prototype cloned for each classification ensemble member (may be null). */
    private Classifier weakL;
    /** Prototype cloned for each regression ensemble member (may be null). */
    private Regressor weakR;
    /** Target variable of the last classification training run. */
    private CategoricalData predicting;
    /** Trained classification members; null unless trained for classification. */
    private Classifier[] hypotsL;
    /** Trained regression members; null unless trained for regression. */
    private Regressor[] hypotsR;

    /**
     * Creates a Wagging ensemble for classification.
     *
     * @param dist the distribution the data point weights are sampled from
     * @param weakL the weak learner cloned for every ensemble member
     * @param iterations the number of ensemble members to train
     * @throws NullPointerException if {@code dist} or {@code weakL} is null
     * @throws ArithmeticException if {@code iterations} is less than one
     */
    public Wagging(ContinuousDistribution dist, Classifier weakL, int iterations) {
        this.setDistribution(dist);
        this.setIterations(iterations);
        this.setWeakLearner(weakL);
    }

    /**
     * Creates a Wagging ensemble for regression.
     *
     * @param dist the distribution the data point weights are sampled from
     * @param weakR the weak learner cloned for every ensemble member
     * @param iterations the number of ensemble members to train
     * @throws NullPointerException if {@code dist} or {@code weakR} is null
     * @throws ArithmeticException if {@code iterations} is less than one
     */
    public Wagging(ContinuousDistribution dist, Regressor weakR, int iterations) {
        this.setDistribution(dist);
        this.setIterations(iterations);
        this.setWeakLearner(weakR);
    }

    /**
     * Deep-copy constructor used by {@link #clone()}.
     *
     * @param clone the ensemble to copy
     */
    protected Wagging(Wagging clone) {
        this.dist = clone.dist.clone();
        this.iterations = clone.iterations;
        if (clone.weakL != null) {
            this.setWeakLearner(clone.weakL.clone());
        }
        // When one object serves as both prototypes, the call above already set both fields.
        if (clone.weakR != null && (Object) clone.weakR != clone.weakL) {
            this.setWeakLearner(clone.weakR.clone());
        }
        if (clone.predicting != null) {
            this.predicting = clone.predicting.clone();
        }
        if (clone.hypotsL != null) {
            this.hypotsL = new Classifier[clone.hypotsL.length];
            for (int i = 0; i < this.hypotsL.length; ++i) {
                this.hypotsL[i] = clone.hypotsL[i].clone();
            }
        }
        if (clone.hypotsR != null) {
            this.hypotsR = new Regressor[clone.hypotsR.length];
            for (int i = 0; i < this.hypotsR.length; ++i) {
                this.hypotsR[i] = clone.hypotsR[i].clone();
            }
        }
    }

    /**
     * Sets the weak learner used for classification. If it also implements {@link Regressor},
     * it becomes the regression weak learner as well.
     *
     * @param weakL the classification weak learner
     * @throws NullPointerException if {@code weakL} is null
     */
    public void setWeakLearner(Classifier weakL) {
        if (weakL == null) {
            throw new NullPointerException();
        }
        this.weakL = weakL;
        if (weakL instanceof Regressor) {
            // Cast through Object, as in the original. A direct interface-to-interface cast
            // may be rejected because both interfaces declare clone() with different return
            // types.
            this.weakR = (Regressor) ((Object) weakL);
        }
    }

    /**
     * @return the weak learner used for classification, or null if none was set
     */
    public Classifier getWeakClassifier() {
        return this.weakL;
    }

    /**
     * Sets the weak learner used for regression. If it also implements {@link Classifier},
     * it becomes the classification weak learner as well.
     *
     * @param weakR the regression weak learner
     * @throws NullPointerException if {@code weakR} is null
     */
    public void setWeakLearner(Regressor weakR) {
        if (weakR == null) {
            throw new NullPointerException();
        }
        this.weakR = weakR;
        if (weakR instanceof Classifier) {
            this.weakL = (Classifier) ((Object) weakR);
        }
    }

    /**
     * @return the weak learner used for regression, or null if none was set
     */
    public Regressor getWeakRegressor() {
        return this.weakR;
    }

    /**
     * Sets the number of ensemble members to train.
     *
     * @param iterations the number of members, at least one
     * @throws ArithmeticException if {@code iterations} is less than one
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new ArithmeticException("The number of iterations must be positive");
        }
        this.iterations = iterations;
    }

    /**
     * @return the number of ensemble members to train
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the distribution the data point weights are sampled from.
     *
     * @param dist the weight distribution
     * @throws NullPointerException if {@code dist} is null
     */
    public void setDistribution(ContinuousDistribution dist) {
        if (dist == null) {
            throw new NullPointerException();
        }
        this.dist = dist;
    }

    /**
     * @return the distribution the data point weights are sampled from
     */
    public ContinuousDistribution getDistribution() {
        return this.dist;
    }

    /**
     * Trains all ensemble members, splitting them into contiguous index ranges across at most
     * {@code SystemInfo.LogicalCores} tasks. Blocks until every task has finished. The
     * destination array ({@code hypotsL} or {@code hypotsR}) must already be allocated with
     * length {@code iterations}.
     *
     * @param threadPool the executor the training tasks are submitted to
     * @param dataSet the training data; its data points are not modified
     * @throws FailedToFitException if the waiting thread is interrupted
     */
    private void performTraining(ExecutorService threadPool, DataSet dataSet) {
        // Never create more tasks than there are members to train; guard against a zero core count.
        final int tasks = Math.max(1, Math.min(this.iterations, SystemInfo.LogicalCores));
        final int chunkSize = this.iterations / tasks;
        int extra = this.iterations % tasks;
        final Random rand = new Random();
        final CountDownLatch latch = new CountDownLatch(tasks);
        final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        int start = 0;
        for (int t = 0; t < tasks; ++t) {
            int end = start + chunkSize;
            // The first (iterations % tasks) chunks each take one additional member, so
            // the chunks cover exactly [0, iterations).
            if (extra > 0) {
                ++end;
                --extra;
            }
            threadPool.submit(new WagFill(start, end, dataSet, new Random(rand.nextInt()), latch, failure));
            start = end;
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
        // Surface the first failure from any worker instead of leaving null members behind.
        Throwable cause = failure.get();
        if (cause instanceof RuntimeException) {
            throw (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        if (cause != null) {
            throw new FailedToFitException("Training of an ensemble member failed: " + cause);
        }
    }

    /**
     * Classifies a data point. The result holds, per category, the fraction of ensemble
     * members whose most likely label was that category.
     *
     * @param data the point to classify
     * @return the normalized vote counts
     * @throws UntrainedModelException if the model was not trained for classification
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.hypotsL == null) {
            throw new UntrainedModelException("Model has not been trained for classification");
        }
        CategoricalResults results = new CategoricalResults(this.predicting.getNumOfCategories());
        for (Classifier hypot : this.hypotsL) {
            results.incProb(hypot.classify(data).mostLikely(), 1.0);
        }
        results.normalize();
        return results;
    }

    /**
     * Trains the ensemble for classification, discarding any regression members.
     *
     * @param dataSet the training data
     * @param threadPool the executor used to train members in parallel
     * @throws FailedToFitException if no classification weak learner was provided
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        if (this.weakL == null) {
            throw new FailedToFitException("No classification weak learner was provided");
        }
        this.predicting = dataSet.getPredicting();
        this.hypotsL = new Classifier[this.iterations];
        this.hypotsR = null;
        this.performTraining(threadPool, dataSet);
    }

    /**
     * Trains the ensemble for classification using a {@link FakeExecutor}.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * @return false. Training overwrites the data point weights with randomly sampled ones.
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts a value as the mean of all regression ensemble members' predictions.
     *
     * @param data the point to predict
     * @return the averaged prediction
     * @throws UntrainedModelException if the model was not trained for regression
     */
    @Override
    public double regress(DataPoint data) {
        if (this.hypotsR == null) {
            throw new UntrainedModelException("Model has not been trained for regression");
        }
        double sum = 0.0;
        for (Regressor hypot : this.hypotsR) {
            sum += hypot.regress(data);
        }
        return sum / this.hypotsR.length;
    }

    /**
     * Trains the ensemble for regression, discarding any classification members.
     *
     * @param dataSet the training data
     * @param threadPool the executor used to train members in parallel
     * @throws FailedToFitException if no regression weak learner was provided
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        if (this.weakR == null) {
            throw new FailedToFitException("No regression weak learner was provided");
        }
        this.hypotsL = null;
        this.hypotsR = new Regressor[this.iterations];
        this.performTraining(threadPool, dataSet);
    }

    /**
     * Trains the ensemble for regression using a {@link FakeExecutor}.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, new FakeExecutor());
    }

    @Override
    public Wagging clone() {
        return new Wagging(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Trains the ensemble members with indices in {@code [start, end)} on a private copy of
     * the data set. Each task owns its own copied data points, so concurrent tasks never
     * write to the same point's weight.
     * NOTE(review): {@code dist.invCdf} and the weak learner's {@code clone()} are called
     * concurrently from several tasks; this assumes both are thread-safe.
     */
    private class WagFill
    implements Runnable {
        private final int start;
        private final int end;
        private final DataSet ds;
        private final Random rand;
        private final CountDownLatch latch;
        private final AtomicReference<Throwable> failure;

        WagFill(int start, int end, DataSet ds, Random rand, CountDownLatch latch, AtomicReference<Throwable> failure) {
            this.start = start;
            this.end = end;
            this.ds = ds.shallowClone();
            this.rand = rand;
            this.latch = latch;
            this.failure = failure;
            // Replace each point with a fresh DataPoint sharing the same values, so that
            // re-weighting below never touches the caller's DataPoint objects.
            for (int i = 0; i < this.ds.getSampleSize(); ++i) {
                DataPoint dp = this.ds.getDataPoint(i);
                this.ds.setDataPoint(i, new DataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), dp.getCategoricalData()));
            }
        }

        /** Draws a new weight for every data point in this task's copy of the data set. */
        private void reweight() {
            for (int j = 0; j < this.ds.getSampleSize(); ++j) {
                double newWeight = Math.max(MIN_WEIGHT, Wagging.this.dist.invCdf(this.rand.nextDouble()));
                this.ds.getDataPoint(j).setWeight(newWeight);
            }
        }

        @Override
        public void run() {
            try {
                if (this.ds instanceof ClassificationDataSet) {
                    ClassificationDataSet cds = (ClassificationDataSet) this.ds;
                    for (int i = this.start; i < this.end; ++i) {
                        this.reweight();
                        Classifier hypot = Wagging.this.weakL.clone();
                        hypot.trainC(cds);
                        Wagging.this.hypotsL[i] = hypot;
                    }
                } else if (this.ds instanceof RegressionDataSet) {
                    RegressionDataSet rds = (RegressionDataSet) this.ds;
                    for (int i = this.start; i < this.end; ++i) {
                        this.reweight();
                        Regressor hypot = Wagging.this.weakR.clone();
                        hypot.train(rds);
                        Wagging.this.hypotsR[i] = hypot;
                    }
                } else {
                    throw new RuntimeException("BUG: please report");
                }
            }
            catch (Throwable t) {
                // Record only the first failure; performTraining rethrows it after the latch opens.
                this.failure.compareAndSet(null, t);
            }
            finally {
                // Always count down so the waiting thread cannot deadlock on a failed task.
                this.latch.countDown();
            }
        }
    }
}

