/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.featureselection;

import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.DataTransformFactoryParm;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/**
 * Greedy forward feature selection (presumably "Sequential Forward
 * Selection", going by the class name) exposed as a {@link DataTransform}.
 * <p>
 * The search starts with every feature marked for removal. In each round every
 * still-available feature is tentatively restored, the learner is scored by
 * cross validation on the reduced data set, and the feature giving the lowest
 * error is kept. Once the search stops, a {@link RemoveAttributeTransform}
 * built from the features that were never selected is used to transform data
 * points.
 * <p>
 * Features are addressed by a single index: indices below the number of
 * categorical variables are categorical, the remaining ones are numerical
 * (offset by the number of categorical variables).
 */
public class SFS implements DataTransform {
    private static final long serialVersionUID = 140187978708131002L;

    /** Scores at or below this value are treated as "zero error". */
    private static final double ZERO_ERROR = 1.0E-14;
    /** Score differences smaller than this are treated as "no change". */
    private static final double SCORE_TOLERANCE = 0.001;

    private RemoveAttributeTransform finalTransform;
    private Set<Integer> catSelected;
    private Set<Integer> numSelected;
    /**
     * Stored and copied, but not read by the search logic visible in this
     * class; the tolerance actually applied is {@link #SCORE_TOLERANCE}.
     */
    private double maxIncrease;
    private Classifier classifier;
    private Regressor regressor;

    /**
     * Copy constructor used by {@link #clone()}.
     *
     * @param toClone the instance to copy
     */
    private SFS(SFS toClone) {
        if (toClone.catSelected != null) {
            this.finalTransform = toClone.finalTransform.clone();
            this.catSelected = new IntSet(toClone.catSelected);
            this.numSelected = new IntSet(toClone.numSelected);
        }
        this.maxIncrease = toClone.maxIncrease;
        if (toClone.classifier != null) {
            this.classifier = toClone.classifier.clone();
        }
        if (toClone.regressor != null) {
            this.regressor = toClone.regressor.clone();
        }
    }

    /**
     * Runs the feature search for a classification problem.
     *
     * @param minFeatures the number of features that are selected regardless of score
     * @param maxFeatures the search stops once this many features are selected
     * @param dataSet the data set to select features from
     * @param evaluater the classifier used for scoring; a clone of it is kept
     * @param folds the number of cross validation folds used for scoring
     * @param maxIncrease stored on the instance (see the field documentation)
     */
    public SFS(int minFeatures, int maxFeatures, ClassificationDataSet dataSet, Classifier evaluater, int folds, double maxIncrease) {
        this.classifier = evaluater.clone();
        this.maxIncrease = maxIncrease;
        this.search(minFeatures, maxFeatures, dataSet, folds);
    }

    /**
     * Runs the feature search for a regression problem.
     *
     * @param minFeatures the number of features that are selected regardless of score
     * @param maxFeatures the search stops once this many features are selected
     * @param dataSet the data set to select features from
     * @param regressor the regressor used for scoring; a clone of it is kept
     * @param folds the number of cross validation folds used for scoring
     * @param maxIncrease stored on the instance (see the field documentation)
     */
    public SFS(int minFeatures, int maxFeatures, RegressionDataSet dataSet, Regressor regressor, int folds, double maxIncrease) {
        this.regressor = regressor.clone();
        this.maxIncrease = maxIncrease;
        this.search(minFeatures, maxFeatures, dataSet, folds);
    }

