diff options
3 files changed, 63 insertions, 10 deletions
diff --git a/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java b/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java index 7f1a4a0e5..0ac870a1a 100644 --- a/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java +++ b/common/src/main/java/org/apache/drill/common/util/DecimalUtility.java @@ -276,7 +276,7 @@ public class DecimalUtility { } // Truncate the input as per the scale provided - input = input.setScale(scale, BigDecimal.ROUND_DOWN); + input = input.setScale(scale, BigDecimal.ROUND_HALF_UP); // Separate out the integer part BigDecimal integerPart = input.setScale(0, BigDecimal.ROUND_DOWN); @@ -329,14 +329,14 @@ public class DecimalUtility { } public static int getDecimal9FromBigDecimal(BigDecimal input, int scale, int precision) { // Truncate/ or pad to set the input to the correct scale - input = input.setScale(scale, BigDecimal.ROUND_DOWN); + input = input.setScale(scale, BigDecimal.ROUND_HALF_UP); return (input.unscaledValue().intValue()); } public static long getDecimal18FromBigDecimal(BigDecimal input, int scale, int precision) { // Truncate or pad to set the input to the correct scale - input = input.setScale(scale, BigDecimal.ROUND_DOWN); + input = input.setScale(scale, BigDecimal.ROUND_HALF_UP); return (input.unscaledValue().longValue()); } diff --git a/exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java b/exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java index ceebc0a5a..8a50bb65e 100644 --- a/exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java +++ b/exec/java-exec/src/main/codegen/templates/Decimal/CastVarCharDecimal.java @@ -84,6 +84,7 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc { int integerStartIndex = readIndex; int integerEndIndex = endIndex; boolean leadingDigitFound = false; + boolean round = false; int radix = 10; @@ -96,8 +97,10 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc { // Integer end index is just before the scale part begins integerEndIndex = scaleIndex - 1; // If the number of fractional digits is > scale specified we might have to truncate - endIndex = (scaleIndex + out.scale) < endIndex ? (scaleIndex + out.scale) : endIndex; - + if ((scaleIndex + out.scale) < endIndex ) { + endIndex = scaleIndex + out.scale; + round = true; + } continue; } else { // If its not a '.' we expect only numbers @@ -129,6 +132,21 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc { " Total Digits: " + (out.scale + (integerEndIndex - integerStartIndex))); } + // Check if we need to round up + if (round == true) { + next = in.buffer.getByte(endIndex); + next = (byte) Character.digit(next, radix); + if (next == -1) { + // not a valid digit + byte[] buf = new byte[in.end - in.start]; + in.buffer.getBytes(in.start, buf, 0, in.end - in.start); + throw new org.apache.drill.common.exceptions.DrillRuntimeException(new String(buf, com.google.common.base.Charsets.UTF_8)); + } + if (next > 4) { + out.value++; + } + } + // Number of fractional digits in the input int fractionalDigits = (scaleIndex == -1) ? 0 : ((endIndex - scaleIndex)); @@ -180,7 +198,6 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc { } public void eval() { - out.buffer = buffer; out.start = 0; @@ -225,6 +242,7 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc { int radix = 10; boolean leadingDigitFound = false; + boolean round = false; /* This is the first pass, we get the number of integer digits and based on the provided scale * we compute which index into the ByteBuf we start storing the integer part of the Decimal @@ -239,7 +257,10 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc { // We have found the decimal point. we can compute the starting index into the Decimal's bytebuf scaleIndex = readIndex; // We may have to truncate fractional part if > scale - scaleEndIndex = ((in.end - scaleIndex) <= out.scale) ? in.end : (scaleIndex + out.scale); + if ((in.end - scaleIndex) > out.scale) { + scaleEndIndex = scaleIndex + out.scale; + round = true; + } break; } @@ -337,9 +358,41 @@ public class Cast${type.from}${type.to} implements DrillSimpleFunc { // added another digit to the current index ndigits++; } + + // round up the decimal if we had to chop off a part of it + if (round == true) { + next = in.buffer.getByte(scaleEndIndex); + + // We expect only numbers beyond this + next = (byte) Character.digit(next, radix); + + if (next == -1) { + // not a valid digit + byte[] buf = new byte[in.end - in.start]; + in.buffer.getBytes(in.start, buf, 0, in.end - in.start); + throw new NumberFormatException(new String(buf, com.google.common.base.Charsets.UTF_8)); + } + if (next > 4) { + // Need to round up + out.setInteger(decimalBufferIndex, out.getInteger(decimalBufferIndex)+1); + } + } // Pad zeroes in the fractional part so that number of digits = MAX_DIGITS int padding = (int) org.apache.drill.common.util.DecimalUtility.getPowerOfTen((int) (org.apache.drill.common.util.DecimalUtility.MAX_DIGITS - ndigits)); out.setInteger(decimalBufferIndex, out.getInteger(decimalBufferIndex) * padding); + + int carry = 0; + do { + // propogate the carry + int tempValue = out.getInteger(decimalBufferIndex) + carry; + if (tempValue >= org.apache.drill.common.util.DecimalUtility.DIGITS_BASE) { + carry = tempValue / org.apache.drill.common.util.DecimalUtility.DIGITS_BASE; + tempValue = (tempValue % org.apache.drill.common.util.DecimalUtility.DIGITS_BASE); + } else { + carry = 0; + } + out.setInteger(decimalBufferIndex--, tempValue); + } while (carry > 0 && decimalBufferIndex >= 0); } out.setSign(sign); } diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java index 093366f29..489336a0b 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/physical/impl/TestDecimal.java @@ -66,7 +66,7 @@ public class TestDecimal extends PopUnitTestBase{ QueryResultBatch batch = results.get(0); assertTrue(batchLoader.load(batch.getHeader().getDef(), batch.getData())); - String decimal9Output[] = {"99.0000", "11.1234", "0.1000", "-0.1200", "-123.1234", "-1.0001"}; + String decimal9Output[] = {"99.0000", "11.1235", "0.1000", "-0.1200", "-123.1234", "-1.0001"}; String decimal18Output[] = {"123456789.000000000", "11.123456789", "0.100000000", "-0.100400000", "-987654321.123456789", "-2.030100000"}; Iterator<VectorWrapper<?>> itr = batchLoader.iterator(); @@ -111,8 +111,8 @@ public class TestDecimal extends PopUnitTestBase{ QueryResultBatch batch = results.get(0); assertTrue(batchLoader.load(batch.getHeader().getDef(), batch.getData())); - String decimal9Output[] = {"99.0000", "11.1234", "0.1000", "-0.1200", "-123.1234", "-1.0001"}; - String decimal38Output[] = {"123456789.0000", "11.1234", "0.1000", "-0.1004", "-987654321.1234", "-2.0301"}; + String decimal9Output[] = {"99.0000", "11.1235", "0.1000", "-0.1200", "-123.1234", "-1.0001"}; + String decimal38Output[] = {"123456789.0000", "11.1235", "0.1000", "-0.1004", "-987654321.1235", "-2.0301"}; Iterator<VectorWrapper<?>> itr = batchLoader.iterator(); |