/*
 * Decompiled with CFR 0.152.
 */
package com.nexr.rhive.hive.udf;

import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.Hashtable;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.rosuda.REngine.REXP;
import org.rosuda.REngine.REXPDouble;
import org.rosuda.REngine.REXPInteger;
import org.rosuda.REngine.REXPString;
import org.rosuda.REngine.Rserve.RConnection;

@Description(name="R", value="_FUNC_(export-name,arg1,arg2,...,return-type) - Returns the result of R scalar function")
public class RUDF
extends GenericUDF {
    private static Map<String, String> funclist = new Hashtable<String, String>();
    private static String NULL = "";
    private static int STRING_TYPE = 1;
    private static int NUMBER_TYPE = 0;
    private static RConnection rconnection;
    private ObjectInspectorConverters.Converter[] converters;
    private int[] types;

    public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {
        String function_name = this.converters[0].convert(arguments[0].get()).toString();
        this.loadExportedRScript(function_name);
        StringBuffer argument = new StringBuffer();
        for (int i = 1; i < arguments.length - 1; ++i) {
            Object value = this.converters[i].convert(arguments[i].get());
            if (value == null) {
                argument.append("NULL");
            } else if (this.types[i] == STRING_TYPE) {
                argument.append("\"" + this.converters[i].convert(arguments[i].get()) + "\"");
            } else {
                argument.append(this.converters[i].convert(arguments[i].get()));
            }
            if (i >= arguments.length - 2) continue;
            argument.append(",");
        }
        REXP rdata = null;
        try {
            rdata = this.getConnection().eval(function_name + "(" + argument.toString() + ")");
        }
        catch (Exception e) {
            ByteArrayOutputStream output = new ByteArrayOutputStream();
            e.printStackTrace(new PrintStream(output));
            throw new HiveException(new String(output.toByteArray()) + " -- fail to eval : " + function_name + "(" + argument.toString() + ")");
        }
        if (rdata != null) {
            try {
                if (rdata instanceof REXPInteger) {
                    return new IntWritable(rdata.asInteger());
                }
                if (rdata instanceof REXPString) {
                    return new Text(rdata.asString());
                }
                if (rdata instanceof REXPDouble) {
                    return new DoubleWritable(rdata.asDouble());
                }
                throw new HiveException("only support integer, string and double");
            }
            catch (Exception e) {
                ByteArrayOutputStream output = new ByteArrayOutputStream();
                e.printStackTrace(new PrintStream(output));
                throw new HiveException(new String(output.toByteArray()));
            }
        }
        return null;
    }

    public String getDisplayString(String[] children) {
        StringBuilder sb = new StringBuilder();
        sb.append("Rfunction(");
        for (int i = 0; i < children.length; ++i) {
            sb.append(children[i]);
            if (i + 1 == children.length) continue;
            sb.append(",");
        }
        sb.append(")");
        return sb.toString();
    }

    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        GenericUDFUtils.ReturnObjectInspectorResolver returnOIResolver = new GenericUDFUtils.ReturnObjectInspectorResolver(true);
        for (int i = 0; i < arguments.length; ++i) {
            if (returnOIResolver.update(arguments[i])) continue;
            throw new UDFArgumentTypeException(i, "Argument type \"" + arguments[i].getTypeName() + "\" is different from preceding arguments. " + "Previous type was \"" + arguments[i - 1].getTypeName() + "\"");
        }
        this.converters = new ObjectInspectorConverters.Converter[arguments.length];
        this.types = new int[arguments.length];
        ObjectInspector returnOI = returnOIResolver.get();
        if (returnOI == null) {
            returnOI = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector((PrimitiveObjectInspector.PrimitiveCategory)PrimitiveObjectInspector.PrimitiveCategory.STRING);
        }
        for (int i = 0; i < arguments.length; ++i) {
            this.converters[i] = ObjectInspectorConverters.getConverter((ObjectInspector)arguments[i], (ObjectInspector)returnOI);
            this.types[i] = arguments[i].getCategory() == ObjectInspector.Category.PRIMITIVE && ((PrimitiveObjectInspector)arguments[i]).getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.STRING ? STRING_TYPE : NUMBER_TYPE;
        }
        String typeName = arguments[arguments.length - 1].getTypeName();
        if (typeName.equals("int")) {
            return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
        }
        if (typeName.equals("double")) {
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }
        if (typeName.equals("string")) {
            return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
        }
        throw new IllegalArgumentException("can't support this type : " + typeName);
    }

    private void loadExportedRScript(String export_name) throws HiveException {
        if (!funclist.containsKey(export_name)) {
            try {
                this.getConnection().eval("load(file=paste(Sys.getenv('RHIVE_DATA'),'/" + export_name + ".Rdata',sep=''))");
            }
            catch (Exception e) {
                ByteArrayOutputStream output = new ByteArrayOutputStream();
                e.printStackTrace(new PrintStream(output));
                throw new HiveException(new String(output.toByteArray()));
            }
            funclist.put(export_name, NULL);
        }
    }

    private RConnection getConnection() throws UDFArgumentException {
        if (rconnection == null || !rconnection.isConnected()) {
            try {
                rconnection = new RConnection("127.0.0.1");
            }
            catch (Exception e) {
                ByteArrayOutputStream output = new ByteArrayOutputStream();
                e.printStackTrace(new PrintStream(output));
                throw new UDFArgumentException(new String(output.toByteArray()));
            }
        }
        return rconnection;
    }
}

