package com.cloudera.science.pig;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import org.apache.pig.EvalFunc;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.BagFactory;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.FrontendException;
import org.apache.pig.impl.logicalLayer.schema.Schema;
import org.apache.pig.impl.logicalLayer.schema.Schema.FieldSchema;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
public class Bin extends EvalFunc<DataBag> {
private final TupleFactory tupleFactory = TupleFactory.getInstance();
private final BagFactory bagFactory = BagFactory.getInstance();
@Override
public DataBag
exec(Tuple input)
throws IOException {
DataBag output = bagFactory.newDefaultBag();
Object o1 = input.get(0);
if (!(o1 instanceof DataBag)) {
throw new IOException("Expected input to be a bag, but got: " + o1.getClass());
}
DataBag inputBag = (DataBag) o1;
Object o2 = input.get(1);
if (!(o2 instanceof DataBag)) {
throw new IOException("Expected second input to be a bag, but got: " + o2.getClass());
}
List<Double> quantiles = getQuantiles((DataBag) o2);
for (Tuple t : inputBag) {
if (t != null && t.get(0) != null) {
double val = ((Number)t.get(0)).doubleValue();
int index = Collections.binarySearch(quantiles, val);
if (index > -1) {
t = tupleFactory.newTuple(ImmutableList.of(index, t.get(0)));
} else {
t = tupleFactory.newTuple(ImmutableList.of(-index - 1, t.get(0)));
}
output.add(t);
}
}
return output;
}
return pigType == DataType.DOUBLE || pigType == DataType.FLOAT ||
pigType == DataType.INTEGER || pigType == DataType.LONG;
}
private byte checkField(FieldSchema field)
throws FrontendException {
if (field.type != DataType.BAG) {
throw new IllegalArgumentException("Expected a bag; found: " +
DataType.findTypeName(field.type));
}
if (field.schema.size() != 1) {
throw new IllegalArgumentException("The bag must contain a single field");
}
byte bagType = field.schema.getField(0).type;
if (bagType == DataType.TUPLE) {
bagType = field.schema.getField(0).schema.getField(0).type;
}
if (!isNumeric(bagType)) {
throw new IllegalArgumentException("The bag's field must be a numeric type");
}
return bagType;
}
@Override
if (input.size() != 2) {
throw new IllegalArgumentException("Expected two bags; input has != 2 fields");
}
try {
byte binType = checkField(input.getField(0));
byte quantileType = checkField(input.getField(1));
if (quantileType != DataType.DOUBLE) {
throw new IllegalArgumentException("Expected doubles for quantile bag");
}
List<FieldSchema> fields = Lists.newArrayList(new FieldSchema("bin", DataType.INTEGER),
new FieldSchema("value", binType));
Schema tupleSchema = new Schema(fields);
FieldSchema tupleFieldSchema = new FieldSchema("t", tupleSchema,
DataType.TUPLE);
Schema bagSchema = new Schema(tupleFieldSchema);
bagSchema.setTwoLevelAccessRequired(true);
FieldSchema bagFieldSchema = new FieldSchema("b", bagSchema, DataType.BAG);
return new Schema(bagFieldSchema);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private List<Double>
getQuantiles(DataBag bag)
throws ExecException {
List<Double> quantiles = Lists.newArrayList();
for (Tuple t : bag) {
if (t != null && t.get(0) != null) {
quantiles.add(((Number)t.get(0)).doubleValue());
}
}
Collections.sort(quantiles);
return quantiles;
}
}