Skip to content

Commit

Permalink
feat: Add support for top level aggregates (#594)
Browse files Browse the repository at this point in the history
* Rework count input objects

Decouple them from host, cleaner for user, and allows reuse for top-level aggs

* Remove sourceInfo param from Sum

Minimal cost in re-getting, and makes it easier to call from other locations

* Use switch instead of if for type check

Switch will also gain a new case shortly

* Remove unused type from createExpandedFieldAggregate

* Use correct collection name

Whilst they should have the same value at the moment, the disinction between the two becomes more important when introducing top-level aggregates

* Extract out aggregate request logic to function

Will be called multiple times once top-level aggregates are introduced

* Remove legacy code

This has been incorrect for a while, and will cause problems for top-level aggregates

* Add support for top level aggregates
  • Loading branch information
AndrewSisley authored Jul 11, 2022
1 parent f9d5c0b commit 46ec563
Show file tree
Hide file tree
Showing 14 changed files with 809 additions and 64 deletions.
68 changes: 53 additions & 15 deletions query/graphql/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func toSelect(
Targetable: toTargetable(thisIndex, parsed, mapping),
DocumentMapping: *mapping,
Cid: parsed.CID,
CollectionName: desc.Name,
CollectionName: collectionName,
Fields: fields,
}, nil
}
Expand Down Expand Up @@ -365,6 +365,24 @@ func getRequestables(
desc *client.CollectionDescription,
descriptionsRepo *DescriptionsRepo,
) (fields []Requestable, aggregates []*aggregateRequest, err error) {
// If this parser.Select is itself an aggregate, we need to append the
// relevent info here as if it was a field of its own (due to a quirk of
// the parser package).
if _, isAggregate := parserTypes.Aggregates[parsed.Name]; isAggregate {
index := mapping.GetNextIndex()
aggregateReq, err := getAggregateRequests(index, parsed)
if err != nil {
return nil, nil, err
}

mapping.RenderKeys = append(mapping.RenderKeys, core.RenderKey{
Index: index,
Key: parsed.Alias,
})
mapping.Add(index, parsed.Name)
aggregates = append(aggregates, &aggregateReq)
}

for _, field := range parsed.Fields {
switch f := field.(type) {
case *parser.Field:
Expand All @@ -390,24 +408,12 @@ func getRequestables(
// aggregates have been requested and their targets here, before finalizing
// their evaluation later.
if _, isAggregate := parserTypes.Aggregates[f.Name]; isAggregate {
aggregateTargets, err := getAggregateSources(f)
aggregateRequest, err := getAggregateRequests(index, f)
if err != nil {
return nil, nil, err
}

if len(aggregateTargets) == 0 {
return nil, nil, fmt.Errorf(
"Aggregate must be provided with a property to aggregate.",
)
}

aggregates = append(aggregates, &aggregateRequest{
field: Field{
Index: index,
Name: f.Name,
},
targets: aggregateTargets,
})
aggregates = append(aggregates, &aggregateRequest)
} else {
innerSelect, err := toSelect(descriptionsRepo, index, f, desc.Name)
if err != nil {
Expand All @@ -433,13 +439,39 @@ func getRequestables(
return
}

func getAggregateRequests(index int, aggregate *parser.Select) (aggregateRequest, error) {
aggregateTargets, err := getAggregateSources(aggregate)
if err != nil {
return aggregateRequest{}, err
}

if len(aggregateTargets) == 0 {
return aggregateRequest{}, fmt.Errorf(
"Aggregate must be provided with a property to aggregate.",
)
}

return aggregateRequest{
field: Field{
Index: index,
Name: aggregate.Name,
},
targets: aggregateTargets,
}, nil
}

// getCollectionName returns the name of the parsed collection. This may be empty
// if this is a commit request.
func getCollectionName(
descriptionsRepo *DescriptionsRepo,
parsed *parser.Select,
parentCollectionName string,
) (string, error) {
if _, isAggregate := parserTypes.Aggregates[parsed.Name]; isAggregate {
// This string is not used or referenced, its value is only there to aid debugging
return "_topLevel", nil
}

if parsed.Name == parserTypes.GroupFieldName {
return parentCollectionName, nil
} else if parsed.Root == parserTypes.CommitSelection {
Expand Down Expand Up @@ -471,6 +503,12 @@ func getTopLevelInfo(
) (*core.DocumentMapping, *client.CollectionDescription, error) {
mapping := core.NewDocumentMapping()

if _, isAggregate := parserTypes.Aggregates[parsed.Name]; isAggregate {
// If this is a (top-level) aggregate, then it will have no collection
// description, and no top-level fields, so we return an empty mapping only
return mapping, &client.CollectionDescription{}, nil
}

if parsed.Root != parserTypes.CommitSelection {
mapping.Add(core.DocKeyFieldIndex, parserTypes.DocKeyFieldName)

Expand Down
19 changes: 2 additions & 17 deletions query/graphql/parser/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ func parseSelect(rootType parserTypes.SelectionType, field *ast.Field, index int

// parse arguments
for _, argument := range field.Arguments {
prop, astValue := getArgumentKeyValue(field, argument)
prop := argument.Name.Value
astValue := argument.Value

// parse filter
if prop == parserTypes.FilterClause {
Expand Down Expand Up @@ -302,22 +303,6 @@ func parseSelect(rootType parserTypes.SelectionType, field *ast.Field, index int
return slct, err
}

// getArgumentKeyValue returns the relevant arguement name and value for the given field-argument
// Note: this function will likely need some rework when adding more aggregate options (e.g. limit)
func getArgumentKeyValue(field *ast.Field, argument *ast.Argument) (string, ast.Value) {
if _, isAggregate := parserTypes.Aggregates[field.Name.Value]; isAggregate {
switch innerProps := argument.Value.(type) {
case *ast.ObjectValue:
for _, innerV := range innerProps.Fields {
if innerV.Name.Value == parserTypes.FilterClause {
return parserTypes.FilterClause, innerV.Value
}
}
}
}
return argument.Name.Value, argument.Value
}

func getFieldAlias(field *ast.Field) string {
if field.Alias == nil {
return field.Name.Value
Expand Down
1 change: 1 addition & 0 deletions query/graphql/planner/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var (
_ planNode = (*selectTopNode)(nil)
_ planNode = (*orderNode)(nil)
_ planNode = (*sumNode)(nil)
_ planNode = (*topLevelNode)(nil)
_ planNode = (*typeIndexJoin)(nil)
_ planNode = (*typeJoinMany)(nil)
_ planNode = (*typeJoinOne)(nil)
Expand Down
25 changes: 25 additions & 0 deletions query/graphql/planner/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ func (p *Planner) newPlan(stmt interface{}) (planNode, error) {
if err != nil {
return nil, err
}

if _, isAgg := parserTypes.Aggregates[n.Name]; isAgg {
// If this Select is an aggregate, then it must be a top-level
// aggregate and we need to resolve it within the context of a
// top-level node.
return p.Top(m)
}

return p.Select(m)
case *mapper.Select:
return p.Select(n)
Expand Down Expand Up @@ -223,6 +231,23 @@ func (p *Planner) expandPlan(plan planNode, parentPlan *selectTopNode) error {
case *deleteNode:
return p.expandPlan(n.source, parentPlan)

case *topLevelNode:
for _, child := range n.children {
switch c := child.(type) {
case *selectTopNode:
// We only care about expanding the child source here, it is assumed that the parent source
// is expanded elsewhere/already
err := p.expandPlan(child, parentPlan)
if err != nil {
return err
}
case aggregateNode:
// top-level aggregates use the top-level node as a source
c.SetPlan(n)
}
}
return nil

default:
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion query/graphql/planner/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (n *selectNode) initFields(parsed *mapper.Select) ([]aggregateNode, error)
case parserTypes.CountFieldName:
plan, aggregateError = n.p.Count(f, parsed)
case parserTypes.SumFieldName:
plan, aggregateError = n.p.Sum(&n.sourceInfo, f, parsed)
plan, aggregateError = n.p.Sum(f, parsed)
case parserTypes.AverageFieldName:
plan, aggregateError = n.p.Average(f)
}
Expand Down
21 changes: 11 additions & 10 deletions query/graphql/planner/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@ type sumNode struct {
}

func (p *Planner) Sum(
sourceInfo *sourceInfo,
field *mapper.Aggregate,
parent *mapper.Select,
) (*sumNode, error) {
isFloat := false
for _, target := range field.AggregateTargets {
isTargetFloat, err := p.isValueFloat(&sourceInfo.collectionDescription, parent, &target)
isTargetFloat, err := p.isValueFloat(parent, &target)
if err != nil {
return nil, err
}
Expand All @@ -61,7 +60,6 @@ func (p *Planner) Sum(

// Returns true if the value to be summed is a float, otherwise false.
func (p *Planner) isValueFloat(
parentDescription *client.CollectionDescription,
parent *mapper.Select,
source *mapper.AggregateTarget,
) (bool, error) {
Expand All @@ -72,7 +70,11 @@ func (p *Planner) isValueFloat(
}

if !source.ChildTarget.HasValue {
// If path length is one - we are summing an inline array
parentDescription, err := p.getCollectionDesc(parent.CollectionName)
if err != nil {
return false, err
}

fieldDescription, fieldDescriptionFound := parentDescription.GetField(source.Name)
if !fieldDescriptionFound {
return false, fmt.Errorf(
Expand All @@ -95,11 +97,6 @@ func (p *Planner) isValueFloat(
return false, fmt.Errorf("Expected child select but none was found")
}

childCollectionDescription, err := p.getCollectionDesc(child.CollectionName)
if err != nil {
return false, err
}

if _, isAggregate := parserTypes.Aggregates[source.ChildTarget.Name]; isAggregate {
// If we are aggregating an aggregate, we need to traverse the aggregation chain down to
// the root field in order to determine the value type. This is recursive to allow handling
Expand All @@ -108,7 +105,6 @@ func (p *Planner) isValueFloat(

for _, aggregateTarget := range sourceField.AggregateTargets {
isFloat, err := p.isValueFloat(
&childCollectionDescription,
child,
&aggregateTarget,
)
Expand All @@ -124,6 +120,11 @@ func (p *Planner) isValueFloat(
return false, nil
}

childCollectionDescription, err := p.getCollectionDesc(child.CollectionName)
if err != nil {
return false, err
}

fieldDescription, fieldDescriptionFound := childCollectionDescription.GetField(source.ChildTarget.Name)
if !fieldDescriptionFound {
return false,
Expand Down
Loading

0 comments on commit 46ec563

Please sign in to comment.