/*
 * Decompiled with CFR 0.152.
 */
package jsat.graphing;

import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.graphing.Graph2D;
import jsat.graphing.ProgressPanel;
import jsat.guitool.GUIUtils;
import jsat.utils.IndexTable;

/**
 * Plots Receiver Operating Characteristic (ROC) curves for binary
 * classification problems: the true positive rate against the false positive
 * rate, one curve per evaluated classifier, together with the TPR = FPR
 * diagonal for reference and a key naming each curve.
 * <p>
 * Category {@code 0} is treated as the positive class, and each prediction's
 * probability of category {@code 0} is used as its ranking score.
 */
public class ROCPlot extends Graph2D {
    private static final long serialVersionUID = -5956609892534834431L;
    /**
     * {@code curves[c][0]} holds the true positive rates and
     * {@code curves[c][1]} the false positive rates of curve {@code c}, one
     * entry per threshold, in order of decreasing score.
     */
    private double[][][] curves;
    private List<Color> categoryColors;
    private List<String> names;
    /**
     * Set only after every curve (and every name) has been stored, so a
     * painting thread that reads {@code true} here also sees the finished data.
     */
    private volatile boolean readyToDraw = false;

    /**
     * Sets up the axes, the key colors and the storage for the curves.
     *
     * @param numCurves the number of curves that will be plotted
     */
    private ROCPlot(int numCurves) {
        super(0.0, 1.0, 0.0, 1.0);
        this.setYAxisTtile("True Positive Rate");
        this.setXAxisTtile("False Positive Rate");
        this.curves = new double[numCurves][2][];
        this.names = new ArrayList<String>(numCurves);
        this.categoryColors = GUIUtils.getDistinctColors(numCurves);
    }

    /**
     * Plots one ROC curve per already-run evaluation.
     *
     * @param names the key label of each curve, one per evaluation
     * @param evaluations the evaluations to plot
     * @throws IllegalArgumentException if the number of names and evaluations differ
     */
    public ROCPlot(List<String> names, ClassificationModelEvaluation ... evaluations) {
        this(names, Arrays.asList(evaluations));
    }

    /**
     * Plots one ROC curve per already-run evaluation. The curves are computed
     * immediately, on the calling thread.
     * <p>
     * Each evaluation must provide its predictions through
     * {@code getPredictions()}. Presumably this means it was run after
     * {@code keepPredictions(true)}, as the cross-validation constructor does.
     *
     * @param names the key label of each curve, one per evaluation
     * @param evaluations the evaluations to plot
     * @throws IllegalArgumentException if the number of names and evaluations differ
     */
    public ROCPlot(List<String> names, List<ClassificationModelEvaluation> evaluations) {
        this(names.size());
        if (names.size() != evaluations.size()) {
            throw new IllegalArgumentException("Got " + names.size() + " names for "
                    + evaluations.size() + " evaluations");
        }
        this.names.addAll(names);
        this.computeCurves(evaluations);
    }

    /**
     * Plots the ROC curves of the given classifiers, evaluated by
     * cross-validation on {@code dataSet}.
     * <p>
     * The evaluation runs on a background thread. Until it finishes, the plot
     * shows a "Computing ROC Curves..." message. Each classifier is cloned
     * before being evaluated, and its key label is its simple class name.
     *
     * @param dataSet the binary classification data to evaluate on
     * @param folds the number of cross-validation folds
     * @param classifiers the classifiers to compare
     * @throws IllegalArgumentException if the predicted attribute of
     *         {@code dataSet} does not have exactly two categories
     */
    public ROCPlot(final ClassificationDataSet dataSet, final int folds, final Classifier ... classifiers) {
        this(classifiers.length);
        if (dataSet.getPredicting().getNumOfCategories() != 2) {
            throw new IllegalArgumentException("ROC curves can only be done for binary classification problems");
        }
        Thread thread = new Thread(new Runnable() {
            @Override
            public void run() {
                List<ClassificationModelEvaluation> cmes =
                        new ArrayList<ClassificationModelEvaluation>(classifiers.length);
                for (Classifier classifier : classifiers) {
                    ROCPlot.this.names.add(classifier.getClass().getSimpleName());
                    ClassificationModelEvaluation cme = new ClassificationModelEvaluation(classifier.clone(), dataSet);
                    cme.keepPredictions(true);
                    cme.evaluateCrossValidation(folds);
                    cmes.add(cme);
                }
                ROCPlot.this.computeCurves(cmes);
            }
        });
        thread.start();
    }

