/*
 * Decompiled with CFR 0.152.
 */
package jsat.guitool;

import java.awt.BorderLayout;
import java.awt.Component;
import java.awt.FlowLayout;
import java.awt.Frame;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.swing.JButton;
import javax.swing.JCheckBox;
import javax.swing.JDialog;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.MultinomialLogisticRegression;
import jsat.classifiers.bayesian.AODE;
import jsat.classifiers.bayesian.BestClassDistribution;
import jsat.classifiers.bayesian.MultivariateNormals;
import jsat.classifiers.bayesian.NaiveBayes;
import jsat.classifiers.bayesian.NaiveBayesUpdateable;
import jsat.classifiers.boosting.AdaBoostM1;
import jsat.classifiers.boosting.Bagging;
import jsat.classifiers.boosting.SAMME;
import jsat.classifiers.knn.DANN;
import jsat.classifiers.knn.LWL;
import jsat.classifiers.knn.NearestNeighbour;
import jsat.classifiers.neuralnetwork.LVQ;
import jsat.classifiers.neuralnetwork.Perceptron;
import jsat.classifiers.neuralnetwork.SOM;
import jsat.classifiers.svm.PlatSMO;
import jsat.classifiers.trees.DecisionStump;
import jsat.classifiers.trees.DecisionTree;
import jsat.classifiers.trees.RandomForest;
import jsat.distributions.kernels.LinearKernel;
import jsat.distributions.kernels.RBFKernel;
import jsat.distributions.multivariate.MetricKDE;
import jsat.guitool.ParameterPanel;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.parameters.DoubleParameter;
import jsat.parameters.GridSearch;
import jsat.parameters.Parameterized;

/**
 * Modal dialog that lets the user choose which classifiers to build for a
 * {@link ClassificationDataSet}.
 * <p>
 * Every catalogue entry is shown as a check box. Entries whose
 * {@link ClassifierInfo#canTrain(ClassificationDataSet)} returns {@code false}
 * for the supplied data set are disabled. After the dialog has been hidden, the
 * caller checks {@link #isCanceled()}, then calls {@link #getSelectedClassifiers()}
 * and finally {@link #getSelectedNames()}.
 */
public class ClassifierSelectionDialog extends JDialog {
    private static final long serialVersionUID = 208983866163325774L;

    /** Simple learners that may also be wrapped by the meta-classifiers of the full catalogue. */
    private static final List<ClassifierInfo> weakClassifiers =
            Collections.unmodifiableList(createWeakClassifiers());

    /**
     * The full catalogue: the weak learners followed by every other classifier.
     * It must stay declared after {@link #weakClassifiers}, because building it
     * copies that list during static initialization.
     */
    private static final List<ClassifierInfo> possClass =
            Collections.unmodifiableList(createAllClassifiers());

    /** Frame owning this dialog; it also owns the nested weak-learner and parameter dialogs. */
    private final Frame owner;
    /** Catalogue displayed by this instance: {@link #weakClassifiers} or {@link #possClass}. */
    private final List<ClassifierInfo> listInUse;
    /** Data set used to enable or disable entries; it is passed on to nested dialogs. */
    private final ClassificationDataSet dataSet;
    /** One check box per entry of {@link #listInUse}, in the same order. */
    private final List<JCheckBox> checkBoxes;
    /** Set when the user presses "Cancel" or closes the window. */
    private boolean hitCancel = false;
    /** Display names parallel to the classifiers returned by the last {@link #getSelectedClassifiers()}. */
    private List<String> names = new ArrayList<String>();

    /**
     * Creates a dialog that offers the full classifier catalogue.
     *
     * @param dataSet the data set the chosen classifiers are meant for
     * @param owner   the frame that owns this modal dialog
     */
    public ClassifierSelectionDialog(ClassificationDataSet dataSet, Frame owner) {
        this(dataSet, owner, "Classifier Selection Dialog", false);
    }