    /**
     * Performs the greedy search and builds the final transform.
     *
     * @param minFeatures the number of features that are selected regardless of score
     * @param maxFeatures the search stops once this many features are selected
     * @param dataSet the data set to select features from
     * @param folds the number of cross validation folds used for scoring
     */
    private void search(int minFeatures, int maxFeatures, DataSet dataSet, int folds) {
        Random rand = new Random();
        int nF = dataSet.getNumFeatures();
        int nCat = dataSet.getNumCategoricalVars();

        // Every feature index is a candidate at the start.
        IntSet available = new IntSet();
        ListUtils.addRange(available, 0, nF, 1);

        this.catSelected = new IntSet(dataSet.getNumCategoricalVars());
        this.numSelected = new IntSet(dataSet.getNumNumericalVars());

        // Initially everything is removed; selected features are taken back out of these sets.
        IntSet catToRemove = new IntSet(dataSet.getNumCategoricalVars());
        IntSet numToRemove = new IntSet(dataSet.getNumNumericalVars());
        ListUtils.addRange(catToRemove, 0, nCat, 1);
        ListUtils.addRange(numToRemove, 0, nF - nCat, 1);

        // Single-element array so the selection step can update the best score in place.
        double[] bestScore = new double[]{Double.POSITIVE_INFINITY};

        Object learner;
        if (dataSet instanceof ClassificationDataSet) {
            learner = this.classifier;
        } else {
            learner = this.regressor;
        }

        while (this.catSelected.size() + this.numSelected.size() < maxFeatures) {
            int added = SFS.SFSSelectFeature(available, dataSet, catToRemove, numToRemove,
                    this.catSelected, this.numSelected, learner, folds, rand, bestScore, minFeatures);
            if (added < 0) {
                break; // no feature was accepted this round
            }
        }
        this.finalTransform = new RemoveAttributeTransform(dataSet, catToRemove, numToRemove);
    }

    /**
     * Adds a feature, given by its combined index, to the matching set.
     *
     * @param curBest the combined feature index
     * @param nCat the number of categorical features
     * @param catF the set receiving categorical indices
     * @param numF the set receiving numerical indices (stored as {@code curBest - nCat})
     */
    protected static void addFeature(int curBest, int nCat, Set<Integer> catF, Set<Integer> numF) {
        if (curBest >= nCat) {
            numF.add(curBest - nCat);
        } else {
            catF.add(curBest);
        }
    }

    /**
     * Removes a feature, given by its combined index, from the matching set.
     *
     * @param feature the combined feature index
     * @param nCat the number of categorical features
     * @param catF the set holding categorical indices
     * @param numF the set holding numerical indices (stored as {@code feature - nCat})
     */
    protected static void removeFeature(int feature, int nCat, Set<Integer> catF, Set<Integer> numF) {
        if (feature >= nCat) {
            numF.remove(feature - nCat);
        } else {
            catF.remove(feature);
        }
    }

    /**
     * Applies the transform built at the end of the search.
     *
     * @param dp the data point to transform
     * @return the result of the underlying {@link RemoveAttributeTransform}
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        return this.finalTransform.transform(dp);
    }

    @Override
    public SFS clone() {
        return new SFS(this);
    }

    /**
     * @return a copy of the set of selected categorical feature indices
     */
    public Set<Integer> getSelectedCategorical() {
        return new IntSet(this.catSelected);
    }

    /**
     * @return a copy of the set of selected numerical feature indices
     */
    public Set<Integer> getSelectedNumerical() {
        return new IntSet(this.numSelected);
    }