    /**
     * Draws a progress message while the curves are being computed. Afterwards
     * it draws the reference diagonal, every ROC curve (starting at the origin)
     * and the key.
     */
    @Override
    protected void paintWork(Graphics g, int imageWidth, int imageHeight, ProgressPanel pp) {
        super.paintWork(g, imageWidth, imageHeight, pp);
        if (!this.readyToDraw) {
            String message = "Computing ROC Curves...";
            int strWidth = g.getFontMetrics().stringWidth(message);
            g.drawString(message, imageWidth / 2 - strWidth / 2, imageHeight / 2);
            return;
        }
        // TPR == FPR diagonal (chance level) for reference
        g.setColor(Color.DARK_GRAY);
        g.drawLine(this.toXCord(0.0), this.toYCord(0.0), this.toXCord(1.0), this.toYCord(1.0));
        for (int ci = 0; ci < this.curves.length; ++ci) {
            g.setColor(this.categoryColors.get(ci));
            double[] tpr = this.curves[ci][0];
            double[] fpr = this.curves[ci][1];
            // every curve starts at the origin, where nothing is predicted positive
            double prevTP = 0.0;
            double prevFP = 0.0;
            // visit every threshold, including the last one (the curve's end point)
            for (int i = 0; i < tpr.length; ++i) {
                double nextTP = tpr[i];
                double nextFP = fpr[i];
                g.drawLine(this.toXCord(prevFP), this.toYCord(prevTP), this.toXCord(nextFP), this.toYCord(nextTP));
                prevFP = nextFP;
                prevTP = nextTP;
            }
        }
        this.drawKey((Graphics2D) g, 3, this.names, this.categoryColors, null);
    }

    /**
     * Computes one ROC curve per evaluation, then marks the plot as ready and
     * requests a repaint.
     * <p>
     * Each prediction's probability of category 0 is used in turn as a
     * threshold. Points scoring at least the threshold count as predicted
     * category 0 (positive); all others count as predicted negative. Points
     * whose true class is neither 0 nor 1 are ignored. The scores are sorted
     * once, and the weighted true/false positive totals are accumulated
     * incrementally, instead of rescanning every point for every threshold.
     *
     * @param cmes the evaluations, in the same order as the curve names
     */
    private void computeCurves(List<ClassificationModelEvaluation> cmes) {
        for (int ci = 0; ci < cmes.size(); ++ci) {
            ClassificationModelEvaluation cme = cmes.get(ci);
            CategoricalResults[] results = cme.getPredictions();
            int[] truth = cme.getTruths();
            double[] weights = cme.getPointWeights();
            int n = results.length;

            double[] scores = new double[n];
            for (int j = 0; j < n; ++j) {
                scores[j] = results[j].getProb(0);
            }
            // maps sorted position -> original index, in order of decreasing score
            IndexTable it = new IndexTable(Arrays.asList(results), new Comparator<CategoricalResults>() {
                @Override
                public int compare(CategoricalResults a, CategoricalResults b) {
                    return Double.compare(b.getProb(0), a.getProb(0));
                }
            });

            double totalPos = 0.0; // weight of all category 0 points (TP + FN)
            double totalNeg = 0.0; // weight of all category 1 points (FP + TN)
            for (int j = 0; j < n; ++j) {
                if (truth[j] == 0) {
                    totalPos += weights[j];
                } else if (truth[j] == 1) {
                    totalNeg += weights[j];
                }
            }

            double[] tpr = new double[n];
            double[] fpr = new double[n];
            double tp = 0.0; // weight of category 0 points predicted positive
            double fp = 0.0; // weight of category 1 points predicted positive
            int next = 0;    // first sorted position not yet counted as predicted positive
            // NaN scores sort first under Double.compare but are never >= a
            // threshold, so they are never predicted positive: skip past them.
            while (next < n && Double.isNaN(scores[it.index(next)])) {
                ++next;
            }
            for (int i = 0; i < n; ++i) {
                double thresh = scores[it.index(i)];
                // Scores do not increase along the sorted order, so the points
                // scoring >= thresh (ties included) form a prefix of that order.
                while (next < n && scores[it.index(next)] >= thresh) {
                    int j = it.index(next++);
                    if (truth[j] == 0) {
                        tp += weights[j];
                    } else if (truth[j] == 1) {
                        fp += weights[j];
                    }
                }
                // guard against a class with no (weighted) examples instead of producing NaN
                tpr[i] = totalPos > 0.0 ? tp / totalPos : 1.0;
                fpr[i] = totalNeg > 0.0 ? fp / totalNeg : 1.0;
            }
            this.curves[ci][0] = tpr;
            this.curves[ci][1] = fpr;
        }
        this.readyToDraw = true;
        this.forceRedraw();
        this.repaint();
    }
}

