From 61e9a4e16fa6aa6964676d95c4e4ed203bad7625 Mon Sep 17 00:00:00 2001 From: AndrewSisley Date: Mon, 11 Jul 2022 12:04:34 -0400 Subject: [PATCH] feat: Add support for top level aggregates (#594) * Rework count input objects Decouple them from host, cleaner for user, and allows reuse for top-level aggs * Remove sourceInfo param from Sum Minimal cost in re-getting, and makes it easier to call from other locations * Use switch instead of if for type check Switch will also gain a new case shortly * Remove unused type from createExpandedFieldAggregate * Use correct collection name Whilst they should have the same value at the moment, the disinction between the two becomes more important when introducing top-level aggregates * Extract out aggregate request logic to function Will be called multiple times once top-level aggregates are introduced * Remove legacy code This has been incorrect for a while, and will cause problems for top-level aggregates * Add support for top level aggregates --- query/graphql/mapper/mapper.go | 68 ++++-- query/graphql/parser/query.go | 19 +- query/graphql/planner/operations.go | 1 + query/graphql/planner/planner.go | 25 +++ query/graphql/planner/select.go | 2 +- query/graphql/planner/sum.go | 21 +- query/graphql/planner/top.go | 212 ++++++++++++++++++ query/graphql/schema/generate.go | 159 +++++++++++-- .../query/simple/with_average_filter_test.go | 49 ++++ .../query/simple/with_average_test.go | 73 ++++++ .../query/simple/with_count_filter_test.go | 49 ++++ .../query/simple/with_count_test.go | 73 ++++++ .../query/simple/with_sum_filter_test.go | 49 ++++ .../integration/query/simple/with_sum_test.go | 73 ++++++ 14 files changed, 809 insertions(+), 64 deletions(-) create mode 100644 query/graphql/planner/top.go create mode 100644 tests/integration/query/simple/with_average_filter_test.go create mode 100644 tests/integration/query/simple/with_average_test.go create mode 100644 tests/integration/query/simple/with_count_filter_test.go create mode 100644 tests/integration/query/simple/with_count_test.go create mode 100644 tests/integration/query/simple/with_sum_filter_test.go create mode 100644 tests/integration/query/simple/with_sum_test.go diff --git a/query/graphql/mapper/mapper.go b/query/graphql/mapper/mapper.go index 07d3e9f564..7413daee4d 100644 --- a/query/graphql/mapper/mapper.go +++ b/query/graphql/mapper/mapper.go @@ -93,7 +93,7 @@ func toSelect( Targetable: toTargetable(thisIndex, parsed, mapping), DocumentMapping: *mapping, Cid: parsed.CID, - CollectionName: desc.Name, + CollectionName: collectionName, Fields: fields, }, nil } @@ -365,6 +365,24 @@ func getRequestables( desc *client.CollectionDescription, descriptionsRepo *DescriptionsRepo, ) (fields []Requestable, aggregates []*aggregateRequest, err error) { + // If this parser.Select is itself an aggregate, we need to append the + // relevent info here as if it was a field of its own (due to a quirk of + // the parser package). 
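+ // For example, a top-level `_count(users: {})` request arrives here as a
+ // parser.Select named "_count", so its aggregate request is registered
+ // against this mapping directly rather than via a child field.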
+ if _, isAggregate := parserTypes.Aggregates[parsed.Name]; isAggregate { + index := mapping.GetNextIndex() + aggregateReq, err := getAggregateRequests(index, parsed) + if err != nil { + return nil, nil, err + } + + mapping.RenderKeys = append(mapping.RenderKeys, core.RenderKey{ + Index: index, + Key: parsed.Alias, + }) + mapping.Add(index, parsed.Name) + aggregates = append(aggregates, &aggregateReq) + } + for _, field := range parsed.Fields { switch f := field.(type) { case *parser.Field: @@ -390,24 +408,12 @@ func getRequestables( // aggregates have been requested and their targets here, before finalizing // their evaluation later. if _, isAggregate := parserTypes.Aggregates[f.Name]; isAggregate { - aggregateTargets, err := getAggregateSources(f) + aggregateRequest, err := getAggregateRequests(index, f) if err != nil { return nil, nil, err } - if len(aggregateTargets) == 0 { - return nil, nil, fmt.Errorf( - "Aggregate must be provided with a property to aggregate.", - ) - } - - aggregates = append(aggregates, &aggregateRequest{ - field: Field{ - Index: index, - Name: f.Name, - }, - targets: aggregateTargets, - }) + aggregates = append(aggregates, &aggregateRequest) } else { innerSelect, err := toSelect(descriptionsRepo, index, f, desc.Name) if err != nil { @@ -433,6 +439,27 @@ func getRequestables( return } +func getAggregateRequests(index int, aggregate *parser.Select) (aggregateRequest, error) { + aggregateTargets, err := getAggregateSources(aggregate) + if err != nil { + return aggregateRequest{}, err + } + + if len(aggregateTargets) == 0 { + return aggregateRequest{}, fmt.Errorf( + "Aggregate must be provided with a property to aggregate.", + ) + } + + return aggregateRequest{ + field: Field{ + Index: index, + Name: aggregate.Name, + }, + targets: aggregateTargets, + }, nil +} + // getCollectionName returns the name of the parsed collection. This may be empty // if this is a commit request. 
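+ // Top-level aggregates have no backing collection, so a debugging-only
+ // placeholder name is returned for them.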
func getCollectionName( @@ -440,6 +467,11 @@ func getCollectionName( parsed *parser.Select, parentCollectionName string, ) (string, error) { + if _, isAggregate := parserTypes.Aggregates[parsed.Name]; isAggregate { + // This string is not used or referenced, its value is only there to aid debugging + return "_topLevel", nil + } + if parsed.Name == parserTypes.GroupFieldName { return parentCollectionName, nil } else if parsed.Root == parserTypes.CommitSelection { @@ -471,6 +503,12 @@ func getTopLevelInfo( ) (*core.DocumentMapping, *client.CollectionDescription, error) { mapping := core.NewDocumentMapping() + if _, isAggregate := parserTypes.Aggregates[parsed.Name]; isAggregate { + // If this is a (top-level) aggregate, then it will have no collection + // description, and no top-level fields, so we return an empty mapping only + return mapping, &client.CollectionDescription{}, nil + } + if parsed.Root != parserTypes.CommitSelection { mapping.Add(core.DocKeyFieldIndex, parserTypes.DocKeyFieldName) diff --git a/query/graphql/parser/query.go b/query/graphql/parser/query.go index f73bb1a550..f8cd6bc926 100644 --- a/query/graphql/parser/query.go +++ b/query/graphql/parser/query.go @@ -208,7 +208,8 @@ func parseSelect(rootType parserTypes.SelectionType, field *ast.Field, index int // parse arguments for _, argument := range field.Arguments { - prop, astValue := getArgumentKeyValue(field, argument) + prop := argument.Name.Value + astValue := argument.Value // parse filter if prop == parserTypes.FilterClause { @@ -302,22 +303,6 @@ func parseSelect(rootType parserTypes.SelectionType, field *ast.Field, index int return slct, err } -// getArgumentKeyValue returns the relevant arguement name and value for the given field-argument -// Note: this function will likely need some rework when adding more aggregate options (e.g. limit) -func getArgumentKeyValue(field *ast.Field, argument *ast.Argument) (string, ast.Value) { - if _, isAggregate := parserTypes.Aggregates[field.Name.Value]; isAggregate { - switch innerProps := argument.Value.(type) { - case *ast.ObjectValue: - for _, innerV := range innerProps.Fields { - if innerV.Name.Value == parserTypes.FilterClause { - return parserTypes.FilterClause, innerV.Value - } - } - } - } - return argument.Name.Value, argument.Value -} - func getFieldAlias(field *ast.Field) string { if field.Alias == nil { return field.Name.Value diff --git a/query/graphql/planner/operations.go b/query/graphql/planner/operations.go index 60cccb34c7..94beab9b27 100644 --- a/query/graphql/planner/operations.go +++ b/query/graphql/planner/operations.go @@ -30,6 +30,7 @@ var ( _ planNode = (*selectTopNode)(nil) _ planNode = (*orderNode)(nil) _ planNode = (*sumNode)(nil) + _ planNode = (*topLevelNode)(nil) _ planNode = (*typeIndexJoin)(nil) _ planNode = (*typeJoinMany)(nil) _ planNode = (*typeJoinOne)(nil) diff --git a/query/graphql/planner/planner.go b/query/graphql/planner/planner.go index 0c183bc4cb..5be29f3786 100644 --- a/query/graphql/planner/planner.go +++ b/query/graphql/planner/planner.go @@ -125,6 +125,14 @@ func (p *Planner) newPlan(stmt interface{}) (planNode, error) { if err != nil { return nil, err } + + if _, isAgg := parserTypes.Aggregates[n.Name]; isAgg { + // If this Select is an aggregate, then it must be a top-level + // aggregate and we need to resolve it within the context of a + // top-level node. 
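+ // A request such as `query { _sum(users: {field: Age}) }` selects no
+ // collection of its own, so it is planned by Top rather than Select.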
+ return p.Top(m) + } + return p.Select(m) case *mapper.Select: return p.Select(n) @@ -223,6 +231,23 @@ func (p *Planner) expandPlan(plan planNode, parentPlan *selectTopNode) error { case *deleteNode: return p.expandPlan(n.source, parentPlan) + case *topLevelNode: + for _, child := range n.children { + switch c := child.(type) { + case *selectTopNode: + // We only care about expanding the child source here, it is assumed that the parent source + // is expanded elsewhere/already + err := p.expandPlan(child, parentPlan) + if err != nil { + return err + } + case aggregateNode: + // top-level aggregates use the top-level node as a source + c.SetPlan(n) + } + } + return nil + default: return nil } diff --git a/query/graphql/planner/select.go b/query/graphql/planner/select.go index 5a7ee9c9bd..90b0af0156 100644 --- a/query/graphql/planner/select.go +++ b/query/graphql/planner/select.go @@ -254,7 +254,7 @@ func (n *selectNode) initFields(parsed *mapper.Select) ([]aggregateNode, error) case parserTypes.CountFieldName: plan, aggregateError = n.p.Count(f, parsed) case parserTypes.SumFieldName: - plan, aggregateError = n.p.Sum(&n.sourceInfo, f, parsed) + plan, aggregateError = n.p.Sum(f, parsed) case parserTypes.AverageFieldName: plan, aggregateError = n.p.Average(f) } diff --git a/query/graphql/planner/sum.go b/query/graphql/planner/sum.go index ca6da55100..13d144ed1c 100644 --- a/query/graphql/planner/sum.go +++ b/query/graphql/planner/sum.go @@ -33,13 +33,12 @@ type sumNode struct { } func (p *Planner) Sum( - sourceInfo *sourceInfo, field *mapper.Aggregate, parent *mapper.Select, ) (*sumNode, error) { isFloat := false for _, target := range field.AggregateTargets { - isTargetFloat, err := p.isValueFloat(&sourceInfo.collectionDescription, parent, &target) + isTargetFloat, err := p.isValueFloat(parent, &target) if err != nil { return nil, err } @@ -61,7 +60,6 @@ func (p *Planner) Sum( // Returns true if the value to be summed is a float, otherwise false. func (p *Planner) isValueFloat( - parentDescription *client.CollectionDescription, parent *mapper.Select, source *mapper.AggregateTarget, ) (bool, error) { @@ -72,7 +70,11 @@ func (p *Planner) isValueFloat( } if !source.ChildTarget.HasValue { - // If path length is one - we are summing an inline array + parentDescription, err := p.getCollectionDesc(parent.CollectionName) + if err != nil { + return false, err + } + fieldDescription, fieldDescriptionFound := parentDescription.GetField(source.Name) if !fieldDescriptionFound { return false, fmt.Errorf( @@ -95,11 +97,6 @@ func (p *Planner) isValueFloat( return false, fmt.Errorf("Expected child select but none was found") } - childCollectionDescription, err := p.getCollectionDesc(child.CollectionName) - if err != nil { - return false, err - } - if _, isAggregate := parserTypes.Aggregates[source.ChildTarget.Name]; isAggregate { // If we are aggregating an aggregate, we need to traverse the aggregation chain down to // the root field in order to determine the value type. 
This is recursive to allow handling @@ -108,7 +105,6 @@ func (p *Planner) isValueFloat( for _, aggregateTarget := range sourceField.AggregateTargets { isFloat, err := p.isValueFloat( - &childCollectionDescription, child, &aggregateTarget, ) @@ -124,6 +120,11 @@ func (p *Planner) isValueFloat( return false, nil } + childCollectionDescription, err := p.getCollectionDesc(child.CollectionName) + if err != nil { + return false, err + } + fieldDescription, fieldDescriptionFound := childCollectionDescription.GetField(source.ChildTarget.Name) if !fieldDescriptionFound { return false, diff --git a/query/graphql/planner/top.go b/query/graphql/planner/top.go new file mode 100644 index 0000000000..976e370c3d --- /dev/null +++ b/query/graphql/planner/top.go @@ -0,0 +1,212 @@ +// Copyright 2022 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package planner + +import ( + "errors" + + "github.com/sourcenetwork/defradb/core" + "github.com/sourcenetwork/defradb/query/graphql/mapper" + parserTypes "github.com/sourcenetwork/defradb/query/graphql/parser/types" +) + +// topLevelNode is a special node that represents the very top of the +// plan graph. It has no source, and will only yield a single item +// containing all of its children. +type topLevelNode struct { + documentIterator + docMapper + + children []planNode + childIndexes []int + isdone bool + + // This node's children may use this node as a source + // this property controls the recursive flow preventing + // infinate loops. 
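+ // (aggregate children use this node as their source, so each recursive
+ // method returns early while isInRecurse is set).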
+ isInRecurse bool +} + +func (n *topLevelNode) Spans(spans core.Spans) { + if n.isInRecurse { + return + } + n.isInRecurse = true + defer func() { + n.isInRecurse = false + }() + + for _, child := range n.children { + child.Spans(spans) + } +} + +func (n *topLevelNode) Kind() string { + return "topLevelNode" +} + +func (n *topLevelNode) Init() error { + if n.isInRecurse { + return nil + } + n.isInRecurse = true + defer func() { + n.isInRecurse = false + }() + + n.isdone = false + for _, child := range n.children { + err := child.Init() + if err != nil { + return err + } + } + + return nil +} + +func (n *topLevelNode) Start() error { + if n.isInRecurse { + return nil + } + n.isInRecurse = true + defer func() { + n.isInRecurse = false + }() + + for _, child := range n.children { + err := child.Start() + if err != nil { + return err + } + } + + return nil +} + +func (n *topLevelNode) Close() error { + if n.isInRecurse { + return nil + } + n.isInRecurse = true + defer func() { + n.isInRecurse = false + }() + + for _, child := range n.children { + err := child.Close() + if err != nil { + return err + } + } + + return nil +} + +func (n *topLevelNode) Source() planNode { + return nil +} + +func (n *topLevelNode) Next() (bool, error) { + if n.isdone { + return false, nil + } + + if n.isInRecurse { + return true, nil + } + + n.currentValue = n.documentMapping.NewDoc() + n.isInRecurse = true + defer func() { + n.isInRecurse = false + }() + + for i, child := range n.children { + switch child.(type) { + case *selectTopNode: + docs := []core.Doc{} + for { + hasChild, err := child.Next() + if err != nil { + return false, err + } + if !hasChild { + break + } + docs = append(docs, child.Value()) + } + n.currentValue.Fields[n.childIndexes[i]] = docs + default: + // This Next will always return a value, as it's source is this node! + // Even if it adds nothing to the current currentValue, it should still + // yield it unchanged. + hasChild, err := child.Next() + if err != nil { + return false, err + } + if !hasChild { + return false, errors.New("Expected child value, however none was yielded") + } + + n.currentValue = child.Value() + } + } + + n.isdone = true + return true, nil +} + +// Top creates a new topLevelNode using the given Select. +func (p *Planner) Top(m *mapper.Select) (*topLevelNode, error) { + node := topLevelNode{ + docMapper: docMapper{&m.DocumentMapping}, + } + + aggregateChildren := []planNode{} + aggregateChildIndexes := []int{} + for _, field := range m.Fields { + switch f := field.(type) { + case *mapper.Aggregate: + var child planNode + var err error + switch field.GetName() { + case parserTypes.CountFieldName: + child, err = p.Count(f, m) + case parserTypes.SumFieldName: + child, err = p.Sum(f, m) + case parserTypes.AverageFieldName: + child, err = p.Average(f) + } + if err != nil { + return nil, err + } + aggregateChildren = append(aggregateChildren, child) + aggregateChildIndexes = append(aggregateChildIndexes, field.GetIndex()) + case *mapper.Select: + child, err := p.Select(f) + if err != nil { + return nil, err + } + node.children = append(node.children, child) + node.childIndexes = append(node.childIndexes, field.GetIndex()) + } + } + + // Iterate through the aggregates backwards to ensure dependencies + // execute *before* any aggregate dependent on them. 
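+ // (an _avg, for example, depends on a backing _sum and _count having
+ // already produced their values).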
+ for i := len(aggregateChildren) - 1; i >= 0; i-- { + node.children = append(node.children, aggregateChildren[i]) + node.childIndexes = append(node.childIndexes, aggregateChildIndexes[i]) + } + + return &node, nil +} diff --git a/query/graphql/schema/generate.go b/query/graphql/schema/generate.go index 43cd940132..9ce3f4e031 100644 --- a/query/graphql/schema/generate.go +++ b/query/graphql/schema/generate.go @@ -177,11 +177,20 @@ func (g *Generator) fromAST(ctx context.Context, document *ast.Document) ([]*gql // queries := query.Fields() // only apply to generated query fields, and only once for _, def := range generatedQueryFields { - t := def.Type - if obj, ok := t.(*gql.List); ok { + switch obj := def.Type.(type) { + case *gql.List: if err := g.expandInputArgument(obj.OfType.(*gql.Object)); err != nil { return nil, err } + case *gql.Scalar: + if _, isAggregate := parserTypes.Aggregates[def.Name]; isAggregate { + for name, aggregateTarget := range def.Args { + expandedField := &gql.InputObjectFieldConfig{ + Type: g.manager.schema.TypeMap()[name+"FilterArg"], + } + aggregateTarget.Type.(*gql.InputObject).AddFieldConfig(parserTypes.FilterClause, expandedField) + } + } } } @@ -265,7 +274,7 @@ func (g *Generator) expandInputArgument(obj *gql.Object) error { } case *gql.Scalar: if _, isAggregate := parserTypes.Aggregates[f]; isAggregate { - g.createExpandedFieldAggregate(obj, def, t) + g.createExpandedFieldAggregate(obj, def) } // @todo: check if NonNull is possible here //case *gql.NonNull: @@ -279,7 +288,6 @@ func (g *Generator) expandInputArgument(obj *gql.Object) error { func (g *Generator) createExpandedFieldAggregate( obj *gql.Object, f *gql.FieldDefinition, - t gql.Type, ) { for _, aggregateTarget := range f.Args { target := aggregateTarget.Name() @@ -511,17 +519,38 @@ func getRelationshipName( func (g *Generator) genAggregateFields(ctx context.Context) error { numBaseArgs := make(map[string]*gql.InputObject) + topLevelCountInputs := map[string]*gql.InputObject{} + topLevelNumericAggInputs := map[string]*gql.InputObject{} + for _, t := range g.typeDefs { numArg := g.genNumericAggregateBaseArgInputs(t) numBaseArgs[numArg.Name()] = numArg + topLevelNumericAggInputs[t.Name()] = numArg // All base types need to be appended to the schema before calling genSumFieldConfig err := g.manager.schema.AppendType(numArg) if err != nil { return err } - objs := g.genNumericInlineArraySelectorObject(t) - for _, obj := range objs { + numericInlineArrayInputs := g.genNumericInlineArraySelectorObject(t) + for _, obj := range numericInlineArrayInputs { + numBaseArgs[obj.Name()] = obj + err := g.manager.schema.AppendType(obj) + if err != nil { + return err + } + } + + obj := g.genCountBaseArgInputs(t) + numBaseArgs[obj.Name()] = obj + topLevelCountInputs[t.Name()] = obj + err = g.manager.schema.AppendType(obj) + if err != nil { + return err + } + + countableInlineArrayInputs := g.genCountInlineArrayInputs(t) + for _, obj := range countableInlineArrayInputs { numBaseArgs[obj.Name()] = obj err := g.manager.schema.AppendType(obj) if err != nil { @@ -531,7 +560,7 @@ func (g *Generator) genAggregateFields(ctx context.Context) error { } for _, t := range g.typeDefs { - countField, err := g.genCountFieldConfig(t) + countField, err := g.genCountFieldConfig(t, numBaseArgs) if err != nil { return err } @@ -550,10 +579,54 @@ func (g *Generator) genAggregateFields(ctx context.Context) error { t.AddFieldConfig(averageField.Name, &averageField) } + queryType := g.manager.schema.QueryType() + + topLevelCountField := 
genTopLevelCount(topLevelCountInputs) + queryType.AddFieldConfig(topLevelCountField.Name, topLevelCountField) + + for _, topLevelAgg := range genTopLevelNumericAggregates(topLevelNumericAggInputs) { + queryType.AddFieldConfig(topLevelAgg.Name, topLevelAgg) + } + return nil } -func (g *Generator) genCountFieldConfig(obj *gql.Object) (gql.Field, error) { +func genTopLevelCount(topLevelCountInputs map[string]*gql.InputObject) *gql.Field { + topLevelCountField := gql.Field{ + Name: parserTypes.CountFieldName, + Type: gql.Int, + Args: gql.FieldConfigArgument{}, + } + + for name, inputObject := range topLevelCountInputs { + topLevelCountField.Args[name] = schemaTypes.NewArgConfig(inputObject) + } + + return &topLevelCountField +} + +func genTopLevelNumericAggregates(topLevelNumericAggInputs map[string]*gql.InputObject) []*gql.Field { + topLevelSumField := gql.Field{ + Name: parserTypes.SumFieldName, + Type: gql.Float, + Args: gql.FieldConfigArgument{}, + } + + topLevelAverageField := gql.Field{ + Name: parserTypes.AverageFieldName, + Type: gql.Float, + Args: gql.FieldConfigArgument{}, + } + + for name, inputObject := range topLevelNumericAggInputs { + topLevelSumField.Args[name] = schemaTypes.NewArgConfig(inputObject) + topLevelAverageField.Args[name] = schemaTypes.NewArgConfig(inputObject) + } + + return []*gql.Field{&topLevelSumField, &topLevelAverageField} +} + +func (g *Generator) genCountFieldConfig(obj *gql.Object, numBaseArgs map[string]*gql.InputObject) (gql.Field, error) { childTypesByFieldName := map[string]*gql.InputObject{} caser := cases.Title(language.Und) @@ -562,21 +635,19 @@ func (g *Generator) genCountFieldConfig(obj *gql.Object) (gql.Field, error) { if _, isList := field.Type.(*gql.List); !isList { continue } - countableObject := gql.NewInputObject(gql.InputObjectConfig{ - Name: fmt.Sprintf("%s%s%s", obj.Name(), caser.String(field.Name), "CountInputObj"), - Fields: gql.InputObjectConfigFieldMap{ - "_": &gql.InputObjectFieldConfig{ - Type: gql.Int, - Description: "Placeholder - empty object not permitted, but will have fields shortly", - }, - }, - }) - childTypesByFieldName[field.Name] = countableObject - err := g.manager.schema.AppendType(countableObject) - if err != nil { - return gql.Field{}, err + inputObjectName := genTypeName(field.Type, "CountInputObj") + countableObject, isSubTypeCountableCollection := numBaseArgs[inputObjectName] + if !isSubTypeCountableCollection { + inputObjectName = genNumericInlineArrayCountName(obj.Name(), caser.String(field.Name)) + var isSubTypeCountableInlineArray bool + countableObject, isSubTypeCountableInlineArray = numBaseArgs[inputObjectName] + if !isSubTypeCountableInlineArray { + continue + } } + + childTypesByFieldName[field.Name] = countableObject } field := gql.Field{ @@ -704,6 +775,52 @@ func genNumericInlineArraySelectorName(hostName string, fieldName string) string return fmt.Sprintf("%s%s%s", hostName, caser.String(fieldName), "NumericInlineArraySelector") } +func (g *Generator) genCountBaseArgInputs(obj *gql.Object) *gql.InputObject { + countableObject := gql.NewInputObject(gql.InputObjectConfig{ + Name: genTypeName(obj, "CountInputObj"), + Fields: gql.InputObjectConfigFieldMap{ + "_": &gql.InputObjectFieldConfig{ + Type: gql.Int, + Description: "Placeholder - empty object not permitted, but will have fields shortly", + }, + }, + }) + + return countableObject +} + +func (g *Generator) genCountInlineArrayInputs(obj *gql.Object) []*gql.InputObject { + objects := []*gql.InputObject{} + caser := cases.Title(language.Und) + for _, 
field := range obj.Fields() { + // we can only act on list items + _, isList := field.Type.(*gql.List) + if !isList { + continue + } + + // If it is an inline scalar array then we require an empty + // object as an argument due to the lack of union input types + selectorObject := gql.NewInputObject(gql.InputObjectConfig{ + Name: genNumericInlineArrayCountName(obj.Name(), caser.String(field.Name)), + Fields: gql.InputObjectConfigFieldMap{ + "_": &gql.InputObjectFieldConfig{ + Type: gql.Int, + Description: "Placeholder - empty object not permitted, but will have fields shortly", + }, + }, + }) + + objects = append(objects, selectorObject) + } + return objects +} + +func genNumericInlineArrayCountName(hostName string, fieldName string) string { + caser := cases.Title(language.Und) + return fmt.Sprintf("%s%s%s", hostName, caser.String(fieldName), "InlineArrayCountInput") +} + // Generates the base (numeric-only) aggregate input object-type for the give gql object, // declaring which fields are available for aggregation. func (g *Generator) genNumericAggregateBaseArgInputs(obj *gql.Object) *gql.InputObject { diff --git a/tests/integration/query/simple/with_average_filter_test.go b/tests/integration/query/simple/with_average_filter_test.go new file mode 100644 index 0000000000..a99e4e710c --- /dev/null +++ b/tests/integration/query/simple/with_average_filter_test.go @@ -0,0 +1,49 @@ +// Copyright 2022 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package simple + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQuerySimpleWithAverageWithFilter(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, average with filter", + Query: `query { + _avg(users: {field: Age, filter: {Age: {_gt: 26}}}) + }`, + Docs: map[int][]string{ + 0: { + `{ + "Name": "John", + "Age": 21 + }`, + `{ + "Name": "Bob", + "Age": 30 + }`, + `{ + "Name": "Alice", + "Age": 32 + }`, + }, + }, + Results: []map[string]interface{}{ + { + "_avg": float64(31), + }, + }, + } + + executeTestCase(t, test) +} diff --git a/tests/integration/query/simple/with_average_test.go b/tests/integration/query/simple/with_average_test.go new file mode 100644 index 0000000000..b8d12fe3e1 --- /dev/null +++ b/tests/integration/query/simple/with_average_test.go @@ -0,0 +1,73 @@ +// Copyright 2022 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package simple + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQuerySimpleWithAverageOnUndefined(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, average on undefined", + Query: `query { + _avg + }`, + ExpectedError: "Aggregate must be provided with a property to aggregate.", + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithAverageOnEmptyCollection(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, average on empty", + Query: `query { + _avg(users: {field: Age}) + }`, + Results: []map[string]interface{}{ + { + "_avg": float64(0), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithAverage(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, average", + Query: `query { + _avg(users: {field: Age}) + }`, + Docs: map[int][]string{ + 0: { + `{ + "Name": "John", + "Age": 28 + }`, + `{ + "Name": "Bob", + "Age": 30 + }`, + }, + }, + Results: []map[string]interface{}{ + { + "_avg": float64(29), + }, + }, + } + + executeTestCase(t, test) +} diff --git a/tests/integration/query/simple/with_count_filter_test.go b/tests/integration/query/simple/with_count_filter_test.go new file mode 100644 index 0000000000..fee530d216 --- /dev/null +++ b/tests/integration/query/simple/with_count_filter_test.go @@ -0,0 +1,49 @@ +// Copyright 2022 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package simple + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQuerySimpleWithCountWithFilter(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, count with filter", + Query: `query { + _count(users: {filter: {Age: {_gt: 26}}}) + }`, + Docs: map[int][]string{ + 0: { + `{ + "Name": "John", + "Age": 21 + }`, + `{ + "Name": "Bob", + "Age": 30 + }`, + `{ + "Name": "Alice", + "Age": 32 + }`, + }, + }, + Results: []map[string]interface{}{ + { + "_count": 2, + }, + }, + } + + executeTestCase(t, test) +} diff --git a/tests/integration/query/simple/with_count_test.go b/tests/integration/query/simple/with_count_test.go new file mode 100644 index 0000000000..4b0ef066a0 --- /dev/null +++ b/tests/integration/query/simple/with_count_test.go @@ -0,0 +1,73 @@ +// Copyright 2022 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package simple + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQuerySimpleWithCountOnUndefined(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, count on undefined", + Query: `query { + _count + }`, + ExpectedError: "Aggregate must be provided with a property to aggregate.", + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithCountOnEmptyCollection(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, count on empty", + Query: `query { + _count(users: {}) + }`, + Results: []map[string]interface{}{ + { + "_count": 0, + }, + }, + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithCount(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, count", + Query: `query { + _count(users: {}) + }`, + Docs: map[int][]string{ + 0: { + `{ + "Name": "John", + "Age": 21 + }`, + `{ + "Name": "Bob", + "Age": 30 + }`, + }, + }, + Results: []map[string]interface{}{ + { + "_count": 2, + }, + }, + } + + executeTestCase(t, test) +} diff --git a/tests/integration/query/simple/with_sum_filter_test.go b/tests/integration/query/simple/with_sum_filter_test.go new file mode 100644 index 0000000000..6c595cfd70 --- /dev/null +++ b/tests/integration/query/simple/with_sum_filter_test.go @@ -0,0 +1,49 @@ +// Copyright 2022 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package simple + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQuerySimpleWithSumWithFilter(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, sum with filter", + Query: `query { + _sum(users: {field: Age, filter: {Age: {_gt: 26}}}) + }`, + Docs: map[int][]string{ + 0: { + `{ + "Name": "John", + "Age": 21 + }`, + `{ + "Name": "Bob", + "Age": 30 + }`, + `{ + "Name": "Alice", + "Age": 32 + }`, + }, + }, + Results: []map[string]interface{}{ + { + "_sum": int64(62), + }, + }, + } + + executeTestCase(t, test) +} diff --git a/tests/integration/query/simple/with_sum_test.go b/tests/integration/query/simple/with_sum_test.go new file mode 100644 index 0000000000..cd72a23d51 --- /dev/null +++ b/tests/integration/query/simple/with_sum_test.go @@ -0,0 +1,73 @@ +// Copyright 2022 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package simple + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQuerySimpleWithSumOnUndefined(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, sum on undefined", + Query: `query { + _sum + }`, + ExpectedError: "Aggregate must be provided with a property to aggregate.", + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithSumOnEmptyCollection(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, sum on empty", + Query: `query { + _sum(users: {field: Age}) + }`, + Results: []map[string]interface{}{ + { + "_sum": int64(0), + }, + }, + } + + executeTestCase(t, test) +} + +func TestQuerySimpleWithSum(t *testing.T) { + test := testUtils.QueryTestCase{ + Description: "Simple query, sum", + Query: `query { + _sum(users: {field: Age}) + }`, + Docs: map[int][]string{ + 0: { + `{ + "Name": "John", + "Age": 21 + }`, + `{ + "Name": "Bob", + "Age": 30 + }`, + }, + }, + Results: []map[string]interface{}{ + { + "_sum": int64(51), + }, + }, + } + + executeTestCase(t, test) +}
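
As an illustrative sketch only, not part of this patch: the new top-level aggregates could plausibly be combined in a single request using the same test harness as the files above. The test name TestQuerySimpleWithAllTopLevelAggregates is hypothetical, and the assumption that several top-level aggregates may share one query and land in the single result document yielded by topLevelNode is not exercised by the patch itself, which tests each aggregate individually.

// Hypothetical sketch, not part of the patch: combines the new top-level
// aggregates in one query. Assumes multiple top-level aggregates may be
// requested together and are returned in the single document yielded by
// topLevelNode.
package simple

import (
	"testing"

	testUtils "github.com/sourcenetwork/defradb/tests/integration"
)

func TestQuerySimpleWithAllTopLevelAggregates(t *testing.T) {
	test := testUtils.QueryTestCase{
		Description: "Simple query, combined top-level aggregates (sketch)",
		Query: `query {
					_count(users: {})
					_sum(users: {field: Age})
					_avg(users: {field: Age})
				}`,
		Docs: map[int][]string{
			0: {
				`{
					"Name": "John",
					"Age": 21
				}`,
				`{
					"Name": "Bob",
					"Age": 30
				}`,
			},
		},
		Results: []map[string]interface{}{
			{
				"_count": 2,
				"_sum":   int64(51),
				"_avg":   float64(25.5),
			},
		},
	}

	executeTestCase(t, test)
}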