    /**
     * Tries every available feature and selects the one with the lowest
     * cross validation score, if it is acceptable.
     * <p>
     * The best candidate is accepted when it improves on the previous best
     * score, when fewer than {@code minFeatures} features have been selected,
     * or when its score is within a small tolerance of the previous best. On
     * acceptance the given sets and {@code PbestScore[0]} are updated.
     *
     * @param available the combined indices of the candidate features; the
     * accepted feature is removed from it
     * @param dataSet the original data set
     * @param catToRemove the categorical features currently removed
     * @param numToRemove the numerical features currently removed
     * @param catSelecteed the categorical features selected so far
     * @param numSelected the numerical features selected so far
     * @param evaluater the {@link Classifier} or {@link Regressor} used for scoring
     * @param folds the number of cross validation folds
     * @param rand the source of randomness for cross validation
     * @param PbestScore single-element array holding the previous best score
     * @param minFeatures the number of features that are selected regardless of score
     * @return the combined index of the accepted feature, or -1 if none was accepted
     */
    protected static int SFSSelectFeature(Set<Integer> available, DataSet dataSet, Set<Integer> catToRemove, Set<Integer> numToRemove, Set<Integer> catSelecteed, Set<Integer> numSelected, Object evaluater, int folds, Random rand, double[] PbestScore, int minFeatures) {
        int nCat = dataSet.getNumCategoricalVars();
        int curBest = -1;
        double curBestScore = Double.POSITIVE_INFINITY;

        for (int feature : available) {
            // Temporarily restore the candidate, score the reduced data set, then remove it again.
            SFS.removeFeature(feature, nCat, catToRemove, numToRemove);
            DataSet workOn = dataSet.shallowClone();
            RemoveAttributeTransform remove = new RemoveAttributeTransform(workOn, catToRemove, numToRemove);
            workOn.applyTransform(remove);
            double score = SFS.getScore(workOn, evaluater, folds, rand);
            if (score < curBestScore) {
                curBestScore = score;
                curBest = feature;
            }
            SFS.addFeature(feature, nCat, catToRemove, numToRemove);
        }

        // No candidate won the round (nothing available, or no score below +infinity).
        // Without this guard the minFeatures rule below would record -1 as a selected feature.
        if (curBest < 0) {
            return -1;
        }

        int selectedCount = catSelecteed.size() + numSelected.size();
        boolean minimumReached = selectedCount >= minFeatures;

        // Error is already (numerically) zero and cannot improve further.
        if (curBestScore <= ZERO_ERROR && PbestScore[0] <= ZERO_ERROR && minimumReached) {
            return -1;
        }

        boolean improved = curBestScore < PbestScore[0];
        boolean withinTolerance = Math.abs(PbestScore[0] - curBestScore) < SCORE_TOLERANCE;
        if (improved || !minimumReached || withinTolerance) {
            PbestScore[0] = curBestScore;
            SFS.addFeature(curBest, nCat, catSelecteed, numSelected);
            SFS.removeFeature(curBest, nCat, catToRemove, numToRemove);
            available.remove(curBest);
            return curBest;
        }
        return -1;
    }

    /**
     * Scores a learner on a data set by cross validation; lower is better.
     *
     * @param workOn the data set to evaluate on
     * @param evaluater a {@link Classifier} for classification data sets, a
     * {@link Regressor} for regression data sets
     * @param folds the number of cross validation folds
     * @param rand the source of randomness for cross validation
     * @return the error rate for classification, the mean error for regression,
     * or {@link Double#POSITIVE_INFINITY} for any other kind of data set
     */
    protected static double getScore(DataSet workOn, Object evaluater, int folds, Random rand) {
        if (workOn instanceof ClassificationDataSet) {
            ClassificationModelEvaluation cme = new ClassificationModelEvaluation((Classifier)evaluater, (ClassificationDataSet)workOn);
            cme.evaluateCrossValidation(folds, rand);
            return cme.getErrorRate();
        }
        if (workOn instanceof RegressionDataSet) {
            RegressionModelEvaluation rme = new RegressionModelEvaluation((Regressor)evaluater, (RegressionDataSet)workOn);
            rme.evaluateCrossValidation(folds, rand);
            return rme.getMeanError();
        }
        return Double.POSITIVE_INFINITY;
    }

    /**
     * Factory producing {@link SFS} transforms, scored with 5-fold cross
     * validation.
     */
    public static class SFSFactory extends DataTransformFactoryParm {
        private double maxDecrease;
        private Classifier classifier;
        private Regressor regressor;
        private int minFeatures;
        private int maxFeatures;

        /**
         * Creates a factory around a classifier. If the object is also a
         * {@link Regressor} it is used for regression data sets as well.
         *
         * @param maxDecrease the tolerance handed to each created {@link SFS}; must not be negative
         * @param evaluater the classifier used for scoring
         * @param minFeatures the minimum number of features to select
         * @param maxFeatures the maximum number of features to select
         */
        public SFSFactory(double maxDecrease, Classifier evaluater, int minFeatures, int maxFeatures) {
            this.setMaxDecrease(maxDecrease);
            this.classifier = evaluater;
            if (evaluater instanceof Regressor) {
                this.regressor = (Regressor)((Object)evaluater);
            }
            this.setMinFeatures(minFeatures);
            this.setMaxFeatures(maxFeatures);
        }

