/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.featureselection;

import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.DataTransformFactoryParm;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.datatransform.featureselection.SBS;
import jsat.datatransform.featureselection.SFS;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/**
 * Feature-selection transform that interleaves a forward-selection step ({@link SFS}) with a
 * backward-elimination step ({@link SBS}); the name presumably stands for Bidirectional Search.
 * <p>
 * The search runs {@code min(featureCount, numFeatures / 2)} rounds. Each round:
 * <ol>
 *   <li>The forward side selects a feature. That feature is removed from the backward side's
 *       candidate pool.</li>
 *   <li>The backward side selects a feature to eliminate. That feature is removed from the forward
 *       side's candidate pool.</li>
 * </ol>
 * The resulting {@link RemoveAttributeTransform} is built from the complement of the
 * forward-selected features, i.e. every categorical/numerical attribute the forward side did not
 * pick.
 * <p>
 * NOTE(review): feature indices handed to the SFS/SBS helpers appear to live in one combined space
 * of size {@code numFeatures}, split at {@code numCategoricalVars} (the helpers receive that split
 * point). This is inferred from the call sites here; confirm against {@link SFS}.
 */
public class BDS
implements DataTransform {
    private static final long serialVersionUID = 8633823674617843754L;
    /** Transform that drops every attribute not selected by the search. */
    private RemoveAttributeTransform finalTransform;
    /** Categorical attribute indices chosen by the forward (SFS) side. */
    private Set<Integer> catSelected;
    /** Numerical attribute indices chosen by the forward (SFS) side. */
    private Set<Integer> numSelected;

    /**
     * Copy constructor. Deep-copies the final transform and the selected-index sets.
     * If {@code toClone} has no final transform, the new instance is left unpopulated.
     *
     * @param toClone the instance to copy
     */
    public BDS(BDS toClone) {
        if (toClone.finalTransform != null) {
            this.finalTransform = toClone.finalTransform.clone();
            this.catSelected = new IntSet(toClone.catSelected);
            this.numSelected = new IntSet(toClone.numSelected);
        }
    }

    /**
     * Runs the search on a classification problem.
     *
     * @param featureCount upper bound on the number of search rounds (also capped at half the
     *                     number of features)
     * @param dataSet      the data to select features from
     * @param evaluator    the model handed to the SFS/SBS helpers to score feature subsets
     * @param folds        fold count handed to the SFS/SBS helpers (presumably cross-validation folds)
     */
    public BDS(int featureCount, ClassificationDataSet dataSet, Classifier evaluator, int folds) {
        this.search(dataSet, featureCount, folds, evaluator);
    }

    /**
     * Runs the search on a regression problem.
     *
     * @param featureCount upper bound on the number of search rounds (also capped at half the
     *                     number of features)
     * @param dataSet      the data to select features from
     * @param evaluator    the model handed to the SFS/SBS helpers to score feature subsets
     * @param folds        fold count handed to the SFS/SBS helpers (presumably cross-validation folds)
     */
    public BDS(int featureCount, RegressionDataSet dataSet, Regressor evaluator, int folds) {
        this.search(dataSet, featureCount, folds, evaluator);
    }

    /** Applies the attribute-removal transform produced by the search. */
    @Override
    public DataPoint transform(DataPoint dp) {
        return this.finalTransform.transform(dp);
    }

    @Override
    public BDS clone() {
        return new BDS(this);
    }

    /**
     * Returns a copy of the categorical attribute indices kept by this transform.
     *
     * @return a fresh, mutable set; modifying it does not affect this transform
     */
    public Set<Integer> getSelectedCategorical() {
        return new IntSet(this.catSelected);
    }

    /**
     * Returns a copy of the numerical attribute indices kept by this transform.
     *
     * @return a fresh, mutable set; modifying it does not affect this transform
     */
    public Set<Integer> getSelectedNumerical() {
        return new IntSet(this.numSelected);
    }

    /**
     * Performs the interleaved forward/backward search.
     * <p>
     * This method populates {@link #catSelected} and {@link #numSelected} and builds
     * {@link #finalTransform}.
     *
     * @param dataSet     data to search over
     * @param maxFeatures requested number of rounds; capped at {@code numFeatures / 2}
     * @param folds       fold count passed through to the SFS/SBS helpers
     * @param evaluator   a {@link Classifier} or {@link Regressor}, passed through to the helpers
     */
    private void search(DataSet dataSet, int maxFeatures, int folds, Object evaluator) {
        Random rand = new Random();
        int nF = dataSet.getNumFeatures();
        int nCat = dataSet.getNumCategoricalVars();

        // Forward (SFS) state: the "true" selected sets plus the SFS candidate pool.
        this.catSelected = new IntSet(dataSet.getNumCategoricalVars());
        this.numSelected = new IntSet(dataSet.getNumNumericalVars());
        IntSet availableSFS = new IntSet();
        ListUtils.addRange(availableSFS, 0, nF, 1);
        IntSet catToRemoveSFS = new IntSet(dataSet.getNumCategoricalVars());
        IntSet numToRemoveSFS = new IntSet(dataSet.getNumNumericalVars());
        ListUtils.addRange(catToRemoveSFS, 0, nCat, 1);
        ListUtils.addRange(numToRemoveSFS, 0, nF - nCat, 1);

        // Backward (SBS) state: starts with every feature selected.
        IntSet availableSBS = new IntSet();
        ListUtils.addRange(availableSBS, 0, nF, 1);
        IntSet catSelectedSBS = new IntSet(dataSet.getNumCategoricalVars());
        IntSet numSelectedSBS = new IntSet(dataSet.getNumNumericalVars());
        IntSet catToRemoveSBS = new IntSet(dataSet.getNumCategoricalVars());
        IntSet numToRemoveSBS = new IntSet(dataSet.getNumNumericalVars());
        ListUtils.addRange(catSelectedSBS, 0, nCat, 1);
        ListUtils.addRange(numSelectedSBS, 0, nF - nCat, 1);

        // Single-element arrays act as in/out "best score so far" holders for the helpers.
        double[] pBestScore0 = new double[]{Double.POSITIVE_INFINITY};
        double[] pBestScore1 = new double[]{Double.POSITIVE_INFINITY};

        int max = Math.min(maxFeatures, nF / 2);
        for (int i = 0; i < max; ++i) {
            // Forward step: the chosen feature can no longer be eliminated by the backward side.
            int mustKeep = SFS.SFSSelectFeature(availableSFS, dataSet, catToRemoveSFS, numToRemoveSFS,
                    this.catSelected, this.numSelected, evaluator, folds, rand, pBestScore0, max);
            availableSBS.remove(mustKeep);
            SFS.removeFeature(mustKeep, nCat, catToRemoveSBS, numToRemoveSBS);

            // Backward step: the eliminated feature can no longer be chosen by the forward side.
            int mustRemove = SBS.SBSRemoveFeature(availableSBS, dataSet, catToRemoveSBS, numToRemoveSBS,
                    catSelectedSBS, numSelectedSBS, evaluator, folds, rand, max, pBestScore1, 0.0);
            availableSFS.remove(mustRemove);
            SFS.addFeature(mustRemove, nCat, catToRemoveSFS, numToRemoveSFS);
        }

        // Attributes to drop = every attribute the forward side did not select.
        IntSet catToDrop = new IntSet(dataSet.getNumCategoricalVars());
        IntSet numToDrop = new IntSet(dataSet.getNumNumericalVars());
        ListUtils.addRange(catToDrop, 0, nCat, 1);
        ListUtils.addRange(numToDrop, 0, nF - nCat, 1);
        catToDrop.removeAll(this.catSelected);
        numToDrop.removeAll(this.numSelected);
        this.finalTransform = new RemoveAttributeTransform(dataSet, catToDrop, numToDrop);
    }

    /**
     * Factory producing {@link BDS} transforms for either classification or regression data.
     */
    public static class BDSFactory
    extends DataTransformFactoryParm {
        /** Fold count used for both classification and regression searches. */
        private static final int DEFAULT_FOLDS = 5;

        private Classifier classifier;
        private Regressor regressor;
        private int featureCount;

        /**
         * @param evaluater    model used to score subsets; if it is also a {@link Regressor}, it
         *                     is used for regression data as well
         * @param featureCount number of features to select; must be at least 1
         * @throws IllegalArgumentException if {@code featureCount < 1}
         */
        public BDSFactory(Classifier evaluater, int featureCount) {
            this.classifier = evaluater;
            if (evaluater instanceof Regressor) {
                // Cast via Object: Classifier and Regressor declare clone() with different return types.
                this.regressor = (Regressor)((Object)evaluater);
            }
            this.setFeatureCount(featureCount);
        }

        /**
         * @param evaluater    model used to score subsets; if it is also a {@link Classifier}, it
         *                     is used for classification data as well
         * @param featureCount number of features to select; must be at least 1
         * @throws IllegalArgumentException if {@code featureCount < 1}
         */
        public BDSFactory(Regressor evaluater, int featureCount) {
            this.regressor = evaluater;
            if (evaluater instanceof Classifier) {
                this.classifier = (Classifier)((Object)evaluater);
            }
            this.setFeatureCount(featureCount);
        }

        /**
         * Copy constructor. Deep-clones the evaluator(s).
         * <p>
         * When one object serves as both classifier and regressor, it is cloned once and the clone
         * fills both roles.
         *
         * @param toCopy the factory to copy
         * @throws RuntimeException if {@code toCopy} has neither a classifier nor a regressor
         */
        public BDSFactory(BDSFactory toCopy) {
            if (toCopy.classifier == null && toCopy.regressor == null) {
                // Checked first: with both null, the identity test below would be true and NPE.
                throw new RuntimeException("BUG: Please report");
            }
            if ((Object) toCopy.classifier == (Object) toCopy.regressor) {
                this.classifier = toCopy.classifier.clone();
                this.regressor = (Regressor)((Object)this.classifier);
            } else {
                if (toCopy.classifier != null) {
                    this.classifier = toCopy.classifier.clone();
                }
                if (toCopy.regressor != null) {
                    this.regressor = toCopy.regressor.clone();
                }
            }
            this.featureCount = toCopy.featureCount;
        }

        /**
         * @param featureCount number of features to select; must be at least 1
         * @throws IllegalArgumentException if {@code featureCount < 1}
         */
        public void setFeatureCount(int featureCount) {
            if (featureCount < 1) {
                throw new IllegalArgumentException("Number of features to select must be positive, not " + featureCount);
            }
            this.featureCount = featureCount;
        }

        /** @return the number of features to select */
        public int getFeatureCount() {
            return this.featureCount;
        }

        /**
         * Builds a {@link BDS} for {@code dataset}, using the classifier for
         * {@link ClassificationDataSet}s and the regressor otherwise.
         * <p>
         * Both paths use {@link #DEFAULT_FOLDS}; the fold count is independent of
         * {@link #getFeatureCount()}.
         */
        @Override
        public BDS getTransform(DataSet dataset) {
            if (dataset instanceof ClassificationDataSet) {
                return new BDS(this.featureCount, (ClassificationDataSet)dataset, this.classifier, DEFAULT_FOLDS);
            }
            return new BDS(this.featureCount, (RegressionDataSet)dataset, this.regressor, DEFAULT_FOLDS);
        }

        @Override
        public BDSFactory clone() {
            return new BDSFactory(this);
        }
    }
}

