aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMehant Baid <mehantr@gmail.com>2014-07-03 14:35:13 -0700
committerJacques Nadeau <jacques@apache.org>2014-07-07 14:50:31 -0700
commite987e06c7e9de907a689e658768d38b4f2dc5bd7 (patch)
tree7b9497e7fc2bb0d195cb9727c9ecd78aeaf0324d
parentcd4f7267de975f89aef775860532d49220ee8e88 (diff)
DRILL-861: Implement sum, avg for decimal data type.
-rw-r--r--common/src/main/java/org/apache/drill/common/util/DecimalUtility.java10
-rw-r--r--exec/java-exec/src/main/codegen/config.fmpp1
-rw-r--r--exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd13
-rw-r--r--exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd31
-rw-r--r--exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java (renamed from exec/java-exec/src/main/codegen/templates/DecimalAggrTypeFunctions1.java)42
-rw-r--r--exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java123
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java1
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillDecimalSumAggFuncHolder.java46
-rw-r--r--exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java3
9 files changed, 266 insertions, 4 deletions
diff --git a/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java b/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java
index 4f9f09601..465cf8219 100644
--- a/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java
+++ b/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java
@@ -23,6 +23,7 @@ import org.apache.drill.common.types.TypeProtos;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.math.RoundingMode;
import java.util.Arrays;
public class DecimalUtility {
@@ -341,6 +342,15 @@ public class DecimalUtility {
return (input.unscaledValue().longValue());
}
+ public static BigDecimal getBigDecimalFromPrimitiveTypes(int input, int scale, int precision) {
+ return BigDecimal.valueOf(input, scale);
+ }
+
+ public static BigDecimal getBigDecimalFromPrimitiveTypes(long input, int scale, int precision) {
+ return BigDecimal.valueOf(input, scale);
+ }
+
+
public static int compareDenseBytes(ByteBuf left, int leftStart, boolean leftSign, ByteBuf right, int rightStart, boolean rightSign, int width) {
int invert = 1;
diff --git a/exec/java-exec/src/main/codegen/config.fmpp b/exec/java-exec/src/main/codegen/config.fmpp
index 83a0b70c5..72520f6a3 100644
--- a/exec/java-exec/src/main/codegen/config.fmpp
+++ b/exec/java-exec/src/main/codegen/config.fmpp
@@ -22,6 +22,7 @@ data: {
mathFunc:tdd(../data/MathFunc.tdd),
aggrtypes1: tdd(../data/AggrTypes1.tdd),
decimalaggrtypes1: tdd(../data/DecimalAggrTypes1.tdd),
+ decimalaggrtypes2: tdd(../data/DecimalAggrTypes2.tdd),
aggrtypes2: tdd(../data/AggrTypes2.tdd),
aggrtypes3: tdd(../data/AggrTypes3.tdd),
covarTypes: tdd(../data/CovarTypes.tdd),
diff --git a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd
index 558f95bae..d8a2c73db 100644
--- a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd
+++ b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd
@@ -60,6 +60,17 @@
{inputType: "Decimal38Dense", outputType: "BigInt", runningType: "BigInt", major: "Numeric"},
{inputType: "NullableDecimal38Dense", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}
]
+ },
+ {className: "Sum", funcName: "sum", types: [
+ {inputType: "Decimal9", outputType: "Decimal38Sparse", major: "Numeric"},
+ {inputType: "NullableDecimal9", outputType: "Decimal38Sparse", major: "Numeric"},
+ {inputType: "Decimal18", outputType: "Decimal38Sparse", major: "Numeric"},
+ {inputType: "NullableDecimal18", outputType: "Decimal38Sparse", major: "Numeric"},
+ {inputType: "Decimal28Sparse", outputType: "Decimal38Sparse", major: "Numeric"},
+ {inputType: "NullableDecimal28Sparse", outputType: "Decimal38Sparse", major: "Numeric"},
+ {inputType: "Decimal38Sparse", outputType: "Decimal38Sparse", major: "Numeric"},
+ {inputType: "NullableDecimal38Sparse", outputType: "Decimal38Sparse", major: "Numeric"}
+ ]
}
- ]
+ ]
}
diff --git a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd
new file mode 100644
index 000000000..ed8d1334b
--- /dev/null
+++ b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http:# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+{
+ aggrtypes: [
+ {className: "Avg", funcName: "avg", types: [
+ {inputType: "Decimal9", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableDecimal9", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "Decimal18", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableDecimal18", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "Decimal28Sparse", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableDecimal28Sparse", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "Decimal38Sparse", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"},
+ {inputType: "NullableDecimal38Sparse", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"}
+ ]
+ }
+ ]
+} \ No newline at end of file
diff --git a/exec/java-exec/src/main/codegen/templates/DecimalAggrTypeFunctions1.java b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java
index c5a927c71..f284a191f 100644
--- a/exec/java-exec/src/main/codegen/templates/DecimalAggrTypeFunctions1.java
+++ b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java
@@ -15,6 +15,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
+import org.apache.drill.exec.expr.annotations.Workspace;
+
<@pp.dropOutputFile />
@@ -50,19 +53,25 @@ public class Decimal${aggrtype.className}Functions {
<#list aggrtype.types as type>
-@FunctionTemplate(name = "${aggrtype.funcName}", scope = FunctionTemplate.FunctionScope.DECIMAL_AGGREGATE)
+@FunctionTemplate(name = "${aggrtype.funcName}", <#if aggrtype.funcName == "sum"> scope = FunctionTemplate.FunctionScope.DECIMAL_SUM_AGGREGATE <#else>scope = FunctionTemplate.FunctionScope.DECIMAL_AGGREGATE</#if>)
public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc{
@Param ${type.inputType}Holder in;
+ <#if aggrtype.funcName == "sum">
+ @Workspace java.math.BigDecimal value;
+ @Workspace int outputScale;
+ <#else>
@Workspace ${type.runningType}Holder value;
+ </#if>
@Workspace ByteBuf buffer;
@Output ${type.outputType}Holder out;
public void setup(RecordBatch b) {
- value = new ${type.runningType}Holder();
<#if aggrtype.funcName == "count">
+ value = new ${type.runningType}Holder();
value.value = 0;
<#elseif aggrtype.funcName == "max" || aggrtype.funcName == "min">
+ value = new ${type.runningType}Holder();
<#if type.outputType.endsWith("Dense") || type.outputType.endsWith("Sparse")>
buffer = io.netty.buffer.Unpooled.wrappedBuffer(new byte[value.WIDTH]);
buffer = new io.netty.buffer.SwappedByteBuf(buffer);
@@ -84,6 +93,9 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
<#elseif type.outputType == "Decimal9" || type.outputType == "Decimal18">
value.value = ${type.initValue};
</#if>
+ <#elseif aggrtype.funcName == "sum">
+ value = java.math.BigDecimal.ZERO;
+ outputScale = Integer.MIN_VALUE;
</#if>
}
@@ -145,6 +157,16 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
<#elseif type.outputType == "Decimal9" || type.outputType == "Decimal18">
value.value = Math.min(value.value, in.value);
</#if>
+ <#elseif aggrtype.funcName == "sum">
+ <#if type.inputType.endsWith("Decimal9") || type.inputType.endsWith("Decimal18")>
+ java.math.BigDecimal currentValue = org.apache.drill.common.util.DecimalUtility.getBigDecimalFromPrimitiveTypes(in.value, in.scale, in.precision);
+ <#else>
+ java.math.BigDecimal currentValue = org.apache.drill.common.util.DecimalUtility.getBigDecimalFromSparse(in.buffer, in.start, in.nDecimalDigits, in.scale);
+ </#if>
+ value = value.add(currentValue);
+ if (outputScale == Integer.MIN_VALUE) {
+ outputScale = in.scale;
+ }
</#if>
<#if type.inputType?starts_with("Nullable")>
} // end of sout block
@@ -155,8 +177,20 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
public void output() {
<#if aggrtype.funcName == "count">
out.value = value.value;
+ <#elseif aggrtype.funcName == "sum">
+ buffer = io.netty.buffer.Unpooled.wrappedBuffer(new byte[out.WIDTH]);
+ buffer = new io.netty.buffer.SwappedByteBuf(buffer);
+ out.buffer = buffer;
+ out.start = 0;
+ out.scale = outputScale;
+ out.precision = 38;
+ for (int i = 0; i < out.nDecimalDigits; i++) {
+ out.setInteger(i, 0);
+ }
+ value = value.setScale(out.scale, java.math.BigDecimal.ROUND_HALF_UP);
+ org.apache.drill.common.util.DecimalUtility.getSparseFromBigDecimal(value, out.buffer, out.start, out.scale, out.precision, out.nDecimalDigits);
<#else>
- <#if type.outputType.endsWith("Dense") || type.outputType.endsWith("Sparse")>
+ <#if type.outputType.endsWith("Dense") || type.outputType.endsWith("Sparse") || aggrtype.funcName == "sum">
out.buffer = value.buffer;
out.start = value.start;
out.setSign(value.getSign());
@@ -191,6 +225,8 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu
<#elseif type.outputType == "Decimal9" || type.outputType == "Decimal18">
value.value = ${type.initValue};
</#if>
+ <#elseif aggrtype.funcName == "sum">
+ value = java.math.BigDecimal.ZERO;
</#if>
}
diff --git a/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java
new file mode 100644
index 000000000..60d708a3c
--- /dev/null
+++ b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java
@@ -0,0 +1,123 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.drill.exec.expr.annotations.Workspace;
+
+<@pp.dropOutputFile />
+
+
+
+<#list decimalaggrtypes2.aggrtypes as aggrtype>
+<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/gaggr/Decimal${aggrtype.className}Functions.java" />
+
+<#include "/@includes/license.ftl" />
+
+<#-- A utility class that is used to generate java code for aggr functions for decimal data type that maintain a single -->
+<#-- running counter to hold the result. This includes: MIN, MAX, COUNT. -->
+
+/*
+ * This class is automatically generated from AggrTypeFunctions1.tdd using FreeMarker.
+ */
+
+package org.apache.drill.exec.expr.fn.impl.gaggr;
+
+import org.apache.drill.exec.expr.DrillAggFunc;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope;
+import org.apache.drill.exec.expr.annotations.Output;
+import org.apache.drill.exec.expr.annotations.Param;
+import org.apache.drill.exec.expr.annotations.Workspace;
+import org.apache.drill.exec.expr.holders.*;
+import org.apache.drill.exec.record.RecordBatch;
+import io.netty.buffer.ByteBuf;
+
+@SuppressWarnings("unused")
+
+public class Decimal${aggrtype.className}Functions {
+<#list aggrtype.types as type>
+
+@FunctionTemplate(name = "${aggrtype.funcName}", scope = FunctionTemplate.FunctionScope.DECIMAL_SUM_AGGREGATE)
+public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc{
+
+ @Param ${type.inputType}Holder in;
+ @Workspace java.math.BigDecimal value;
+ @Workspace ${type.countRunningType}Holder count;
+ @Workspace ByteBuf buffer;
+ @Workspace int outputScale;
+ @Output ${type.outputType}Holder out;
+
+ public void setup(RecordBatch b) {
+ value = java.math.BigDecimal.ZERO;
+ count = new ${type.countRunningType}Holder();
+ count.value = 0;
+ outputScale = Integer.MIN_VALUE;
+ }
+
+ @Override
+ public void add() {
+ <#if type.inputType?starts_with("Nullable")>
+ sout: {
+ if (in.isSet == 0) {
+ // processing nullable input and the value is null, so don't do anything...
+ break sout;
+ }
+ </#if>
+ count.value++;
+ <#if type.inputType.endsWith("Decimal9") || type.inputType.endsWith("Decimal18")>
+ java.math.BigDecimal currentValue = org.apache.drill.common.util.DecimalUtility.getBigDecimalFromPrimitiveTypes(in.value, in.scale, in.precision);
+ <#else>
+ java.math.BigDecimal currentValue = org.apache.drill.common.util.DecimalUtility.getBigDecimalFromSparse(in.buffer, in.start, in.nDecimalDigits, in.scale);
+ </#if>
+ value = value.add(currentValue);
+ if (outputScale == Integer.MIN_VALUE) {
+ outputScale = in.scale;
+ }
+ <#if type.inputType?starts_with("Nullable")>
+ } // end of sout block
+ </#if>
+ }
+
+ @Override
+ public void output() {
+ buffer = io.netty.buffer.Unpooled.wrappedBuffer(new byte[out.WIDTH]);
+ buffer = new io.netty.buffer.SwappedByteBuf(buffer);
+ out.buffer = buffer;
+ out.start = 0;
+ out.scale = outputScale;
+ out.precision = 38;
+ for (int i = 0; i < out.nDecimalDigits; i++) {
+ out.setInteger(i, 0);
+ }
+ java.math.BigDecimal average = value.divide(java.math.BigDecimal.valueOf(count.value, 0), out.scale, java.math.BigDecimal.ROUND_HALF_UP);
+ org.apache.drill.common.util.DecimalUtility.getSparseFromBigDecimal(average, out.buffer, out.start, out.scale, out.precision, out.nDecimalDigits);
+ }
+
+ @Override
+ public void reset() {
+ value = java.math.BigDecimal.ZERO;
+ count = new ${type.countRunningType}Holder();
+ count.value = 0;
+ outputScale = Integer.MIN_VALUE;
+ }
+}
+
+
+</#list>
+}
+</#list>
+
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java
index c2ad3d5b0..78a22d82e 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java
@@ -56,6 +56,7 @@ public @interface FunctionTemplate {
SIMPLE,
POINT_AGGREGATE,
DECIMAL_AGGREGATE,
+ DECIMAL_SUM_AGGREGATE,
HOLISTIC_AGGREGATE,
RANGE_AGGREGATE,
DECIMAL_MAX_SCALE,
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillDecimalSumAggFuncHolder.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillDecimalSumAggFuncHolder.java
new file mode 100644
index 000000000..69ab0677b
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillDecimalSumAggFuncHolder.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.expr.fn;
+
+import org.apache.drill.common.expression.LogicalExpression;
+import org.apache.drill.common.types.TypeProtos;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate;
+
+import java.util.List;
+import java.util.Map;
+
+public class DrillDecimalSumAggFuncHolder extends DrillAggFuncHolder {
+ public DrillDecimalSumAggFuncHolder(FunctionTemplate.FunctionScope scope, FunctionTemplate.NullHandling nullHandling, boolean isBinaryCommutative, boolean isRandom, String[] registeredNames, ValueReference[] parameters, ValueReference returnValue, WorkspaceReference[] workspaceVars, Map<String, String> methods, List<String> imports) {
+ super(scope, nullHandling, isBinaryCommutative, isRandom, registeredNames, parameters, returnValue, workspaceVars, methods, imports);
+ }
+
+ @Override
+ public TypeProtos.MajorType getReturnType(List<LogicalExpression> args) {
+
+ int scale = 0;
+ int precision = 0;
+
+ // Get the max scale and precision from the inputs
+ for (LogicalExpression e : args) {
+ scale = Math.max(scale, e.getMajorType().getScale());
+ precision = Math.max(precision, e.getMajorType().getPrecision());
+ }
+
+ return (TypeProtos.MajorType.newBuilder().setMinorType(returnValue.type.getMinorType()).setScale(scale).setPrecision(38).setMode(TypeProtos.DataMode.REQUIRED).build());
+ }
+}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java
index 3c8536cd5..1d7dd0b36 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java
@@ -224,6 +224,9 @@ public class FunctionConverter {
case DECIMAL_AGGREGATE:
return new DrillDecimalAggFuncHolder(template.scope(), template.nulls(), template.isBinaryCommutative(),
template.isRandom(), registeredNames, ps, outputField, works, methods, imports);
+ case DECIMAL_SUM_AGGREGATE:
+ return new DrillDecimalSumAggFuncHolder(template.scope(), template.nulls(), template.isBinaryCommutative(),
+ template.isRandom(), registeredNames, ps, outputField, works, methods, imports);
case SIMPLE:
if (outputField.isComplexWriter)
return new DrillComplexWriterFuncHolder(template.scope(), template.nulls(),