Hive Programming and Development (2)

Implementing Custom UDAF and UDTF Functions

Posted by HC on April 2, 2018

Getting Started Guide (2)

1. UDAF Functions

UDAF stands for User Defined Aggregation Function. It is used to implement user-defined aggregation operations such as sum, avg, and max. As with UDFs, there are two ways to implement a custom UDAF:

  1. Extend the UDAF class
  2. Extend the AbstractGenericUDAFResolver class

The first approach is deprecated and no longer recommended. It is simpler to implement, but correspondingly more limited: it is less efficient than the second approach and offers no convenient way to validate input arguments. See the official Hive documentation for more details. Both approaches are illustrated below by implementing an avg function.

1.1 Extending the UDAF class

Since this approach is no longer recommended, it is covered only briefly. The overall development steps are:

  1. Extend the UDAF class.
  2. Create a public static inner class that implements the UDAFEvaluator interface and provides the required methods. This class carries out the core aggregation logic:
    1. init: performs initialization.
    2. iterate: receives and processes one input row.
    3. terminatePartial: returns the partial (map-side) result.
    4. merge: combines partial results from other tasks.
    5. terminate: returns the final result.
package MyUDAF;

import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;

/**
 * multiple in one out
 * */
public class UDAFTest extends UDAF{
    public static class ValuePairs{
        private double value;
        private long count;
    }

    public static class MeansUDAFEvaluator implements UDAFEvaluator{
        private ValuePairs valuePairs;

        public MeansUDAFEvaluator(){
            valuePairs = new ValuePairs();
            init();
        }

        public void init() {
            valuePairs.count = 0;
            valuePairs.value = 0.0;
        }

        /**
         * process a row value
         * */
        public boolean iterate(int value){
            return iterate(Double.valueOf(value));
        }

        private boolean iterate(Double aDouble) {
            if(aDouble == null)
                return false;
            valuePairs.value += aDouble;
            valuePairs.count += 1;
            return true;
        }

        /**
         * return the partial (map-side) result
         * */
        public ValuePairs terminatePartial() {
            return valuePairs;
        }

        /**
         * combine a partial result from another task
         * */
        public boolean merge(ValuePairs other) {
            if(other == null)
                return false;
            valuePairs.count += other.count;
            valuePairs.value += other.value;
            return true;
        }

        /**
         * compute the final result
         * */
        public double terminate() {
            if(valuePairs == null || valuePairs.count == 0){
                return Double.NEGATIVE_INFINITY;
            }
            return valuePairs.value / valuePairs.count;
        }
    }
}

Package the code above into a JAR and run it from the Hive CLI.
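As a minimal usage sketch (the JAR path, function name my_avg1, table test_scores, and column score are placeholders, not from the original post), registering and invoking the function from the CLI might look like this:

ADD JAR /path/to/my-udaf.jar;
CREATE TEMPORARY FUNCTION my_avg1 AS 'MyUDAF.UDAFTest';

-- aggregate an int column; the result should match the built-in avg()
SELECT my_avg1(score) FROM test_scores;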

1.2 Extending the AbstractGenericUDAFResolver class

Extending AbstractGenericUDAFResolver is the approach recommended by the official documentation, although it takes somewhat more effort to develop. The overall workflow is:

  1. Extend the AbstractGenericUDAFResolver class and override the getEvaluator method.
  2. Create a public static inner class that extends GenericUDAFEvaluator and implements the seven required methods:
    1. init: performs initialization, such as checking the input types and specifying the output type.
    2. getNewAggregationBuffer: GenericUDAFEvaluator defines an AggregationBuffer interface for caching intermediate results; getNewAggregationBuffer is usually used together with reset to create and clear that buffer.
    3. reset: see above.
    4. iterate: receives and processes each input row.
    5. terminatePartial: returns the partial aggregation result.
    6. merge: combines partial aggregation results.
    7. terminate: returns the final result.

These methods map closely onto the stages of the MapReduce job that invokes the UDAF. A MapReduce job consists of a map phase and a reduce phase, and to identify which stage the evaluator is running in, the GenericUDAFEvaluator class defines the following enum:

public static enum Mode {
    /**
     * PARTIAL1: from original data to partial aggregation data: iterate() and
     * terminatePartial() will be called.
     */
    PARTIAL1,
    /**
     * PARTIAL2: from partial aggregation data to partial aggregation data:
     * merge() and terminatePartial() will be called.
     */
    PARTIAL2,
    /**
     * FINAL: from partial aggregation to full aggregation: merge() and
     * terminate() will be called.
     */
    FINAL,
    /**
     * COMPLETE: from original data directly to full aggregation: iterate() and
     * terminate() will be called.
     */
    COMPLETE
};

Specifically:

  1. PARTIAL1 corresponds to the map phase. The framework calls iterate to process each original input row, then terminatePartial to produce the map-side partial aggregation.
  2. PARTIAL2 corresponds to the combiner phase. It calls merge and then terminatePartial to further aggregate the map-side output.
  3. FINAL is the reduce phase. It calls merge and terminate to assemble the final result.
  4. COMPLETE denotes a map-only job with no reduce task. It calls iterate and terminate directly to obtain the final result.

Therefore, original input rows appear only in the PARTIAL1 and COMPLETE stages. terminatePartial is called only in PARTIAL1 and PARTIAL2, which is exactly why the map output type and the combiner output type must be identical; merge is called only in PARTIAL2 and FINAL, meaning the combiner and the reducer consume the same input type. Finally, init is called at the start of each of the four stages to declare the input and output types, which is why it takes a Mode parameter identifying the current execution stage.

The code below also computes avg. It uses a ValuePairsAgg buffer to hold the running sum and count.

package MyUDAF;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;

import java.util.ArrayList;

public class GenericUDAFTest extends AbstractGenericUDAFResolver {
    /**
     * receives information about how the UDAF is being invoked
     * */
    @Override
    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
        ObjectInspector[] objectInspectors = info.getParameterObjectInspectors();
        if(objectInspectors.length != 1){
            throw new UDFArgumentLengthException("It only takes one parameter");
        }

        if(objectInspectors[0].getCategory() == ObjectInspector.Category.PRIMITIVE){
            PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) info.getParameters()[0];
            switch (primitiveTypeInfo.getPrimitiveCategory()){
                case INT:
                case BYTE:
                case LONG:
                case FLOAT:
                case SHORT:
                case DOUBLE:
                    break;
                case DATE:
                case BOOLEAN:
                case STRING:
                default:
                    throw new UDFArgumentException("Only numeric type is allowed");
            }
        }
        else
            throw new UDFArgumentException("Only numeric type is allowed");

        return new MyUDAFEvaluator();
    }

    public static class MyUDAFEvaluator extends GenericUDAFEvaluator{
        private PrimitiveObjectInspector inputDataInspector;
        private StructObjectInspector structObjectInspector;
        private StructField valueStruct;
        private StructField countStruct;
        private DoubleObjectInspector doubleObjectInspector;
        private LongObjectInspector longObjectInspector;

        // PARTIAL1 and PARTIAL2
        private Object[] objects;
        // FINAL and COMPLETE
        private DoubleWritable means;

        /**
         * temp results
         * */
        public static class ValuePairsAgg implements AggregationBuffer{
            private double value;
            private long count;
        }

        /**
         * check input
         * output format
         * */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if(m == Mode.COMPLETE || m == Mode.PARTIAL1){
                //from data, run iterate()
                inputDataInspector = (PrimitiveObjectInspector) parameters[0];
            }
            else{
                // PARTIAL2 and FINAL, run merge()
                structObjectInspector = (StructObjectInspector) parameters[0];
                valueStruct = structObjectInspector.getStructFieldRef("value");
                countStruct = structObjectInspector.getStructFieldRef("count");
                doubleObjectInspector = (DoubleObjectInspector) valueStruct.getFieldObjectInspector();
                longObjectInspector = (LongObjectInspector) countStruct.getFieldObjectInspector();
            }

            if(m == Mode.PARTIAL1 || m == Mode.PARTIAL2){
                // output type of terminatePartial()
                ArrayList<ObjectInspector> objectInspectors = new ArrayList<ObjectInspector>();
                objectInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); // value
                objectInspectors.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); // count
                ArrayList<String> objectNames = new ArrayList<String>();
                objectNames.add("value");
                objectNames.add("count");
                objects = new Object[2];
                objects[0] = new DoubleWritable(0.0);
                objects[1] = new LongWritable(0);
                return ObjectInspectorFactory.getStandardStructObjectInspector(objectNames, objectInspectors);

            }
            else{
                means = new DoubleWritable(0.0);
                return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            }
        }

        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            ValuePairsAgg valuePairsAgg = new ValuePairsAgg();
            reset(valuePairsAgg);
            return valuePairsAgg;
        }

        public void reset(AggregationBuffer agg) throws HiveException {
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            valuePairsAgg.count = 0;
            valuePairsAgg.value = 0.0;
        }

        /**
         * process one input row
         * */
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if(parameters == null || parameters.length == 0){
                return;
            }
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            valuePairsAgg.count++;
            valuePairsAgg.value += PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputDataInspector);
        }

        /**
         * return the partial result as (value, count)
         * */
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            ((DoubleWritable) objects[0]).set(valuePairsAgg.value);
            ((LongWritable) objects[1]).set(valuePairsAgg.count);
            return objects;
        }

        /**
         * combine a partial (value, count) result into the buffer
         * */
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            double value = doubleObjectInspector.get(structObjectInspector.getStructFieldData(partial, valueStruct));
            long count = longObjectInspector.get(structObjectInspector.getStructFieldData(partial, countStruct));
            valuePairsAgg.count += count;
            valuePairsAgg.value += value;
        }

        /**
         * compute the final mean
         * */
        public Object terminate(AggregationBuffer agg) throws HiveException {
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            if(valuePairsAgg == null || valuePairsAgg.count == 0){
                return null;
            }
            means.set(valuePairsAgg.value / valuePairsAgg.count);
            return means;
        }
    }
}

When this version is run, passing input of the wrong type produces a corresponding error message.
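A minimal usage sketch follows, again with placeholder names: the first query aggregates a numeric column, while the second is rejected during query compilation because the argument is a string.

ADD JAR /path/to/my-udaf.jar;
CREATE TEMPORARY FUNCTION my_avg2 AS 'MyUDAF.GenericUDAFTest';

-- numeric input works as expected
SELECT my_avg2(score) FROM test_scores;

-- a string column is rejected with "Only numeric type is allowed"
SELECT my_avg2(name) FROM test_scores;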

2. UDTF Functions

Unlike UDFs and UDAFs, a UDTF can generate multiple rows and/or multiple columns, like the built-in explode function. There is only one way to implement a UDTF, and it is relatively simple: extend the GenericUDTF class and override the initialize, process, and close methods, where:

  1. initialize plays a role similar to initialize in GenericUDF: it performs initialization such as checking the input types and declaring the output types. Because a UDTF can produce multiple columns, it returns a StandardStructObjectInspector specifying the name and type of every generated column.
  2. process: if initialize determines how many columns are generated, process determines how many rows. Each call to forward inside process emits one row; the argument passed to forward is an array holding that row's values, and its length must equal the number of output columns declared in initialize.
  3. close: can be used for any cleanup work.

The code below implements a stripped-down version of explode. The input is restricted to an array type, and each array element is emitted as a new row with two columns: the element itself and its index.

package MyUDTF;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * one in multiple rows/cols out
 * array<T> to multiple rows and cols
 * e.g. value and index
 * */
@UDFType(deterministic = true)
@Description(name = "flat", value = "_FUN_(Array<T>) gives two cols and many rows")
public class UDTFTest extends GenericUDTF{
    private ListObjectInspector objectInspector;
    private Object[] forwardObjects = new Object[2];

    @Override
    public StructObjectInspector initialize(StructObjectInspector argOIs) throws UDFArgumentException {
        List<? extends StructField> inputFields = argOIs.getAllStructFieldRefs();

        if(inputFields.size() != 1){
            throw new UDFArgumentLengthException("It only takes one parameter");
        }
        ObjectInspector objectInspector_ = inputFields.get(0).getFieldObjectInspector();
        if(!(objectInspector_ instanceof ListObjectInspector)){
            throw new UDFArgumentTypeException(0, "Array<T> / List<T> type is expected");
        }

        objectInspector = (ListObjectInspector) objectInspector_;

        ArrayList<String> fieldNames = new ArrayList<String>();
        fieldNames.add("element");
        fieldNames.add("index");
        ArrayList<ObjectInspector> objectInspectorArrayList = new ArrayList<ObjectInspector>();
        objectInspectorArrayList.add(objectInspector.getListElementObjectInspector());
        objectInspectorArrayList.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, objectInspectorArrayList);
    }

    /**
     * emit one row (element, index) per array element
     * */
    public void process(Object[] args) throws HiveException {
        List list = objectInspector.getList(args[0]);
        if(list == null){
            return;
        }
        for (int i = 0; i < list.size(); i++) {
            forwardObjects[0] = list.get(i);
            forwardObjects[1] = i;
            forward(forwardObjects);
        }
    }

    public void close() throws HiveException {
        // no cleanup needed
    }
}

Register the JAR and the function in the CLI and run it to see the expanded rows.
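A minimal usage sketch (the JAR path, table articles, and columns id and tags are placeholders): the first query calls the UDTF on its own, and the second uses the common LATERAL VIEW pattern to keep other columns alongside the generated ones.

ADD JAR /path/to/my-udtf.jar;
CREATE TEMPORARY FUNCTION flat AS 'MyUDTF.UDTFTest';

-- used alone: each array element becomes one row of (element, index)
SELECT flat(tags) FROM articles;

-- used with LATERAL VIEW to combine generated rows with other columns
SELECT a.id, t.element, t.idx
FROM articles a
LATERAL VIEW flat(a.tags) t AS element, idx;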