/*
 * Decompiled with CFR 0.152.
 */
package dmLab.experiment.classification;

import dmLab.array.Array;
import dmLab.array.FArray;
import dmLab.array.functions.DiscFunctions;
import dmLab.array.loader.File2Array;
import dmLab.array.saver.Array2File;
import dmLab.classifier.Classifier;
import dmLab.classifier.Prediction;
import dmLab.classifier.PredictionResult;
import dmLab.experiment.classification.ClassificationParams;
import dmLab.utils.ArrayUtils;
import dmLab.utils.MathUtils;
import dmLab.utils.cmatrix.ConfusionMatrix;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

public final class ClassificationBody {
    public ClassificationParams classParams;
    public Classifier classifier;
    private File2Array file2Container;
    private Array2File array2file;
    private float learningTime;
    private float testingTime;
    private float experimentTime;
    public PredictionResult predResult;
    private double[] predQualityVector;
    private DiscFunctions selectFunctions$49f14b5d;

    /*
     * WARNING - void declaration
     */
    public ClassificationBody(Random random) {
        void var1_1;
        this.selectFunctions$49f14b5d = new DiscFunctions((Random)var1_1);
        this.file2Container = new File2Array();
        this.array2file = new Array2File();
        this.cleanTimeStats();
    }

    private void cleanTimeStats() {
        this.learningTime = 0.0f;
        this.testingTime = 0.0f;
        this.experimentTime = 0.0f;
    }

    public final float run() {
        Float f;
        FArray[] fArrayArray = null;
        fArrayArray = null;
        fArrayArray = this;
        if (this.classifier == null && !fArrayArray.initClassifier()) {
            f = null;
        } else {
            FArray[] fArrayArray2;
            FArray[] fArrayArray3 = fArrayArray;
            FArray fArray = new FArray();
            FArray fArray2 = null;
            if (fArrayArray3.classParams.verbose) {
                System.out.println("Loading input data...");
            }
            if (!fArrayArray3.file2Container.load(fArray, String.valueOf(fArrayArray3.classParams.inputFilesPATH) + fArrayArray3.classParams.inputFileName)) {
                fArrayArray2 = null;
            } else if (!fArray.checkDecisionValues()) {
                fArrayArray2 = null;
            } else {
                if (fArrayArray3.classParams.verbose) {
                    System.out.println("Input data loaded.");
                }
                if (fArrayArray3.classParams.validationType == 3) {
                    if (fArrayArray3.classParams.verbose) {
                        System.out.println("Loading testing data...");
                    }
                    fArray2 = new FArray();
                    new FArray().dictionary = fArray.dictionary.clone();
                    fArray2.setDecValues(fArray.getDecValues());
                    fArray2.setDecAttrIdx(fArray.getDecAttrIdx());
                    fArrayArray3.file2Container.load(fArray2, String.valueOf(fArrayArray3.classParams.inputFilesPATH) + fArrayArray3.classParams.testFileName);
                    if (fArrayArray3.classParams.verbose) {
                        System.out.println("Testing data loaded.");
                    }
                }
                FArray[] fArrayArray4 = new FArray[2];
                fArrayArray3 = fArrayArray4;
                fArrayArray4[0] = fArray;
                fArrayArray3[1] = fArray2;
                fArrayArray2 = fArrayArray3 = fArrayArray3;
            }
            if (fArrayArray2 == null) {
                f = null;
            } else if (fArrayArray.classParams == null || !fArrayArray.classifier.params.check((FArray)((Object)fArrayArray3[0]))) {
                f = null;
            } else {
                if (fArrayArray.classParams.validationType == 2) {
                    fArrayArray.runCV(fArrayArray3[0]);
                } else {
                    fArrayArray.runTrainTest(fArrayArray3[0], fArrayArray3[1]);
                }
                f = Float.valueOf((float)fArrayArray.predResult.getPredQuality());
            }
        }
        return f.floatValue();
    }

    /*
     * WARNING - void declaration
     */
    public final boolean setParameters(ClassificationParams classParams) {
        void var1_1;
        this.classParams = classParams;
        return var1_1.check(null);
    }

