diff options
Diffstat (limited to 'exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java')
-rw-r--r-- | exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java | 208 |
1 files changed, 167 insertions, 41 deletions
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java index c6c93209a..529546f54 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java @@ -18,6 +18,7 @@ package org.apache.drill.exec.planner.sql; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import org.apache.calcite.avatica.util.TimeUnit; @@ -45,12 +46,16 @@ import org.apache.drill.common.expression.MajorTypeInLogicalExpression; import org.apache.drill.common.exceptions.UserException; import org.apache.drill.common.types.TypeProtos; import org.apache.drill.common.types.Types; +import org.apache.drill.common.util.CoreDecimalUtility; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; import org.apache.drill.exec.expr.fn.DrillFuncHolder; +import org.apache.drill.exec.planner.types.DrillRelDataTypeSystem; import org.apache.drill.exec.resolver.FunctionResolver; import org.apache.drill.exec.resolver.FunctionResolverFactory; import org.apache.drill.exec.resolver.TypeCastRules; import java.util.List; +import java.util.Set; public class TypeInferenceUtils { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(TypeInferenceUtils.class); @@ -111,6 +116,7 @@ public class TypeInferenceUtils { .put(SqlTypeName.INTERVAL_MINUTE, TypeProtos.MinorType.INTERVALDAY) .put(SqlTypeName.INTERVAL_MINUTE_SECOND, TypeProtos.MinorType.INTERVALDAY) .put(SqlTypeName.INTERVAL_SECOND, TypeProtos.MinorType.INTERVALDAY) + .put(SqlTypeName.DECIMAL, TypeProtos.MinorType.VARDECIMAL) // SqlTypeName.CHAR is the type for Literals in Calcite, Drill treats Literals as VARCHAR also .put(SqlTypeName.CHAR, TypeProtos.MinorType.VARCHAR) @@ -154,10 +160,10 @@ public class TypeInferenceUtils { .put("CONVERT_FROM", DrillDeferToExecSqlReturnTypeInference.INSTANCE) // Functions that return the same type - .put("LOWER", DrillSameSqlReturnTypeInference.INSTANCE) - .put("UPPER", DrillSameSqlReturnTypeInference.INSTANCE) - .put("INITCAP", DrillSameSqlReturnTypeInference.INSTANCE) - .put("REVERSE", DrillSameSqlReturnTypeInference.INSTANCE) + .put("LOWER", DrillSameSqlReturnTypeInference.THE_SAME_RETURN_TYPE) + .put("UPPER", DrillSameSqlReturnTypeInference.THE_SAME_RETURN_TYPE) + .put("INITCAP", DrillSameSqlReturnTypeInference.THE_SAME_RETURN_TYPE) + .put("REVERSE", DrillSameSqlReturnTypeInference.THE_SAME_RETURN_TYPE) // Window Functions // RANKING @@ -175,8 +181,8 @@ public class TypeInferenceUtils { .put(SqlKind.LAG.name(), DrillLeadLagSqlReturnTypeInference.INSTANCE) // FIRST_VALUE, LAST_VALUE - .put(SqlKind.FIRST_VALUE.name(), DrillSameSqlReturnTypeInference.INSTANCE) - .put(SqlKind.LAST_VALUE.name(), DrillSameSqlReturnTypeInference.INSTANCE) + .put(SqlKind.FIRST_VALUE.name(), DrillSameSqlReturnTypeInference.THE_SAME_RETURN_TYPE) + .put(SqlKind.LAST_VALUE.name(), DrillSameSqlReturnTypeInference.THE_SAME_RETURN_TYPE) // Functions rely on DrillReduceAggregatesRule for expression simplification as opposed to getting evaluated directly .put(SqlKind.AVG.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) @@ -184,6 +190,18 @@ public class TypeInferenceUtils { .put(SqlKind.STDDEV_SAMP.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) .put(SqlKind.VAR_POP.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) .put(SqlKind.VAR_SAMP.name(), DrillAvgAggSqlReturnTypeInference.INSTANCE) + .put(SqlKind.MIN.name(), DrillSameSqlReturnTypeInference.ALL_NULLABLE) + .put(SqlKind.MAX.name(), DrillSameSqlReturnTypeInference.ALL_NULLABLE) + .build(); + + /** + * Set of the decimal functions which return type cannot be determined exactly for some reasons at the current stage. + * For example functions which takes as an parameter scale or precision of return type. + */ + private static final Set<String> SET_SCALE_DECIMAL_FUNCTIONS = ImmutableSet.<String> builder() + .add("ROUND") + .add("TRUNC") + .add("TRUNCATE") .build(); /** @@ -281,17 +299,27 @@ public class TypeInferenceUtils { factory.createSqlType(SqlTypeName.ANY), true); } + } else if (SET_SCALE_DECIMAL_FUNCTIONS.contains(opBinding.getOperator().getName()) + && getDrillTypeFromCalciteType(type) == TypeProtos.MinorType.VARDECIMAL) { + return factory.createTypeWithNullability( + factory.createSqlType(SqlTypeName.ANY), + true); } } - final DrillFuncHolder func = resolveDrillFuncHolder(opBinding, functions); - final RelDataType returnType = getReturnType(opBinding, func); + final FunctionCall functionCall = convertSqlOperatorBindingToFunctionCall(opBinding); + final DrillFuncHolder func = resolveDrillFuncHolder(opBinding, functions, functionCall); + + final RelDataType returnType = getReturnType(opBinding, + func.getReturnType(functionCall.args), func.getNullHandling()); + return returnType.getSqlTypeName() == SqlTypeName.VARBINARY ? createCalciteTypeWithNullability(factory, SqlTypeName.ANY, returnType.isNullable()) : returnType; } - private static RelDataType getReturnType(final SqlOperatorBinding opBinding, final DrillFuncHolder func) { + private static RelDataType getReturnType(final SqlOperatorBinding opBinding, + final TypeProtos.MajorType returnType, FunctionTemplate.NullHandling nullHandling) { final RelDataTypeFactory factory = opBinding.getTypeFactory(); // least restrictive type (nullable ANY type) @@ -299,7 +327,6 @@ public class TypeInferenceUtils { factory.createSqlType(SqlTypeName.ANY), true); - final TypeProtos.MajorType returnType = func.getReturnType(); if (UNKNOWN_TYPE.equals(returnType)) { return nullableAnyType; } @@ -318,7 +345,7 @@ public class TypeInferenceUtils { break; case REQUIRED: - switch (func.getNullHandling()) { + switch (nullHandling) { case INTERNAL: isNullable = false; break; @@ -343,10 +370,7 @@ public class TypeInferenceUtils { throw new UnsupportedOperationException(); } - return createCalciteTypeWithNullability( - factory, - sqlTypeName, - isNullable); + return convertToCalciteType(factory, returnType, isNullable); } } @@ -381,15 +405,36 @@ public class TypeInferenceUtils { isNullable); } + // Determines SqlTypeName of the result. + // For the case when input may be implicitly casted to BIGINT, the type of result is BIGINT. + // Else for the case when input may be implicitly casted to FLOAT4, the type of result is DOUBLE. + // Else for the case when input may be implicitly casted to VARDECIMAL, the type of result is DECIMAL + // with the same scale as input and max allowed numeric precision. + // Else for the case when input may be implicitly casted to FLOAT8, the type of result is DOUBLE. + // When none of these conditions is satisfied, error is thrown. + // This order of checks is caused by the order of types in ResolverTypePrecedence.precedenceMap final RelDataType operandType = opBinding.getOperandType(0); final TypeProtos.MinorType inputMinorType = getDrillTypeFromCalciteType(operandType); - if(TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.BIGINT)) + if (TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.BIGINT)) == TypeProtos.MinorType.BIGINT) { return createCalciteTypeWithNullability( factory, SqlTypeName.BIGINT, isNullable); - } else if(TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.FLOAT8)) + } else if (TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.FLOAT4)) + == TypeProtos.MinorType.FLOAT4) { + return createCalciteTypeWithNullability( + factory, + SqlTypeName.DOUBLE, + isNullable); + } else if (TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.VARDECIMAL)) + == TypeProtos.MinorType.VARDECIMAL) { + RelDataType sqlType = factory.createSqlType(SqlTypeName.DECIMAL, + DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericPrecision(), + Math.min(operandType.getScale(), + DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericScale())); + return factory.createTypeWithNullability(sqlType, isNullable); + } else if (TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.FLOAT8)) == TypeProtos.MinorType.FLOAT8) { return createCalciteTypeWithNullability( factory, @@ -669,10 +714,21 @@ public class TypeInferenceUtils { } private static class DrillSameSqlReturnTypeInference implements SqlReturnTypeInference { - private static final DrillSameSqlReturnTypeInference INSTANCE = new DrillSameSqlReturnTypeInference(); + private static final DrillSameSqlReturnTypeInference THE_SAME_RETURN_TYPE = new DrillSameSqlReturnTypeInference(true); + private static final DrillSameSqlReturnTypeInference ALL_NULLABLE = new DrillSameSqlReturnTypeInference(false); + + private final boolean preserveNullability; + + public DrillSameSqlReturnTypeInference(boolean preserveNullability) { + this.preserveNullability = preserveNullability; + } + @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { - return opBinding.getOperandType(0); + if (preserveNullability) { + return opBinding.getOperandType(0); + } + return opBinding.getTypeFactory().createTypeWithNullability(opBinding.getOperandType(0), true); } } @@ -680,27 +736,72 @@ public class TypeInferenceUtils { private static final DrillAvgAggSqlReturnTypeInference INSTANCE = new DrillAvgAggSqlReturnTypeInference(); @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { - final boolean isNullable = opBinding.getGroupCount() == 0 || opBinding.hasFilter() || opBinding.getOperandType(0).isNullable(); - return createCalciteTypeWithNullability( - opBinding.getTypeFactory(), - SqlTypeName.DOUBLE, - isNullable); + final RelDataTypeFactory factory = opBinding.getTypeFactory(); + // If there is group-by and the imput type is Non-nullable, + // the output is Non-nullable; + // Otherwise, the output is nullable. + final boolean isNullable = opBinding.getGroupCount() == 0 + || opBinding.getOperandType(0).isNullable(); + + if (getDrillTypeFromCalciteType(opBinding.getOperandType(0)) == TypeProtos.MinorType.LATE) { + return createCalciteTypeWithNullability( + factory, + SqlTypeName.ANY, + isNullable); + } + + // Determines SqlTypeName of the result. + // For the case when input may be implicitly casted to FLOAT4, the type of result is DOUBLE. + // Else for the case when input may be implicitly casted to VARDECIMAL, the type of result is DECIMAL + // with scale max(6, input) and max allowed numeric precision. + // Else for the case when input may be implicitly casted to FLOAT8, the type of result is DOUBLE. + // When none of these conditions is satisfied, error is thrown. + // This order of checks is caused by the order of types in ResolverTypePrecedence.precedenceMap + final RelDataType operandType = opBinding.getOperandType(0); + final TypeProtos.MinorType inputMinorType = getDrillTypeFromCalciteType(operandType); + if (TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.FLOAT4)) + == TypeProtos.MinorType.FLOAT4) { + return createCalciteTypeWithNullability( + factory, + SqlTypeName.DOUBLE, + isNullable); + } else if (TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.VARDECIMAL)) + == TypeProtos.MinorType.VARDECIMAL) { + RelDataType sqlType = factory.createSqlType(SqlTypeName.DECIMAL, + DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericPrecision(), + Math.min(Math.max(6, operandType.getScale()), + DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericScale())); + return factory.createTypeWithNullability(sqlType, isNullable); + } else if (TypeCastRules.getLeastRestrictiveType(Lists.newArrayList(inputMinorType, TypeProtos.MinorType.FLOAT8)) + == TypeProtos.MinorType.FLOAT8) { + return createCalciteTypeWithNullability( + factory, + SqlTypeName.DOUBLE, + isNullable); + } else { + throw UserException + .functionError() + .message(String.format("%s does not support operand types (%s)", + opBinding.getOperator().getName(), + opBinding.getOperandType(0).getSqlTypeName())) + .build(logger); + } } } - private static DrillFuncHolder resolveDrillFuncHolder(final SqlOperatorBinding opBinding, final List<DrillFuncHolder> functions) { - final FunctionCall functionCall = convertSqlOperatorBindingToFunctionCall(opBinding); + private static DrillFuncHolder resolveDrillFuncHolder(final SqlOperatorBinding opBinding, + final List<DrillFuncHolder> functions, FunctionCall functionCall) { final FunctionResolver functionResolver = FunctionResolverFactory.getResolver(functionCall); final DrillFuncHolder func = functionResolver.getBestMatch(functions, functionCall); // Throw an exception // if no DrillFuncHolder matched for the given list of operand types - if(func == null) { - String operandTypes = ""; - for(int i = 0; i < opBinding.getOperandCount(); ++i) { - operandTypes += opBinding.getOperandType(i).getSqlTypeName(); - if(i < opBinding.getOperandCount() - 1) { - operandTypes += ","; + if (func == null) { + StringBuilder operandTypes = new StringBuilder(); + for (int i = 0; i < opBinding.getOperandCount(); ++i) { + operandTypes.append(opBinding.getOperandType(i).getSqlTypeName()); + if (i < opBinding.getOperandCount() - 1) { + operandTypes.append(","); } } @@ -708,7 +809,7 @@ public class TypeInferenceUtils { .functionError() .message(String.format("%s does not support operand types (%s)", opBinding.getOperator().getName(), - operandTypes)) + operandTypes.toString())) .build(logger); } return func; @@ -768,6 +869,25 @@ public class TypeInferenceUtils { } /** + * Creates a RelDataType using specified RelDataTypeFactory which corresponds to specified TypeProtos.MajorType. + * + * @param typeFactory RelDataTypeFactory used to create the RelDataType + * @param drillType the given TypeProtos.MajorType + * @param isNullable nullability of the resulting type + * @return RelDataType which corresponds to specified TypeProtos.MajorType + */ + public static RelDataType convertToCalciteType(RelDataTypeFactory typeFactory, + TypeProtos.MajorType drillType, boolean isNullable) { + SqlTypeName sqlTypeName = getCalciteTypeFromDrillType(drillType.getMinorType()); + if (sqlTypeName == SqlTypeName.DECIMAL) { + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(sqlTypeName, drillType.getPrecision(), + drillType.getScale()), isNullable); + } + return createCalciteTypeWithNullability(typeFactory, sqlTypeName, isNullable); + } + + /** * Given a SqlOperatorBinding, convert it to FunctionCall * @param opBinding the given SqlOperatorBinding * @return FunctionCall the converted FunctionCall @@ -778,22 +898,28 @@ public class TypeInferenceUtils { for (int i = 0; i < opBinding.getOperandCount(); ++i) { final RelDataType type = opBinding.getOperandType(i); final TypeProtos.MinorType minorType = getDrillTypeFromCalciteType(type); - final TypeProtos.MajorType majorType; - if (type.isNullable()) { - majorType = Types.optional(minorType); - } else { - majorType = Types.required(minorType); + TypeProtos.DataMode dataMode = + type.isNullable() ? TypeProtos.DataMode.OPTIONAL : TypeProtos.DataMode.REQUIRED; + + TypeProtos.MajorType.Builder builder = + TypeProtos.MajorType.newBuilder() + .setMode(dataMode) + .setMinorType(minorType); + + if (CoreDecimalUtility.isDecimalType(minorType)) { + builder + .setScale(type.getScale()) + .setPrecision(type.getPrecision()); } - args.add(new MajorTypeInLogicalExpression(majorType)); + args.add(new MajorTypeInLogicalExpression(builder.build())); } final String drillFuncName = FunctionCallFactory.replaceOpWithFuncName(opBinding.getOperator().getName()); - final FunctionCall functionCall = new FunctionCall( + return new FunctionCall( drillFuncName, args, ExpressionPosition.UNKNOWN); - return functionCall; } /** |