diff --git a/mapper/mapper.go b/mapper/mapper.go index 68220dab8f..94de196051 100644 --- a/mapper/mapper.go +++ b/mapper/mapper.go @@ -14,11 +14,8 @@ import ( "context" "fmt" "reflect" - "strconv" "strings" - "github.com/graphql-go/graphql/language/ast" - "github.com/sourcenetwork/defradb/client" "github.com/sourcenetwork/defradb/connor" "github.com/sourcenetwork/defradb/core" @@ -443,24 +440,6 @@ func getRequestables( desc *client.CollectionDescription, descriptionsRepo *DescriptionsRepo, ) (fields []Requestable, aggregates []*aggregateRequest, err error) { - // If this parser.Select is itself an aggregate, we need to append the - // relevent info here as if it was a field of its own (due to a quirk of - // the parser package). - if _, isAggregate := parserTypes.Aggregates[parsed.Name]; isAggregate { - index := mapping.GetNextIndex() - aggregateReq, err := getAggregateRequests(index, parsed) - if err != nil { - return nil, nil, err - } - - mapping.RenderKeys = append(mapping.RenderKeys, core.RenderKey{ - Index: index, - Key: parsed.Alias, - }) - mapping.Add(index, parsed.Name) - aggregates = append(aggregates, &aggregateReq) - } - for _, field := range parsed.Fields { switch f := field.(type) { case *parser.Field: @@ -481,26 +460,28 @@ func getRequestables( case *parser.Select: index := mapping.GetNextIndex() - // Aggregate targets are not known at this point, and must be evaluated - // after all requested fields have been evaluated - so we note which - // aggregates have been requested and their targets here, before finalizing - // their evaluation later. - if _, isAggregate := parserTypes.Aggregates[f.Name]; isAggregate { - aggregateRequest, err := getAggregateRequests(index, f) - if err != nil { - return nil, nil, err - } + innerSelect, err := toSelect(descriptionsRepo, index, f, desc.Name) + if err != nil { + return nil, nil, err + } + fields = append(fields, innerSelect) + mapping.SetChildAt(index, &innerSelect.DocumentMapping) - aggregates = append(aggregates, &aggregateRequest) - } else { - innerSelect, err := toSelect(descriptionsRepo, index, f, desc.Name) - if err != nil { - return nil, nil, err - } - fields = append(fields, innerSelect) - mapping.SetChildAt(index, &innerSelect.DocumentMapping) + mapping.RenderKeys = append(mapping.RenderKeys, core.RenderKey{ + Index: index, + Key: f.Alias, + }) + + mapping.Add(index, f.Name) + case *parser.Aggregate: + index := mapping.GetNextIndex() + aggregateRequest, err := getAggregateRequests(index, f) + if err != nil { + return nil, nil, err } + aggregates = append(aggregates, &aggregateRequest) + mapping.RenderKeys = append(mapping.RenderKeys, core.RenderKey{ Index: index, Key: f.Alias, @@ -517,7 +498,7 @@ func getRequestables( return } -func getAggregateRequests(index int, aggregate *parser.Select) (aggregateRequest, error) { +func getAggregateRequests(index int, aggregate *parser.Aggregate) (aggregateRequest, error) { aggregateTargets, err := getAggregateSources(aggregate) if err != nil { return aggregateRequest{}, err @@ -1163,119 +1144,22 @@ type aggregateRequestTarget struct { } // Returns the source of the aggregate as requested by the consumer -func getAggregateSources(field *parser.Select) ([]*aggregateRequestTarget, error) { - targets := make([]*aggregateRequestTarget, len(field.Statement.Arguments)) - - for i, argument := range field.Statement.Arguments { - switch argumentValue := argument.Value.GetValue().(type) { - case string: - targets[i] = &aggregateRequestTarget{ - hostExternalName: argumentValue, - } - case []*ast.ObjectField: - hostExternalName := argument.Name.Value - var childExternalName string - var filter *parser.Filter - var limit *Limit - var order *parserTypes.OrderBy - - fieldArg, hasFieldArg := tryGet(argumentValue, parserTypes.Field) - if hasFieldArg { - if innerPathStringValue, isString := fieldArg.Value.GetValue().(string); isString { - childExternalName = innerPathStringValue - } - } - - filterArg, hasFilterArg := tryGet(argumentValue, parserTypes.FilterClause) - if hasFilterArg { - var err error - filter, err = parser.NewFilter(filterArg.Value.(*ast.ObjectValue)) - if err != nil { - return nil, err - } - } - - limitArg, hasLimitArg := tryGet(argumentValue, parserTypes.LimitClause) - offsetArg, hasOffsetArg := tryGet(argumentValue, parserTypes.OffsetClause) - var limitValue int64 - var offsetValue int64 - if hasLimitArg { - var err error - limitValue, err = strconv.ParseInt(limitArg.Value.(*ast.IntValue).Value, 10, 64) - if err != nil { - return nil, err - } - } - - if hasOffsetArg { - var err error - offsetValue, err = strconv.ParseInt(offsetArg.Value.(*ast.IntValue).Value, 10, 64) - if err != nil { - return nil, err - } - } - - if hasLimitArg || hasOffsetArg { - limit = &Limit{ - Limit: limitValue, - Offset: offsetValue, - } - } - - orderArg, hasOrderArg := tryGet(argumentValue, parserTypes.OrderClause) - if hasOrderArg { - switch orderArgValue := orderArg.Value.(type) { - case *ast.EnumValue: - // For inline arrays the order arg will be a simple enum declaring the order direction - orderDirectionString := orderArgValue.Value - orderDirection := parserTypes.OrderDirection(orderDirectionString) - - order = &parserTypes.OrderBy{ - Conditions: []parserTypes.OrderCondition{ - { - Direction: orderDirection, - }, - }, - } - - case *ast.ObjectValue: - // For relations the order arg will be the complex order object as used by the host object - // for non-aggregate ordering - - // We use the parser package parsing for convienience here - orderConditions, err := parser.ParseConditionsInOrder(orderArgValue) - if err != nil { - return nil, err - } - - order = &parserTypes.OrderBy{ - Conditions: orderConditions, - } - } - } - - targets[i] = &aggregateRequestTarget{ - hostExternalName: hostExternalName, - childExternalName: childExternalName, - filter: filter, - limit: limit, - order: order, - } +func getAggregateSources(field *parser.Aggregate) ([]*aggregateRequestTarget, error) { + targets := make([]*aggregateRequestTarget, len(field.Targets)) + + for i, target := range field.Targets { + targets[i] = &aggregateRequestTarget{ + hostExternalName: target.HostName, + childExternalName: target.ChildName.Value(), + filter: target.Filter, + limit: (*Limit)(target.Limit), + order: target.OrderBy, } } return targets, nil } -func tryGet(fields []*ast.ObjectField, name string) (*ast.ObjectField, bool) { - for _, field := range fields { - if field.Name.Value == name { - return field, true - } - } - return nil, false -} - // tryGetMatchingAggregate scans the given collection for aggregates with the given name and targets. // // Will return the matching target and true if one is found, otherwise will return false. diff --git a/planner/planner.go b/planner/planner.go index 3e677698de..6e92e5e9e2 100644 --- a/planner/planner.go +++ b/planner/planner.go @@ -133,6 +133,7 @@ func (p *Planner) newPlan(stmt any) (planNode, error) { } return p.Select(m) + case *mapper.Select: return p.Select(n) diff --git a/query/graphql/parser/query.go b/query/graphql/parser/query.go index 4c232be727..863e7eeae5 100644 --- a/query/graphql/parser/query.go +++ b/query/graphql/parser/query.go @@ -216,6 +216,20 @@ func parseQueryOperationDefinition(def *ast.OperationDefinition) (*OperationDefi } parsedSelection = parsed + } else if _, isAggregate := parserTypes.Aggregates[node.Name.Value]; isAggregate { + parsed, err := parseAggregate(node, i) + if err != nil { + return nil, []error{err} + } + + // Top-level aggregates must be wrapped in a top-level Select for now + parsedSelection = &Select{ + Name: parsed.Name, + Alias: parsed.Alias, + Fields: []Selection{ + parsed, + }, + } } else { // the query doesn't match a reserve name // so its probably a generated query @@ -351,7 +365,7 @@ func parseSelectFields(root parserTypes.SelectionType, fields *ast.SelectionSet) switch node := selection.(type) { case *ast.Field: if _, isAggregate := parserTypes.Aggregates[node.Name.Value]; isAggregate { - s, err := parseSelect(root, node, i) + s, err := parseAggregate(node, i) if err != nil { return nil, err } @@ -385,6 +399,139 @@ func parseField(root parserTypes.SelectionType, field *ast.Field) *Field { } } +type Aggregate struct { + Name string + Alias string + + Targets []*AggregateTarget +} + +type AggregateTarget struct { + HostName string + ChildName client.Option[string] + + Limit *parserTypes.Limit + OrderBy *parserTypes.OrderBy + Filter *Filter +} + +func parseAggregate(field *ast.Field, index int) (*Aggregate, error) { + targets := make([]*AggregateTarget, len(field.Arguments)) + + for i, argument := range field.Arguments { + switch argumentValue := argument.Value.GetValue().(type) { + case string: + targets[i] = &AggregateTarget{ + HostName: argumentValue, + } + case []*ast.ObjectField: + hostName := argument.Name.Value + var childName string + var filter *Filter + var limit *parserTypes.Limit + var order *parserTypes.OrderBy + + fieldArg, hasFieldArg := tryGet(argumentValue, parserTypes.Field) + if hasFieldArg { + if innerPathStringValue, isString := fieldArg.Value.GetValue().(string); isString { + childName = innerPathStringValue + } + } + + filterArg, hasFilterArg := tryGet(argumentValue, parserTypes.FilterClause) + if hasFilterArg { + var err error + filter, err = NewFilter(filterArg.Value.(*ast.ObjectValue)) + if err != nil { + return nil, err + } + } + + limitArg, hasLimitArg := tryGet(argumentValue, parserTypes.LimitClause) + offsetArg, hasOffsetArg := tryGet(argumentValue, parserTypes.OffsetClause) + var limitValue int64 + var offsetValue int64 + if hasLimitArg { + var err error + limitValue, err = strconv.ParseInt(limitArg.Value.(*ast.IntValue).Value, 10, 64) + if err != nil { + return nil, err + } + } + + if hasOffsetArg { + var err error + offsetValue, err = strconv.ParseInt(offsetArg.Value.(*ast.IntValue).Value, 10, 64) + if err != nil { + return nil, err + } + } + + if hasLimitArg || hasOffsetArg { + limit = &parserTypes.Limit{ + Limit: limitValue, + Offset: offsetValue, + } + } + + orderArg, hasOrderArg := tryGet(argumentValue, parserTypes.OrderClause) + if hasOrderArg { + switch orderArgValue := orderArg.Value.(type) { + case *ast.EnumValue: + // For inline arrays the order arg will be a simple enum declaring the order direction + orderDirectionString := orderArgValue.Value + orderDirection := parserTypes.OrderDirection(orderDirectionString) + + order = &parserTypes.OrderBy{ + Conditions: []parserTypes.OrderCondition{ + { + Direction: orderDirection, + }, + }, + } + + case *ast.ObjectValue: + // For relations the order arg will be the complex order object as used by the host object + // for non-aggregate ordering + + // We use the parser package parsing for convienience here + orderConditions, err := ParseConditionsInOrder(orderArgValue) + if err != nil { + return nil, err + } + + order = &parserTypes.OrderBy{ + Conditions: orderConditions, + } + } + } + + targets[i] = &AggregateTarget{ + HostName: hostName, + ChildName: client.Some(childName), + Filter: filter, + Limit: limit, + OrderBy: order, + } + } + } + + return &Aggregate{ + Alias: getFieldAlias(field), + Name: field.Name.Value, + Targets: targets, + }, nil +} + +func tryGet(fields []*ast.ObjectField, name string) (*ast.ObjectField, bool) { + for _, field := range fields { + if field.Name.Value == name { + return field, true + } + } + return nil, false +} + func parseAPIQuery(field *ast.Field) (Selection, error) { switch field.Name.Value { case "latestCommits", "commits":