diff options
author | Mehant Baid <mehantr@gmail.com> | 2014-07-03 14:35:13 -0700 |
---|---|---|
committer | Jacques Nadeau <jacques@apache.org> | 2014-07-07 14:50:31 -0700 |
commit | e987e06c7e9de907a689e658768d38b4f2dc5bd7 (patch) | |
tree | 7b9497e7fc2bb0d195cb9727c9ecd78aeaf0324d | |
parent | cd4f7267de975f89aef775860532d49220ee8e88 (diff) |
DRILL-861: Implement sum, avg for decimal data type.
-rw-r--r-- | common/src/main/java/org/apache/drill/common/util/DecimalUtility.java | 10 | ||||
-rw-r--r-- | exec/java-exec/src/main/codegen/config.fmpp | 1 | ||||
-rw-r--r-- | exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd | 13 | ||||
-rw-r--r-- | exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd | 31 | ||||
-rw-r--r-- | exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java (renamed from exec/java-exec/src/main/codegen/templates/DecimalAggrTypeFunctions1.java) | 42 | ||||
-rw-r--r-- | exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java | 123 | ||||
-rw-r--r-- | exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java | 1 | ||||
-rw-r--r-- | exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillDecimalSumAggFuncHolder.java | 46 | ||||
-rw-r--r-- | exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java | 3 |
9 files changed, 266 insertions, 4 deletions
diff --git a/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java b/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java index 4f9f09601..465cf8219 100644 --- a/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java +++ b/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java @@ -23,6 +23,7 @@ import org.apache.drill.common.types.TypeProtos; import java.math.BigDecimal; import java.math.BigInteger; +import java.math.RoundingMode; import java.util.Arrays; public class DecimalUtility { @@ -341,6 +342,15 @@ public class DecimalUtility { return (input.unscaledValue().longValue()); } + public static BigDecimal getBigDecimalFromPrimitiveTypes(int input, int scale, int precision) { + return BigDecimal.valueOf(input, scale); + } + + public static BigDecimal getBigDecimalFromPrimitiveTypes(long input, int scale, int precision) { + return BigDecimal.valueOf(input, scale); + } + + public static int compareDenseBytes(ByteBuf left, int leftStart, boolean leftSign, ByteBuf right, int rightStart, boolean rightSign, int width) { int invert = 1; diff --git a/exec/java-exec/src/main/codegen/config.fmpp b/exec/java-exec/src/main/codegen/config.fmpp index 83a0b70c5..72520f6a3 100644 --- a/exec/java-exec/src/main/codegen/config.fmpp +++ b/exec/java-exec/src/main/codegen/config.fmpp @@ -22,6 +22,7 @@ data: { mathFunc:tdd(../data/MathFunc.tdd), aggrtypes1: tdd(../data/AggrTypes1.tdd), decimalaggrtypes1: tdd(../data/DecimalAggrTypes1.tdd), + decimalaggrtypes2: tdd(../data/DecimalAggrTypes2.tdd), aggrtypes2: tdd(../data/AggrTypes2.tdd), aggrtypes3: tdd(../data/AggrTypes3.tdd), covarTypes: tdd(../data/CovarTypes.tdd), diff --git a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd index 558f95bae..d8a2c73db 100644 --- a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd +++ b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes1.tdd @@ -60,6 +60,17 @@ {inputType: "Decimal38Dense", outputType: "BigInt", runningType: "BigInt", major: "Numeric"}, {inputType: "NullableDecimal38Dense", outputType: "BigInt", runningType: "BigInt", major: "Numeric"} ] + }, + {className: "Sum", funcName: "sum", types: [ + {inputType: "Decimal9", outputType: "Decimal38Sparse", major: "Numeric"}, + {inputType: "NullableDecimal9", outputType: "Decimal38Sparse", major: "Numeric"}, + {inputType: "Decimal18", outputType: "Decimal38Sparse", major: "Numeric"}, + {inputType: "NullableDecimal18", outputType: "Decimal38Sparse", major: "Numeric"}, + {inputType: "Decimal28Sparse", outputType: "Decimal38Sparse", major: "Numeric"}, + {inputType: "NullableDecimal28Sparse", outputType: "Decimal38Sparse", major: "Numeric"}, + {inputType: "Decimal38Sparse", outputType: "Decimal38Sparse", major: "Numeric"}, + {inputType: "NullableDecimal38Sparse", outputType: "Decimal38Sparse", major: "Numeric"} + ] } - ] + ] } diff --git a/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd new file mode 100644 index 000000000..ed8d1334b --- /dev/null +++ b/exec/java-exec/src/main/codegen/data/DecimalAggrTypes2.tdd @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http:# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +{ + aggrtypes: [ + {className: "Avg", funcName: "avg", types: [ + {inputType: "Decimal9", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "NullableDecimal9", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "Decimal18", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "NullableDecimal18", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "Decimal28Sparse", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "NullableDecimal28Sparse", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "Decimal38Sparse", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"}, + {inputType: "NullableDecimal38Sparse", outputType: "Decimal38Sparse", countRunningType: "BigInt", major: "Numeric"} + ] + } + ] +}
\ No newline at end of file diff --git a/exec/java-exec/src/main/codegen/templates/DecimalAggrTypeFunctions1.java b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java index c5a927c71..f284a191f 100644 --- a/exec/java-exec/src/main/codegen/templates/DecimalAggrTypeFunctions1.java +++ b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions1.java @@ -15,6 +15,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +import org.apache.drill.exec.expr.annotations.Workspace; + <@pp.dropOutputFile /> @@ -50,19 +53,25 @@ public class Decimal${aggrtype.className}Functions { <#list aggrtype.types as type> -@FunctionTemplate(name = "${aggrtype.funcName}", scope = FunctionTemplate.FunctionScope.DECIMAL_AGGREGATE) +@FunctionTemplate(name = "${aggrtype.funcName}", <#if aggrtype.funcName == "sum"> scope = FunctionTemplate.FunctionScope.DECIMAL_SUM_AGGREGATE <#else>scope = FunctionTemplate.FunctionScope.DECIMAL_AGGREGATE</#if>) public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc{ @Param ${type.inputType}Holder in; + <#if aggrtype.funcName == "sum"> + @Workspace java.math.BigDecimal value; + @Workspace int outputScale; + <#else> @Workspace ${type.runningType}Holder value; + </#if> @Workspace ByteBuf buffer; @Output ${type.outputType}Holder out; public void setup(RecordBatch b) { - value = new ${type.runningType}Holder(); <#if aggrtype.funcName == "count"> + value = new ${type.runningType}Holder(); value.value = 0; <#elseif aggrtype.funcName == "max" || aggrtype.funcName == "min"> + value = new ${type.runningType}Holder(); <#if type.outputType.endsWith("Dense") || type.outputType.endsWith("Sparse")> buffer = io.netty.buffer.Unpooled.wrappedBuffer(new byte[value.WIDTH]); buffer = new io.netty.buffer.SwappedByteBuf(buffer); @@ -84,6 +93,9 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu <#elseif type.outputType == "Decimal9" || type.outputType == "Decimal18"> value.value = ${type.initValue}; </#if> + <#elseif aggrtype.funcName == "sum"> + value = java.math.BigDecimal.ZERO; + outputScale = Integer.MIN_VALUE; </#if> } @@ -145,6 +157,16 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu <#elseif type.outputType == "Decimal9" || type.outputType == "Decimal18"> value.value = Math.min(value.value, in.value); </#if> + <#elseif aggrtype.funcName == "sum"> + <#if type.inputType.endsWith("Decimal9") || type.inputType.endsWith("Decimal18")> + java.math.BigDecimal currentValue = org.apache.drill.common.util.DecimalUtility.getBigDecimalFromPrimitiveTypes(in.value, in.scale, in.precision); + <#else> + java.math.BigDecimal currentValue = org.apache.drill.common.util.DecimalUtility.getBigDecimalFromSparse(in.buffer, in.start, in.nDecimalDigits, in.scale); + </#if> + value = value.add(currentValue); + if (outputScale == Integer.MIN_VALUE) { + outputScale = in.scale; + } </#if> <#if type.inputType?starts_with("Nullable")> } // end of sout block @@ -155,8 +177,20 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu public void output() { <#if aggrtype.funcName == "count"> out.value = value.value; + <#elseif aggrtype.funcName == "sum"> + buffer = io.netty.buffer.Unpooled.wrappedBuffer(new byte[out.WIDTH]); + buffer = new io.netty.buffer.SwappedByteBuf(buffer); + out.buffer = buffer; + out.start = 0; + out.scale = outputScale; + out.precision = 38; + for (int i = 0; i < out.nDecimalDigits; i++) { + out.setInteger(i, 0); + } + value = value.setScale(out.scale, java.math.BigDecimal.ROUND_HALF_UP); + org.apache.drill.common.util.DecimalUtility.getSparseFromBigDecimal(value, out.buffer, out.start, out.scale, out.precision, out.nDecimalDigits); <#else> - <#if type.outputType.endsWith("Dense") || type.outputType.endsWith("Sparse")> + <#if type.outputType.endsWith("Dense") || type.outputType.endsWith("Sparse") || aggrtype.funcName == "sum"> out.buffer = value.buffer; out.start = value.start; out.setSign(value.getSign()); @@ -191,6 +225,8 @@ public static class ${type.inputType}${aggrtype.className} implements DrillAggFu <#elseif type.outputType == "Decimal9" || type.outputType == "Decimal18"> value.value = ${type.initValue}; </#if> + <#elseif aggrtype.funcName == "sum"> + value = java.math.BigDecimal.ZERO; </#if> } diff --git a/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java new file mode 100644 index 000000000..60d708a3c --- /dev/null +++ b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalAggrTypeFunctions2.java @@ -0,0 +1,123 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.drill.exec.expr.annotations.Workspace; + +<@pp.dropOutputFile /> + + + +<#list decimalaggrtypes2.aggrtypes as aggrtype> +<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/gaggr/Decimal${aggrtype.className}Functions.java" /> + +<#include "/@includes/license.ftl" /> + +<#-- A utility class that is used to generate java code for aggr functions for decimal data type that maintain a single --> +<#-- running counter to hold the result. This includes: MIN, MAX, COUNT. --> + +/* + * This class is automatically generated from AggrTypeFunctions1.tdd using FreeMarker. + */ + +package org.apache.drill.exec.expr.fn.impl.gaggr; + +import org.apache.drill.exec.expr.DrillAggFunc; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; +import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope; +import org.apache.drill.exec.expr.annotations.Output; +import org.apache.drill.exec.expr.annotations.Param; +import org.apache.drill.exec.expr.annotations.Workspace; +import org.apache.drill.exec.expr.holders.*; +import org.apache.drill.exec.record.RecordBatch; +import io.netty.buffer.ByteBuf; + +@SuppressWarnings("unused") + +public class Decimal${aggrtype.className}Functions { +<#list aggrtype.types as type> + +@FunctionTemplate(name = "${aggrtype.funcName}", scope = FunctionTemplate.FunctionScope.DECIMAL_SUM_AGGREGATE) +public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc{ + + @Param ${type.inputType}Holder in; + @Workspace java.math.BigDecimal value; + @Workspace ${type.countRunningType}Holder count; + @Workspace ByteBuf buffer; + @Workspace int outputScale; + @Output ${type.outputType}Holder out; + + public void setup(RecordBatch b) { + value = java.math.BigDecimal.ZERO; + count = new ${type.countRunningType}Holder(); + count.value = 0; + outputScale = Integer.MIN_VALUE; + } + + @Override + public void add() { + <#if type.inputType?starts_with("Nullable")> + sout: { + if (in.isSet == 0) { + // processing nullable input and the value is null, so don't do anything... + break sout; + } + </#if> + count.value++; + <#if type.inputType.endsWith("Decimal9") || type.inputType.endsWith("Decimal18")> + java.math.BigDecimal currentValue = org.apache.drill.common.util.DecimalUtility.getBigDecimalFromPrimitiveTypes(in.value, in.scale, in.precision); + <#else> + java.math.BigDecimal currentValue = org.apache.drill.common.util.DecimalUtility.getBigDecimalFromSparse(in.buffer, in.start, in.nDecimalDigits, in.scale); + </#if> + value = value.add(currentValue); + if (outputScale == Integer.MIN_VALUE) { + outputScale = in.scale; + } + <#if type.inputType?starts_with("Nullable")> + } // end of sout block + </#if> + } + + @Override + public void output() { + buffer = io.netty.buffer.Unpooled.wrappedBuffer(new byte[out.WIDTH]); + buffer = new io.netty.buffer.SwappedByteBuf(buffer); + out.buffer = buffer; + out.start = 0; + out.scale = outputScale; + out.precision = 38; + for (int i = 0; i < out.nDecimalDigits; i++) { + out.setInteger(i, 0); + } + java.math.BigDecimal average = value.divide(java.math.BigDecimal.valueOf(count.value, 0), out.scale, java.math.BigDecimal.ROUND_HALF_UP); + org.apache.drill.common.util.DecimalUtility.getSparseFromBigDecimal(average, out.buffer, out.start, out.scale, out.precision, out.nDecimalDigits); + } + + @Override + public void reset() { + value = java.math.BigDecimal.ZERO; + count = new ${type.countRunningType}Holder(); + count.value = 0; + outputScale = Integer.MIN_VALUE; + } +} + + +</#list> +} +</#list> + diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java index c2ad3d5b0..78a22d82e 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/annotations/FunctionTemplate.java @@ -56,6 +56,7 @@ public @interface FunctionTemplate { SIMPLE, POINT_AGGREGATE, DECIMAL_AGGREGATE, + DECIMAL_SUM_AGGREGATE, HOLISTIC_AGGREGATE, RANGE_AGGREGATE, DECIMAL_MAX_SCALE, diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillDecimalSumAggFuncHolder.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillDecimalSumAggFuncHolder.java new file mode 100644 index 000000000..69ab0677b --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/DrillDecimalSumAggFuncHolder.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.expr.fn; + +import org.apache.drill.common.expression.LogicalExpression; +import org.apache.drill.common.types.TypeProtos; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; + +import java.util.List; +import java.util.Map; + +public class DrillDecimalSumAggFuncHolder extends DrillAggFuncHolder { + public DrillDecimalSumAggFuncHolder(FunctionTemplate.FunctionScope scope, FunctionTemplate.NullHandling nullHandling, boolean isBinaryCommutative, boolean isRandom, String[] registeredNames, ValueReference[] parameters, ValueReference returnValue, WorkspaceReference[] workspaceVars, Map<String, String> methods, List<String> imports) { + super(scope, nullHandling, isBinaryCommutative, isRandom, registeredNames, parameters, returnValue, workspaceVars, methods, imports); + } + + @Override + public TypeProtos.MajorType getReturnType(List<LogicalExpression> args) { + + int scale = 0; + int precision = 0; + + // Get the max scale and precision from the inputs + for (LogicalExpression e : args) { + scale = Math.max(scale, e.getMajorType().getScale()); + precision = Math.max(precision, e.getMajorType().getPrecision()); + } + + return (TypeProtos.MajorType.newBuilder().setMinorType(returnValue.type.getMinorType()).setScale(scale).setPrecision(38).setMode(TypeProtos.DataMode.REQUIRED).build()); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java index 3c8536cd5..1d7dd0b36 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/FunctionConverter.java @@ -224,6 +224,9 @@ public class FunctionConverter { case DECIMAL_AGGREGATE: return new DrillDecimalAggFuncHolder(template.scope(), template.nulls(), template.isBinaryCommutative(), template.isRandom(), registeredNames, ps, outputField, works, methods, imports); + case DECIMAL_SUM_AGGREGATE: + return new DrillDecimalSumAggFuncHolder(template.scope(), template.nulls(), template.isBinaryCommutative(), + template.isRandom(), registeredNames, ps, outputField, works, methods, imports); case SIMPLE: if (outputField.isComplexWriter) return new DrillComplexWriterFuncHolder(template.scope(), template.nulls(), |