    /*
     * WARNING - void declaration
     */
    public final boolean loadParameters(String paramsFileName) {
        this.classParams = new ClassificationParams();
        if (!this.classParams.load("", paramsFileName)) {
            void var1_1;
            System.err.println("Error loading configuration file. File: " + (String)var1_1);
            return false;
        }
        if (this.classParams.verbose) {
            System.out.println(this.classParams.toString());
        }
        return this.classParams.check(null);
    }

    public final boolean initClassifier() {
        this.classifier = Classifier.getClassifier(this.classParams.model);
        if (!this.classifier.params.load(this.classParams.classifierCfgPATH, this.classifier.label)) {
            return false;
        }
        this.classifier.init();
        this.classifier.setTempPath(this.classParams.resFilesPATH);
        this.predResult = new PredictionResult(this.classifier.modelType);
        if (this.classParams.verbose) {
            System.out.println(this.classifier.params.toString());
        }
        this.cleanTimeStats();
        return true;
    }

    /*
     * WARNING - void declaration
     */
    private Array[] split(FArray array, int[] splitMask) {
        void var2_2;
        Array[] arrayArray;
        if (splitMask == null) {
            if (this.classParams.splitType == 1) {
                splitMask = this.selectFunctions$49f14b5d.getSplitMaskRandom(array, this.classParams.splitRatio);
            } else if (this.classParams.splitType == 2) {
                splitMask = this.selectFunctions$49f14b5d.getSplitMaskUniform(array, this.classParams.splitRatio);
            } else {
                System.err.println("classParams.splitType does not equal to SPLIT_RANDOM or SPLIT_UNIFORM.");
                return null;
            }
        }
        arrayArray = DiscFunctions.split((Array)arrayArray, (int[])var2_2);
        return arrayArray;
    }

    /*
     * WARNING - void declaration
     */
    private boolean savePredictionArray(FArray array, String fileName) {
        void var2_2;
        void var3_3;
        FArray predictionArray = array.clone();
        String[] decValues = array.getDecValuesStr();
        int[] scoreIndex = new int[decValues.length];
        boolean saveScores = true;
        if (this.predResult.predictions[0].getScores() == null) {
            saveScores = false;
        }
        if (saveScores) {
            int i = 0;
            while (i < decValues.length) {
                String scoreAttrName = "score_" + decValues[i];
                DiscFunctions.addAttribute(predictionArray, scoreAttrName);
                scoreIndex[i] = predictionArray.getColIndex(scoreAttrName);
                predictionArray.attributes[scoreIndex[i]].type = (short)2;
                ++i;
            }
        }
        DiscFunctions.addAttribute(predictionArray, "prediction");
        int predictionIndex = predictionArray.getColIndex("prediction");
        int rows = predictionArray.rowsNumber();
        int j = 0;
        while (j < rows) {
            if (saveScores) {
                int i = 0;
                while (i < decValues.length) {
                    predictionArray.writeValue(scoreIndex[i], j, (float)this.predResult.predictions[j].getScores()[i]);
                    ++i;
                }
            }
            predictionArray.writeValueStr(predictionIndex, j, this.predResult.predictions[j].getPredicted());
            ++j;
        }
        this.array2file.setFormat(0);
        this.array2file.saveFile(predictionArray, String.valueOf(fileName) + "_pred");
        this.array2file.setFormat(3);
        this.array2file.saveFile((Array)var3_3, String.valueOf(var2_2) + "_pred");
        return true;
    }