    /**
     * Creates a modal classifier selection dialog.
     *
     * @param dataSet  the data set the chosen classifiers are meant for; entries
     *                 that cannot train on it are disabled
     * @param owner    the frame that owns this modal dialog
     * @param title    the dialog title
     * @param weakOnly {@code true} to offer only the weak-learner catalogue,
     *                 {@code false} to offer the full catalogue
     */
    public ClassifierSelectionDialog(ClassificationDataSet dataSet, Frame owner, String title, boolean weakOnly) {
        super(owner, title, true);
        this.owner = owner;
        this.dataSet = dataSet;
        this.listInUse = weakOnly ? weakClassifiers : possClass;
        this.checkBoxes = new ArrayList<JCheckBox>(this.listInUse.size());
        this.setLayout(new BorderLayout());

        // Size the grid to the catalogue actually shown, so the weak-only
        // dialog is not padded with empty rows.
        JPanel choicePanel = new JPanel(new GridLayout(this.listInUse.size(), 1));
        for (ClassifierInfo info : this.listInUse) {
            JCheckBox checkBox = new JCheckBox(info.toString());
            checkBox.setEnabled(info.canTrain(dataSet));
            choicePanel.add(checkBox);
            this.checkBoxes.add(checkBox);
        }
        this.add(new JScrollPane(choicePanel), BorderLayout.CENTER);

        JPanel buttonPanel = new JPanel(new FlowLayout());
        JButton okButton = new JButton("Ok");
        okButton.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                ClassifierSelectionDialog.this.setVisible(false);
            }
        });
        buttonPanel.add(okButton);
        JButton cancelButton = new JButton("Cancel");
        cancelButton.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                ClassifierSelectionDialog.this.hitCancel = true;
                ClassifierSelectionDialog.this.setVisible(false);
            }
        });
        buttonPanel.add(cancelButton);
        this.add(buttonPanel, BorderLayout.SOUTH);

        // Closing the window from its title bar only hides a JDialog by default.
        // Treat that as a cancel instead of silently accepting the selection.
        this.addWindowListener(new WindowAdapter() {
            @Override
            public void windowClosing(WindowEvent e) {
                ClassifierSelectionDialog.this.hitCancel = true;
            }
        });
    }

    /**
     * Builds the weak-learner catalogue.
     *
     * @return a new mutable list of the weak-learner entries
     */
    private static List<ClassifierInfo> createWeakClassifiers() {
        List<ClassifierInfo> infos = new ArrayList<ClassifierInfo>();
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new DecisionStump();
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new DecisionTree();
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new NaiveBayesUpdateable();
            }
        });
        return infos;
    }

    /**
     * Builds the full catalogue: the entries of {@link #weakClassifiers} followed
     * by the remaining classifiers, including meta-classifiers that wrap a weak
     * learner.
     *
     * @return a new mutable list of all catalogue entries
     */
    private static List<ClassifierInfo> createAllClassifiers() {
        List<ClassifierInfo> infos = new ArrayList<ClassifierInfo>();
        infos.addAll(weakClassifiers);
        infos.add(new ClassifierInfoWeakLearner() {
            @Override
            public Classifier getNewClassifier(Classifier weakLearner) {
                return new Bagging(weakLearner);
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new RandomForest(200);
            }
        });
        infos.add(new ClassifierInfoWeakLearner() {
            @Override
            public Classifier getNewClassifier(Classifier weakLearner) {
                return new AdaBoostM1(weakLearner, 200);
            }
        });
        infos.add(new ClassifierInfoWeakLearner() {
            @Override
            public Classifier getNewClassifier(Classifier weakLearner) {
                return new SAMME(weakLearner, 100);
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new NaiveBayes();
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new AODE();
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new MultivariateNormals(true);
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public boolean canTrain(ClassificationDataSet cds) {
                return cds.getNumCategoricalVars() == 0;
            }

            @Override
            public Classifier getNewClassifier() {
                return new NearestNeighbour(1);
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new DANN();
            }
        });
        infos.add(new ClassifierInfoWeakLearner() {
            @Override
            public Classifier getNewClassifier(Classifier weakLearner) {
                return new LWL(weakLearner, 40, (DistanceMetric) new EuclideanDistance());
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public String toString() {
                return "Kernel Density Estimator";
            }

            @Override
            public Classifier getNewClassifier() {
                return new BestClassDistribution(new MetricKDE());
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public boolean canTrain(ClassificationDataSet cds) {
                return cds.getNumCategoricalVars() == 0;
            }

            @Override
            public Classifier getNewClassifier() {
                return new MultinomialLogisticRegression();
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new SOM(5, 5);
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new LVQ(new EuclideanDistance(), 200);
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public Classifier getNewClassifier() {
                return new Perceptron();
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public String toString() {
                return "SVM-PlatSMO-Linear Kernel w/ Grid Search";
            }

            @Override
            public Classifier getNewClassifier() {
                // Grid-search PlatSMO's "C" parameter; the values are looked up on the
                // grid search's own base classifier.
                GridSearch gSearch = new GridSearch(new PlatSMO(new LinearKernel()), 5) {
                    private static final long serialVersionUID = 5389228909478509872L;
                    {
                        DoubleParameter paramC = (DoubleParameter) ((Parameterized) ((Object) this.getBaseClassifier())).getParameter("C");
                        this.addParameter(paramC, 2.0E-5, 0.002, 0.2, 20.0, 2000.0, 200000.0, 2.0E7, 2.0E9, 2.0E11, 2.0E13, 2.0E15);
                    }
                };
                return gSearch;
            }
        });
        infos.add(new ClassifierInfo() {
            @Override
            public String toString() {
                return "SVM-PlatSMO-RBF Kernel w/ Grid Search";
            }

            @Override
            public Classifier getNewClassifier() {
                // Grid-search PlatSMO's "C" together with the RBF kernel's sigma.
                GridSearch g = new GridSearch(new PlatSMO(new RBFKernel(2.0)), 5) {
                    private static final long serialVersionUID = 5216566364915038112L;
                    {
                        DoubleParameter paramC = (DoubleParameter) ((Parameterized) ((Object) this.getBaseClassifier())).getParameter("C");
                        DoubleParameter paramRBF = (DoubleParameter) ((Parameterized) ((Object) this.getBaseClassifier())).getParameter("RBFKernel_sigma");
                        this.addParameter(paramC, 2.0E-5, 0.002, 0.2, 20.0, 2000.0, 200000.0, 2.0E7, 2.0E9, 2.0E11, 2.0E13, 2.0E15);
                        this.addParameter(paramRBF, 2.0E15, 2.0E13, 2.0E11, 2.0E9, 2.0E7, 200000.0, 2000.0, 20.0, 0.2, 0.002);
                    }
                };
                return g;
            }
        });
        return infos;
    }

    /**
     * Reports whether the user dismissed the dialog with "Cancel" or by closing
     * the window.
     *
     * @return {@code true} if the selection was cancelled
     */
    public boolean isCanceled() {
        return this.hitCancel;
    }

    /**
     * Instantiates every checked classifier.
     * <p>
     * A checked meta-classifier opens a nested weak-learner dialog. One wrapped
     * classifier is created per weak learner chosen there, and none if that dialog
     * is cancelled. Every result that implements {@link Parameterized} is followed
     * by a parameter dialog. The display names of the returned classifiers are
     * afterwards available, in the same order, from {@link #getSelectedNames()}.
     *
     * @return the new classifiers, in catalogue order
     */
    public List<Classifier> getSelectedClassifiers() {
        this.names = new ArrayList<String>();
        List<Classifier> classifiers = new ArrayList<Classifier>();
        for (int i = 0; i < this.listInUse.size(); ++i) {
            if (!this.checkBoxes.get(i).isSelected()) {
                continue;
            }
            // Index the list the check boxes were built from, not the full catalogue.
            ClassifierInfo info = this.listInUse.get(i);
            if (info instanceof ClassifierInfoWeakLearner) {
                this.addWrappedClassifiers((ClassifierInfoWeakLearner) info, classifiers);
            } else {
                this.addClassifier(info.getNewClassifier(), info.toString(), classifiers);
            }
        }
        return classifiers;
    }

    /**
     * Returns the display names of the classifiers built by the most recent call
     * to {@link #getSelectedClassifiers()}, index-aligned with that result.
     *
     * @return the names; an empty list before the first selection
     */
    public List<String> getSelectedNames() {
        return this.names;
    }

    /**
     * Lets the user pick weak learners for {@code meta} and adds one wrapped
     * classifier per chosen weak learner.
     *
     * @param meta        the meta-classifier entry to instantiate
     * @param classifiers the result list to append to
     */
    private void addWrappedClassifiers(ClassifierInfoWeakLearner meta, List<Classifier> classifiers) {
        String metaName = meta.toString();
        ClassifierSelectionDialog weakSelect = new ClassifierSelectionDialog(this.dataSet, this.owner, "Select weak learner for " + metaName, true);
        List<Classifier> weakLearners;
        List<String> weakNames;
        try {
            weakSelect.setSize(400, 400);
            weakSelect.setVisible(true);
            if (weakSelect.isCanceled()) {
                return;
            }
            weakLearners = weakSelect.getSelectedClassifiers();
            weakNames = weakSelect.getSelectedNames();
        } finally {
            // A fresh nested dialog is created on every call; release its window resources.
            weakSelect.dispose();
        }
        for (int z = 0; z < weakLearners.size(); ++z) {
            this.addClassifier(meta.getNewClassifier(weakLearners.get(z)), metaName + " using " + weakNames.get(z), classifiers);
        }
    }

    /**
     * Records {@code classifier} under {@code name} and, if it is
     * {@link Parameterized}, shows its parameter dialog.
     *
     * @param classifier  the newly created classifier
     * @param name        its display name
     * @param classifiers the result list to append to
     */
    private void addClassifier(Classifier classifier, String name, List<Classifier> classifiers) {
        classifiers.add(classifier);
        this.names.add(name);
        if (classifier instanceof Parameterized) {
            ParameterPanel.showParameterDiag(this.owner, "Select Parameters for " + name, (Parameterized) classifier);
        }
    }

    /**
     * Catalogue entry. It creates a fresh classifier on demand and says whether
     * that classifier can be trained on a given data set.
     */
    private static abstract class ClassifierInfo {
        private ClassifierInfo() {
        }

        /**
         * @param cds the candidate data set
         * @return whether this classifier can be trained on {@code cds};
         *         always {@code true} unless overridden
         */
        public boolean canTrain(ClassificationDataSet cds) {
            return true;
        }

        /**
         * Label shown in the dialog. By default this is the simple class name of a
         * freshly constructed classifier, so every call builds a throw-away instance.
         */
        @Override
        public String toString() {
            return this.getNewClassifier().getClass().getSimpleName();
        }

        /** @return a new, untrained classifier */
        public abstract Classifier getNewClassifier();
    }

    /**
     * Catalogue entry for a meta-classifier that wraps a weak learner chosen by
     * the user.
     */
    private static abstract class ClassifierInfoWeakLearner extends ClassifierInfo {
        private ClassifierInfoWeakLearner() {
        }

        /**
         * @param weakLearner the learner to wrap
         * @return a new meta-classifier built around {@code weakLearner}
         */
        public abstract Classifier getNewClassifier(Classifier weakLearner);

        /** Builds the meta-classifier around a {@link DecisionStump}; the default label relies on this. */
        @Override
        public Classifier getNewClassifier() {
            return this.getNewClassifier(new DecisionStump());
        }
    }
}

