/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.supervised.instance.RUS;

public class RUSBoost
extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
Sourcable,
TechnicalInformationHandler {
    private static final long serialVersionUID = -4258566093517813109L;
    private static int MAX_NUM_RESAMPLING_ITERATIONS = 10;
    protected double[] m_Betas;
    protected int m_NumIterationsPerformed;
    protected int m_WeightThreshold = 100;
    protected boolean m_UseResampling;
    protected int m_NumClasses;
    protected Classifier m_ZeroR;
    protected double m_RUS_Percentage = 75.0;

    public RUSBoost() {
        this.m_Classifier = new DecisionStump();
    }

    public String globalInfo() {
        return "Class for boosting a binary class classifier using the RUSboost method.\n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "C. Seiffert and  T. Khoshgoftaar and  J. Van Hulse and  A. Napolitano");
        result.setValue(TechnicalInformation.Field.TITLE, "Rusboost: A hybrid approach to alleviating class imbalance");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "IEEE Transactions on Systems and  Man and Cybernetics and  Part A");
        result.setValue(TechnicalInformation.Field.VOLUME, "40");
        result.setValue(TechnicalInformation.Field.YEAR, "2010");
        result.setValue(TechnicalInformation.Field.PAGES, "185-197");
        return result;
    }

    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    protected Instances selectWeightQuantile(Instances data, double quantile) {
        int numInstances = data.numInstances();
        Instances trainData = new Instances(data, numInstances);
        double[] weights = new double[numInstances];
        double sumOfWeights = 0.0;
        int i = 0;
        while (i < numInstances) {
            weights[i] = data.instance(i).weight();
            sumOfWeights += weights[i];
            ++i;
        }
        double weightMassToSelect = sumOfWeights * quantile;
        int[] sortedIndices = Utils.sort((double[])weights);
        sumOfWeights = 0.0;
        int i2 = numInstances - 1;
        while (i2 >= 0) {
            Instance instance = (Instance)data.instance(sortedIndices[i2]).copy();
            trainData.add(instance);
            if ((sumOfWeights += weights[sortedIndices[i2]]) > weightMassToSelect && i2 > 0 && weights[sortedIndices[i2]] != weights[sortedIndices[i2 - 1]]) break;
            --i2;
        }
        if (this.m_Debug) {
            System.err.println("Selected " + trainData.numInstances() + " out of " + numInstances);
        }
        return trainData;
    }

    public Enumeration listOptions() {
        Vector<Object> newVector = new Vector<Object>();
        newVector.addElement(new Option("\tPercentage of weight mass to base training on.\n\t(default 100, reduce to around 90 speed up)", "P", 1, "-P <num>"));
        newVector.addElement(new Option("\tUse resampling for boosting.", "Q", 0, "-Q"));
        newVector.addElement(new Option("\tSpecifies percentage of the proportion of majority class.\n\t(default 75.0)\n", "rusP", 1, "-rusP <percentage>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement(enu.nextElement());
        }
        return newVector.elements();
    }

    public void setOptions(String[] options) throws Exception {
        String thresholdString = Utils.getOption((char)'P', (String[])options);
        if (thresholdString.length() != 0) {
            this.setWeightThreshold(Integer.parseInt(thresholdString));
        } else {
            this.setWeightThreshold(100);
        }
        this.setUseResampling(Utils.getFlag((char)'Q', (String[])options));
        String percentageStr = Utils.getOption((String)"rusP", (String[])options);
        if (percentageStr.length() != 0) {
            this.setRUS_Percentage(new Double(percentageStr));
        }
        super.setOptions(options);
    }

    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-rusP");
        result.add("" + this.getRUS_Percentage());
        if (this.getUseResampling()) {
            result.add("-Q");
        }
        result.add("-P");
        result.add("" + this.getWeightThreshold());
        String[] options = super.getOptions();
        int i = 0;
        while (i < options.length) {
            result.add(options[i]);
            ++i;
        }
        return result.toArray(new String[result.size()]);
    }

    public String weightThresholdTipText() {
        return "Weight threshold for weight pruning.";
    }

    public void setWeightThreshold(int threshold) {
        this.m_WeightThreshold = threshold;
    }

    public int getWeightThreshold() {
        return this.m_WeightThreshold;
    }

    public String useResamplingTipText() {
        return "Whether resampling is used instead of reweighting.";
    }

    public void setUseResampling(boolean r) {
        this.m_UseResampling = r;
    }

    public boolean getUseResampling() {
        return this.m_UseResampling;
    }

    public String RUS_percentageTipText() {
        return "Specifies the proportion of majority class desired.";
    }

    public void setRUS_Percentage(double value) {
        if (value < 1.0 || value > 99.0) {
            throw new IllegalArgumentException("Percentage must be between 1 and 99.");
        }
        this.m_RUS_Percentage = value;
    }

    public double getRUS_Percentage() {
        return this.m_RUS_Percentage;
    }

    public RUS initRUS() {
        RUS rus = new RUS();
        rus.setPercentage(this.getRUS_Percentage());
        return rus;
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enableAllAttributes();
        result.enable(Capabilities.Capability.BINARY_CLASS);
        return result;
    }

    public void buildClassifier(Instances data) throws Exception {
        super.buildClassifier(data);
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        if (data.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_ZeroR = new ZeroR();
            this.m_ZeroR.buildClassifier(data);
            return;
        }
        this.m_ZeroR = null;
        this.m_NumClasses = data.numClasses();
        if (!this.m_UseResampling && this.m_Classifier instanceof WeightedInstancesHandler) {
            this.buildClassifierWithWeights(data);
        } else {
            this.buildClassifierUsingResampling(data);
        }
    }

    protected void buildClassifierUsingResampling(Instances data) throws Exception {
        Instances[] training = new Instances[this.m_Classifiers.length];
        int numInstances = data.numInstances();
        Random randomInstance = new Random(this.m_Seed);
        int resamplingIterations = 0;
        this.m_Betas = new double[this.m_Classifiers.length];
        this.m_NumIterationsPerformed = 0;
        training[this.m_NumIterationsPerformed] = new Instances(data, 0, numInstances);
        this.m_NumIterationsPerformed = 0;
        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            Evaluation evaluation;
            double epsilon;
            if (this.m_Debug) {
                System.err.println("Training classifier " + (this.m_NumIterationsPerformed + 1));
            }
            Instances trainData = this.m_WeightThreshold < 100 ? this.selectWeightQuantile(training[this.m_NumIterationsPerformed], (double)this.m_WeightThreshold / 100.0) : new Instances(training[this.m_NumIterationsPerformed]);
            RUS rus = this.initRUS();
            rus.setInputFormat(trainData);
            trainData = Filter.useFilter((Instances)trainData, (Filter)rus);
            resamplingIterations = 0;
            double[] weights = new double[trainData.numInstances()];
            int i = 0;
            while (i < weights.length) {
                weights[i] = trainData.instance(i).weight();
                ++i;
            }
            do {
                Instances sample = trainData.resampleWithWeights(randomInstance, weights);
                this.m_Classifiers[this.m_NumIterationsPerformed].buildClassifier(sample);
                evaluation = new Evaluation(data);
                evaluation.evaluateModel(this.m_Classifiers[this.m_NumIterationsPerformed], data, new Object[0]);
            } while (Utils.eq((double)(epsilon = evaluation.errorRate()), (double)0.0) && ++resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS);
            this.m_Betas[this.m_NumIterationsPerformed] = Math.log((1.0 - epsilon) / epsilon);
            double reweight = (1.0 - epsilon) / epsilon;
            if (this.m_Debug) {
                System.err.println("\terror rate = " + epsilon + "  beta = " + this.m_Betas[this.m_NumIterationsPerformed]);
            }
            if (this.m_NumIterationsPerformed + 1 < this.m_Classifiers.length) {
                training[this.m_NumIterationsPerformed + 1] = new Instances(training[this.m_NumIterationsPerformed], 0, numInstances);
                this.setWeights(training[this.m_NumIterationsPerformed + 1], reweight);
            }
            ++this.m_NumIterationsPerformed;
        }
    }

    protected void setWeights(Instances training, double reweight) throws Exception {
        Instance instance;
        double oldSumOfWeights = training.sumOfWeights();
        Enumeration enu = training.enumerateInstances();
        while (enu.hasMoreElements()) {
            instance = (Instance)enu.nextElement();
            if (Utils.eq((double)this.m_Classifiers[this.m_NumIterationsPerformed].classifyInstance(instance), (double)instance.classValue())) continue;
            instance.setWeight(instance.weight() * reweight);
        }
        double newSumOfWeights = training.sumOfWeights();
        enu = training.enumerateInstances();
        while (enu.hasMoreElements()) {
            instance = (Instance)enu.nextElement();
            instance.setWeight(instance.weight() * oldSumOfWeights / newSumOfWeights);
        }
    }

    protected void buildClassifierWithWeights(Instances data) throws Exception {
        Instances[] training = new Instances[this.m_Classifiers.length];
        int numInstances = data.numInstances();
        Random randomInstance = new Random(this.m_Seed);
        this.m_Betas = new double[this.m_Classifiers.length];
        this.m_NumIterationsPerformed = 0;
        training[this.m_NumIterationsPerformed] = new Instances(data, 0, numInstances);
        this.m_NumIterationsPerformed = 0;
        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier, LOOP: " + this.m_NumIterationsPerformed);
            }
            Instances trainData = this.m_WeightThreshold < 100 ? this.selectWeightQuantile(training[this.m_NumIterationsPerformed], (double)this.m_WeightThreshold / 100.0) : new Instances(training[this.m_NumIterationsPerformed], 0, numInstances);
            RUS rus = this.initRUS();
            rus.setInputFormat(trainData);
            trainData = Filter.useFilter((Instances)trainData, (Filter)rus);
            if (this.m_Classifiers[this.m_NumIterationsPerformed] instanceof Randomizable) {
                ((Randomizable)this.m_Classifiers[this.m_NumIterationsPerformed]).setSeed(randomInstance.nextInt());
            }
            this.m_Classifiers[this.m_NumIterationsPerformed].buildClassifier(trainData);
            Evaluation evaluation = new Evaluation(data);
            evaluation.evaluateModel(this.m_Classifiers[this.m_NumIterationsPerformed], data, new Object[0]);
            double epsilon = evaluation.errorRate();
            this.m_Betas[this.m_NumIterationsPerformed] = Math.log((1.0 - epsilon) / epsilon);
            double reweight = (1.0 - epsilon) / epsilon;
            if (this.m_Debug) {
                System.err.println("\terror rate = " + epsilon + "  beta = " + this.m_Betas[this.m_NumIterationsPerformed]);
            }
            if (this.m_NumIterationsPerformed + 1 < this.m_Classifiers.length) {
                training[this.m_NumIterationsPerformed + 1] = new Instances(training[this.m_NumIterationsPerformed], 0, numInstances);
                this.setWeights(training[this.m_NumIterationsPerformed + 1], reweight);
            }
            ++this.m_NumIterationsPerformed;
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.distributionForInstance(instance);
        }
        if (this.m_NumIterationsPerformed == 0) {
            throw new Exception("No model built");
        }
        double[] sums = new double[instance.numClasses()];
        if (this.m_NumIterationsPerformed == 1) {
            return this.m_Classifiers[0].distributionForInstance(instance);
        }
        int i = 0;
        while (i < this.m_NumIterationsPerformed) {
            int n = (int)this.m_Classifiers[i].classifyInstance(instance);
            sums[n] = sums[n] + this.m_Betas[i];
            ++i;
        }
        return Utils.logs2probs((double[])sums);
    }

    public String toSource(String className) throws Exception {
        int i;
        if (this.m_NumIterationsPerformed == 0) {
            throw new Exception("No model built yet");
        }
        if (!(this.m_Classifiers[0] instanceof Sourcable)) {
            throw new Exception("Base learner " + this.m_Classifier.getClass().getName() + " is not Sourcable");
        }
        StringBuffer text = new StringBuffer("class ");
        text.append(className).append(" {\n\n");
        text.append("  public static double classify(Object[] i) {\n");
        if (this.m_NumIterationsPerformed == 1) {
            text.append("    return " + className + "_0.classify(i);\n");
        } else {
            text.append("    double [] sums = new double [" + this.m_NumClasses + "];\n");
            i = 0;
            while (i < this.m_NumIterationsPerformed) {
                text.append("    sums[(int) " + className + '_' + i + ".classify(i)] += " + this.m_Betas[i] + ";\n");
                ++i;
            }
            text.append("    double maxV = sums[0];\n    int maxI = 0;\n    for (int j = 1; j < " + this.m_NumClasses + "; j++) {\n" + "      if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n" + "    }\n    return (double) maxI;\n");
        }
        text.append("  }\n}\n");
        i = 0;
        while (i < this.m_Classifiers.length) {
            text.append(((Sourcable)this.m_Classifiers[i]).toSource(String.valueOf(className) + '_' + i));
            ++i;
        }
        return text.toString();
    }

    public String toString() {
        if (this.m_ZeroR != null) {
            StringBuffer buf = new StringBuffer();
            buf.append(String.valueOf(((Object)((Object)this)).getClass().getName().replaceAll(".*\\.", "")) + "\n");
            buf.append(String.valueOf(((Object)((Object)this)).getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=")) + "\n\n");
            buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            buf.append(this.m_ZeroR.toString());
            return buf.toString();
        }
        StringBuffer text = new StringBuffer();
        if (this.m_NumIterationsPerformed == 0) {
            text.append("RUSBoost: No model built yet.\n");
        } else if (this.m_NumIterationsPerformed == 1) {
            text.append("RUSBoost: No boosting possible, one classifier used!\n");
            text.append(String.valueOf(this.m_Classifiers[0].toString()) + "\n");
        } else {
            text.append("RUSBoost: Base classifiers and their weights: \n\n");
            int i = 0;
            while (i < this.m_NumIterationsPerformed) {
                text.append(String.valueOf(this.m_Classifiers[i].toString()) + "\n\n");
                text.append("Weight: " + Utils.roundDouble((double)this.m_Betas[i], (int)2) + "\n\n");
                ++i;
            }
            text.append("Number of performed Iterations: " + this.m_NumIterationsPerformed + "\n");
        }
        return text.toString();
    }

    public String getRevision() {
        return RevisionUtils.extract((String)"$Revision: 1 $");
    }

    public static void main(String[] argv) {
        RUSBoost.runClassifier((Classifier)new RUSBoost(), (String[])argv);
    }
}

