Getting Started Guide (2)
1. UDAF Functions
UDAF is short for User Defined Aggregation Function. It is used to implement custom aggregation operations such as sum, avg, and max. As with UDFs, there are two ways to implement a custom UDAF:
- Extend the UDAF class
- Extend the AbstractGenericUDAFResolver class
The first approach has been deprecated and is not recommended, although it is simpler to implement; correspondingly, its functionality is more limited. It is less efficient than the second approach and does not lend itself to checking input parameters, among other things. More details can be found on Hive's official website. The walkthrough below implements an avg function with each approach.
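For reference, the goal of both implementations is to behave like Hive's built-in avg aggregate in a query along these lines (a minimal sketch; the employees table and its columns are hypothetical):
-- hypothetical table and columns, for illustration only
SELECT dept, avg(salary) AS avg_salary
FROM employees
GROUP BY dept;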
1.1 Extending the UDAF class
Since this approach is no longer recommended, it is only covered briefly. In outline, the development steps are:
- Extend the UDAF class.
- Create a public static inner class that implements the UDAFEvaluator interface and provides the required methods. This class performs the core aggregation work for Hive:
  - init: performs initialization.
  - iterate: receives and processes one row of input.
  - terminatePartial: returns the partial result that forms the map output.
  - merge: merges intermediate (partial) results.
  - terminate: produces the final result.
package MyUDAF;

import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;

/**
 * multiple in, one out: computes the average of an int column
 */
public class UDAFTest extends UDAF {

    // holds the running sum and the number of rows seen so far
    public static class ValuePairs {
        private double value;
        private long count;
    }

    public static class MeansUDAFEvaluator implements UDAFEvaluator {
        private ValuePairs valuePairs;

        public MeansUDAFEvaluator() {
            valuePairs = new ValuePairs();
            init();
        }

        // reset the running sum and count
        public void init() {
            valuePairs.count = 0;
            valuePairs.value = 0.0;
        }

        /**
         * process one row value
         */
        public boolean iterate(int value) {
            return iterate(Double.valueOf(value));
        }

        private boolean iterate(Double aDouble) {
            if (aDouble == null) {
                return false;
            }
            valuePairs.value += aDouble;
            valuePairs.count += 1;
            return true;
        }

        // partial result emitted on the map side
        public ValuePairs terminatePartial() {
            return valuePairs;
        }

        // combine a partial result from another task
        public boolean merge(ValuePairs other) {
            if (other == null) {
                return false;
            }
            valuePairs.count += other.count;
            valuePairs.value += other.value;
            return true;
        }

        // final result: sum / count
        public double terminate() {
            if (valuePairs == null || valuePairs.count == 0) {
                return Double.NEGATIVE_INFINITY;
            }
            return valuePairs.value / valuePairs.count;
        }
    }
}
Package the code above into a jar and run it from the Hive CLI; the result is shown in the screenshot below.
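As a rough sketch, registering and invoking the function in the CLI might look like this (the jar path, function name, and test table are hypothetical; adjust them to your environment):
-- hypothetical jar path, function name, and table
ADD JAR /tmp/MyUDAF.jar;
CREATE TEMPORARY FUNCTION my_avg AS 'MyUDAF.UDAFTest';
SELECT my_avg(score) FROM student_scores;
Note that because the evaluator only exposes iterate(int), this version only accepts integer columns.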
1.2 Extending the AbstractGenericUDAFResolver class
Extending AbstractGenericUDAFResolver is the approach recommended by the official documentation, although it is somewhat harder to develop. The overall workflow is:
- Extend AbstractGenericUDAFResolver and override the getEvaluator method.
- Create a public static inner class that extends GenericUDAFEvaluator and implements the seven required methods:
  - init: performs initialization, such as validating the input and declaring the output type.
  - getNewAggregationBuffer: GenericUDAFEvaluator defines an AggregationBuffer interface used to cache the intermediate results of the computation. getNewAggregationBuffer is typically used together with reset to create and clear that buffer.
  - reset: as above.
  - iterate: receives and processes each row of input.
  - terminatePartial: returns the partial aggregation result.
  - merge: merges partial aggregation results.
  - terminate: returns the final result.
These methods map closely onto how the MapReduce job is executed. A MapReduce job consists of a map operation and a reduce operation, and to mark which stage is currently running, GenericUDAFEvaluator also defines an enum type:
public static enum Mode {
    /**
     * PARTIAL1: from original data to partial aggregation data: iterate() and
     * terminatePartial() will be called.
     */
    PARTIAL1,
    /**
     * PARTIAL2: from partial aggregation data to partial aggregation data:
     * merge() and terminatePartial() will be called.
     */
    PARTIAL2,
    /**
     * FINAL: from partial aggregation to full aggregation: merge() and
     * terminate() will be called.
     */
    FINAL,
    /**
     * COMPLETE: from original data directly to full aggregation: iterate() and
     * terminate() will be called.
     */
    COMPLETE
};
Specifically:
- PARTIAL1 corresponds to the map phase. Here the framework calls iterate on each raw input row and then terminatePartial to produce the map-side partial aggregation.
- PARTIAL2 corresponds to the combiner phase. It calls merge followed by terminatePartial to further aggregate the map output.
- FINAL is the reduce phase. It calls merge and then terminate to assemble the final result.
- COMPLETE denotes a map-only job with no reduce task. It calls iterate and terminate directly to produce the final result.
Raw input data therefore appears only in the PARTIAL1 and COMPLETE stages. terminatePartial is called only in PARTIAL1 and PARTIAL2, which is exactly why the map output type and the combiner output type must be identical; merge is called only in PARTIAL2 and FINAL, which means the combiner input type and the reduce input type must be identical as well. Finally, every one of the four stages starts by calling init to declare its input and output, which is why init takes a Mode parameter identifying the stage being executed.
The code below again computes avg; it uses a ValuePairsAgg buffer to hold the running sum and count.
package MyUDAF;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;

import java.util.ArrayList;

public class GenericUDAFTest extends AbstractGenericUDAFResolver {

    /**
     * receives information about how the UDAF is being invoked
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
        ObjectInspector[] objectInspectors = info.getParameterObjectInspectors();
        if (objectInspectors.length != 1) {
            throw new UDFArgumentLengthException("It only takes one parameter");
        }
        if (objectInspectors[0].getCategory() == ObjectInspector.Category.PRIMITIVE) {
            PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) info.getParameters()[0];
            switch (primitiveTypeInfo.getPrimitiveCategory()) {
                case INT:
                case BYTE:
                case LONG:
                case FLOAT:
                case SHORT:
                case DOUBLE:
                    break;
                case DATE:
                case BOOLEAN:
                case STRING:
                default:
                    throw new UDFArgumentException("Only numeric type is allowed");
            }
        } else {
            throw new UDFArgumentException("Only numeric type is allowed");
        }
        return new MyUDAFEvaluator();
    }

    public static class MyUDAFEvaluator extends GenericUDAFEvaluator {
        // inspector for the raw input value (PARTIAL1 and COMPLETE)
        private PrimitiveObjectInspector inputDataInspector;
        // inspectors for the partial struct {value, count} (PARTIAL2 and FINAL)
        private StructObjectInspector structObjectInspector;
        private StructField valueStruct;
        private StructField countStruct;
        private DoubleObjectInspector doubleObjectInspector;
        private LongObjectInspector longObjectInspector;
        // output of terminatePartial(): PARTIAL1 and PARTIAL2
        private Object[] objects;
        // output of terminate(): FINAL and COMPLETE
        private DoubleWritable means;

        /**
         * temp results
         */
        public static class ValuePairsAgg implements AggregationBuffer {
            private double value;
            private long count;
        }

        /**
         * check input
         * output format
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if (m == Mode.COMPLETE || m == Mode.PARTIAL1) {
                // input is original data, run iterate()
                inputDataInspector = (PrimitiveObjectInspector) parameters[0];
            } else {
                // PARTIAL2 and FINAL: input is a partial struct, run merge()
                structObjectInspector = (StructObjectInspector) parameters[0];
                valueStruct = structObjectInspector.getStructFieldRef("value");
                countStruct = structObjectInspector.getStructFieldRef("count");
                doubleObjectInspector = (DoubleObjectInspector) valueStruct.getFieldObjectInspector();
                longObjectInspector = (LongObjectInspector) countStruct.getFieldObjectInspector();
            }
            if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
                // output type of terminatePartial(): a struct of {value, count}
                ArrayList<ObjectInspector> objectInspectors = new ArrayList<ObjectInspector>();
                objectInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); // value
                objectInspectors.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);   // count
                ArrayList<String> objectNames = new ArrayList<String>();
                objectNames.add("value");
                objectNames.add("count");
                objects = new Object[2];
                objects[0] = new DoubleWritable(0.0);
                objects[1] = new LongWritable(0);
                return ObjectInspectorFactory.getStandardStructObjectInspector(objectNames, objectInspectors);
            } else {
                // output type of terminate(): a single double
                means = new DoubleWritable(0.0);
                return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            }
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            ValuePairsAgg valuePairsAgg = new ValuePairsAgg();
            reset(valuePairsAgg);
            return valuePairsAgg;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            valuePairsAgg.count = 0;
            valuePairsAgg.value = 0.0;
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters == null || parameters.length == 0) {
                return;
            }
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            valuePairsAgg.count++;
            valuePairsAgg.value += PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputDataInspector);
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            ((DoubleWritable) objects[0]).set(valuePairsAgg.value);
            ((LongWritable) objects[1]).set(valuePairsAgg.count);
            return objects;
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            // skip empty partials
            if (partial == null) {
                return;
            }
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            double value = doubleObjectInspector.get(structObjectInspector.getStructFieldData(partial, valueStruct));
            long count = longObjectInspector.get(structObjectInspector.getStructFieldData(partial, countStruct));
            valuePairsAgg.count += count;
            valuePairsAgg.value += value;
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            ValuePairsAgg valuePairsAgg = (ValuePairsAgg) agg;
            if (valuePairsAgg == null || valuePairsAgg.count == 0) {
                return null;
            }
            means.set(valuePairsAgg.value / valuePairsAgg.count);
            return means;
        }
    }
}
The screenshot below shows the result of a run. Note that feeding the function input of the wrong type produces a corresponding error message.
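A sketch of how this version might be exercised from the CLI, including the error case just mentioned (the jar path, function name, and table are hypothetical):
-- hypothetical jar path, function name, and table
ADD JAR /tmp/MyUDAF.jar;
CREATE TEMPORARY FUNCTION my_generic_avg AS 'MyUDAF.GenericUDAFTest';
SELECT my_generic_avg(price) FROM orders;  -- numeric column: returns the average
SELECT my_generic_avg(name) FROM orders;   -- string column: rejected with "Only numeric type is allowed"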
2. UDTF Functions
Unlike UDFs and UDAFs, a UDTF can generate multiple rows and/or multiple columns; explode is a typical example. There is only one way to implement a UDTF, and it is relatively straightforward: extend the GenericUDTF class and override the initialize, process, and close methods. In detail:
- initialize plays much the same role as initialize in GenericUDF: it performs setup work such as checking the input types and declaring the output type. Because a UDTF can produce several columns, however, it must return a StandardStructObjectInspector that specifies the name and type of every generated column.
- process: if initialize determines how many columns are generated, process determines how many rows are generated. Inside process, each call to forward emits one row. Every call to forward is passed an array holding that row's values, and the array's length must equal the number of output columns declared in initialize.
- close: can be used for any cleanup work.
The code below implements a stripped-down version of explode. The input is restricted to an array type, and each array element is emitted as a new row with two columns: the element's value and its index.
package MyUDTF;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * one in, multiple rows/cols out:
 * expands array<T> into multiple rows and cols, e.g. (value, index)
 */
@UDFType(deterministic = true)
@Description(name = "flat", value = "_FUNC_(Array<T>) gives two cols and many rows")
public class UDTFTest extends GenericUDTF {
    // inspector for the input array
    private ListObjectInspector objectInspector;
    // reused buffer holding one output row: (element, index)
    private Object[] forwardObjects = new Object[2];

    @Override
    public StructObjectInspector initialize(StructObjectInspector argOIs) throws UDFArgumentException {
        List<? extends StructField> inputFields = argOIs.getAllStructFieldRefs();
        if (inputFields.size() != 1) {
            throw new UDFArgumentLengthException("It only takes one param");
        }
        ObjectInspector fieldObjectInspector = inputFields.get(0).getFieldObjectInspector();
        if (!(fieldObjectInspector instanceof ListObjectInspector)) {
            throw new UDFArgumentTypeException(0, "Array<T> / List<T> type is expected");
        }
        objectInspector = (ListObjectInspector) fieldObjectInspector;
        // declare the two output columns: element (same type as the array elements) and index (int)
        ArrayList<String> fieldNames = new ArrayList<String>();
        fieldNames.add("element");
        fieldNames.add("index");
        ArrayList<ObjectInspector> objectInspectorArrayList = new ArrayList<ObjectInspector>();
        objectInspectorArrayList.add(objectInspector.getListElementObjectInspector());
        objectInspectorArrayList.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, objectInspectorArrayList);
    }

    @Override
    public void process(Object[] args) throws HiveException {
        List<?> list = objectInspector.getList(args[0]);
        if (list == null) {
            return;
        }
        // one forward() call per array element, i.e. one output row per element
        for (int i = 0; i < list.size(); i++) {
            forwardObjects[0] = list.get(i);
            forwardObjects[1] = i;
            forward(forwardObjects);
        }
    }

    @Override
    public void close() throws HiveException {
    }
}
The result of a run is shown below:
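A sketch of how this UDTF might be registered and invoked (the jar path and table are hypothetical; the function name flat matches the @Description annotation above):
-- hypothetical jar path and table
ADD JAR /tmp/MyUDTF.jar;
CREATE TEMPORARY FUNCTION flat AS 'MyUDTF.UDTFTest';
-- a UDTF used directly in the SELECT list cannot be mixed with other columns
SELECT flat(tags) FROM articles;
-- to combine it with other columns, use LATERAL VIEW and rename the generated columns
SELECT a.title, t.elem, t.pos
FROM articles a LATERAL VIEW flat(a.tags) t AS elem, pos;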