Skip to content

Commit

Permalink
feat: AggregateRel should use grouping references (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
srikrishnak authored Jan 8, 2025
1 parent ec31db0 commit e403121
Show file tree
Hide file tree
Showing 5 changed files with 506 additions and 130 deletions.
230 changes: 161 additions & 69 deletions plan/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,9 @@ type Builder interface {
Project(input Rel, exprs ...expr.Expression) (*ProjectRel, error)
// Deprecated: Use Project(...).Remap() instead.
ProjectRemap(input Rel, remap []int32, exprs ...expr.Expression) (*ProjectRel, error)
// Deprecated: Use AggregateColumns(...).Remap() instead.
AggregateColumnsRemap(input Rel, remap []int32, measures []AggRelMeasure, groupByCols ...int32) (*AggregateRel, error)
// Deprecated: Use GetRelBuilder().AggregateRel(...) instead.
AggregateColumns(input Rel, measures []AggRelMeasure, groupByCols ...int32) (*AggregateRel, error)
// Deprecated: Use AggregateExprs(...).Remap() instead.
AggregateExprsRemap(input Rel, remap []int32, measures []AggRelMeasure, groups ...[]expr.Expression) (*AggregateRel, error)
// Deprecated: Use GetRelBuilder().AggregateRel(...) instead.
AggregateExprs(input Rel, measures []AggRelMeasure, groups ...[]expr.Expression) (*AggregateRel, error)
// Deprecated: Use CreateTableAsSelect(...).Remap() instead.
CreateTableAsSelectRemap(input Rel, remap []int32, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error)
Expand Down Expand Up @@ -146,6 +144,10 @@ type Builder interface {
// GetExprBuilder returns an expr.ExprBuilder that shares the extension
// registry that this Builder uses.
GetExprBuilder() *expr.ExprBuilder

// GetRelBuilder returns an expr.RelBuilder that can be used to construct
// relations which need multiple stages to build them.
GetRelBuilder() *RelBuilder
}

const FETCH_COUNT_ALL_RECORDS = -1
Expand All @@ -166,6 +168,10 @@ func NewBuilder(c *extensions.Collection) Builder {
var (
errOutputMappingOutOfRange = fmt.Errorf("%w: output mapping index out of range", substraitgo.ErrInvalidRel)
errNilInputRel = fmt.Errorf("%w: input Relation must not be nil", substraitgo.ErrInvalidRel)
errNoGroupingOrMeasure = fmt.Errorf("%w: must have at least one grouping expression or measure for AggregateRel", substraitgo.ErrInvalidRel)
errNoGroupingExpression = fmt.Errorf("%w: groupings cannot contain empty expression list or nil expression", substraitgo.ErrInvalidRel)
errInvalidGroupingIndex = fmt.Errorf("%w: groupingReferences contains invalid indices", substraitgo.ErrInvalidRel)
errCubeGroupingSizeLimit = fmt.Errorf("cannot exceed %d grouping references for AddCube", maxGroupingSize)
)

type builder struct {
Expand Down Expand Up @@ -260,77 +266,34 @@ func (b *builder) Measure(measure *expr.AggregateFunction, filter expr.Expressio
}
}

func (b *builder) AggregateColumnsRemap(input Rel, remap []int32, measures []AggRelMeasure, groupByCols ...int32) (*AggregateRel, error) {
if input == nil {
return nil, errNilInputRel
}

if (len(measures) + len(groupByCols)) == 0 {
return nil, fmt.Errorf("%w: must have at least one grouping expression or measure for AggregateRel",
substraitgo.ErrInvalidRel)
}

exprs := make([][]expr.Expression, len(groupByCols))
for i, c := range groupByCols {
ref, err := b.RootFieldRef(input, c)
if err != nil {
return nil, err
}
exprs[i] = []expr.Expression{ref}
}

noutput := int32(len(measures) + len(groupByCols))
for _, idx := range remap {
if idx < 0 || idx >= noutput {
return nil, errOutputMappingOutOfRange
}
}

return &AggregateRel{
RelCommon: RelCommon{mapping: remap},
input: input,
groups: exprs,
measures: measures,
}, nil
}

func (b *builder) AggregateColumns(input Rel, measures []AggRelMeasure, groupByCols ...int32) (*AggregateRel, error) {
return b.AggregateColumnsRemap(input, nil, measures, groupByCols...)
}

func (b *builder) AggregateExprsRemap(input Rel, remap []int32, measures []AggRelMeasure, groups ...[]expr.Expression) (*AggregateRel, error) {
if input == nil {
return nil, errNilInputRel
}

if (len(measures) + len(groups)) == 0 {
return nil, fmt.Errorf("%w: must have at least one grouping expression or measure for AggregateRel",
substraitgo.ErrInvalidRel)
}

if slices.ContainsFunc(groups, func(exlist []expr.Expression) bool {
return len(exlist) == 0 || slices.ContainsFunc(exlist, func(e expr.Expression) bool { return e == nil })
}) {
return nil, fmt.Errorf("%w: groupings cannot contain empty expression list or nil expression", substraitgo.ErrInvalidRel)
}

noutput := int32(len(measures) + len(groups))
for _, idx := range remap {
if idx < 0 || idx >= noutput {
return nil, errOutputMappingOutOfRange
arb := b.GetRelBuilder().AggregateRel(input, measures)
if len(groupByCols) > 0 {
groupingReferences := []uint32{}
for _, c := range groupByCols {
ref, err := b.RootFieldRef(input, c)
if err != nil {
return nil, err
}
i := arb.AddExpression(ref)
groupingReferences = append(groupingReferences, i)
}
arb.AddGroupingSet(groupingReferences)
}

return &AggregateRel{
RelCommon: RelCommon{mapping: remap},
input: input,
groups: groups,
measures: measures,
}, nil
return arb.Build()
}

func (b *builder) AggregateExprs(input Rel, measures []AggRelMeasure, groups ...[]expr.Expression) (*AggregateRel, error) {
return b.AggregateExprsRemap(input, nil, measures, groups...)
arb := b.GetRelBuilder().AggregateRel(input, measures)
for _, group := range groups {
groupingSet := []uint32{}
for _, expr := range group {
i := arb.AddExpression(expr)
groupingSet = append(groupingSet, i)
}
arb.AddGroupingSet(groupingSet)
}
return arb.Build()
}

func (b *builder) CreateTableAsSelectRemap(input Rel, remap []int32, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) {
Expand Down Expand Up @@ -744,3 +707,132 @@ func (b *builder) Plan(root Rel, rootNames []string, others ...Rel) (*Plan, erro
var (
_ Builder = (*builder)(nil)
)

// RelBuilder is a builder for constructing a plan.Rel expression.
type RelBuilder struct {
}

func (b *builder) GetRelBuilder() *RelBuilder {
return &RelBuilder{}
}

// AggregateRel returns a builder for constructing an AggregateRelation
// expression. The input plan.Rel is the input relation to the aggregation.
// The measures are the aggregation measures to be computed.
func (r *RelBuilder) AggregateRel(input Rel, measures []AggRelMeasure) *AggregateRelBuilder {
return &AggregateRelBuilder{input: input, measures: measures}
}

type AggregateRelBuilder struct {
input Rel
measures []AggRelMeasure
groupingExpressions []expr.Expression
groupingReferences [][]uint32
}

// AddExpression adds an expression to the expression map and returns an expression reference.
func (arb *AggregateRelBuilder) AddExpression(e expr.Expression) uint32 {
for idx, expr := range arb.groupingExpressions {
if expr == e {
return uint32(idx)
}
}

arb.groupingExpressions = append(arb.groupingExpressions, e)
return uint32(len(arb.groupingExpressions) - 1)
}

// maxGroupingSize is the maximum allowed size for the grouping references in the AddCube API.
const maxGroupingSize = 20

// AddCube generates all combinations (subsets) of the group represented by the set of expressionReferences and appends them to groupingReferences.
// It uses the power set to generate all possible subsets and adds them to the groupingReferences.
// If the length of expressionReferences exceeds maxGroupingSize, an error is returned to avoid excessive computation.
func (arb *AggregateRelBuilder) AddCube(expressionReferences []uint32) error {
// Ensure the input size is within allowed limits
if len(expressionReferences) > maxGroupingSize {
return errCubeGroupingSizeLimit
}

// Total combinations in the power set (2^n)
totalCombinations := 1 << len(expressionReferences)

// Generate each subset based on the binary representation of the combination index
for combinationIndex := 1; combinationIndex < totalCombinations; combinationIndex++ {
group := extractGroup(expressionReferences, combinationIndex)
arb.groupingReferences = append(arb.groupingReferences, group)
}

return nil
}

// extractGroup generates a subset of expressionReferences based on the binary representation of combinationIndex.
// For each bit set to 1 in combinationIndex, the corresponding element from expressionReferences is included in the subset.
func extractGroup(expressionReferences []uint32, combinationIndex int) []uint32 {
var group []uint32
for bit := 0; bit < len(expressionReferences); bit++ {
if (combinationIndex & (1 << bit)) != 0 {
group = append(group, expressionReferences[bit])
}
}
return group
}

// addRollup constructs the rollup grouping strategy from the provided grouping references.
func (arb *AggregateRelBuilder) AddRollup(groupingReferences []uint32) {
for i := len(groupingReferences); i > 0; i-- {
rollupSet := groupingReferences[:i]
arb.groupingReferences = append(arb.groupingReferences, rollupSet)
}
}

// addGroupingSet adds a new grouping set based on the provided grouping references.
func (arb *AggregateRelBuilder) AddGroupingSet(groupingReferences []uint32) {
arb.groupingReferences = append(arb.groupingReferences, groupingReferences)
}

func (arb *AggregateRelBuilder) Build() (*AggregateRel, error) {
if err := arb.validate(); err != nil {
return nil, err
}

aggregateRel := &AggregateRel{
RelCommon: RelCommon{},
}
aggregateRel.SetInput(arb.input)
aggregateRel.SetMeasures(arb.measures)
aggregateRel.SetGroupingExpressions(arb.groupingExpressions)
aggregateRel.SetGroupingReferences(arb.groupingReferences)

return aggregateRel, nil
}

func (arb *AggregateRelBuilder) validate() error {
if arb.input == nil {
return errNilInputRel
}

if len(arb.measures) == 0 && len(arb.groupingReferences) == 0 {
return errNoGroupingOrMeasure
}

if len(arb.measures) == 0 && len(arb.groupingExpressions) == 0 {
return errNoGroupingExpression
}

if slices.ContainsFunc(arb.groupingExpressions, func(e expr.Expression) bool {
return e == nil
}) {
return errNoGroupingExpression
}

for _, refList := range arb.groupingReferences {
for _, ref := range refList {
if ref >= uint32(len(arb.groupingExpressions)) {
return errInvalidGroupingIndex
}
}
}

return nil
}
66 changes: 55 additions & 11 deletions plan/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,33 @@ func init() {
}
}

// groupingExprs takes 2-dimensional slice of expressions and returns
// a single slice of unique expressions and a slice of references to
// the unique expressions for each group.
func groupingExprs(groups [][]expr.Expression) ([]expr.Expression, [][]uint32) {
groupingExpressions := make([]expr.Expression, 0)
groupingReferences := make([][]uint32, 0)
for _, group := range groups {
refs := make([]uint32, 0)
for _, expr := range group {
existingExpr := false
for eIndex, existing := range groupingExpressions {
if existing.Equals(expr) {
existingExpr = true
refs = append(refs, uint32(eIndex))
break
}
}
if !existingExpr {
groupingExpressions = append(groupingExpressions, expr)
refs = append(refs, uint32(len(groupingExpressions)-1))
}
}
groupingReferences = append(groupingReferences, refs)
}
return groupingExpressions, groupingReferences
}

// Relation is either a Root relation (a relation + list of column names)
// or another relation (such as a CTE or other reference).
type Relation struct {
Expand Down Expand Up @@ -401,16 +428,32 @@ func RelFromProto(rel *proto.Rel, reg expr.ExtensionRegistry) (Rel, error) {
}

base := input.RecordType()
groups := make([][]expr.Expression, len(rel.Aggregate.Groupings))
for i, g := range rel.Aggregate.Groupings {
groups[i] = make([]expr.Expression, len(g.GroupingExpressions))
for j, e := range g.GroupingExpressions {
groups[i][j], err = expr.ExprFromProto(e, &base, reg)
var groupingExpressions []expr.Expression
var groupingReferences [][]uint32
if len(rel.Aggregate.GroupingExpressions) > 0 {
for _, e := range rel.Aggregate.GroupingExpressions {
expr, err := expr.ExprFromProto(e, &base, reg)
if err != nil {
return nil, fmt.Errorf("error getting grouping expr [%d][%d] for AggregateRel: %w",
i, j, err)
return nil, fmt.Errorf("error getting grouping expr for AggregateRel: %w", err)
}
groupingExpressions = append(groupingExpressions, expr)
}
for _, g := range rel.Aggregate.Groupings {
groupingReferences = append(groupingReferences, g.ExpressionReferences)
}
} else { // support old style grouping for backward compatibility
groups := make([][]expr.Expression, len(rel.Aggregate.Groupings))
for i, g := range rel.Aggregate.Groupings {
groups[i] = make([]expr.Expression, len(g.GroupingExpressions))
for j, e := range g.GroupingExpressions {
groups[i][j], err = expr.ExprFromProto(e, &base, reg)
if err != nil {
return nil, fmt.Errorf("error getting grouping expr [%d][%d] for AggregateRel: %w",
i, j, err)
}
}
}
groupingExpressions, groupingReferences = groupingExprs(groups)
}

measures := make([]AggRelMeasure, len(rel.Aggregate.Measures))
Expand All @@ -429,10 +472,11 @@ func RelFromProto(rel *proto.Rel, reg expr.ExtensionRegistry) (Rel, error) {
}

out := &AggregateRel{
input: input,
groups: groups,
measures: measures,
advExtension: rel.Aggregate.AdvancedExtension,
input: input,
measures: measures,
groupingReferences: groupingReferences,
groupingExpressions: groupingExpressions,
advExtension: rel.Aggregate.AdvancedExtension,
}
out.fromProtoCommon(rel.Aggregate.Common)
return out, nil
Expand Down
Loading

0 comments on commit e403121

Please sign in to comment.