Skip to content

Commit

Permalink
[native] Introduce native function namespace SPI
Browse files Browse the repository at this point in the history
  • Loading branch information
pdabre12 committed Feb 1, 2025
1 parent 092dd36 commit f54c498
Show file tree
Hide file tree
Showing 55 changed files with 1,617 additions and 143 deletions.
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,12 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-native-sidecar-plugin</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>com.facebook.hive</groupId>
<artifactId>hive-dwrf</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.block.BlockEncodingSerde;
import com.facebook.presto.common.function.SqlFunctionResult;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.common.type.UserDefinedType;
import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors;
Expand Down Expand Up @@ -192,7 +193,7 @@ public Optional<UserDefinedType> getUserDefinedType(QualifiedObjectName typeName
}

@Override
public final FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespaceTransactionHandle> transactionHandle, Signature signature)
public FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespaceTransactionHandle> transactionHandle, Signature signature)
{
checkCatalog(signature.getName());
// This is the only assumption in this class that we're dealing with sql-invoked regular function.
Expand Down Expand Up @@ -363,10 +364,13 @@ protected AggregationFunctionImplementation sqlInvokedFunctionToAggregationImple
"Need aggregationMetadata to get aggregation function implementation");

AggregationFunctionMetadata aggregationMetadata = function.getAggregationMetadata().get();
List<Type> parameters = function.getSignature().getArgumentTypes().stream().map(
(typeManager::getType)).collect(toImmutableList());
return new SqlInvokedAggregationFunctionImplementation(
typeManager.getType(aggregationMetadata.getIntermediateType()),
typeManager.getType(function.getSignature().getReturnType()),
aggregationMetadata.isOrderSensitive());
aggregationMetadata.isOrderSensitive(),
parameters);
default:
throw new IllegalStateException(format("Unknown function implementation type: %s", implementationType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ private SqlInvokedFunction createSqlInvokedFunction(String functionName, JsonBas
jsonBasedUdfFunctionMetaData.getDocString(),
jsonBasedUdfFunctionMetaData.getRoutineCharacteristics(),
"",
jsonBasedUdfFunctionMetaData.getVariableArity(),
functionVersion,
jsonBasedUdfFunctionMetaData.getFunctionKind(),
functionId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
package com.facebook.presto.hive.functions;

import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.Signature;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;
import java.util.Objects;

import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -54,6 +56,12 @@ public FunctionKind getKind()
return signature.getKind();
}

@Override
public List<TypeSignature> getArgumentTypes()
{
return signature.getArgumentTypes();
}

@JsonProperty
public Signature getSignature()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ public static TestingPrestoServer createTestingPrestoServer()
functionAndTypeManager.loadFunctionNamespaceManager(
"hive-functions",
"hive",
getNamespaceManagerCreationProperties());
getNamespaceManagerCreationProperties(),
server.getPluginNodeManager());
server.refreshNodes();
return server;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ private static TestingPrestoServer createServer()
functionAndTypeManager.loadFunctionNamespaceManager(
"hive-functions",
"hive",
Collections.emptyMap());
Collections.emptyMap(),
server.getPluginNodeManager());
server.refreshNodes();
return server;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
*/
package com.facebook.presto.iceberg.optimizer;

import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.predicate.NullableValue;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.predicate.TupleDomain.ColumnDomain;
Expand Down Expand Up @@ -58,7 +56,6 @@
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;

Expand All @@ -67,7 +64,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;

import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.iceberg.IcebergSessionProperties.getRowsForMetadataOptimizationThreshold;
Expand All @@ -81,16 +78,10 @@
public class IcebergMetadataOptimizer
implements ConnectorPlanOptimizer
{
public static final CatalogSchemaName DEFAULT_NAMESPACE = new CatalogSchemaName("presto", "default");
private static final Set<QualifiedObjectName> ALLOWED_FUNCTIONS = ImmutableSet.of(
QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "max"),
QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "min"),
QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "approx_distinct"));

// Min/Max could be folded into LEAST/GREATEST
private static final Map<QualifiedObjectName, QualifiedObjectName> AGGREGATION_SCALAR_MAPPING = ImmutableMap.of(
QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "max"), QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "greatest"),
QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "min"), QualifiedObjectName.valueOf(DEFAULT_NAMESPACE, "least"));
private static final Map<String, String> AGGREGATION_SCALAR_MAPPING = ImmutableMap.of(
"max", "greatest",
"min", "least");

