aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/src/main/java/org/apache/drill/common/util/DecimalUtility.java6
-rw-r--r--exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java61
-rw-r--r--exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java6
3 files changed, 63 insertions, 10 deletions
diff --git a/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java b/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java
index 7f1a4a0e5..0ac870a1a 100644
--- a/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java
+++ b/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java
@@ -276,7 +276,7 @@ public class DecimalUtility {
}
// Truncate the input as per the scale provided
- input = input.setScale(scale, BigDecimal.ROUND_DOWN);
+ input = input.setScale(scale, BigDecimal.ROUND_HALF_UP);
// Separate out the integer part
BigDecimal integerPart = input.setScale(0, BigDecimal.ROUND_DOWN);
@@ -329,14 +329,14 @@ public class DecimalUtility {
}
public static int getDecimal9FromBigDecimal(BigDecimal input, int scale, int precision) {
// Truncate/ or pad to set the input to the correct scale
- input = input.setScale(scale, BigDecimal.ROUND_DOWN);
+ input = input.setScale(scale, BigDecimal.ROUND_HALF_UP);
return (input.unscaledValue().intValue());
}
public static long getDecimal18FromBigDecimal(BigDecimal input, int scale, int precision) {
// Truncate or pad to set the input to the correct scale
- input = input.setScale(scale, BigDecimal.ROUND_DOWN);
+ input = input.setScale(scale, BigDecimal.ROUND_HALF_UP);
return (input.unscaledValue().longValue());
}
diff --git a/exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java b/exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java
index ceebc0a5a..8a50bb65e 100644
--- a/exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java
+++ b/exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java
@@ -84,6 +84,7 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc {
int integerStartIndex = readIndex;
int integerEndIndex = endIndex;
boolean leadingDigitFound = false;
+ boolean round = false;
int radix = 10;
@@ -96,8 +97,10 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc {
// Integer end index is just before the scale part begins
integerEndIndex = scaleIndex - 1;
// If the number of fractional digits is > scale specified we might have to truncate
- endIndex = (scaleIndex + out.scale) < endIndex ? (scaleIndex + out.scale) : endIndex;
-
+ if ((scaleIndex + out.scale) < endIndex ) {
+ endIndex = scaleIndex + out.scale;
+ round = true;
+ }
continue;
} else {
// If its not a '.' we expect only numbers
@@ -129,6 +132,21 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc {
" Total Digits: " + (out.scale + (integerEndIndex - integerStartIndex)));
}
+ // Check if we need to round up
+ if (round == true) {
+ next = in.buffer.getByte(endIndex);
+ next = (byte) Character.digit(next, radix);
+ if (next == -1) {
+ // not a valid digit
+ byte[] buf = new byte[in.end - in.start];
+ in.buffer.getBytes(in.start, buf, 0, in.end - in.start);
+ throw new org.apache.drill.common.exceptions.DrillRuntimeException(new String(buf, com.google.common.base.Charsets.UTF_8));
+ }
+ if (next > 4) {
+ out.value++;
+ }
+ }
+
// Number of fractional digits in the input
int fractionalDigits = (scaleIndex == -1) ? 0 : ((endIndex - scaleIndex));
@@ -180,7 +198,6 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc {
}
public void eval() {
-
out.buffer = buffer;
out.start = 0;
@@ -225,6 +242,7 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc {
int radix = 10;
boolean leadingDigitFound = false;
+ boolean round = false;
/* This is the first pass, we get the number of integer digits and based on the provided scale
* we compute which index into the ByteBuf we start storing the integer part of the Decimal
@@ -239,7 +257,10 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc {
// We have found the decimal point. we can compute the starting index into the Decimal's bytebuf
scaleIndex = readIndex;
// We may have to truncate fractional part if > scale
- scaleEndIndex = ((in.end - scaleIndex) <= out.scale) ? in.end : (scaleIndex + out.scale);
+ if ((in.end - scaleIndex) > out.scale) {
+ scaleEndIndex = scaleIndex + out.scale;
+ round = true;
+ }
break;
}
@@ -337,9 +358,41 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc {
// added another digit to the current index
ndigits++;
}
+
+ // round up the decimal if we had to chop off a part of it
+ if (round == true) {
+ next = in.buffer.getByte(scaleEndIndex);
+
+ // We expect only numbers beyond this
+ next = (byte) Character.digit(next, radix);
+
+ if (next == -1) {
+ // not a valid digit
+ byte[] buf = new byte[in.end - in.start];
+ in.buffer.getBytes(in.start, buf, 0, in.end - in.start);
+ throw new NumberFormatException(new String(buf, com.google.common.base.Charsets.UTF_8));
+ }
+ if (next > 4) {
+ // Need to round up
+ out.setInteger(decimalBufferIndex, out.getInteger(decimalBufferIndex)+1);
+ }
+ }
// Pad zeroes in the fractional part so that number of digits = MAX_DIGITS
int padding = (int) org.apache.drill.common.util.DecimalUtility.getPowerOfTen((int) (org.apache.drill.common.util.DecimalUtility.MAX_DIGITS - ndigits));
out.setInteger(decimalBufferIndex, out.getInteger(decimalBufferIndex) * padding);
+
+ int carry = 0;
+ do {
+ // propogate the carry
+ int tempValue = out.getInteger(decimalBufferIndex) + carry;
+ if (tempValue >= org.apache.drill.common.util.DecimalUtility.DIGITS_BASE) {
+ carry = tempValue / org.apache.drill.common.util.DecimalUtility.DIGITS_BASE;
+ tempValue = (tempValue % org.apache.drill.common.util.DecimalUtility.DIGITS_BASE);
+ } else {
+ carry = 0;
+ }
+ out.setInteger(decimalBufferIndex--, tempValue);
+ } while (carry > 0 && decimalBufferIndex >= 0);
}
out.setSign(sign);
}
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java
index 093366f29..489336a0b 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java
@@ -66,7 +66,7 @@ public class TestDecimal extends PopUnitTestBase{
QueryResultBatch batch = results.get(0);
assertTrue(batchLoader.load(batch.getHeader().getDef(), batch.getData()));
- String decimal9Output[] = {"99.0000", "11.1234", "0.1000", "-0.1200", "-123.1234", "-1.0001"};
+ String decimal9Output[] = {"99.0000", "11.1235", "0.1000", "-0.1200", "-123.1234", "-1.0001"};
String decimal18Output[] = {"123456789.000000000", "11.123456789", "0.100000000", "-0.100400000", "-987654321.123456789", "-2.030100000"};
Iterator<VectorWrapper<?>> itr = batchLoader.iterator();
@@ -111,8 +111,8 @@ public class TestDecimal extends PopUnitTestBase{
QueryResultBatch batch = results.get(0);
assertTrue(batchLoader.load(batch.getHeader().getDef(), batch.getData()));
- String decimal9Output[] = {"99.0000", "11.1234", "0.1000", "-0.1200", "-123.1234", "-1.0001"};
- String decimal38Output[] = {"123456789.0000", "11.1234", "0.1000", "-0.1004", "-987654321.1234", "-2.0301"};
+ String decimal9Output[] = {"99.0000", "11.1235", "0.1000", "-0.1200", "-123.1234", "-1.0001"};
+ String decimal38Output[] = {"123456789.0000", "11.1235", "0.1000", "-0.1004", "-987654321.1235", "-2.0301"};
Iterator<VectorWrapper<?>> itr = batchLoader.iterator();