/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.featureselection;

import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.DataTransformFactoryParm;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.datatransform.featureselection.SBS;
import jsat.datatransform.featureselection.SFS;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/**
 * Plus-L Minus-R ("LRS") feature selection transform.
 * <p>
 * The search alternates sequential forward selection ({@link SFS}) and sequential
 * backward elimination ({@link SBS}):
 * <ul>
 * <li>If {@code L > R}: start with no features, run {@code L} forward-selection steps,
 * then run {@code R} backward-elimination steps restricted to the features just added.</li>
 * <li>If {@code L < R}: start with all features, run {@code R} backward-elimination
 * steps, then run {@code L} forward-selection steps restricted to the features that
 * were removed.</li>
 * <li>If {@code L == R}: no search is performed and every feature is kept.</li>
 * </ul>
 * The evaluator and fold count are passed straight through to SFS/SBS, which score the
 * candidate subsets. Once the search finishes, {@link #transform(DataPoint)} drops
 * every feature that was not selected.
 */
public class LRS implements DataTransform {
    private static final long serialVersionUID = 3065300352046535656L;
    /** Removes every feature that the search did not select. */
    private RemoveAttributeTransform finalTransform;
    /** Indices of the categorical features kept by the search. */
    private Set<Integer> catSelected;
    /** Indices of the numerical features kept by the search, counted from 0 among numerical features. */
    private Set<Integer> numSelected;

    /**
     * Copy constructor backing {@link #clone()}. Copies the selected index sets and
     * clones the underlying removal transform.
     *
     * @param toClone the instance to copy
     */
    private LRS(LRS toClone) {
        if (toClone.catSelected != null) {
            this.finalTransform = toClone.finalTransform.clone();
            this.catSelected = new IntSet(toClone.catSelected);
            this.numSelected = new IntSet(toClone.numSelected);
        }
    }

    /**
     * Runs the LRS search on a classification problem.
     *
     * @param L         number of forward-selection steps
     * @param R         number of backward-elimination steps
     * @param cds       the data set to select features from
     * @param evaluater the classifier used to score candidate feature subsets
     * @param folds     fold count passed to the SFS/SBS scoring steps
     */
    public LRS(int L, int R, ClassificationDataSet cds, Classifier evaluater, int folds) {
        this.search(cds, L, R, evaluater, folds);
    }

    /**
     * Runs the LRS search on a regression problem.
     *
     * @param L         number of forward-selection steps
     * @param R         number of backward-elimination steps
     * @param rds       the data set to select features from
     * @param evaluater the regressor used to score candidate feature subsets
     * @param folds     fold count passed to the SFS/SBS scoring steps
     */
    public LRS(int L, int R, RegressionDataSet rds, Regressor evaluater, int folds) {
        this.search(rds, L, R, evaluater, folds);
    }

    /**
     * Removes the features that were not selected by the search.
     *
     * @param dp the data point to transform
     * @return the data point restricted to the selected features
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        return this.finalTransform.transform(dp);
    }

    @Override
    public LRS clone() {
        return new LRS(this);
    }

    /**
     * Returns a copy of the selected categorical feature indices.
     *
     * @return a new set holding the selected categorical indices
     */
    public Set<Integer> getSelectedCategorical() {
        return new IntSet(this.catSelected);
    }

    /**
     * Returns a copy of the selected numerical feature indices. Indices are counted
     * from 0 among the numerical features only.
     *
     * @return a new set holding the selected numerical indices
     */
    public Set<Integer> getSelectedNumerical() {
        return new IntSet(this.numSelected);
    }