private final FunctionMetadataManager functionMetadataManager;
private final TypeManager typeManager;
Expand Down Expand Up @@ -137,6 +128,7 @@ private static class Optimizer
private final RowExpressionService rowExpressionService;
private final StandardFunctionResolution functionResolution;
private final int rowsForMetadataOptimizationThreshold;
private final List<Predicate<FunctionHandle>> allowedFunctionsPredicates;

private Optimizer(ConnectorSession connectorSession,
PlanNodeIdAllocator idAllocator,
Expand All @@ -156,15 +148,19 @@ private Optimizer(ConnectorSession connectorSession,
this.functionResolution = functionResolution;
this.typeManager = typeManager;
this.rowsForMetadataOptimizationThreshold = rowsForMetadataOptimizationThreshold;
this.allowedFunctionsPredicates = ImmutableList.of(
functionResolution::isMaxFunction,
functionResolution::isMinFunction,
functionResolution::isApproximateCountDistinctFunction);
}

@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
// supported functions are only MIN/MAX/APPROX_DISTINCT or distinct aggregates
for (Aggregation aggregation : node.getAggregations().values()) {
QualifiedObjectName functionName = functionMetadataManager.getFunctionMetadata(aggregation.getFunctionHandle()).getName();
if (!ALLOWED_FUNCTIONS.contains(functionName) && !aggregation.isDistinct()) {
if (allowedFunctionsPredicates.stream().noneMatch(
pred -> pred.test(aggregation.getFunctionHandle())) && !aggregation.isDistinct()) {
return context.defaultRewrite(node);
}
}
Expand Down Expand Up @@ -270,7 +266,7 @@ private boolean isReducible(AggregationNode node, List<VariableReferenceExpressi
}
for (Aggregation aggregation : node.getAggregations().values()) {
FunctionMetadata functionMetadata = functionMetadataManager.getFunctionMetadata(aggregation.getFunctionHandle());
if (!AGGREGATION_SCALAR_MAPPING.containsKey(functionMetadata.getName()) ||
if (!AGGREGATION_SCALAR_MAPPING.containsKey(functionMetadata.getName().getObjectName()) ||
functionMetadata.getArgumentTypes().size() > 1 ||
!inputs.containsAll(aggregation.getCall().getArguments())) {
return false;
Expand Down Expand Up @@ -340,7 +336,7 @@ private RowExpression evaluateMinMax(FunctionMetadata aggregationFunctionMetadat
return new ConstantExpression(Optional.empty(), null, returnType);
}

String scalarFunctionName = AGGREGATION_SCALAR_MAPPING.get(aggregationFunctionMetadata.getName()).getObjectName();
String scalarFunctionName = AGGREGATION_SCALAR_MAPPING.get(aggregationFunctionMetadata.getName().getObjectName());
while (arguments.size() > 1) {
List<RowExpression> reducedArguments = new ArrayList<>();
// We fold for every 100 values because GREATEST/LEAST has argument count limit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
package com.facebook.presto.metadata;

import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.Signature;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;
import java.util.Objects;

import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -54,6 +56,12 @@ public FunctionKind getKind()
return signature.getKind();
}

@Override
public List<TypeSignature> getArgumentTypes()
{
return signature.getArgumentTypes();
}

@Override
public CatalogSchemaName getCatalogSchemaName()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation;
import com.facebook.presto.sql.analyzer.FunctionsConfig;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.type.BigintOperators;
import com.facebook.presto.type.BooleanOperators;
import com.facebook.presto.type.CharOperators;
Expand Down Expand Up @@ -345,7 +344,6 @@
import static com.facebook.presto.geospatial.SphericalGeographyType.SPHERICAL_GEOGRAPHY;
import static com.facebook.presto.geospatial.type.BingTileType.BING_TILE;
import static com.facebook.presto.geospatial.type.GeometryType.GEOMETRY;
import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables;
import static com.facebook.presto.operator.aggregation.AlternativeArbitraryAggregationFunction.ALTERNATIVE_ANY_VALUE_AGGREGATION;
import static com.facebook.presto.operator.aggregation.AlternativeArbitraryAggregationFunction.ALTERNATIVE_ARBITRARY_AGGREGATION;
import static com.facebook.presto.operator.aggregation.AlternativeMaxAggregationFunction.ALTERNATIVE_MAX;
Expand Down Expand Up @@ -461,7 +459,6 @@
import static com.facebook.presto.spi.function.FunctionKind.SCALAR;
import static com.facebook.presto.spi.function.FunctionKind.WINDOW;
import static com.facebook.presto.spi.function.SqlFunctionVisibility.HIDDEN;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;
import static com.facebook.presto.sql.planner.LiteralEncoder.MAGIC_LITERAL_FUNCTION_PREFIX;
import static com.facebook.presto.type.ArrayParametricType.ARRAY;
import static com.facebook.presto.type.CodePointsType.CODE_POINTS;
Expand Down Expand Up @@ -521,7 +518,6 @@
import static com.facebook.presto.type.Re2JRegexpType.RE2J_REGEXP;
import static com.facebook.presto.type.RowParametricType.ROW;
import static com.facebook.presto.type.SfmSketchType.SFM_SKETCH;
import static com.facebook.presto.type.TypeUtils.resolveTypes;
import static com.facebook.presto.type.khyperloglog.KHyperLogLogType.K_HYPER_LOG_LOG;
import static com.facebook.presto.type.setdigest.SetDigestType.SET_DIGEST;
import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -1343,47 +1339,11 @@ private SpecializedFunctionKey getSpecializedFunctionKey(Signature signature)

private SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature)
{
Iterable<SqlFunction> candidates = getFunctions(null, signature.getName());
// search for exact match
Type returnType = functionAndTypeManager.getType(signature.getReturnType());
List<TypeSignatureProvider> argumentTypeSignatureProviders = fromTypeSignatures(signature.getArgumentTypes());
for (SqlFunction candidate : candidates) {
Optional<BoundVariables> boundVariables = new SignatureBinder(functionAndTypeManager, candidate.getSignature(), false)
.bindVariables(argumentTypeSignatureProviders, returnType);
if (boundVariables.isPresent()) {
return new SpecializedFunctionKey(candidate, boundVariables.get(), argumentTypeSignatureProviders.size());
}
}

// TODO: hack because there could be "type only" coercions (which aren't necessarily included as implicit casts),
// so do a second pass allowing "type only" coercions
List<Type> argumentTypes = resolveTypes(signature.getArgumentTypes(), functionAndTypeManager);
for (SqlFunction candidate : candidates) {
SignatureBinder binder = new SignatureBinder(functionAndTypeManager, candidate.getSignature(), true);
Optional<BoundVariables> boundVariables = binder.bindVariables(argumentTypeSignatureProviders, returnType);
if (!boundVariables.isPresent()) {
continue;
}
Signature boundSignature = applyBoundVariables(candidate.getSignature(), boundVariables.get(), argumentTypes.size());

if (!functionAndTypeManager.isTypeOnlyCoercion(functionAndTypeManager.getType(boundSignature.getReturnType()), returnType)) {
continue;
}
boolean nonTypeOnlyCoercion = false;
for (int i = 0; i < argumentTypes.size(); i++) {
Type expectedType = functionAndTypeManager.getType(boundSignature.getArgumentTypes().get(i));
if (!functionAndTypeManager.isTypeOnlyCoercion(argumentTypes.get(i), expectedType)) {
nonTypeOnlyCoercion = true;
break;
}
}
if (nonTypeOnlyCoercion) {
continue;
}

return new SpecializedFunctionKey(candidate, boundVariables.get(), argumentTypes.size());
}
return functionAndTypeManager.getSpecializedFunctionKey(signature);
}

public SpecializedFunctionKey doGetSpecializedFunctionKeyForMagicLiteralFunctions(Signature signature, FunctionAndTypeManager functionAndTypeManager)
{
// TODO: this is a hack and should be removed
if (signature.getNameSuffix().startsWith(MAGIC_LITERAL_FUNCTION_PREFIX)) {
List<TypeSignature> parameterTypes = signature.getArgumentTypes();
Expand All @@ -1406,7 +1366,6 @@ private SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature)
.build(),
1);
}

throw new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("%s not found", signature));
}

Expand Down
Loading

0 comments on commit f54c498

Please sign in to comment.