/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.supervised.instance.SMOTE;

public class SMOTEBoost
extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
Sourcable,
TechnicalInformationHandler {
    /** Serialization version identifier. */
    private static final long serialVersionUID = -1397262307604906824L;
    /**
     * Upper bound on how many times one boosting round re-draws its weighted sample
     * while the freshly trained model shows an error rate of zero. It is a constant,
     * hence {@code final}.
     */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;
    /** Vote weight (log odds of the error rate) of each base classifier, indexed by round. */
    protected double[] m_Betas;
    /** Number of boosting rounds whose classifier takes part in prediction. */
    protected int m_NumIterationsPerformed;
    /** Percentage of the weight mass used for training; values below 100 enable weight pruning. */
    protected int m_WeightThreshold = 100;
    /** Whether resampling is used instead of reweighting. */
    protected boolean m_UseResampling;
    /** Number of classes of the training data, recorded when the model is built. */
    protected int m_NumClasses;
    /** Fallback model used when the data holds nothing but the class attribute; otherwise null. */
    protected Classifier m_ZeroR;
    /** Class value index handed to the SMOTE filter ("0" is described as auto-detect in the tip text). */
    protected String m_SMOTE_ClassValueIndex = "0";
    /** Number of nearest neighbors handed to the SMOTE filter. */
    protected int m_SMOTE_NearestNeighbors = 5;
    /** Percentage of SMOTE instances to create. */
    protected double m_SMOTE_Percentage = 100.0;
    /** Random seed handed to the SMOTE filter. */
    protected int m_SMOTE_RandomSeed = 1;

    /**
     * Creates a SMOTEBoost instance whose base learner is initialised to a
     * {@link DecisionStump}.
     */
    public SMOTEBoost() {
        m_Classifier = new DecisionStump();
    }

    /**
     * Returns a short description of this classifier followed by the textual form of
     * its technical reference.
     *
     * @return the description string
     */
    public String globalInfo() {
        String reference = getTechnicalInformation().toString();
        return "Class for boosting a binary class classifier using the SMOTEBoost method.\n\nFor more information, see\n\n" + reference;
    }

    /**
     * Builds the bibliographic record of the SMOTEBoost paper (author, title,
     * book title, year, pages and address).
     *
     * @return a freshly populated INPROCEEDINGS record
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation paper = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        paper.setValue(TechnicalInformation.Field.AUTHOR,
            "Nitesh V. Chawla, Aleksandar Lazarevic, Lawrence O.Hall and Kevin W. Bowyer");
        paper.setValue(TechnicalInformation.Field.TITLE,
            "SMOTEBoost: Improving Prediction of the Minority Class in Boosting");
        paper.setValue(TechnicalInformation.Field.BOOKTITLE,
            "7th European Conference on Principles and Practice of Knowledge Discovery in Databases (PKDD)");
        paper.setValue(TechnicalInformation.Field.YEAR, "2003");
        paper.setValue(TechnicalInformation.Field.PAGES, "107-119");
        paper.setValue(TechnicalInformation.Field.ADDRESS, "Dubrovnik, Croatia");
        return paper;
    }

    /**
     * Returns the fully qualified class name of the default base classifier,
     * which is the same type the constructor installs.
     *
     * @return the class name of the decision stump learner
     */
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Copies a subset of {@code data} chosen by instance weight.
     *
     * Instances are visited from the end of the index order returned by
     * {@code Utils.sort} towards its start (presumably heaviest first, assuming the
     * sort is ascending). Copying stops once the accumulated weight exceeds
     * {@code quantile} times the total weight, but never between two instances of
     * equal weight.
     *
     * @param data the weighted instances to select from
     * @param quantile fraction of the total weight mass to cover
     * @return a new dataset holding copies of the selected instances
     */
    protected Instances selectWeightQuantile(Instances data, double quantile) {
        final int total = data.numInstances();
        Instances selected = new Instances(data, total);
        double[] weights = new double[total];
        double weightSum = 0.0;
        for (int idx = 0; idx < total; idx++) {
            weights[idx] = data.instance(idx).weight();
            weightSum += weights[idx];
        }

        final double targetMass = weightSum * quantile;
        int[] order = Utils.sort(weights);
        double accumulated = 0.0;
        for (int pos = total - 1; pos >= 0; pos--) {
            int source = order[pos];
            selected.add((Instance) data.instance(source).copy());
            accumulated += weights[source];
            // Do not cut inside a run of equal weights.
            if (accumulated > targetMass && pos > 0 && weights[source] != weights[order[pos - 1]]) {
                break;
            }
        }

        if (m_Debug) {
            System.err.println("Selected " + selected.numInstances() + " out of " + total);
        }
        return selected;
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * Besides the boosting options (-P, -Q) this lists the SMOTE options
     * (-C, -K, -smoteP, -smoteS) that {@code setOptions} parses and
     * {@code getOptions} emits, followed by the options of the superclass.
     * The documented defaults are the initial values of the corresponding fields.
     *
     * @return an enumeration of all available {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Object> newVector = new Vector<Object>();
        newVector.addElement(new Option("\tPercentage of weight mass to base training on.\n\t(default 100, reduce to around 90 speed up)", "P", 1, "-P <num>"));
        newVector.addElement(new Option("\tUse resampling for boosting.", "Q", 0, "-Q"));
        newVector.addElement(new Option("\tThe index of the class value to which SMOTE should be applied.\n\tUse a value of 0 to auto-detect the non-empty minority class.\n\t(default 0)", "C", 1, "-C <value-index>"));
        newVector.addElement(new Option("\tThe number of nearest neighbors SMOTE should use.\n\t(default 5)", "K", 1, "-K <nearest-neighbors>"));
        newVector.addElement(new Option("\tThe percentage of SMOTE instances to create.\n\t(default 100.0)", "smoteP", 1, "-smoteP <percentage>"));
        newVector.addElement(new Option("\tThe seed SMOTE uses for random sampling.\n\t(default 1)", "smoteS", 1, "-smoteS <num>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement(enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses the given option list and configures this classifier.
     *
     * Recognised here: -P (weight threshold, reset to 100 when absent), -Q (use
     * resampling), -smoteS (SMOTE seed), -smoteP (SMOTE percentage), -K (SMOTE
     * nearest neighbors) and -C (SMOTE class value index). Unlike -P, an absent
     * SMOTE option leaves the current value untouched. The remaining options are
     * passed on to the superclass.
     *
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported or a numeric value cannot be parsed
     */
    public void setOptions(String[] options) throws Exception {
        String thresholdString = Utils.getOption('P', options);
        if (thresholdString.length() != 0) {
            this.setWeightThreshold(Integer.parseInt(thresholdString));
        } else {
            this.setWeightThreshold(100);
        }
        this.setUseResampling(Utils.getFlag('Q', options));

        String smoteSeedStr = Utils.getOption("smoteS", options);
        if (smoteSeedStr.length() != 0) {
            this.setSMOTE_RandomSeed(Integer.parseInt(smoteSeedStr));
        }
        String percentageStr = Utils.getOption("smoteP", options);
        if (percentageStr.length() != 0) {
            // parseDouble replaces the deprecated boxing constructor new Double(String).
            this.setSMOTE_Percentage(Double.parseDouble(percentageStr));
        }
        String nnStr = Utils.getOption('K', options);
        if (nnStr.length() != 0) {
            this.setSMOTE_NearestNeighbors(Integer.parseInt(nnStr));
        }
        String classValueIndexStr = Utils.getOption('C', options);
        if (classValueIndexStr.length() != 0) {
            this.setSMOTE_ClassValue(classValueIndexStr);
        }
        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array: the SMOTE options
     * (-C, -K, -smoteP, -smoteS), then -Q if resampling is enabled, then -P,
     * followed by the options of the superclass.
     *
     * @return an array of strings suitable for passing to setOptions
     */
    public String[] getOptions() {
        Vector<String> opts = new Vector<String>();

        opts.add("-C");
        opts.add(getSMOTE_ClassValue());
        opts.add("-K");
        opts.add("" + getSMOTE_NearestNeighbors());
        opts.add("-smoteP");
        opts.add("" + getSMOTE_Percentage());
        opts.add("-smoteS");
        opts.add("" + getSMOTE_RandomSeed());

        if (getUseResampling()) {
            opts.add("-Q");
        }
        opts.add("-P");
        opts.add("" + getWeightThreshold());

        for (String inherited : super.getOptions()) {
            opts.add(inherited);
        }
        return opts.toArray(new String[opts.size()]);
    }

    /**
     * Returns the descriptive tip text for the weightThreshold property.
     *
     * @return the tip text
     */
    public String weightThresholdTipText() {
        return "Weight threshold for weight pruning.";
    }

    /**
     * Sets the weight threshold used for weight pruning.
     *
     * @param threshold percentage of the weight mass to train on; no validation is performed
     */
    public void setWeightThreshold(int threshold) {
        m_WeightThreshold = threshold;
    }

    /**
     * Gets the weight threshold used for weight pruning.
     *
     * @return the currently configured threshold
     */
    public int getWeightThreshold() {
        return m_WeightThreshold;
    }

    /**
     * Returns the descriptive tip text for the useResampling property.
     *
     * @return the tip text
     */
    public String useResamplingTipText() {
        return "Whether resampling is used instead of reweighting.";
    }

    /**
     * Chooses between resampling and reweighting.
     *
     * @param r true to use resampling instead of reweighting
     */
    public void setUseResampling(boolean r) {
        m_UseResampling = r;
    }

    /**
     * Tells whether resampling is used instead of reweighting.
     *
     * @return the current value of the resampling flag
     */
    public boolean getUseResampling() {
        return m_UseResampling;
    }

    /**
     * Returns the descriptive tip text for the SMOTE random seed property.
     *
     * @return the tip text
     */
    public String SMOTE_randomSeedTipText() {
        return "The seed used for random sampling.";
    }

    /**
     * Gets the random seed handed to the SMOTE filter.
     *
     * @return the configured seed
     */
    public int getSMOTE_RandomSeed() {
        return m_SMOTE_RandomSeed;
    }

    /**
     * Sets the random seed handed to the SMOTE filter.
     *
     * @param value the new seed
     */
    public void setSMOTE_RandomSeed(int value) {
        m_SMOTE_RandomSeed = value;
    }

    /**
     * Returns the descriptive tip text for the SMOTE percentage property.
     *
     * @return the tip text
     */
    public String SMOTE_percentageTipText() {
        return "The percentage of SMOTE instances to create.";
    }

    /**
     * Sets the percentage of SMOTE instances to create.
     *
     * A value that is not greater than or equal to zero (including NaN) is
     * rejected: a message is written to standard error and the current setting is
     * kept.
     *
     * @param value the new percentage, expected to be &gt;= 0
     */
    public void setSMOTE_Percentage(double value) {
        // Negated ">=" (rather than "<") so that NaN is rejected as well.
        if (!(value >= 0.0)) {
            System.err.println("Percentage must be >= 0!");
            return;
        }
        m_SMOTE_Percentage = value;
    }

    /**
     * Gets the percentage of SMOTE instances to create.
     *
     * @return the configured percentage
     */
    public double getSMOTE_Percentage() {
        return m_SMOTE_Percentage;
    }

    /**
     * Returns the descriptive tip text for the SMOTE nearest neighbors property.
     *
     * @return the tip text
     */
    public String SMOTE_nearestNeighborsTipText() {
        return "The number of nearest neighbors to use.";
    }

    /**
     * Sets the number of nearest neighbors handed to the SMOTE filter.
     *
     * A value below 1 is rejected: a message is written to standard error and the
     * current setting is kept.
     *
     * @param value the new neighbor count, expected to be &gt;= 1
     */
    public void setSMOTE_NearestNeighbors(int value) {
        if (value < 1) {
            System.err.println("At least 1 neighbor necessary!");
            return;
        }
        m_SMOTE_NearestNeighbors = value;
    }

    /**
     * Gets the number of nearest neighbors handed to the SMOTE filter.
     *
     * @return the configured neighbor count
     */
    public int getSMOTE_NearestNeighbors() {
        return m_SMOTE_NearestNeighbors;
    }

    /**
     * Returns the descriptive tip text for the SMOTE class value property.
     *
     * @return the tip text
     */
    public String SMOTE_classValueTipText() {
        return "The index of the class value to which SMOTE should be applied. Use a value of 0 to auto-detect the non-empty minority class.";
    }

    /**
     * Sets the class value index handed to the SMOTE filter.
     *
     * @param value the index as a string; stored as given, without validation
     */
    public void setSMOTE_ClassValue(String value) {
        m_SMOTE_ClassValueIndex = value;
    }

    /**
     * Gets the class value index handed to the SMOTE filter.
     *
     * @return the configured index string
     */
    public String getSMOTE_ClassValue() {
        return m_SMOTE_ClassValueIndex;
    }

    /**
     * Creates a new SMOTE filter configured with this classifier's current SMOTE
     * settings (seed, percentage, nearest neighbors and class value).
     *
     * @return the configured filter; its input format has not been set yet
     */
    public SMOTE initSMOTE() {
        SMOTE filter = new SMOTE();
        filter.setRandomSeed(getSMOTE_RandomSeed());
        filter.setPercentage(getSMOTE_Percentage());
        filter.setNearestNeighbors(getSMOTE_NearestNeighbors());
        filter.setClassValue(getSMOTE_ClassValue());
        return filter;
    }

    /**
     * Returns the capabilities of this classifier: those of the superclass with all
     * class types and class dependencies disabled, then nominal and binary class
     * re-enabled wherever the superclass reports handling them.
     *
     * @return the restricted capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();

        Capabilities.Capability[] classTypes = {
            Capabilities.Capability.NOMINAL_CLASS,
            Capabilities.Capability.BINARY_CLASS
        };
        for (Capabilities.Capability classType : classTypes) {
            // Query a fresh superclass instance each time; 'result' has been stripped above.
            if (super.getCapabilities().handles(classType)) {
                result.enable(classType);
            }
        }
        return result;
    }

    /**
     * Builds the boosted classifier.
     *
     * After the superclass set-up and the capability check, the method works on a
     * copy of the data with instances lacking a class value removed. If only the
     * class attribute is present, a ZeroR model is built instead. Otherwise boosting
     * runs by reweighting when resampling is off and the base classifier is a
     * {@link WeightedInstancesHandler}, and by resampling in every other case.
     *
     * @param data the training data
     * @throws Exception if the data fails the capability check or training fails
     */
    public void buildClassifier(Instances data) throws Exception {
        super.buildClassifier(data);
        getCapabilities().testWithFail(data);

        Instances cleaned = new Instances(data);
        cleaned.deleteWithMissingClass();

        if (cleaned.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            m_ZeroR = new ZeroR();
            m_ZeroR.buildClassifier(cleaned);
            return;
        }

        m_ZeroR = null;
        m_NumClasses = cleaned.numClasses();

        boolean handlesWeights = m_Classifier instanceof WeightedInstancesHandler;
        if (m_UseResampling || !handlesWeights) {
            buildClassifierUsingResampling(cleaned);
        } else {
            buildClassifierWithWeights(cleaned);
        }
    }

    /**
     * Boosts the base classifier using weighted resampling.
     *
     * Each round optionally prunes the current weighted data to a weight quantile,
     * runs it through a freshly configured SMOTE filter, draws a sample according to
     * the instance weights, trains the round's classifier on that sample and measures
     * its error rate on {@code data}. The sample is re-drawn (a bounded number of
     * times) while the error rate is zero.
     *
     * Boosting stops early when the error rate is 0, 1 or NaN. In those cases the
     * vote weight {@code log((1 - e) / e)} and the reweighting factor are infinite,
     * zero or undefined, which previously led to NaN votes and NaN instance weights.
     * If this happens in the very first round, that single classifier is kept and
     * used on its own.
     *
     * Only the current round's weighted copy of the data is retained; earlier copies
     * are never read again.
     *
     * NOTE(review): the error is measured on {@code data} as passed in, not on the
     * re-weighted copy - confirm this is intended. Weka's AdaBoostM1 is also
     * understood to stop once the error reaches 0.5; that rule is not applied here.
     *
     * @param data the training data
     * @throws Exception if filtering, training or evaluation fails
     */
    protected void buildClassifierUsingResampling(Instances data) throws Exception {
        int numInstances = data.numInstances();
        Random randomInstance = new Random(this.m_Seed);
        this.m_Betas = new double[this.m_Classifiers.length];
        this.m_NumIterationsPerformed = 0;
        Instances training = new Instances(data, 0, numInstances);
        if (this.m_Debug) {
            System.err.println("iteration: " + (this.m_NumIterationsPerformed + 1) + "\nWeights:");
            for (int i = 0; i < training.size(); i++) {
                System.err.print(String.valueOf(Math.rint(training.get(i).weight() * 100.0) / 100.0) + ";");
            }
            System.err.println();
        }
        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier " + (this.m_NumIterationsPerformed + 1));
                System.err.println("iteration: " + (this.m_NumIterationsPerformed + 1) + " trainData sumOfWeights before SMOTE: " + training.sumOfWeights() + " N\u00ba instances: " + training.size());
            }
            Instances trainData = this.m_WeightThreshold < 100
                ? this.selectWeightQuantile(training, (double) this.m_WeightThreshold / 100.0)
                : new Instances(training);
            SMOTE smote = this.initSMOTE();
            smote.setInputFormat(trainData);
            trainData = Filter.useFilter(trainData, smote);
            if (this.m_Debug) {
                System.err.println("iteration: " + (this.m_NumIterationsPerformed + 1) + " trainData sumOfWeights after SMOTE: " + trainData.sumOfWeights() + " N\u00ba instances: " + trainData.size());
            }

            double[] weights = new double[trainData.numInstances()];
            for (int i = 0; i < weights.length; i++) {
                weights[i] = trainData.instance(i).weight();
            }
            if (this.m_Debug) {
                System.err.println("iteration: " + (this.m_NumIterationsPerformed + 1) + "\nWeights:");
                for (int i = 0; i < weights.length; i++) {
                    System.err.print(String.valueOf(Math.rint(weights[i] * 100.0) / 100.0) + ";");
                }
                System.err.println();
            }

            // Re-draw the sample while the resulting model is error-free, up to the retry limit.
            int resamplingIterations = 0;
            double epsilon;
            do {
                Instances sample = trainData.resampleWithWeights(randomInstance, weights);
                this.m_Classifiers[this.m_NumIterationsPerformed].buildClassifier(sample);
                Evaluation evaluation = new Evaluation(data);
                evaluation.evaluateModel(this.m_Classifiers[this.m_NumIterationsPerformed], data, new Object[0]);
                epsilon = evaluation.errorRate();
            } while (Utils.eq(epsilon, 0.0) && ++resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS);

            // Degenerate error rates make beta and the reweighting factor non-finite: stop here.
            if (Double.isNaN(epsilon) || Utils.eq(epsilon, 0.0) || Utils.eq(epsilon, 1.0)) {
                if (this.m_Debug) {
                    System.err.println("\terror rate = " + epsilon + ", stopping boosting");
                }
                if (this.m_NumIterationsPerformed == 0) {
                    // The first classifier is all we have, so it is used on its own.
                    this.m_NumIterationsPerformed = 1;
                }
                break;
            }

            double reweight = (1.0 - epsilon) / epsilon;
            this.m_Betas[this.m_NumIterationsPerformed] = Math.log(reweight);
            if (this.m_Debug) {
                System.err.println("\terror rate = " + epsilon + "  beta = " + this.m_Betas[this.m_NumIterationsPerformed]);
            }
            if (this.m_NumIterationsPerformed + 1 < this.m_Classifiers.length) {
                // setWeights reads m_Classifiers[m_NumIterationsPerformed], so call it before incrementing.
                training = new Instances(training, 0, numInstances);
                this.setWeights(training, reweight);
            }
            ++this.m_NumIterationsPerformed;
        }
    }

    /**
     * Updates the instance weights after a boosting round.
     *
     * Every instance that the classifier at index {@code m_NumIterationsPerformed}
     * misclassifies has its weight multiplied by {@code reweight}. All weights are
     * then rescaled so that their sum equals the sum before the update.
     *
     * @param training the instances whose weights are modified in place
     * @param reweight the factor applied to misclassified instances
     * @throws Exception if classifying an instance fails
     */
    protected void setWeights(Instances training, double reweight) throws Exception {
        final double sumBefore = training.sumOfWeights();
        final int count = training.numInstances();

        for (int idx = 0; idx < count; idx++) {
            Instance current = training.instance(idx);
            double predicted = m_Classifiers[m_NumIterationsPerformed].classifyInstance(current);
            if (!Utils.eq(predicted, current.classValue())) {
                current.setWeight(current.weight() * reweight);
            }
        }

        // Renormalise so that the total weight mass is unchanged.
        final double sumAfter = training.sumOfWeights();
        for (int idx = 0; idx < count; idx++) {
            Instance current = training.instance(idx);
            current.setWeight(current.weight() * sumBefore / sumAfter);
        }
    }

    /**
     * Boosts the base classifier using instance reweighting.
     *
     * Each round optionally prunes the current weighted data to a weight quantile,
     * runs it through a freshly configured SMOTE filter, seeds the round's
     * classifier if it is {@link Randomizable}, trains it on the filtered data and
     * measures its error rate on {@code data}.
     *
     * Boosting stops early when the error rate is 0, 1 or NaN. In those cases the
     * vote weight {@code log((1 - e) / e)} and the reweighting factor are infinite,
     * zero or undefined, which previously led to NaN votes and NaN instance weights.
     * If this happens in the very first round, that single classifier is kept and
     * used on its own.
     *
     * Only the current round's weighted copy of the data is retained; earlier copies
     * are never read again.
     *
     * NOTE(review): the error is measured on {@code data} as passed in, not on the
     * re-weighted copy - confirm this is intended. Weka's AdaBoostM1 is also
     * understood to stop once the error reaches 0.5; that rule is not applied here.
     *
     * @param data the training data
     * @throws Exception if filtering, training or evaluation fails
     */
    protected void buildClassifierWithWeights(Instances data) throws Exception {
        int numInstances = data.numInstances();
        Random randomInstance = new Random(this.m_Seed);
        this.m_Betas = new double[this.m_Classifiers.length];
        this.m_NumIterationsPerformed = 0;
        Instances training = new Instances(data, 0, numInstances);
        while (this.m_NumIterationsPerformed < this.m_Classifiers.length) {
            if (this.m_Debug) {
                System.err.println("Training classifier, LOOP: " + this.m_NumIterationsPerformed);
            }
            Instances trainData = this.m_WeightThreshold < 100
                ? this.selectWeightQuantile(training, (double) this.m_WeightThreshold / 100.0)
                : new Instances(training, 0, numInstances);
            SMOTE smote = this.initSMOTE();
            smote.setInputFormat(trainData);
            trainData = Filter.useFilter(trainData, smote);

            if (this.m_Classifiers[this.m_NumIterationsPerformed] instanceof Randomizable) {
                ((Randomizable) this.m_Classifiers[this.m_NumIterationsPerformed]).setSeed(randomInstance.nextInt());
            }
            this.m_Classifiers[this.m_NumIterationsPerformed].buildClassifier(trainData);

            Evaluation evaluation = new Evaluation(data);
            evaluation.evaluateModel(this.m_Classifiers[this.m_NumIterationsPerformed], data, new Object[0]);
            double epsilon = evaluation.errorRate();

            // Degenerate error rates make beta and the reweighting factor non-finite: stop here.
            if (Double.isNaN(epsilon) || Utils.eq(epsilon, 0.0) || Utils.eq(epsilon, 1.0)) {
                if (this.m_Debug) {
                    System.err.println("\terror rate = " + epsilon + ", stopping boosting");
                }
                if (this.m_NumIterationsPerformed == 0) {
                    // The first classifier is all we have, so it is used on its own.
                    this.m_NumIterationsPerformed = 1;
                }
                break;
            }

            double reweight = (1.0 - epsilon) / epsilon;
            this.m_Betas[this.m_NumIterationsPerformed] = Math.log(reweight);
            if (this.m_Debug) {
                System.err.println("\terror rate = " + epsilon + "  beta = " + this.m_Betas[this.m_NumIterationsPerformed]);
            }
            if (this.m_NumIterationsPerformed + 1 < this.m_Classifiers.length) {
                // setWeights reads m_Classifiers[m_NumIterationsPerformed], so call it before incrementing.
                training = new Instances(training, 0, numInstances);
                this.setWeights(training, reweight);
            }
            ++this.m_NumIterationsPerformed;
        }
    }

    /**
     * Computes the class distribution for the given instance.
     *
     * The ZeroR fallback answers if it is in use. With a single trained classifier
     * its own distribution is returned. Otherwise each trained classifier adds its
     * beta to the class it predicts, and the sums are converted with
     * {@code Utils.logs2probs}.
     *
     * @param instance the instance to classify
     * @return the class distribution
     * @throws Exception if no model has been built or prediction fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_ZeroR != null) {
            return m_ZeroR.distributionForInstance(instance);
        }
        if (m_NumIterationsPerformed == 0) {
            throw new Exception("No model built");
        }

        double[] votes = new double[instance.numClasses()];
        if (m_NumIterationsPerformed == 1) {
            return m_Classifiers[0].distributionForInstance(instance);
        }

        for (int round = 0; round < m_NumIterationsPerformed; round++) {
            int predictedClass = (int) m_Classifiers[round].classifyInstance(instance);
            votes[predictedClass] += m_Betas[round];
        }
        return Utils.logs2probs(votes);
    }

    /**
     * Returns the boosted model as Java source code.
     *
     * The generated class {@code className} has a static {@code classify} method
     * that either delegates to the single trained member or sums the members' betas
     * per predicted class and returns the class with the largest sum. The source of
     * each member follows as class {@code className_i}.
     *
     * Member source is emitted only for the first {@code m_NumIterationsPerformed}
     * classifiers. These are the ones that have been trained and the only ones the
     * generated {@code classify} method refers to; the original loop ran over the
     * whole array, including untrained entries.
     *
     * @param className the name to give the generated class
     * @return the generated source code
     * @throws Exception if no model has been built or the base learner is not Sourcable
     */
    public String toSource(String className) throws Exception {
        if (this.m_NumIterationsPerformed == 0) {
            throw new Exception("No model built yet");
        }
        if (!(this.m_Classifiers[0] instanceof Sourcable)) {
            throw new Exception("Base learner " + this.m_Classifier.getClass().getName() + " is not Sourcable");
        }
        StringBuilder text = new StringBuilder("class ");
        text.append(className).append(" {\n\n");
        text.append("  public static double classify(Object[] i) {\n");
        if (this.m_NumIterationsPerformed == 1) {
            text.append("    return " + className + "_0.classify(i);\n");
        } else {
            text.append("    double [] sums = new double [" + this.m_NumClasses + "];\n");
            for (int i = 0; i < this.m_NumIterationsPerformed; i++) {
                text.append("    sums[(int) " + className + '_' + i + ".classify(i)] += " + this.m_Betas[i] + ";\n");
            }
            text.append("    double maxV = sums[0];\n    int maxI = 0;\n    for (int j = 1; j < " + this.m_NumClasses + "; j++) {\n" + "      if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n" + "    }\n    return (double) maxI;\n");
        }
        text.append("  }\n}\n");
        // Only trained members are emitted; the rest of the array may hold unbuilt copies.
        for (int i = 0; i < this.m_NumIterationsPerformed; i++) {
            text.append(((Sourcable) this.m_Classifiers[i]).toSource(className + '_' + i));
        }
        return text.toString();
    }

    /**
     * Returns a textual description of the model: the ZeroR fallback with a warning
     * header if it is in use, a "no model" notice before training, the single
     * classifier if only one round was performed, or otherwise every base classifier
     * with its rounded weight.
     *
     * @return the description
     */
    public String toString() {
        if (m_ZeroR != null) {
            String simpleName = ((Object) this).getClass().getName().replaceAll(".*\\.", "");
            StringBuilder fallback = new StringBuilder();
            fallback.append(simpleName + "\n");
            // Underline the class name with '=' characters of the same length.
            fallback.append(simpleName.replaceAll(".", "=") + "\n\n");
            fallback.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            fallback.append(m_ZeroR.toString());
            return fallback.toString();
        }

        if (m_NumIterationsPerformed == 0) {
            return "SMOTEBoost: No model built yet.\n";
        }

        StringBuilder out = new StringBuilder();
        if (m_NumIterationsPerformed == 1) {
            out.append("SMOTEBoost: No boosting possible, one classifier used!\n");
            out.append(m_Classifiers[0].toString() + "\n");
            return out.toString();
        }

        out.append("SMOTEBoost: Base classifiers and their weights: \n\n");
        for (int round = 0; round < m_NumIterationsPerformed; round++) {
            out.append(m_Classifiers[round].toString() + "\n\n");
            out.append("Weight: " + Utils.roundDouble(m_Betas[round], 2) + "\n\n");
        }
        out.append("Number of performed Iterations: " + m_NumIterationsPerformed + "\n");
        return out.toString();
    }

    /**
     * Returns the revision string extracted from the embedded revision keyword.
     *
     * @return the revision
     */
    public String getRevision() {
        final String revisionKeyword = "$Revision: 1 $";
        return RevisionUtils.extract(revisionKeyword);
    }

    /**
     * Command-line entry point: hands a new SMOTEBoost instance and the arguments to
     * the inherited {@code runClassifier} helper.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        Classifier booster = new SMOTEBoost();
        runClassifier(booster, argv);
    }
}