    /*
     * WARNING - void declaration
     */
    public final PredictionResult runCV(FArray trainArray) {
        void var2_2;
        if (this.classParams.verbose) {
            System.out.println("Running mult CV...");
        }
        ConfusionMatrix cMatrix = null;
        if (trainArray.isTargetNominal()) {
            cMatrix = new ConfusionMatrix(trainArray.getColNames(true)[trainArray.getDecAttrIdx()], trainArray.getDecValuesStr());
        }
        this.predQualityVector = new double[this.classParams.repetitions];
        int i = 0;
        while (i < this.classParams.repetitions) {
            double start = System.currentTimeMillis();
            String label = "_rep" + Integer.toString(i + 1);
            PredictionResult singlePredResult = this.singleCV(trainArray, label);
            this.experimentTime = (float)((double)this.experimentTime + ((double)System.currentTimeMillis() - start) / 1000.0);
            this.predQualityVector[i] = singlePredResult.getPredQuality();
            this.predResult.predictions = singlePredResult.predictions;
            if (cMatrix != null) {
                ConfusionMatrix singleMatrix = singlePredResult.confusionMatrix;
                cMatrix.add(singleMatrix);
            }
            String experimentLabel = String.valueOf(this.classParams.label) + Integer.toString(i + 1);
            if (this.classParams.savePredictionResult) {
                this.savePredictionArray(trainArray, String.valueOf(this.classParams.resFilesPATH) + "//" + experimentLabel);
            }
            if (this.classParams.verbose) {
                System.out.println("\n##### CV " + Integer.toString(i + 1) + " RESULT #####");
                System.out.println(singlePredResult.toString());
            }
            System.gc();
            ++i;
        }
        this.predResult.confusionMatrix = var2_2;
        this.finalizeTimeStats();
        return this.predResult;
    }

    /*
     * WARNING - void declaration
     */
    private PredictionResult singleCV(FArray array, String label) {
        void var3_3;
        ConfusionMatrix cMatrix = null;
        if (array.isTargetNominal()) {
            cMatrix = new ConfusionMatrix(array.getColNames(true)[array.getDecAttrIdx()], array.getDecValuesStr());
        }
        int cvFolds = this.classParams.folds;
        int rows = array.rowsNumber();
        int[] cvTable = new int[rows];
        int[] splitMask = new int[rows];
        int n = cvFolds;
        int[] nArray = cvTable;
        ArrayUtils arrayUtils = this.selectFunctions$49f14b5d.arrayUtils;
        int n2 = 0;
        while (n2 < nArray.length) {
            nArray[n2] = (int)(arrayUtils.random.nextFloat() * (float)n);
            ++n2;
        }
        Prediction[] predictions = new Prediction[rows];
        int i = 0;
        while (i < cvFolds) {
            int j = 0;
            while (j < rows) {
                splitMask[j] = cvTable[j] == i ? 0 : 1;
                ++j;
            }
            Array[] cvArrays = this.split(array, splitMask);
            PredictionResult singlePredResult = this.singleTrainTest((FArray)cvArrays[0], (FArray)cvArrays[1]);
            this.learningTime = (float)((double)this.learningTime + this.classifier.getLearningTime());
            this.testingTime = (float)((double)this.testingTime + this.classifier.getTestingTime());
            if (cMatrix != null) {
                PredictionResult predictionResult = singlePredResult;
                cMatrix.add(predictionResult.confusionMatrix);
            }
            PredictionResult predictionResult = singlePredResult;
            Prediction[] singlePrediction = predictionResult.predictions;
            int k = 0;
            int j2 = 0;
            while (j2 < rows) {
                if (cvTable[j2] == i) {
                    predictions[j2] = singlePrediction[k++];
                }
                ++j2;
            }
            String experimentLabel = String.valueOf(this.classParams.label) + label + "_fold" + Integer.toString(i + 1);
            if (this.classParams.saveClassifier) {
                try {
                    this.classifier.saveDefinition(this.classParams.resFilesPATH, experimentLabel);
                }
                catch (IOException e) {
                    System.err.println("Error saving classifier.");
                    e.printStackTrace();
                }
            }
            ++i;
        }
        PredictionResult myPredResult = new PredictionResult(this.classifier.modelType);
        new PredictionResult(this.classifier.modelType).predictions = predictions;
        myPredResult.confusionMatrix = var3_3;
        return myPredResult;
    }