        /**
         * Creates a factory around a regressor. If the object is also a
         * {@link Classifier} it is used for classification data sets as well.
         *
         * @param maxDecrease the tolerance handed to each created {@link SFS}; must not be negative
         * @param evaluater the regressor used for scoring
         * @param minFeatures the minimum number of features to select
         * @param maxFeatures the maximum number of features to select
         */
        public SFSFactory(double maxDecrease, Regressor evaluater, int minFeatures, int maxFeatures) {
            this.setMaxDecrease(maxDecrease);
            this.regressor = evaluater;
            if (evaluater instanceof Classifier) {
                this.classifier = (Classifier)((Object)evaluater);
            }
            this.setMinFeatures(minFeatures);
            this.setMaxFeatures(maxFeatures);
        }

        /**
         * Copy constructor. When the source uses one object as both classifier
         * and regressor, the copy shares a single clone for both roles.
         *
         * @param toCopy the factory to copy
         * @throws RuntimeException if the source has neither a classifier nor a regressor
         */
        public SFSFactory(SFSFactory toCopy) {
            // Checked first: null == null would otherwise enter the shared-object
            // branch below and fail with a NullPointerException.
            if (toCopy.classifier == null && toCopy.regressor == null) {
                throw new RuntimeException("BUG: Please report");
            }
            if (toCopy.classifier == toCopy.regressor) {
                this.classifier = toCopy.classifier.clone();
                this.regressor = (Regressor)((Object)this.classifier);
            } else if (toCopy.classifier != null) {
                this.classifier = toCopy.classifier.clone();
            } else {
                this.regressor = toCopy.regressor.clone();
            }
            this.maxDecrease = toCopy.maxDecrease;
            this.minFeatures = toCopy.minFeatures;
            this.maxFeatures = toCopy.maxFeatures;
        }

        /**
         * Sets the tolerance handed to each created {@link SFS}.
         *
         * @param maxDecrease the tolerance; zero is allowed
         * @throws IllegalArgumentException if the value is negative
         */
        public void setMaxDecrease(double maxDecrease) {
            if (maxDecrease < 0.0) {
                throw new IllegalArgumentException("Decrease must be a non-negative value, not " + maxDecrease);
            }
            this.maxDecrease = maxDecrease;
        }

        /**
         * @return the tolerance handed to each created {@link SFS}
         */
        public double getMaxDecrease() {
            return this.maxDecrease;
        }

        /**
         * @param minFeatures the minimum number of features to select
         */
        public void setMinFeatures(int minFeatures) {
            this.minFeatures = minFeatures;
        }

        /**
         * @return the minimum number of features to select
         */
        public int getMinFeatures() {
            return this.minFeatures;
        }

        /**
         * @param maxFeatures the maximum number of features to select
         */
        public void setMaxFeatures(int maxFeatures) {
            this.maxFeatures = maxFeatures;
        }

        /**
         * @return the maximum number of features to select
         */
        public int getMaxFeatures() {
            return this.maxFeatures;
        }

        /**
         * Builds an {@link SFS} for the given data set using 5 cross validation
         * folds. Classification data sets use the classifier; every other data
         * set is cast to {@link RegressionDataSet} and uses the regressor. The
         * learner needed for the data set type must have been supplied, since
         * the {@link SFS} constructor clones it.
         *
         * @param dataset the data set to select features for
         * @return the fitted transform
         */
        @Override
        public SFS getTransform(DataSet dataset) {
            if (dataset instanceof ClassificationDataSet) {
                return new SFS(this.minFeatures, this.maxFeatures, (ClassificationDataSet)dataset, this.classifier, 5, this.maxDecrease);
            }
            return new SFS(this.minFeatures, this.maxFeatures, (RegressionDataSet)dataset, this.regressor, 5, this.maxDecrease);
        }

        @Override
        public SFSFactory clone() {
            return new SFSFactory(this);
        }
    }
}

