Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] (Part3) Support creating materialized views with common aggregate state functions #51510

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ private Expr analyzeExpr(SelectAnalyzer.RewriteAliasVisitor visitor,
Expr newExpr = defineExpr.clone(smap);
newExpr = newExpr.accept(visitor, null);
newExpr = Expr.analyzeAndCastFold(newExpr);
if (!newExpr.getType().equals(type)) {
Type newType = newExpr.getType();
if (!type.isFullyCompatible(newType)) {
newExpr = new CastExpr(type, newExpr);
}
return newExpr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,19 @@ public AggStateCombinator(AggStateCombinator other) {

public static Optional<AggStateCombinator> of(AggregateFunction aggFunc) {
try {
Type intermediateType = aggFunc.getIntermediateTypeOrReturnType();
Type intermediateType = aggFunc.getIntermediateTypeOrReturnType().clone();
FunctionName funcName = new FunctionName(aggFunc.functionName() + FunctionSet.AGG_STATE_SUFFIX);
AggStateCombinator aggStateFunc = new AggStateCombinator(funcName, Arrays.asList(aggFunc.getArgs()),
intermediateType);
aggStateFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateFunc.setAggStateDesc(new AggStateDesc(aggFunc));

AggStateDesc aggStateDesc = new AggStateDesc(aggFunc);
aggStateFunc.setAggStateDesc(aggStateDesc);
// `agg_state` function's type will contain agg state desc.
intermediateType.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateFunc.functionName());
return Optional.of(aggStateFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ public static Optional<AggStateMergeCombinator> of(AggregateFunction aggFunc) {
new AggStateMergeCombinator(functionName, imtermediateType, aggFunc.getReturnType());
aggStateMergeFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateMergeFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateMergeFunc.setAggStateDesc(new AggStateDesc(aggFunc));
AggStateDesc aggStateDesc;
if (aggFunc.getAggStateDesc() != null) {
aggStateDesc = aggFunc.getAggStateDesc().clone();
} else {
aggStateDesc = new AggStateDesc(aggFunc);
}
aggStateMergeFunc.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateMergeFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateMergeFunc.functionName());
return Optional.of(aggStateMergeFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ public static Optional<AggStateUnionCombinator> of(AggregateFunction aggFunc) {
new AggStateUnionCombinator(functionName, intermediateType);
aggStateUnionFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateUnionFunc.setPolymorphic(aggFunc.isPolymorphic());
aggStateUnionFunc.setAggStateDesc(new AggStateDesc(aggFunc));
AggStateDesc aggStateDesc;
if (aggFunc.getAggStateDesc() != null) {
aggStateDesc = aggFunc.getAggStateDesc().clone();
} else {
aggStateDesc = new AggStateDesc(aggFunc);
}
aggStateUnionFunc.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateUnionFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateUnionFunc.functionName());
return Optional.of(aggStateUnionFunc);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public static Function getAnalyzedCombinatorFunction(ConnectContext session,
return null;
}
AggregateFunction aggFunc = (AggregateFunction) argFn;
if (aggFunc.getNumArgs() == 1 && argumentTypes[0].isDecimalOfAnyVersion()) {
if (aggFunc.getNumArgs() == 1) {
// only copy argument if it's a decimal type
AggregateFunction argFnCopy = (AggregateFunction) aggFunc.copy();
argFnCopy.setArgsType(argumentTypes);
Expand Down Expand Up @@ -209,7 +209,9 @@ private static AggregateFunction getAggStateFunction(ConnectContext session,
if (!(fn instanceof AggregateFunction)) {
return null;
}
return (AggregateFunction) fn;
AggregateFunction result = (AggregateFunction) fn.copy();
result.setAggStateDesc(aggStateDesc);
return result;
}

private static Type[] getNewArgumentTypes(Type[] origArgTypes, String argFnName, Type arg0Type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.starrocks.analysis.CaseWhenClause;
import com.starrocks.analysis.Expr;
import com.starrocks.analysis.FunctionCallExpr;
import com.starrocks.analysis.FunctionParams;
import com.starrocks.analysis.IntLiteral;
import com.starrocks.analysis.IsNullPredicate;
import com.starrocks.analysis.OrderByElement;
Expand All @@ -58,6 +59,8 @@
import com.starrocks.catalog.Table;
import com.starrocks.catalog.Type;
import com.starrocks.catalog.View;
import com.starrocks.catalog.combinator.AggStateDesc;
import com.starrocks.catalog.combinator.AggStateUnionCombinator;
import com.starrocks.common.ErrorCode;
import com.starrocks.common.ErrorReport;
import com.starrocks.common.FeConstants;
Expand All @@ -68,6 +71,7 @@
import com.starrocks.sql.analyzer.AnalyzerUtils;
import com.starrocks.sql.analyzer.ExpressionAnalyzer;
import com.starrocks.sql.analyzer.Field;
import com.starrocks.sql.analyzer.FunctionAnalyzer;
import com.starrocks.sql.analyzer.RelationFields;
import com.starrocks.sql.analyzer.RelationId;
import com.starrocks.sql.analyzer.Scope;
Expand Down Expand Up @@ -250,7 +254,7 @@ public Map<String, Expr> parseDefineExprWithoutAnalyze(String originalSql) {
case FunctionSet.HLL_UNION:
case FunctionSet.PERCENTILE_UNION:
case FunctionSet.COUNT: {
MVColumnItem item = buildAggColumnItem(selectListItem, slots);
MVColumnItem item = buildAggColumnItem(new ConnectContext(), selectListItem, slots);
expr = item.getDefineExpr();
name = item.getName();
break;
Expand Down Expand Up @@ -337,7 +341,7 @@ public void analyze(ConnectContext context) {
if (!(selectRelation.getRelation() instanceof TableRelation)) {
throw new UnsupportedMVException("Materialized view query statement only support direct query from table.");
}
int beginIndexOfAggregation = genColumnAndSetIntoStmt(table, selectRelation);
int beginIndexOfAggregation = genColumnAndSetIntoStmt(context, table, selectRelation);
if (selectRelation.isDistinct() || selectRelation.hasAggregation()) {
setMvKeysType(KeysType.AGG_KEYS);
}
Expand Down Expand Up @@ -409,7 +413,7 @@ private void analyzeExprWithTableAlias(ConnectContext context,
.collect(Collectors.toList()))), context);
}

private int genColumnAndSetIntoStmt(Table table, SelectRelation selectRelation) {
private int genColumnAndSetIntoStmt(ConnectContext context, Table table, SelectRelation selectRelation) {
List<MVColumnItem> mvColumnItemList = Lists.newArrayList();

boolean meetAggregate = false;
Expand Down Expand Up @@ -442,30 +446,33 @@ private int genColumnAndSetIntoStmt(Table table, SelectRelation selectRelation)
&& ((FunctionCallExpr) selectListItemExpr).isAggregateFunction()) {
// Aggregate Function must match pattern.
FunctionCallExpr functionCallExpr = (FunctionCallExpr) selectListItemExpr;
String functionName = functionCallExpr.getFnName().getFunction();

MVColumnPattern mvColumnPattern =
CreateMaterializedViewStmt.FN_NAME_TO_PATTERN.get(functionName.toLowerCase());
if (mvColumnPattern == null) {
throw new UnsupportedMVException(
"Materialized view does not support function:%s, supported functions are: %s",
functionCallExpr.toSqlImpl(), FN_NAME_TO_PATTERN.keySet());
}
String functionName = functionCallExpr.getFnName().getFunction().toLowerCase();
// current version not support count(distinct) function in creating materialized view
if (!isReplay && functionCallExpr.isDistinct()) {
throw new UnsupportedMVException(
"Materialized view does not support distinct function " + functionCallExpr.toSqlImpl());
}
if (!mvColumnPattern.match(functionCallExpr)) {
throw new UnsupportedMVException(
"The function " + functionName + " must match pattern:" + mvColumnPattern);
if (!FN_NAME_TO_PATTERN.containsKey(functionName)) {
// eg: avg_union(avg_state(xxx))
} else {
MVColumnPattern mvColumnPattern = FN_NAME_TO_PATTERN.get(functionName);
if (mvColumnPattern == null) {

throw new UnsupportedMVException(
"Materialized view does not support function:%s, supported functions are: %s",
functionCallExpr.toSqlImpl(), FN_NAME_TO_PATTERN.keySet());
}
if (!mvColumnPattern.match(functionCallExpr)) {
throw new UnsupportedMVException(
"The function " + functionName + " must match pattern:" + mvColumnPattern);
}
}
if (beginIndexOfAggregation == -1) {
beginIndexOfAggregation = i;
}
meetAggregate = true;

mvColumnItem = buildAggColumnItem(selectListItem, slots);
mvColumnItem = buildAggColumnItem(context, selectListItem, slots);
if (!mvColumnNameSet.add(mvColumnItem.getName())) {
ErrorReport.reportSemanticException(ErrorCode.ERR_DUP_FIELDNAME, mvColumnItem.getName());
}
Expand Down Expand Up @@ -527,17 +534,68 @@ private MVColumnItem buildNonAggColumnItem(SelectListItem selectListItem,
type = AnalyzerUtils.transformTableColumnType(type, false);
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
return new MVColumnItem(columnName, type, null, false, defineExpr,
.collect(Collectors.toSet());
return new MVColumnItem(columnName, type, null, null, false, defineExpr,
defineExpr.isNullable(), baseColumnNames);
}

// Convert the aggregate function to MVColumn.
private MVColumnItem buildAggColumnItem(SelectListItem selectListItem,
private MVColumnItem buildAggColumnItem(ConnectContext context,
SelectListItem selectListItem,
List<SlotRef> baseSlotRefs) {
FunctionCallExpr node = (FunctionCallExpr) selectListItem.getExpr();
String functionName = node.getFnName().getFunction();
Preconditions.checkState(node.getChildren().size() == 1, "Aggregate function only support one child");

if (!FN_NAME_TO_PATTERN.containsKey(functionName)) {
if (Strings.isNullOrEmpty(selectListItem.getAlias())) {
throw new SemanticException("Create materialized view non-slot ref expression should have an alias:" +
selectListItem.getExpr());
}

Expr defineExpr = node.getChild(0);
List<Type> argTypes = node.getChildren().stream().map(Expr::getType).collect(Collectors.toList());
Type arg0Type = argTypes.get(0);
if (arg0Type.getAggStateDesc() == null) {
throw new UnsupportedMVException("Unsupported function:" + functionName + ", cannot find agg state desc from " +
"arg0");
}
FunctionParams params = new FunctionParams(false, Lists.newArrayList());
Type[] argumentTypes = argTypes.toArray(Type[]::new);
Boolean[] isArgumentConstants = argTypes.stream().map(x -> false).toArray(Boolean[]::new);
Function function = FunctionAnalyzer.getAnalyzedAggregateFunction(context, functionName,
params, argumentTypes, isArgumentConstants, NodePosition.ZERO);
if (function == null || !(function instanceof AggStateUnionCombinator)) {
throw new UnsupportedMVException("Unsupported function:" + functionName);
}
AggStateUnionCombinator aggFunction = (AggStateUnionCombinator) function;
String mvColumnName = MVUtils.getMVColumnName(selectListItem.getAlias());
AggStateDesc aggStateDesc = aggFunction.getAggStateDesc();
Type type = aggFunction.getIntermediateTypeOrReturnType();
if (type.isWildcardDecimal()) {
throw new UnsupportedMVException("Unsupported wildcard decimal type in materialized view:" + type + ", " +
"function:" + node);
}
if (aggStateDesc.getArgTypes().stream().anyMatch(t -> t.isWildcardDecimal())) {
throw new UnsupportedMVException("Unsupported wildcard decimal type in materialized view:" + type + ", " +
"function:" + node);
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
AggregateType mvAggregateType = AggregateType.AGG_STATE_UNION;
Type finalType = AnalyzerUtils.transformTableColumnType(type, false);
return new MVColumnItem(mvColumnName, finalType, mvAggregateType, aggStateDesc, false,
defineExpr, aggStateDesc.getResultNullable(), baseColumnNames);
} else {
return buildAggColumnItemWithPattern(selectListItem, baseSlotRefs);
}
}

// Convert the aggregate function to MVColumn.
private MVColumnItem buildAggColumnItemWithPattern(SelectListItem selectListItem,
List<SlotRef> baseSlotRefs) {
FunctionCallExpr functionCallExpr = (FunctionCallExpr) selectListItem.getExpr();
String functionName = functionCallExpr.getFnName().getFunction();
Preconditions.checkState(functionCallExpr.getChildren().size() == 1, "Aggregate function only support one child");
Expr defineExpr = functionCallExpr.getChild(0);
AggregateType mvAggregateType = null;
Type baseType = defineExpr.getType();
Expand Down Expand Up @@ -640,8 +698,8 @@ private MVColumnItem buildAggColumnItem(SelectListItem selectListItem,
String.format("Invalid aggregate function '%s' for '%s'", mvAggregateType, type));
}
Set<String> baseColumnNames = baseSlotRefs.stream().map(slot -> slot.getColumnName())
.collect(Collectors.toSet());
return new MVColumnItem(mvColumnName, type, mvAggregateType, false,
.collect(Collectors.toSet());
return new MVColumnItem(mvColumnName, type, mvAggregateType, null, false,
defineExpr, functionCallExpr.isNullable(), baseColumnNames);
}

LiShuMing marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@
import com.starrocks.catalog.Column;
import com.starrocks.catalog.OlapTable;
import com.starrocks.catalog.Type;
import com.starrocks.catalog.combinator.AggStateDesc;

import java.util.Set;

import static com.starrocks.catalog.Column.COLUMN_UNIQUE_ID_INIT_VALUE;

/**
* This is a result of semantic analysis for AddMaterializedViewClause.
* It is used to construct real mv column in MaterializedViewHandler.
Expand All @@ -54,16 +57,19 @@ public class MVColumnItem {
private Type type;
private boolean isKey;
private AggregateType aggregationType;
private AggStateDesc aggStateDesc;
private boolean isAllowNull;
private boolean isAggregationTypeImplicit;
private Expr defineExpr;
private Set<String> baseColumnNames;

public MVColumnItem(String name, Type type, AggregateType aggregateType, boolean isAggregationTypeImplicit,
public MVColumnItem(String name, Type type, AggregateType aggregateType, AggStateDesc aggStateDesc,
boolean isAggregationTypeImplicit,
Expr defineExpr, boolean isAllowNull, Set<String> baseColumnNames) {
this.name = name;
this.type = type;
this.aggregationType = aggregateType;
this.aggStateDesc = aggStateDesc;
this.isAggregationTypeImplicit = isAggregationTypeImplicit;
this.defineExpr = defineExpr;
this.isAllowNull = isAllowNull;
Expand Down Expand Up @@ -124,8 +130,8 @@ public Column toMVColumn(OlapTable olapTable) {
Column result;
boolean hasUniqueId = olapTable.getMaxColUniqueId() >= 0;
if (baseColumn == null) {
result = new Column(name, type, isKey, aggregationType, isAllowNull,
null, "");
result = new Column(name, type, isKey, aggregationType, aggStateDesc, isAllowNull,
null, "", COLUMN_UNIQUE_ID_INIT_VALUE);
if (defineExpr != null) {
result.setDefineExpr(defineExpr);
}
Expand Down
Loading
Loading