    /**
     * Performs the plus-L / minus-R search and builds {@link #finalTransform}.
     * <p>
     * Candidate features in {@code available} use a single combined index:
     * categorical features occupy {@code [0, nCat)} and numerical feature {@code j}
     * is stored as {@code nCat + j}.
     *
     * @param cds       the data set to search over
     * @param L         number of forward-selection steps
     * @param R         number of backward-elimination steps
     * @param evaluater the Classifier or Regressor handed to SFS/SBS
     * @param folds     fold count handed to SFS/SBS
     */
    private void search(DataSet cds, int L, int R, Object evaluater, int folds) {
        int nF = cds.getNumFeatures();
        int nCat = cds.getNumCategoricalVars();
        int nNum = nF - nCat;
        this.catSelected = new IntSet(nCat);
        this.numSelected = new IntSet(nNum);
        IntSet catToRemove = new IntSet(nCat);
        IntSet numToRemove = new IntSet(nNum);
        IntSet available = new IntSet(nF);
        ListUtils.addRange(available, 0, nF, 1);
        // Unseeded: if SFS/SBS use it for fold assignment, results may differ between runs.
        Random rand = new Random();
        // Single-element array so the best score can be shared across SFS/SBS calls
        // (presumably updated in place by them).
        double[] pBestScore = new double[]{Double.POSITIVE_INFINITY};
        if (L > R) {
            // Start from the empty set: every feature is marked for removal.
            ListUtils.addRange(catToRemove, 0, nCat, 1);
            ListUtils.addRange(numToRemove, 0, nNum, 1);
            for (int step = 0; step < L; ++step) {
                SFS.SFSSelectFeature(available, cds, catToRemove, numToRemove, this.catSelected, this.numSelected, evaluater, folds, rand, pBestScore, L);
            }
            // Backward elimination only considers the features just selected.
            available.clear();
            available.addAll(this.catSelected);
            for (int num : this.numSelected) {
                available.add(num + nCat);
            }
            for (int step = 0; step < R; ++step) {
                SBS.SBSRemoveFeature(available, cds, catToRemove, numToRemove, this.catSelected, this.numSelected, evaluater, folds, rand, L - R, pBestScore, 0.0);
            }
        } else if (L < R) {
            // Start from the full set: every feature is marked as selected.
            ListUtils.addRange(this.catSelected, 0, nCat, 1);
            ListUtils.addRange(this.numSelected, 0, nNum, 1);
            for (int step = 0; step < R; ++step) {
                SBS.SBSRemoveFeature(available, cds, catToRemove, numToRemove, this.catSelected, this.numSelected, evaluater, folds, rand, nF - R, pBestScore, 0.0);
            }
            // Forward selection only considers the features just removed.
            available.clear();
            available.addAll(catToRemove);
            for (int num : numToRemove) {
                available.add(num + nCat);
            }
            // NOTE(review): this branch targets nF - R + L features after this phase, yet
            // R - L is passed as SFS's size bound (the L > R branch passes L, the size
            // reached after its forward phase). Confirm against SFS.SFSSelectFeature.
            for (int step = 0; step < L; ++step) {
                SFS.SFSSelectFeature(available, cds, catToRemove, numToRemove, this.catSelected, this.numSelected, evaluater, folds, rand, pBestScore, R - L);
            }
        } else {
            // L == R: no search runs and nothing is marked for removal, so the transform
            // keeps every feature. Record them all as selected so the getters agree.
            ListUtils.addRange(this.catSelected, 0, nCat, 1);
            ListUtils.addRange(this.numSelected, 0, nNum, 1);
        }
        this.finalTransform = new RemoveAttributeTransform(cds, catToRemove, numToRemove);
    }

    /**
     * Factory that builds an {@link LRS} transform for a given data set.
     */
    public static class LRSFactory extends DataTransformFactoryParm {
        /** Fold count passed to the {@link LRS} constructor by {@link #getTransform}. */
        private static final int DEFAULT_FOLDS = 5;
        private Classifier classifier;
        private Regressor regressor;
        private int featuresToAdd;
        private int featuresToRemove;

        /**
         * Creates a factory that scores subsets with a classifier. If the classifier
         * also implements {@link Regressor}, it is used for regression data as well.
         *
         * @param evaluater the classifier used to score feature subsets
         * @param toAdd     number of forward-selection steps (L), at least 1
         * @param toRemove  number of backward-elimination steps (R), at least 1
         * @throws IllegalArgumentException if {@code toAdd == toRemove} or either is below 1
         * @throws NullPointerException     if {@code evaluater} is null
         */
        public LRSFactory(Classifier evaluater, int toAdd, int toRemove) {
            if (toAdd == toRemove) {
                throw new IllegalArgumentException("L and R must be different");
            }
            if (evaluater == null) {
                throw new NullPointerException("evaluater must not be null");
            }
            this.classifier = evaluater;
            if (evaluater instanceof Regressor) {
                this.regressor = (Regressor) ((Object) evaluater);
            }
            this.setFeaturesToAdd(toAdd);
            this.setFeaturesToRemove(toRemove);
        }