    /*
     * WARNING - void declaration
     */
    public final PredictionResult runTrainTest(FArray trainArray, FArray testArray) {
        void var3_3;
        if (testArray == null) {
            if (this.classParams.verbose) {
                System.out.println("MultTrainTest - split training set...");
            }
        } else if (this.classParams.verbose) {
            System.out.println("MultTrainTest - training set & testing set...");
        }
        ConfusionMatrix cMatrix = null;
        if (trainArray.isTargetNominal()) {
            cMatrix = new ConfusionMatrix(trainArray.getColNames(true)[trainArray.getDecAttrIdx()], trainArray.getDecValuesStr());
        }
        this.predQualityVector = new double[this.classParams.repetitions];
        int i = 0;
        while (i < this.classParams.repetitions) {
            Array[] ttArrays;
            if (testArray == null) {
                ttArrays = this.split(trainArray, null);
            } else {
                Array[] arrayArray = new Array[2];
                ttArrays = arrayArray;
                arrayArray[0] = trainArray;
                ttArrays[1] = testArray;
            }
            double start2 = System.currentTimeMillis();
            PredictionResult singlePredResult = this.singleTrainTest((FArray)ttArrays[0], (FArray)ttArrays[1]);
            this.experimentTime = (float)((double)this.experimentTime + ((double)System.currentTimeMillis() - start2) / 1000.0);
            this.learningTime = (float)((double)this.learningTime + this.classifier.getLearningTime());
            this.testingTime = (float)((double)this.testingTime + this.classifier.getTestingTime());
            this.predQualityVector[i] = singlePredResult.getPredQuality();
            this.predResult.predictions = singlePredResult.predictions;
            if (cMatrix != null) {
                PredictionResult start2 = singlePredResult;
                ConfusionMatrix singleMatrix = start2.confusionMatrix;
                cMatrix.add(singleMatrix);
            }
            String experimentLabel = String.valueOf(this.classParams.label) + Integer.toString(i + 1);
            if (this.classParams.saveClassifier) {
                try {
                    this.classifier.saveDefinition(this.classParams.resFilesPATH, experimentLabel);
                }
                catch (IOException e) {
                    System.err.println("Error saving classifier.");
                    e.printStackTrace();
                }
            }
            if (this.classParams.savePredictionResult) {
                this.savePredictionArray(testArray, String.valueOf(this.classParams.resFilesPATH) + "//" + experimentLabel);
            }
            if (this.classParams.verbose) {
                System.out.println("\n##### SPLIT " + Integer.toString(i + 1) + " RESULT #####");
                System.out.println(singlePredResult.toString());
            }
            ++i;
        }
        this.predResult.confusionMatrix = var3_3;
        this.finalizeTimeStats();
        return this.predResult;
    }

    /*
     * WARNING - void declaration
     */
    private PredictionResult singleTrainTest(FArray trainArray, FArray testArray) {
        void var2_2;
        PredictionResult predictionResult;
        this.classifier.train((FArray)((Object)predictionResult));
        this.classifier.test((FArray)var2_2);
        predictionResult = this.classifier.getPredResult();
        return predictionResult;
    }

    /*
     * WARNING - void declaration
     */
    private void finalizeTimeStats() {
        void var1_1;
        float repetitions = this.classParams.repetitions;
        if (this.classParams.validationType == 2) {
            repetitions *= (float)this.classParams.folds;
        }
        this.experimentTime /= repetitions;
        this.learningTime /= repetitions;
        this.testingTime /= var1_1;
    }

    /*
     * WARNING - void declaration
     */
    public final String toStringResults() {
        void var1_1;
        void var2_2;
        StringBuffer tmp = new StringBuffer();
        tmp.append("\n#######  FINAL RESULT  #######").append("\n");
        tmp.append("repetitions: " + this.classParams.repetitions).append("\n");
        tmp.append(this.predResult.toString()).append("\n");
        int folds = 1;
        if (this.classParams.validationType == 2) {
            folds = this.classParams.folds;
        }
        tmp.append("Average Repetition time: " + DiscFunctions.formatFloat(this.experimentTime * (float)var2_2, 2) + " s.").append("\n");
        tmp.append("Average Experiment time: " + DiscFunctions.formatFloat(this.experimentTime, 2) + " s.").append("\n");
        tmp.append("Average Learning time: " + DiscFunctions.formatFloat(this.learningTime, 2) + " s.").append("\n");
        tmp.append("Average Testing time: " + DiscFunctions.formatFloat(this.testingTime, 2) + " s.").append("\n");
        if (this.predQualityVector != null) {
            tmp.append("\nwPredQuality: " + Arrays.toString(this.predQualityVector)).append("\n");
            tmp.append("\nwPredQuality variance: " + DiscFunctions.formatFloat(MathUtils.variance(this.predQualityVector), 5)).append("\n");
        }
        return var1_1.toString();
    }
}

