/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import weka.classifiers.pmml.consumer.PMMLClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.matrix.Maths;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.TargetMetaInfo;

/**
 * Consumer for a PMML RegressionModel.
 *
 * <p>Each "RegressionTable" element of the model becomes a {@link RegressionTable}.
 * A prediction is built by letting every table add its intercept and predictor
 * contributions into a shared output array. The model's normalization method is
 * then applied to that array.
 */
public class Regression
extends PMMLClassifier
implements Serializable {
    private static final long serialVersionUID = -5551125528409488634L;

    /** Value of the optional "algorithmName" attribute; null when absent or empty. */
    protected String m_algorithmName;

    /** The model's regression tables, in the order they were found. */
    protected RegressionTable[] m_regressionTables;

    /** Transformation applied to the raw table outputs. */
    protected Normalization m_normalizationMethod = Normalization.NONE;

    /**
     * Builds the model from a RegressionModel element.
     *
     * @param element the RegressionModel element
     * @param instances passed through unchanged to the PMMLClassifier constructor
     * @param miningSchema the mining schema used to resolve fields
     * @throws Exception if "functionName" is neither "regression" nor
     *     "classification", or if the regression tables cannot be parsed
     */
    public Regression(Element element, Instances instances, MiningSchema miningSchema) throws Exception {
        super(instances, miningSchema);
        int functionType;
        String functionName = element.getAttribute("functionName");
        if (functionName.equals("regression")) {
            functionType = RegressionTable.REGRESSION;
        } else if (functionName.equals("classification")) {
            functionType = RegressionTable.CLASSIFICATION;
        } else {
            throw new Exception("[PMML Regression] Function name not defined in pmml!");
        }
        String algorithmName = element.getAttribute("algorithmName");
        if (algorithmName != null && algorithmName.length() > 0) {
            this.m_algorithmName = algorithmName;
        }
        this.m_normalizationMethod = Regression.determineNormalization(element);
        this.setUpRegressionTables(element, functionType);
    }

    /**
     * Creates one RegressionTable per "RegressionTable" element found under the
     * given element.
     *
     * @param element the RegressionModel element
     * @param functionType RegressionTable.REGRESSION or RegressionTable.CLASSIFICATION
     * @throws Exception if there are no regression tables or one fails to parse
     */
    private void setUpRegressionTables(Element element, int functionType) throws Exception {
        NodeList tableNodes = element.getElementsByTagName("RegressionTable");
        int numTables = tableNodes.getLength();
        if (numTables == 0) {
            throw new Exception("[Regression] no regression tables defined!");
        }
        // Collect into a list so the array never contains null slots.
        ArrayList<RegressionTable> tables = new ArrayList<RegressionTable>(numTables);
        for (int i = 0; i < numTables; i++) {
            Node node = tableNodes.item(i);
            if (node.getNodeType() == Node.ELEMENT_NODE) {
                tables.add(new RegressionTable((Element) node, functionType, this.m_miningSchema));
            }
        }
        this.m_regressionTables = tables.toArray(new RegressionTable[tables.size()]);
    }

    /**
     * Maps the "normalizationMethod" attribute to a {@link Normalization}.
     * Absent or unrecognized values map to {@link Normalization#NONE}.
     */
    private static Normalization determineNormalization(Element element) {
        String method = element.getAttribute("normalizationMethod");
        if (method.equals("simplemax")) {
            return Normalization.SIMPLEMAX;
        }
        if (method.equals("softmax")) {
            return Normalization.SOFTMAX;
        }
        if (method.equals("logit")) {
            return Normalization.LOGIT;
        }
        if (method.equals("probit")) {
            return Normalization.PROBIT;
        }
        if (method.equals("cloglog")) {
            return Normalization.CLOGLOG;
        }
        if (method.equals("exp")) {
            return Normalization.EXP;
        }
        if (method.equals("loglog")) {
            return Normalization.LOGLOG;
        }
        if (method.equals("cauchit")) {
            return Normalization.CAUCHIT;
        }
        return Normalization.NONE;
    }

    /**
     * Returns a textual description of the model: version, creator, mining
     * schema, all regression tables and (if any) the normalization method.
     */
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("PMML version " + this.getPMMLVersion());
        if (!this.getCreatorApplication().equals("?")) {
            text.append("\nApplication: " + this.getCreatorApplication());
        }
        if (this.m_algorithmName != null) {
            text.append("\nPMML Model: " + this.m_algorithmName);
        }
        text.append("\n\n");
        text.append(this.m_miningSchema);
        for (RegressionTable table : this.m_regressionTables) {
            text.append(table);
        }
        if (this.m_normalizationMethod != Normalization.NONE) {
            text.append("Normalization: " + this.m_normalizationMethod);
        }
        text.append("\n");
        return text.toString();
    }

    /**
     * Computes the prediction for an instance.
     *
     * <p>For a numeric target the result has one element holding the predicted
     * value. Otherwise it has one element per class value.
     *
     * <p>If any non-class input is missing, the target meta data's default value
     * (numeric target) or prior probabilities are returned instead. If there is
     * no target meta data, a warning is logged and a missing value (numeric) or
     * all-zero array is returned.
     *
     * @param instance the instance to predict
     * @return the prediction / class distribution
     * @throws Exception if the normalization method is unknown, or if mapping the
     *     instance to the mining schema fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (!this.m_initialized) {
            this.mapToMiningSchema(instance.dataset());
        }
        Instances schemaFields = this.m_miningSchema.getFieldsAsInstances();
        Attribute classAttribute = schemaFields.classAttribute();
        double[] preds = classAttribute.isNumeric()
            ? new double[1]
            : new double[classAttribute.numValues()];
        double[] incoming = this.m_fieldsMap.instanceToSchema(instance, this.m_miningSchema);

        if (Regression.hasMissingPredictorValue(incoming, schemaFields.classIndex())) {
            return this.fallbackDistribution(preds);
        }

        for (RegressionTable table : this.m_regressionTables) {
            table.predict(preds, incoming);
        }
        this.applyNormalizationMethod(preds);

        if (classAttribute.isNumeric() && this.m_miningSchema.hasTargetMetaData()) {
            TargetMetaInfo targetMetaInfo = this.m_miningSchema.getTargetMetaData();
            preds[0] = targetMetaInfo.applyMinMaxRescaleCast(preds[0]);
        }
        return preds;
    }

    /** Returns true if any value other than the one at classIndex is missing. */
    private static boolean hasMissingPredictorValue(double[] incoming, int classIndex) {
        for (int i = 0; i < incoming.length; i++) {
            if (i != classIndex && Instance.isMissingValue(incoming[i])) {
                return true;
            }
        }
        return false;
    }

    /**
     * Fills preds for an instance with missing inputs. It uses the target meta
     * data's default value / prior probabilities when available. Otherwise it
     * logs a warning and leaves preds zeroed (nominal) or missing (numeric).
     */
    private double[] fallbackDistribution(double[] preds) {
        Attribute classAttribute = this.m_miningSchema.getFieldsAsInstances().classAttribute();
        if (!this.m_miningSchema.hasTargetMetaData()) {
            String message = "[Regression] WARNING: Instance to predict has missing value(s) but there is no missing value handling meta data and no prior probabilities/default value to fall back to. No prediction will be made (" + (classAttribute.isNominal() || classAttribute.isString() ? "zero probabilities output)." : "NaN output).");
            if (this.m_log == null) {
                System.err.println(message);
            } else {
                this.m_log.logMessage(message);
            }
            if (classAttribute.isNumeric()) {
                preds[0] = Instance.missingValue();
            }
            return preds;
        }
        TargetMetaInfo targetMetaInfo = this.m_miningSchema.getTargetMetaData();
        if (classAttribute.isNumeric()) {
            preds[0] = targetMetaInfo.getDefaultValue();
        } else {
            for (int i = 0; i < classAttribute.numValues(); i++) {
                preds[i] = targetMetaInfo.getPriorProbability(classAttribute.value(i));
            }
        }
        return preds;
    }

    /**
     * Applies m_normalizationMethod to the raw table outputs in place.
     *
     * <p>For multi-element outputs the transformed values are rescaled to sum
     * to one. A single output (numeric target) keeps its transformed value:
     * normalizing a one-element array would always yield 1.0.
     *
     * @throws Exception if the normalization method is not recognized
     */
    private void applyNormalizationMethod(double[] preds) throws Exception {
        switch (this.m_normalizationMethod) {
            case NONE:
                return;
            case SIMPLEMAX:
                // Pure rescaling; done by the shared step after the switch.
                break;
            case SOFTMAX:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = Math.exp(preds[i]);
                }
                if (preds.length == 1) {
                    // Single output: exp(y) / (1 + exp(y)).
                    preds[0] = preds[0] / (preds[0] + 1.0);
                    return;
                }
                break;
            case LOGIT:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = 1.0 / (1.0 + Math.exp(-preds[i]));
                }
                break;
            case PROBIT:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = Maths.pnorm(preds[i]);
                }
                break;
            case CLOGLOG:
                // Inverse complementary log-log link: 1 - exp(-exp(y)).
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = 1.0 - Math.exp(-Math.exp(preds[i]));
                }
                break;
            case EXP:
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = Math.exp(preds[i]);
                }
                break;
            case LOGLOG:
                // Inverse log-log link: exp(-exp(-y)).
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = Math.exp(-Math.exp(-preds[i]));
                }
                break;
            case CAUCHIT:
                // Inverse cauchit link: 0.5 + atan(y) / pi.
                for (int i = 0; i < preds.length; i++) {
                    preds[i] = 0.5 + (1.0 / Math.PI) * Math.atan(preds[i]);
                }
                break;
            default:
                throw new Exception("[Regression] unknown normalization method");
        }
        if (preds.length > 1) {
            Utils.normalize(preds);
        }
    }

    /** Returns the revision string. */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 4739 $");
    }

    /** Normalization methods that may be applied to the raw regression outputs. */
    static enum Normalization {
        NONE,
        SIMPLEMAX,
        SOFTMAX,
        LOGIT,
        PROBIT,
        CLOGLOG,
        EXP,
        LOGLOG,
        CAUCHIT;

    }

    /**
     * One PMML "RegressionTable". For regression it contributes to output
     * slot 0. For classification it contributes to the slot of its target
     * category.
     */
    static class RegressionTable
    implements Serializable {
        private static final long serialVersionUID = -5259866093996338995L;
        public static final int REGRESSION = 0;
        public static final int CLASSIFICATION = 1;
        protected int m_functionType = REGRESSION;
        protected MiningSchema m_miningSchema;
        protected double m_intercept = 0.0;
        /** Index of the target class value, or -1 for regression tables. */
        protected int m_targetCategory = -1;
        protected ArrayList<Predictor> m_predictors = new ArrayList<Predictor>();
        protected ArrayList<PredictorTerm> m_predictorTerms = new ArrayList<PredictorTerm>();

        /** Renders the table as "class[=value] =" followed by its terms and intercept. */
        @Override
        public String toString() {
            Instances instances = this.m_miningSchema.getFieldsAsInstances();
            StringBuilder text = new StringBuilder();
            text.append("Regression table:\n");
            text.append(instances.classAttribute().name());
            if (this.m_functionType == CLASSIFICATION) {
                text.append("=" + instances.classAttribute().value(this.m_targetCategory));
            }
            text.append(" =\n\n");
            for (Predictor predictor : this.m_predictors) {
                text.append(predictor.toString() + " +\n");
            }
            for (PredictorTerm term : this.m_predictorTerms) {
                text.append(term.toString() + " +\n");
            }
            text.append(Utils.doubleToString(this.m_intercept, 12, 4));
            text.append("\n\n");
            return text.toString();
        }

        /**
         * Parses a RegressionTable element.
         *
         * @param element the RegressionTable element
         * @param functionType REGRESSION or CLASSIFICATION
         * @param miningSchema the mining schema used to resolve fields
         * @throws Exception if a classification table's target category cannot
         *     be resolved, or if a predictor fails to parse
         */
        protected RegressionTable(Element element, int functionType, MiningSchema miningSchema) throws Exception {
            this.m_miningSchema = miningSchema;
            this.m_functionType = functionType;
            Instances instances = this.m_miningSchema.getFieldsAsInstances();
            String intercept = element.getAttribute("intercept");
            if (intercept.length() > 0) {
                this.m_intercept = Double.parseDouble(intercept);
            }
            if (this.m_functionType == CLASSIFICATION) {
                String targetCategory = element.getAttribute("targetCategory");
                if (targetCategory.length() > 0) {
                    Attribute classAttribute = instances.classAttribute();
                    for (int i = 0; i < classAttribute.numValues(); i++) {
                        if (classAttribute.value(i).equals(targetCategory)) {
                            this.m_targetCategory = i;
                        }
                    }
                }
                if (this.m_targetCategory == -1) {
                    throw new Exception("[RegressionTable] No target categories defined for classification");
                }
            }
            NodeList numericNodes = element.getElementsByTagName("NumericPredictor");
            for (int i = 0; i < numericNodes.getLength(); i++) {
                Node node = numericNodes.item(i);
                if (node.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_predictors.add(new NumericPredictor((Element) node, instances));
                }
            }
            NodeList categoricalNodes = element.getElementsByTagName("CategoricalPredictor");
            for (int i = 0; i < categoricalNodes.getLength(); i++) {
                Node node = categoricalNodes.item(i);
                if (node.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_predictors.add(new CategoricalPredictor((Element) node, instances));
                }
            }
            NodeList termNodes = element.getElementsByTagName("PredictorTerm");
            for (int i = 0; i < termNodes.getLength(); i++) {
                Node node = termNodes.item(i);
                if (node.getNodeType() == Node.ELEMENT_NODE) {
                    this.m_predictorTerms.add(new PredictorTerm((Element) node, instances));
                }
            }
        }

        /**
         * Writes this table's intercept into its output slot, then adds every
         * predictor and predictor-term contribution to that slot.
         *
         * @param preds output array (slot 0 for regression, one slot per class
         *     value for classification)
         * @param input instance values in mining-schema order
         */
        public void predict(double[] preds, double[] input) {
            // The intercept belongs to this table's own slot. Previously it was
            // only written for regression tables, so classification scores
            // silently lost their intercepts.
            if (this.m_targetCategory == -1) {
                preds[0] = this.m_intercept;
            } else {
                preds[this.m_targetCategory] = this.m_intercept;
            }
            for (Predictor predictor : this.m_predictors) {
                predictor.add(preds, input);
            }
            for (PredictorTerm term : this.m_predictorTerms) {
                term.add(preds, input);
            }
        }

        /** Interaction term: coefficient times the product of numeric fields. */
        protected class PredictorTerm
        implements Serializable {
            private static final long serialVersionUID = 5493100145890252757L;
            protected double m_coefficient = 1.0;
            /** Mining-schema indexes of the referenced fields (empty if none). */
            protected int[] m_indexes;
            /** Names of the referenced fields, parallel to m_indexes. */
            protected String[] m_fieldNames;

            /**
             * Parses a PredictorTerm element and its FieldRef children.
             * FieldRefs with an empty "field" attribute are ignored.
             *
             * @throws Exception if the coefficient is not a number, or a field
             *     is not found or is not numeric
             */
            protected PredictorTerm(Element element, Instances instances) throws Exception {
                String coefficient = element.getAttribute("coefficient");
                if (coefficient != null && coefficient.length() > 0) {
                    try {
                        this.m_coefficient = Double.parseDouble(coefficient);
                    }
                    catch (IllegalArgumentException e) {
                        throw new Exception("[PredictorTerm] unable to parse coefficient", e);
                    }
                }
                NodeList fieldRefs = element.getElementsByTagName("FieldRef");
                ArrayList<Integer> indexes = new ArrayList<Integer>();
                ArrayList<String> names = new ArrayList<String>();
                for (int i = 0; i < fieldRefs.getLength(); i++) {
                    Node node = fieldRefs.item(i);
                    if (node.getNodeType() != Node.ELEMENT_NODE) {
                        continue;
                    }
                    String fieldName = ((Element) node).getAttribute("field");
                    if (fieldName == null || fieldName.length() == 0) {
                        continue;
                    }
                    indexes.add(Integer.valueOf(this.findNumericField(instances, fieldName)));
                    names.add(fieldName);
                }
                // Always allocate, so add()/toString() never see null arrays and
                // skipped FieldRefs leave no stray index-0 entries behind.
                this.m_indexes = new int[indexes.size()];
                for (int i = 0; i < this.m_indexes.length; i++) {
                    this.m_indexes[i] = indexes.get(i).intValue();
                }
                this.m_fieldNames = names.toArray(new String[names.size()]);
            }

            /** Returns the index of the first attribute named fieldName, which must be numeric. */
            private int findNumericField(Instances instances, String fieldName) throws Exception {
                for (int j = 0; j < instances.numAttributes(); j++) {
                    Attribute attribute = instances.attribute(j);
                    if (!attribute.name().equals(fieldName)) {
                        continue;
                    }
                    if (!attribute.isNumeric()) {
                        throw new Exception("[PredictorTerm] field is not continuous: " + fieldName);
                    }
                    return j;
                }
                throw new Exception("[PredictorTerm] Unable to find field " + fieldName + " in mining schema!");
            }

            /** Renders the term as "(coefficient * f1 * f2 ...)". */
            @Override
            public String toString() {
                StringBuilder text = new StringBuilder();
                text.append("(" + Utils.doubleToString(this.m_coefficient, 12, 4));
                for (String fieldName : this.m_fieldNames) {
                    text.append(" * " + fieldName);
                }
                text.append(")");
                return text.toString();
            }

            /** Adds coefficient * product(input[index]) to this table's output slot. */
            public void add(double[] preds, double[] input) {
                int slot = RegressionTable.this.m_targetCategory == -1 ? 0 : RegressionTable.this.m_targetCategory;
                double product = this.m_coefficient;
                for (int index : this.m_indexes) {
                    product *= input[index];
                }
                preds[slot] += product;
            }
        }

        /** Adds its coefficient when the referenced field equals a given value. */
        protected class CategoricalPredictor
        extends Predictor {
            private static final long serialVersionUID = 3077920125549906819L;
            protected String m_valueName;
            protected int m_valueIndex;

            /**
             * Parses a CategoricalPredictor element. For string attributes the
             * value is first added to the mining-schema attribute.
             *
             * @throws Exception if "value" is missing or not a value of the attribute
             */
            protected CategoricalPredictor(Element element, Instances instances) throws Exception {
                super(element, instances);
                this.m_valueIndex = -1;
                String value = element.getAttribute("value");
                if (value.length() == 0) {
                    throw new Exception("[CategoricalPredictor] attribute value not specified!");
                }
                this.m_valueName = value;
                Attribute attribute = instances.attribute(this.m_miningSchemaAttIndex);
                if (attribute.isString()) {
                    attribute.addStringValue(this.m_valueName);
                }
                this.m_valueIndex = attribute.indexOfValue(this.m_valueName);
                if (this.m_valueIndex == -1) {
                    throw new Exception("[CategoricalPredictor] unable to find value " + this.m_valueName + " in mining schema attribute " + attribute.name());
                }
            }

            @Override
            public String toString() {
                return super.toString() + this.m_name + "=" + this.m_valueName;
            }

            /** Adds the coefficient to this table's slot when the input matches m_valueIndex. */
            @Override
            public void add(double[] preds, double[] input) {
                if (this.m_valueIndex == (int) input[this.m_miningSchemaAttIndex]) {
                    int slot = RegressionTable.this.m_targetCategory == -1 ? 0 : RegressionTable.this.m_targetCategory;
                    preds[slot] += this.m_coefficient;
                }
            }
        }

        /** Adds coefficient * value^exponent for a numeric field. */
        protected class NumericPredictor
        extends Predictor {
            private static final long serialVersionUID = -4335075205696648273L;
            protected double m_exponent;

            /** Parses a NumericPredictor element; "exponent" defaults to 1. */
            protected NumericPredictor(Element element, Instances instances) throws Exception {
                super(element, instances);
                this.m_exponent = 1.0;
                String exponent = element.getAttribute("exponent");
                if (exponent.length() > 0) {
                    this.m_exponent = Double.parseDouble(exponent);
                }
            }

            @Override
            public String toString() {
                String text = super.toString() + this.m_name;
                if (this.m_exponent > 1.0 || this.m_exponent < 1.0) {
                    text = text + "^" + Utils.doubleToString(this.m_exponent, 4);
                }
                return text;
            }

            @Override
            public void add(double[] preds, double[] input) {
                int slot = RegressionTable.this.m_targetCategory == -1 ? 0 : RegressionTable.this.m_targetCategory;
                preds[slot] += this.m_coefficient * Math.pow(input[this.m_miningSchemaAttIndex], this.m_exponent);
            }
        }

        /** Common state for predictors: field name, schema index and coefficient. */
        static abstract class Predictor
        implements Serializable {
            private static final long serialVersionUID = 7043831847273383618L;
            protected String m_name;
            protected int m_miningSchemaAttIndex = -1;
            protected double m_coefficient = 1.0;

            /**
             * Resolves the "name" attribute against the mining schema (the last
             * matching attribute wins) and reads the optional "coefficient".
             *
             * @throws Exception if no attribute has the predictor's name
             */
            protected Predictor(Element element, Instances instances) throws Exception {
                this.m_name = element.getAttribute("name");
                for (int i = 0; i < instances.numAttributes(); i++) {
                    if (instances.attribute(i).name().equals(this.m_name)) {
                        this.m_miningSchemaAttIndex = i;
                    }
                }
                if (this.m_miningSchemaAttIndex == -1) {
                    throw new Exception("[Predictor] unable to find matching attribute for predictor " + this.m_name);
                }
                String coefficient = element.getAttribute("coefficient");
                if (coefficient.length() > 0) {
                    this.m_coefficient = Double.parseDouble(coefficient);
                }
            }

            @Override
            public String toString() {
                return Utils.doubleToString(this.m_coefficient, 12, 4) + " * ";
            }

            /** Adds this predictor's contribution for input into preds. */
            public abstract void add(double[] var1, double[] var2);
        }
    }
}