        /**
         * Creates a factory that scores subsets with a regressor. If the regressor
         * also implements {@link Classifier}, it is used for classification data as well.
         *
         * @param evaluater the regressor used to score feature subsets
         * @param toAdd     number of forward-selection steps (L), at least 1
         * @param toRemove  number of backward-elimination steps (R), at least 1
         * @throws IllegalArgumentException if {@code toAdd == toRemove} or either is below 1
         * @throws NullPointerException     if {@code evaluater} is null
         */
        public LRSFactory(Regressor evaluater, int toAdd, int toRemove) {
            if (toAdd == toRemove) {
                throw new IllegalArgumentException("L and R must be different");
            }
            if (evaluater == null) {
                throw new NullPointerException("evaluater must not be null");
            }
            this.regressor = evaluater;
            if (evaluater instanceof Classifier) {
                this.classifier = (Classifier) ((Object) evaluater);
            }
            this.setFeaturesToAdd(toAdd);
            this.setFeaturesToRemove(toRemove);
        }

        /**
         * Copy constructor. Evaluators are cloned. An evaluator shared as both
         * classifier and regressor is cloned once, so the copy shares it too.
         *
         * @param toCopy the factory to copy
         */
        public LRSFactory(LRSFactory toCopy) {
            if (toCopy.classifier == null && toCopy.regressor == null) {
                throw new RuntimeException("BUG: Please report");
            }
            if (toCopy.classifier != null && toCopy.classifier == toCopy.regressor) {
                this.classifier = toCopy.classifier.clone();
                this.regressor = (Regressor) ((Object) this.classifier);
            } else {
                if (toCopy.classifier != null) {
                    this.classifier = toCopy.classifier.clone();
                }
                if (toCopy.regressor != null) {
                    this.regressor = toCopy.regressor.clone();
                }
            }
            this.featuresToAdd = toCopy.featuresToAdd;
            this.featuresToRemove = toCopy.featuresToRemove;
        }

        /**
         * Sets the number of forward-selection steps (L). This setter does not check
         * the value against {@link #getFeaturesToRemove()}. If the two end up equal,
         * the produced transform keeps every feature.
         *
         * @param featuresToAdd the number of steps, at least 1
         * @throws IllegalArgumentException if {@code featuresToAdd < 1}
         */
        public void setFeaturesToAdd(int featuresToAdd) {
            if (featuresToAdd < 1) {
                throw new IllegalArgumentException("Number of features to add must be positive, not " + featuresToAdd);
            }
            this.featuresToAdd = featuresToAdd;
        }

        /**
         * @return the number of forward-selection steps (L)
         */
        public int getFeaturesToAdd() {
            return this.featuresToAdd;
        }

        /**
         * Sets the number of backward-elimination steps (R). This setter does not
         * check the value against {@link #getFeaturesToAdd()}. If the two end up
         * equal, the produced transform keeps every feature.
         *
         * @param featuresToRemove the number of steps, at least 1
         * @throws IllegalArgumentException if {@code featuresToRemove < 1}
         */
        public void setFeaturesToRemove(int featuresToRemove) {
            if (featuresToRemove < 1) {
                throw new IllegalArgumentException("Number of features to remove must be positive, not " + featuresToRemove);
            }
            this.featuresToRemove = featuresToRemove;
        }

        /**
         * @return the number of backward-elimination steps (R)
         */
        public int getFeaturesToRemove() {
            return this.featuresToRemove;
        }

        /**
         * Runs the LRS search on {@code dataset}. The classifier is used for a
         * {@link ClassificationDataSet}; otherwise the data set is cast to
         * {@link RegressionDataSet} and the regressor is used.
         *
         * @param dataset the data to select features from
         * @return the fitted transform
         * @throws IllegalStateException if this factory has no evaluator for the data set's kind
         * @throws ClassCastException    if the data set is neither classification nor regression
         */
        @Override
        public LRS getTransform(DataSet dataset) {
            if (dataset instanceof ClassificationDataSet) {
                if (this.classifier == null) {
                    throw new IllegalStateException("A Classifier evaluator is required for a ClassificationDataSet, but this factory only has a Regressor");
                }
                return new LRS(this.featuresToAdd, this.featuresToRemove, (ClassificationDataSet) dataset, this.classifier, DEFAULT_FOLDS);
            }
            RegressionDataSet rds = (RegressionDataSet) dataset;
            if (this.regressor == null) {
                throw new IllegalStateException("A Regressor evaluator is required for a RegressionDataSet, but this factory only has a Classifier");
            }
            return new LRS(this.featuresToAdd, this.featuresToRemove, rds, this.regressor, DEFAULT_FOLDS);
        }

        @Override
        public LRSFactory clone() {
            return new LRSFactory(this);
        }
    }
}

