aboutsummaryrefslogtreecommitdiff
path: root/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java
diff options
context:
space:
mode:
Diffstat (limited to 'exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java')
-rw-r--r--exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java42
1 files changed, 22 insertions, 20 deletions
diff --git a/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java
index df1eb7e81..c633b678c 100644
--- a/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java
+++ b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java
@@ -27,7 +27,7 @@ import org.apache.drill.exec.expr.annotations.Workspace;
<#include "/@includes/license.ftl" />
<#-- A utility class that is used to generate java code for aggr functions for decimal data type that maintain a single -->
-<#-- running counter to hold the result. This includes: MIN, MAX, COUNT. -->
+<#-- running counter to hold the result. This includes: AVG. -->
/*
* This class is automatically generated from AggrTypeFunctions1.tdd using FreeMarker.
@@ -59,9 +59,9 @@ public class Decimal${aggrtype.className}Functions {
<#list aggrtype.types as type>
@FunctionTemplate(name = "${aggrtype.funcName}",
- scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE,
- returnType = FunctionTemplate.ReturnType.DECIMAL_SUM_AGGREGATE)
-public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc{
+ scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE,
+ returnType = FunctionTemplate.ReturnType.DECIMAL_AVG_AGGREGATE)
+public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc {
@Param ${type.inputType}Holder in;
@Inject DrillBuf buffer;
@@ -71,7 +71,6 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
@Output ${type.outputType}Holder out;
public void setup() {
- buffer.reallocIfNeeded(${type.outputType}Holder.WIDTH);
value = new ObjectHolder();
value.obj = java.math.BigDecimal.ZERO;
count = new ${type.countRunningType}Holder();
@@ -83,18 +82,15 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
@Override
public void add() {
<#if type.inputType?starts_with("Nullable")>
- sout: {
+ sout: {
if (in.isSet == 0) {
// processing nullable input and the value is null, so don't do anything...
break sout;
}
</#if>
count.value++;
- <#if type.inputType.endsWith("Decimal9") || type.inputType.endsWith("Decimal18")>
- java.math.BigDecimal currentValue = org.apache.drill.exec.util.DecimalUtility.getBigDecimalFromPrimitiveTypes(in.value, in.scale, in.precision);
- <#else>
- java.math.BigDecimal currentValue = in.getBigDecimal();
- </#if>
+ java.math.BigDecimal currentValue = org.apache.drill.exec.util.DecimalUtility
+ .getBigDecimalFromDrillBuf(in.buffer, in.start, in.end - in.start, in.scale);
value.obj = ((java.math.BigDecimal)(value.obj)).add(currentValue);
if (outputScale.value == Integer.MIN_VALUE) {
outputScale.value = in.scale;
@@ -106,14 +102,21 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
@Override
public void output() {
- out.buffer = buffer;
- out.start = 0;
- out.scale = outputScale.value;
- java.math.BigDecimal average = ((java.math.BigDecimal)(value.obj)).divide(java.math.BigDecimal.valueOf(count.value, 0), out.scale, java.math.BigDecimal.ROUND_HALF_UP);
-<#if !type.inputType.contains("VarDecimal")>
- out.precision = 38;
- org.apache.drill.exec.util.DecimalUtility.getSparseFromBigDecimal(average, out.buffer, out.start, out.scale, out.precision, out.nDecimalDigits);
-</#if>
+ if (count.value > 0) {
+ out.isSet = 1;
+ out.start = 0;
+ out.scale = Math.max(outputScale.value, 6);
+ java.math.BigDecimal average = ((java.math.BigDecimal) value.obj)
+ .divide(java.math.BigDecimal.valueOf(count.value), out.scale, java.math.BigDecimal.ROUND_HALF_UP);
+ out.precision = org.apache.drill.exec.planner.types.DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericPrecision();
+ byte[] bytes = average.unscaledValue().toByteArray();
+ int len = bytes.length;
+ out.buffer = buffer.reallocIfNeeded(len);
+ out.buffer.setBytes(0, bytes);
+ out.end = len;
+ } else {
+ out.isSet = 0;
+ }
}
@Override
@@ -127,7 +130,6 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
}
}
-
</#list>
}
</#list>