diff --git a/client/collection.go b/client/collection.go index 20597e9ea3..c1a0e889a2 100644 --- a/client/collection.go +++ b/client/collection.go @@ -83,26 +83,26 @@ type Collection interface { // // Returns an ErrInvalidUpdateTarget error if the target type is not supported. // Returns an ErrInvalidUpdater error if the updater type is not supported. - UpdateWith(ctx context.Context, target interface{}, updater interface{}) (*UpdateResult, error) + UpdateWith(ctx context.Context, target interface{}, updater string) (*UpdateResult, error) // UpdateWithFilter updates using a filter to target documents for update. // // The provided updater must be a string Patch, string Merge Patch, a parsed Patch, or parsed Merge Patch // else an ErrInvalidUpdater will be returned. - UpdateWithFilter(ctx context.Context, filter interface{}, updater interface{}) (*UpdateResult, error) + UpdateWithFilter(ctx context.Context, filter interface{}, updater string) (*UpdateResult, error) // UpdateWithKey updates using a DocKey to target a single document for update. // // The provided updater must be a string Patch, string Merge Patch, a parsed Patch, or parsed Merge Patch // else an ErrInvalidUpdater will be returned. // // Returns an ErrDocumentNotFound if a document matching the given DocKey is not found. - UpdateWithKey(ctx context.Context, key DocKey, updater interface{}) (*UpdateResult, error) + UpdateWithKey(ctx context.Context, key DocKey, updater string) (*UpdateResult, error) // UpdateWithKeys updates documents matching the given DocKeys. // // The provided updater must be a string Patch, string Merge Patch, a parsed Patch, or parsed Merge Patch // else an ErrInvalidUpdater will be returned. // // Returns an ErrDocumentNotFound if a document is not found for any given DocKey. - UpdateWithKeys(context.Context, []DocKey, interface{}) (*UpdateResult, error) + UpdateWithKeys(context.Context, []DocKey, string) (*UpdateResult, error) // DeleteWith deletes a target document. // diff --git a/db/collection_update.go b/db/collection_update.go index 8916095f76..96253e1d14 100644 --- a/db/collection_update.go +++ b/db/collection_update.go @@ -12,7 +12,6 @@ package db import ( "context" - "encoding/json" "errors" "fmt" "strings" @@ -28,6 +27,7 @@ import ( parserTypes "github.com/sourcenetwork/defradb/query/graphql/parser/types" cbor "github.com/fxamacker/cbor/v2" + "github.com/valyala/fastjson" ) var ( @@ -46,7 +46,7 @@ var ( func (c *collection) UpdateWith( ctx context.Context, target interface{}, - updater interface{}, + updater string, ) (*client.UpdateResult, error) { switch t := target.(type) { case string, map[string]interface{}, *parser.Filter: @@ -66,7 +66,7 @@ func (c *collection) UpdateWith( func (c *collection) UpdateWithFilter( ctx context.Context, filter interface{}, - updater interface{}, + updater string, ) (*client.UpdateResult, error) { txn, err := c.getTxn(ctx, false) if err != nil { @@ -86,7 +86,7 @@ func (c *collection) UpdateWithFilter( func (c *collection) UpdateWithKey( ctx context.Context, key client.DocKey, - updater interface{}, + updater string, ) (*client.UpdateResult, error) { txn, err := c.getTxn(ctx, false) if err != nil { @@ -107,7 +107,7 @@ func (c *collection) UpdateWithKey( func (c *collection) UpdateWithKeys( ctx context.Context, keys []client.DocKey, - updater interface{}, + updater string, ) (*client.UpdateResult, error) { txn, err := c.getTxn(ctx, false) if err != nil { @@ -126,20 +126,17 @@ func (c *collection) updateWithKey( ctx context.Context, txn datastore.Txn, key client.DocKey, - updater interface{}, + updater string, ) (*client.UpdateResult, error) { - patch, err := parseUpdater(updater) + parsedUpdater, err := fastjson.Parse(updater) if err != nil { return nil, err } isPatch := false - switch patch.(type) { - case []map[string]interface{}: + if parsedUpdater.Type() == fastjson.TypeArray { isPatch = true - case map[string]interface{}: - isPatch = false - default: + } else if parsedUpdater.Type() != fastjson.TypeObject { return nil, client.ErrInvalidUpdater } @@ -155,7 +152,7 @@ func (c *collection) updateWithKey( if isPatch { // todo } else { - err = c.applyMerge(ctx, txn, v, patch.(map[string]interface{})) + err = c.applyMerge(ctx, txn, v, parsedUpdater.GetObject()) } if err != nil { return nil, err @@ -172,20 +169,17 @@ func (c *collection) updateWithKeys( ctx context.Context, txn datastore.Txn, keys []client.DocKey, - updater interface{}, + updater string, ) (*client.UpdateResult, error) { - patch, err := parseUpdater(updater) + parsedUpdater, err := fastjson.Parse(updater) if err != nil { return nil, err } isPatch := false - switch patch.(type) { - case []map[string]interface{}: + if parsedUpdater.Type() == fastjson.TypeArray { isPatch = true - case map[string]interface{}: - isPatch = false - default: + } else if parsedUpdater.Type() != fastjson.TypeObject { return nil, client.ErrInvalidUpdater } @@ -205,7 +199,7 @@ func (c *collection) updateWithKeys( if isPatch { // todo } else { - err = c.applyMerge(ctx, txn, v, patch.(map[string]interface{})) + err = c.applyMerge(ctx, txn, v, parsedUpdater.GetObject()) } if err != nil { return nil, err @@ -221,19 +215,19 @@ func (c *collection) updateWithFilter( ctx context.Context, txn datastore.Txn, filter interface{}, - updater interface{}, + updater string, ) (*client.UpdateResult, error) { - patch, err := parseUpdater(updater) + parsedUpdater, err := fastjson.Parse(updater) if err != nil { return nil, err } isPatch := false isMerge := false - switch patch.(type) { - case []map[string]interface{}: + switch parsedUpdater.Type() { + case fastjson.TypeArray: isPatch = true - case map[string]interface{}: + case fastjson.TypeObject: isMerge = true default: return nil, client.ErrInvalidUpdater @@ -275,9 +269,9 @@ func (c *collection) updateWithFilter( // Get the document, and apply the patch doc := docMap.ToMap(query.Value()) if isPatch { - err = c.applyPatch(txn, doc, patch.([]map[string]interface{})) + err = c.applyPatch(txn, doc, parsedUpdater.GetArray()) } else if isMerge { // else is fine here - err = c.applyMerge(ctx, txn, doc, patch.(map[string]interface{})) + err = c.applyMerge(ctx, txn, doc, parsedUpdater.GetObject()) } if err != nil { return nil, err @@ -294,25 +288,29 @@ func (c *collection) updateWithFilter( func (c *collection) applyPatch( txn datastore.Txn, doc map[string]interface{}, - patch []map[string]interface{}, + patch []*fastjson.Value, ) error { for _, op := range patch { - path, ok := op["path"].(string) - if !ok { - return errors.New("Missing document field to update") + opObject, err := op.Object() + if err != nil { + return err + } + path, err := opObject.Get("path").StringBytes() + if err != nil { + return fmt.Errorf("Missing document field to update: %w", err) } - targetCollection, _, err := c.getCollectionForPatchOpPath(txn, path) + targetCollection, _, err := c.getCollectionForPatchOpPath(txn, string(path)) if err != nil { return err } - key, err := c.getTargetKeyForPatchPath(txn, doc, path) + key, err := c.getTargetKeyForPatchPath(txn, doc, string(path)) if err != nil { return err } - field, val, _ := getValFromDocForPatchPath(doc, path) - if err := targetCollection.applyPatchOp(txn, key, field, val, op); err != nil { + field, val, _ := getValFromDocForPatchPath(doc, string(path)) + if err := targetCollection.applyPatchOp(txn, key, field, val, opObject); err != nil { return err } } @@ -326,7 +324,7 @@ func (c *collection) applyPatchOp( dockey string, field string, currentVal interface{}, - patchOp map[string]interface{}, + patchOp *fastjson.Object, ) error { return nil } @@ -335,7 +333,7 @@ func (c *collection) applyMerge( ctx context.Context, txn datastore.Txn, doc map[string]interface{}, - merge map[string]interface{}, + merge *fastjson.Object, ) error { keyStr, ok := doc["_key"].(string) if !ok { @@ -343,8 +341,16 @@ func (c *collection) applyMerge( } key := c.getPrimaryKey(keyStr) links := make([]core.DAGLink, 0) - for mfield, mval := range merge { - if _, ok := mval.(map[string]interface{}); ok { + + mergeMap := make(map[string]*fastjson.Value) + merge.Visit(func(k []byte, v *fastjson.Value) { + mergeMap[string(k)] = v + }) + + mergeCBOR := make(map[string]any) + + for mfield, mval := range mergeMap { + if mval.Type() == fastjson.TypeObject { return ErrInvalidMergeValueType } @@ -353,23 +359,13 @@ func (c *collection) applyMerge( return errors.New("Invalid field in Patch") } - cval, err := validateFieldSchema(mval, fd) + var err error + mergeCBOR[mfield], err = validateFieldSchema(mval, fd) if err != nil { return err } - // handle Int/Float case - // JSON is annoying in that it represents all numbers - // as Float64s. So our merge object contains float64s - // even for fields defined as Ints, which causes issues - // when we serialize that in CBOR. To generate the delta - // payload. - // So let's just make sure ints are ints ref: https://play.golang.org/p/djThEqGXtvR - if fd.Kind == client.FieldKind_INT { - merge[mfield] = int64(mval.(float64)) - } - - val := client.NewCBORValue(fd.Typ, cval) + val := client.NewCBORValue(fd.Typ, mergeCBOR[mfield]) fieldKey, fieldExists := c.tryGetFieldKey(key, mfield) if !fieldExists { return client.ErrFieldNotExist @@ -391,7 +387,7 @@ func (c *collection) applyMerge( if err != nil { return err } - buf, err := em.Marshal(merge) + buf, err := em.Marshal(mergeCBOR) if err != nil { return err } @@ -429,199 +425,133 @@ func (c *collection) applyMerge( // and ensures it matches the supplied field description. // It will do any minor parsing, like dates, and return // the typed value again as an interface. -func validateFieldSchema(val interface{}, field client.FieldDescription) (interface{}, error) { - var cval interface{} - var err error - var ok bool +func validateFieldSchema(val *fastjson.Value, field client.FieldDescription) (interface{}, error) { switch field.Kind { case client.FieldKind_DocKey, client.FieldKind_STRING: - cval, ok = val.(string) + return getString(val) + case client.FieldKind_STRING_ARRAY: - if val == nil { - ok = true - cval = nil - break - } - untypedCollection := val.([]interface{}) - stringArray := make([]string, len(untypedCollection)) - for i, value := range untypedCollection { - if value == nil { - stringArray[i] = "" - continue - } - stringArray[i], ok = value.(string) - if !ok { - return nil, fmt.Errorf( - "Failed to cast value: %v of type: %T to string", - value, - value, - ) - } - } - ok = true - cval = stringArray + return array(val, getString, "") case client.FieldKind_NILLABLE_STRING_ARRAY: - cval, err = convertNillableArray[string](val) - if err != nil { - return nil, err - } - ok = true + return convertNillableArray(val, getString) case client.FieldKind_BOOL: - cval, ok = val.(bool) + return val.Bool() + case client.FieldKind_BOOL_ARRAY: - if val == nil { - ok = true - cval = nil - break - } - untypedCollection := val.([]interface{}) - boolArray := make([]bool, len(untypedCollection)) - for i, value := range untypedCollection { - boolArray[i], ok = value.(bool) - if !ok { - return nil, fmt.Errorf("Failed to cast value: %v of type: %T to bool", value, value) - } - } - ok = true - cval = boolArray + return array(val, getBool, false) case client.FieldKind_NILLABLE_BOOL_ARRAY: - cval, err = convertNillableArray[bool](val) - if err != nil { - return nil, err - } - ok = true + return convertNillableArray(val, getBool) case client.FieldKind_FLOAT, client.FieldKind_DECIMAL: - cval, ok = val.(float64) + return val.Float64() + case client.FieldKind_FLOAT_ARRAY: - if val == nil { - ok = true - cval = nil - break - } - untypedCollection := val.([]interface{}) - floatArray := make([]float64, len(untypedCollection)) - for i, value := range untypedCollection { - floatArray[i], ok = value.(float64) - if !ok { - return nil, fmt.Errorf( - "Failed to cast value: %v of type: %T to float64", - value, - value, - ) - } - } - ok = true - cval = floatArray + return array(val, getFloat64, 0) case client.FieldKind_NILLABLE_FLOAT_ARRAY: - cval, err = convertNillableArray[float64](val) + return convertNillableArray(val, getFloat64) + + case client.FieldKind_DATE: + v, err := val.StringBytes() if err != nil { return nil, err } - ok = true + return time.Parse(time.RFC3339, string(v)) - case client.FieldKind_DATE: - var sval string - sval, ok = val.(string) - cval, err = time.Parse(time.RFC3339, sval) case client.FieldKind_INT: - var fval float64 - fval, ok = val.(float64) - if !ok { - return nil, ErrInvalidMergeValueType - } - cval = int64(fval) + return val.Int64() + case client.FieldKind_INT_ARRAY: - if val == nil { - ok = true - cval = nil - break - } - untypedCollection := val.([]interface{}) - intArray := make([]int64, len(untypedCollection)) - for i, value := range untypedCollection { - valueAsFloat, castOk := value.(float64) - if !castOk { - return nil, fmt.Errorf( - "Failed to cast value: %v of type: %T to float64", - value, - value, - ) - } - intArray[i] = int64(valueAsFloat) - } - ok = true - cval = intArray + return array(val, getInt64, 0) case client.FieldKind_NILLABLE_INT_ARRAY: - cval, err = convertNillableArrayWithConverter(val, func(in float64) int64 { return int64(in) }) - if err != nil { - return nil, err - } - ok = true + return convertNillableArray(val, getInt64) case client.FieldKind_OBJECT, client.FieldKind_OBJECT_ARRAY, client.FieldKind_FOREIGN_OBJECT, client.FieldKind_FOREIGN_OBJECT_ARRAY: - err = errors.New("Merge doesn't support sub types yet") + return nil, errors.New("Merge doesn't support sub types yet") } - if !ok { - return nil, ErrInvalidMergeValueType - } - if err != nil { - return nil, err - } + return nil, errors.New("Merge doesn't support sub types yet") +} + +func getString(v *fastjson.Value) (string, error) { + b, err := v.StringBytes() + return string(b), err +} + +func getBool(v *fastjson.Value) (bool, error) { + b, err := v.Bool() + return b, err +} + +func getFloat64(v *fastjson.Value) (float64, error) { + f, err := v.Float64() + return f, err +} - return cval, err +func getInt64(v *fastjson.Value) (int64, error) { + f, err := v.Int64() + return f, err } -func convertNillableArray[T any](val any) ([]*T, error) { - if val == nil { +func array[T any]( + val *fastjson.Value, + typeGetter func(*fastjson.Value) (T, error), + zeroValue T, +) ([]T, error) { + if val.Type() == fastjson.TypeNull { return nil, nil } - untypedCollection := val.([]interface{}) - // Cbor deals with pointers better than structs by default, however in the future - // we may want to write a custom encoder for the Option[T] type - resultArray := make([]*T, len(untypedCollection)) - for i, value := range untypedCollection { - if value == nil { - resultArray[i] = nil + + valArray, err := val.Array() + if err != nil { + return nil, err + } + + arr := make([]T, len(valArray)) + for i, arrItem := range valArray { + if arrItem.Type() == fastjson.TypeNull { + arr[i] = zeroValue continue } - tValue, ok := value.(T) - if !ok { - return nil, fmt.Errorf("Failed to cast value: %v of type: %T to %T", value, value, *new(T)) + arr[i], err = typeGetter(arrItem) + if err != nil { + return nil, err } - resultArray[i] = &tValue } - return resultArray, nil + return arr, nil } -func convertNillableArrayWithConverter[TIn any, TOut any](val any, converter func(TIn) TOut) ([]*TOut, error) { - if val == nil { +func convertNillableArray[T any]( + val *fastjson.Value, + typeGetter func(*fastjson.Value) (T, error), +) ([]*T, error) { + if val.Type() == fastjson.TypeNull { return nil, nil } - untypedCollection := val.([]interface{}) - // Cbor deals with pointers better than structs by default, however in the future - // we may want to write a custom encoder for the Option[T] type - resultArray := make([]*TOut, len(untypedCollection)) - for i, value := range untypedCollection { - if value == nil { - resultArray[i] = nil + + valArray, err := val.Array() + if err != nil { + return nil, err + } + + arr := make([]*T, len(valArray)) + for i, arrItem := range valArray { + if arrItem.Type() == fastjson.TypeNull { + arr[i] = nil continue } - tValue, ok := value.(TIn) - if !ok { - return nil, fmt.Errorf("Failed to cast value: %v of type: %T to %T", value, value, *new(TIn)) + v, err := typeGetter(arrItem) + if err != nil { + return nil, err } - outValue := converter(tValue) - resultArray[i] = &outValue + arr[i] = &v } - return resultArray, nil + return arr, nil } func (c *collection) applyMergePatchOp( //nolint:unused @@ -783,53 +713,6 @@ func getMapProp( return paths[0], val, true } -type patcher interface{} - -func parseUpdater(updater interface{}) (patcher, error) { - switch v := updater.(type) { - case string: - return parseUpdaterString(v) - case []interface{}: - return parseUpdaterSlice(v) - case []map[string]interface{}, map[string]interface{}: - return patcher(v), nil - case nil: - return nil, ErrUpdateEmpty - default: - return nil, client.ErrInvalidUpdater - } -} - -func parseUpdaterString(v string) (patcher, error) { - if v == "" { - return nil, ErrUpdateEmpty - } - var i interface{} - if err := json.Unmarshal([]byte(v), &i); err != nil { - return nil, err - } - return parseUpdater(i) -} - -// converts an []interface{} to []map[string]interface{} -// which is required to be an array of Patch Ops -func parseUpdaterSlice(v []interface{}) (patcher, error) { - if len(v) == 0 { - return nil, ErrUpdateEmpty - } - - patches := make([]map[string]interface{}, len(v)) - for i, patch := range v { - p, ok := patch.(map[string]interface{}) - if !ok { - return nil, client.ErrInvalidUpdater - } - patches[i] = p - } - - return parseUpdater(patches) -} - /* filter := NewFilterFromString("Name: {_eq: 'bob'}") diff --git a/go.mod b/go.mod index 087cacc4d5..41f2f28504 100644 --- a/go.mod +++ b/go.mod @@ -185,6 +185,7 @@ require ( github.com/stretchr/objx v0.1.1 // indirect github.com/subosito/gotenv v1.2.0 // indirect github.com/textileio/go-log/v2 v2.1.3-gke-2 // indirect + github.com/valyala/fastjson v1.6.3 // indirect github.com/whyrusleeping/cbor-gen v0.0.0-20220514204315-f29c37e9c44c // indirect github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 // indirect diff --git a/go.sum b/go.sum index 633863f4eb..2713718a01 100644 --- a/go.sum +++ b/go.sum @@ -1565,6 +1565,8 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/valyala/fastjson v1.6.3 h1:tAKFnnwmeMGPbwJ7IwxcTPCNr3uIzoIj3/Fh90ra4xc= +github.com/valyala/fastjson v1.6.3/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/wangjia184/sortedset v0.0.0-20160527075905-f5d03557ba30/go.mod h1:YkocrP2K2tcw938x9gCOmT5G5eCD6jsTz0SZuyAqwIE=