diff --git a/cmd/util/cmd/execution-state-extract/cmd.go b/cmd/util/cmd/execution-state-extract/cmd.go index 7c10e8dbdcb..b22ed8c5da3 100644 --- a/cmd/util/cmd/execution-state-extract/cmd.go +++ b/cmd/util/cmd/execution-state-extract/cmd.go @@ -15,15 +15,17 @@ import ( ) var ( - flagExecutionStateDir string - flagOutputDir string - flagBlockHash string - flagStateCommitment string - flagDatadir string - flagChain string - flagNoMigration bool - flagNoReport bool - flagNWorker int + flagExecutionStateDir string + flagOutputDir string + flagBlockHash string + flagStateCommitment string + flagDatadir string + flagChain string + flagNWorker int + flagNoMigration bool + flagNoReport bool + flagValidateMigration bool + flagLogVerboseValidationError bool ) var Cmd = &cobra.Command{ @@ -59,6 +61,13 @@ func init() { "don't report the state") Cmd.Flags().IntVar(&flagNWorker, "n-migrate-worker", 10, "number of workers to migrate payload concurrently") + + Cmd.Flags().BoolVar(&flagValidateMigration, "validate", false, + "validate migrated Cadence values (atree migration)") + + Cmd.Flags().BoolVar(&flagLogVerboseValidationError, "log-verbose-validation-error", false, + "log entire Cadence values on validation error (atree migration)") + } func run(*cobra.Command, []string) { @@ -135,6 +144,14 @@ func run(*cobra.Command, []string) { log.Warn().Msgf("--no-migration flag is deprecated") } + if flagValidateMigration { + log.Warn().Msgf("atree migration validation flag is enabled and will increase duration of migration") + } + + if flagLogVerboseValidationError { + log.Warn().Msgf("atree migration has verbose validation error logging enabled which may increase size of log") + } + err := extractExecutionState( flagExecutionStateDir, stateCommitment, diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract.go b/cmd/util/cmd/execution-state-extract/execution_state_extract.go index 18eba0f5281..f379bc2f9ec 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract.go @@ -85,19 +85,18 @@ func extractExecutionState( rwf := reporters.NewReportFileWriterFactory(dir, log) - cadenceDataValidation := migrators.NewCadenceDataValidationMigrations(rwf, nWorker) - var migrations = []ledger.Migration{ migrators.CreateAccountBasedMigration( log, nWorker, []migrators.AccountBasedMigration{ - cadenceDataValidation.PreMigration(), migrators.NewAtreeRegisterMigrator( rwf, + flagValidateMigration, + flagLogVerboseValidationError, ), + &migrators.DeduplicateContractNamesMigration{}, - cadenceDataValidation.PostMigration(), // This will fix storage used discrepancies caused by the // DeduplicateContractNamesMigration. diff --git a/cmd/util/ledger/migrations/atree_register_migration.go b/cmd/util/ledger/migrations/atree_register_migration.go index fe50671dcfb..afb04e63f4f 100644 --- a/cmd/util/ledger/migrations/atree_register_migration.go +++ b/cmd/util/ledger/migrations/atree_register_migration.go @@ -35,6 +35,9 @@ type AtreeRegisterMigrator struct { rwf reporters.ReportWriterFactory nWorkers int + + validateMigratedValues bool + logVerboseValidationError bool } var _ AccountBasedMigration = (*AtreeRegisterMigrator)(nil) @@ -42,15 +45,18 @@ var _ io.Closer = (*AtreeRegisterMigrator)(nil) func NewAtreeRegisterMigrator( rwf reporters.ReportWriterFactory, + validateMigratedValues bool, + logVerboseValidationError bool, ) *AtreeRegisterMigrator { sampler := util2.NewTimedSampler(30 * time.Second) migrator := &AtreeRegisterMigrator{ - sampler: sampler, - - rwf: rwf, - rw: rwf.ReportWriter("atree-register-migrator"), + sampler: sampler, + rwf: rwf, + rw: rwf.ReportWriter("atree-register-migrator"), + validateMigratedValues: validateMigratedValues, + logVerboseValidationError: logVerboseValidationError, } return migrator @@ -108,6 +114,13 @@ func (m *AtreeRegisterMigrator) MigrateAccount( return nil, err } + if m.validateMigratedValues { + err = validateCadenceValues(address, oldPayloads, newPayloads, m.log, m.logVerboseValidationError) + if err != nil { + return nil, err + } + } + newLen := len(newPayloads) if newLen > originalLen { diff --git a/cmd/util/ledger/migrations/atree_register_migration_test.go b/cmd/util/ledger/migrations/atree_register_migration_test.go index e905d12347b..da6d9f7fdfb 100644 --- a/cmd/util/ledger/migrations/atree_register_migration_test.go +++ b/cmd/util/ledger/migrations/atree_register_migration_test.go @@ -36,7 +36,7 @@ func TestAtreeRegisterMigration(t *testing.T) { migrations.CreateAccountBasedMigration(log, 2, []migrations.AccountBasedMigration{ validation.PreMigration(), - migrations.NewAtreeRegisterMigrator(reporters.NewReportFileWriterFactory(dir, log)), + migrations.NewAtreeRegisterMigrator(reporters.NewReportFileWriterFactory(dir, log), true, false), validation.PostMigration(), }, ), diff --git a/cmd/util/ledger/migrations/cadence_value_validation.go b/cmd/util/ledger/migrations/cadence_value_validation.go new file mode 100644 index 00000000000..ff45b2e2c97 --- /dev/null +++ b/cmd/util/ledger/migrations/cadence_value_validation.go @@ -0,0 +1,599 @@ +package migrations + +import ( + "fmt" + "strings" + "time" + + "github.com/onflow/atree" + "github.com/onflow/cadence" + "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/rs/zerolog" + "go.opentelemetry.io/otel/attribute" + + "github.com/onflow/flow-go/cmd/util/ledger/util" + "github.com/onflow/flow-go/ledger" +) + +var nopMemoryGauge = util.NopMemoryGauge{} + +// TODO: optimize memory by reusing payloads snapshot created for migration +func validateCadenceValues( + address common.Address, + oldPayloads []*ledger.Payload, + newPayloads []*ledger.Payload, + log zerolog.Logger, + verboseLogging bool, +) error { + // Create all the runtime components we need for comparing Cadence values. + oldRuntime, err := newReadonlyStorageRuntime(oldPayloads) + if err != nil { + return fmt.Errorf("failed to create validator runtime with old payloads: %w", err) + } + + newRuntime, err := newReadonlyStorageRuntime(newPayloads) + if err != nil { + return fmt.Errorf("failed to create validator runtime with new payloads: %w", err) + } + + // Iterate through all domains and compare cadence values. + for _, domain := range domains { + err := validateStorageDomain(address, oldRuntime, newRuntime, domain, log, verboseLogging) + if err != nil { + return err + } + } + + return nil +} + +func validateStorageDomain( + address common.Address, + oldRuntime *readonlyStorageRuntime, + newRuntime *readonlyStorageRuntime, + domain string, + log zerolog.Logger, + verboseLogging bool, +) error { + + oldStorageMap := oldRuntime.Storage.GetStorageMap(address, domain, false) + + newStorageMap := newRuntime.Storage.GetStorageMap(address, domain, false) + + if oldStorageMap == nil && newStorageMap == nil { + // No storage for this domain. + return nil + } + + if oldStorageMap == nil && newStorageMap != nil { + return fmt.Errorf("old storage map is nil, new storage map isn't nil") + } + + if oldStorageMap != nil && newStorageMap == nil { + return fmt.Errorf("old storage map isn't nil, new storage map is nil") + } + + if oldStorageMap.Count() != newStorageMap.Count() { + return fmt.Errorf("old storage map count %d, new storage map count %d", oldStorageMap.Count(), newStorageMap.Count()) + } + + oldIterator := oldStorageMap.Iterator(nopMemoryGauge) + for { + key, oldValue := oldIterator.Next() + if key == nil { + break + } + + stringKey, ok := key.(interpreter.StringAtreeValue) + if !ok { + return fmt.Errorf("invalid key type %T, expected interpreter.StringAtreeValue", key) + } + + newValue := newStorageMap.ReadValue(nopMemoryGauge, interpreter.StringStorageMapKey(stringKey)) + + err := cadenceValueEqual(oldRuntime.Interpreter, oldValue, newRuntime.Interpreter, newValue) + if err != nil { + if verboseLogging { + log.Info(). + Str("address", address.Hex()). + Str("domain", domain). + Str("key", string(stringKey)). + Str("trace", err.Error()). + Str("old value", oldValue.String()). + Str("new value", newValue.String()). + Msgf("failed to validate value") + } + + return fmt.Errorf("failed to validate value for address %s, domain %s, key %s: %s", address.Hex(), domain, key, err.Error()) + } + } + + return nil +} + +type validationError struct { + trace []string + errorMsg string + traceReversed bool +} + +func newValidationErrorf(format string, a ...any) *validationError { + return &validationError{ + errorMsg: fmt.Sprintf(format, a...), + } +} + +func (e *validationError) addTrace(trace string) { + e.trace = append(e.trace, trace) +} + +func (e *validationError) Error() string { + if len(e.trace) == 0 { + return fmt.Sprintf("failed to validate: %s", e.errorMsg) + } + // Reverse trace + if !e.traceReversed { + for i, j := 0, len(e.trace)-1; i < j; i, j = i+1, j-1 { + e.trace[i], e.trace[j] = e.trace[j], e.trace[i] + } + e.traceReversed = true + } + trace := strings.Join(e.trace, ".") + return fmt.Sprintf("failed to validate %s: %s", trace, e.errorMsg) +} + +func cadenceValueEqual( + vInterpreter *interpreter.Interpreter, + v interpreter.Value, + otherInterpreter *interpreter.Interpreter, + other interpreter.Value, +) *validationError { + switch v := v.(type) { + case *interpreter.ArrayValue: + return cadenceArrayValueEqual(vInterpreter, v, otherInterpreter, other) + + case *interpreter.CompositeValue: + return cadenceCompositeValueEqual(vInterpreter, v, otherInterpreter, other) + + case *interpreter.DictionaryValue: + return cadenceDictionaryValueEqual(vInterpreter, v, otherInterpreter, other) + + case *interpreter.SomeValue: + return cadenceSomeValueEqual(vInterpreter, v, otherInterpreter, other) + + default: + oldValue, ok := v.(interpreter.EquatableValue) + if !ok { + return newValidationErrorf( + "value doesn't implement interpreter.EquatableValue: %T", + oldValue, + ) + } + if !oldValue.Equal(nil, interpreter.EmptyLocationRange, other) { + return newValidationErrorf( + "values differ: %v (%T) != %v (%T)", + oldValue, + oldValue, + other, + other, + ) + } + } + + return nil +} + +func cadenceSomeValueEqual( + vInterpreter *interpreter.Interpreter, + v *interpreter.SomeValue, + otherInterpreter *interpreter.Interpreter, + other interpreter.Value, +) *validationError { + otherSome, ok := other.(*interpreter.SomeValue) + if !ok { + return newValidationErrorf("types differ: %T != %T", v, other) + } + + innerValue := v.InnerValue(vInterpreter, interpreter.EmptyLocationRange) + + otherInnerValue := otherSome.InnerValue(otherInterpreter, interpreter.EmptyLocationRange) + + return cadenceValueEqual(vInterpreter, innerValue, otherInterpreter, otherInnerValue) +} + +func cadenceArrayValueEqual( + vInterpreter *interpreter.Interpreter, + v *interpreter.ArrayValue, + otherInterpreter *interpreter.Interpreter, + other interpreter.Value, +) *validationError { + otherArray, ok := other.(*interpreter.ArrayValue) + if !ok { + return newValidationErrorf("types differ: %T != %T", v, other) + } + + count := v.Count() + if count != otherArray.Count() { + return newValidationErrorf("array counts differ: %d != %d", count, otherArray.Count()) + } + + if v.Type == nil { + if otherArray.Type != nil { + return newValidationErrorf("array types differ: nil != %s", otherArray.Type) + } + } else { // v.Type != nil + if otherArray.Type == nil { + return newValidationErrorf("array types differ: %s != nil", v.Type) + } else if !v.Type.Equal(otherArray.Type) { + return newValidationErrorf("array types differ: %s != %s", v.Type, otherArray.Type) + } + } + + for i := 0; i < count; i++ { + element := v.Get(vInterpreter, interpreter.EmptyLocationRange, i) + otherElement := otherArray.Get(otherInterpreter, interpreter.EmptyLocationRange, i) + + err := cadenceValueEqual(vInterpreter, element, otherInterpreter, otherElement) + if err != nil { + err.addTrace(fmt.Sprintf("(%s[%d])", v.Type, i)) + return err + } + } + + return nil +} + +func cadenceCompositeValueEqual( + vInterpreter *interpreter.Interpreter, + v *interpreter.CompositeValue, + otherInterpreter *interpreter.Interpreter, + other interpreter.Value, +) *validationError { + otherComposite, ok := other.(*interpreter.CompositeValue) + if !ok { + return newValidationErrorf("types differ: %T != %T", v, other) + } + + if !v.StaticType(vInterpreter).Equal(otherComposite.StaticType(otherInterpreter)) { + return newValidationErrorf( + "composite types differ: %s != %s", + v.StaticType(vInterpreter), + otherComposite.StaticType(otherInterpreter), + ) + } + + if v.Kind != otherComposite.Kind { + return newValidationErrorf( + "composite kinds differ: %d != %d", + v.Kind, + otherComposite.Kind, + ) + } + + var err *validationError + vFieldNames := make([]string, 0, 10) // v's field names + v.ForEachField(nopMemoryGauge, func(fieldName string, fieldValue interpreter.Value) bool { + otherFieldValue := otherComposite.GetField(otherInterpreter, interpreter.EmptyLocationRange, fieldName) + + err = cadenceValueEqual(vInterpreter, fieldValue, otherInterpreter, otherFieldValue) + if err != nil { + err.addTrace(fmt.Sprintf("(%s.%s)", v.TypeID(), fieldName)) + return false + } + + vFieldNames = append(vFieldNames, fieldName) + return true + }) + + // TODO: Use CompositeValue.FieldCount() from Cadence after it is merged and available. + otherFieldNames := make([]string, 0, len(vFieldNames)) // otherComposite's field names + otherComposite.ForEachField(nopMemoryGauge, func(fieldName string, _ interpreter.Value) bool { + otherFieldNames = append(otherFieldNames, fieldName) + return true + }) + + if len(vFieldNames) != len(otherFieldNames) { + return newValidationErrorf( + "composite %s fields differ: %v != %v", + v.TypeID(), + vFieldNames, + otherFieldNames, + ) + } + + return err +} + +func cadenceDictionaryValueEqual( + vInterpreter *interpreter.Interpreter, + v *interpreter.DictionaryValue, + otherInterpreter *interpreter.Interpreter, + other interpreter.Value, +) *validationError { + otherDictionary, ok := other.(*interpreter.DictionaryValue) + if !ok { + return newValidationErrorf("types differ: %T != %T", v, other) + } + + if v.Count() != otherDictionary.Count() { + return newValidationErrorf("dict counts differ: %d != %d", v.Count(), otherDictionary.Count()) + } + + if !v.Type.Equal(otherDictionary.Type) { + return newValidationErrorf("dict types differ: %s != %s", v.Type, otherDictionary.Type) + } + + oldIterator := v.Iterator() + for { + key := oldIterator.NextKey(nopMemoryGauge) + if key == nil { + break + } + + oldValue, oldValueExist := v.Get(vInterpreter, interpreter.EmptyLocationRange, key) + if !oldValueExist { + err := newValidationErrorf("old value doesn't exist with key %v (%T)", key, key) + err.addTrace(fmt.Sprintf("(%s[%s])", v.Type, key)) + return err + } + newValue, newValueExist := otherDictionary.Get(otherInterpreter, interpreter.EmptyLocationRange, key) + if !newValueExist { + err := newValidationErrorf("new value doesn't exist with key %v (%T)", key, key) + err.addTrace(fmt.Sprintf("(%s[%s])", otherDictionary.Type, key)) + return err + } + err := cadenceValueEqual(vInterpreter, oldValue, otherInterpreter, newValue) + if err != nil { + err.addTrace(fmt.Sprintf("(%s[%s])", otherDictionary.Type, key)) + return err + } + } + + return nil +} + +type readonlyStorageRuntime struct { + Interpreter *interpreter.Interpreter + Storage *runtime.Storage +} + +func newReadonlyStorageRuntime(payloads []*ledger.Payload) ( + *readonlyStorageRuntime, + error, +) { + snapshot, err := util.NewPayloadSnapshot(payloads) + if err != nil { + return nil, fmt.Errorf("failed to create payload snapshot: %w", err) + } + + readonlyLedger := util.NewPayloadsReadonlyLedger(snapshot) + + storage := runtime.NewStorage(readonlyLedger, nopMemoryGauge) + + env := runtime.NewBaseInterpreterEnvironment(runtime.Config{ + AccountLinkingEnabled: true, + // Attachments are enabled everywhere except for Mainnet + AttachmentsEnabled: true, + // Capability Controllers are enabled everywhere except for Mainnet + CapabilityControllersEnabled: true, + }) + + env.Configure( + &NoopRuntimeInterface{}, + runtime.NewCodesAndPrograms(), + storage, + nil, + ) + + inter, err := interpreter.NewInterpreter(nil, nil, env.InterpreterConfig) + if err != nil { + return nil, err + } + + return &readonlyStorageRuntime{ + Interpreter: inter, + Storage: storage, + }, nil +} + +// NoopRuntimeInterface is a runtime interface that can be used in migrations. +type NoopRuntimeInterface struct { +} + +func (NoopRuntimeInterface) ResolveLocation(_ []runtime.Identifier, _ runtime.Location) ([]runtime.ResolvedLocation, error) { + panic("unexpected ResolveLocation call") +} + +func (NoopRuntimeInterface) GetCode(_ runtime.Location) ([]byte, error) { + panic("unexpected GetCode call") +} + +func (NoopRuntimeInterface) GetAccountContractCode(_ common.AddressLocation) ([]byte, error) { + panic("unexpected GetAccountContractCode call") +} + +func (NoopRuntimeInterface) GetOrLoadProgram(_ runtime.Location, _ func() (*interpreter.Program, error)) (*interpreter.Program, error) { + panic("unexpected GetOrLoadProgram call") +} + +func (NoopRuntimeInterface) MeterMemory(_ common.MemoryUsage) error { + return nil +} + +func (NoopRuntimeInterface) MeterComputation(_ common.ComputationKind, _ uint) error { + return nil +} + +func (NoopRuntimeInterface) GetValue(_, _ []byte) (value []byte, err error) { + panic("unexpected GetValue call") +} + +func (NoopRuntimeInterface) SetValue(_, _, _ []byte) (err error) { + panic("unexpected SetValue call") +} + +func (NoopRuntimeInterface) CreateAccount(_ runtime.Address) (address runtime.Address, err error) { + panic("unexpected CreateAccount call") +} + +func (NoopRuntimeInterface) AddEncodedAccountKey(_ runtime.Address, _ []byte) error { + panic("unexpected AddEncodedAccountKey call") +} + +func (NoopRuntimeInterface) RevokeEncodedAccountKey(_ runtime.Address, _ int) (publicKey []byte, err error) { + panic("unexpected RevokeEncodedAccountKey call") +} + +func (NoopRuntimeInterface) AddAccountKey(_ runtime.Address, _ *runtime.PublicKey, _ runtime.HashAlgorithm, _ int) (*runtime.AccountKey, error) { + panic("unexpected AddAccountKey call") +} + +func (NoopRuntimeInterface) GetAccountKey(_ runtime.Address, _ int) (*runtime.AccountKey, error) { + panic("unexpected GetAccountKey call") +} + +func (NoopRuntimeInterface) RevokeAccountKey(_ runtime.Address, _ int) (*runtime.AccountKey, error) { + panic("unexpected RevokeAccountKey call") +} + +func (NoopRuntimeInterface) UpdateAccountContractCode(_ common.AddressLocation, _ []byte) (err error) { + panic("unexpected UpdateAccountContractCode call") +} + +func (NoopRuntimeInterface) RemoveAccountContractCode(common.AddressLocation) (err error) { + panic("unexpected RemoveAccountContractCode call") +} + +func (NoopRuntimeInterface) GetSigningAccounts() ([]runtime.Address, error) { + panic("unexpected GetSigningAccounts call") +} + +func (NoopRuntimeInterface) ProgramLog(_ string) error { + panic("unexpected ProgramLog call") +} + +func (NoopRuntimeInterface) EmitEvent(_ cadence.Event) error { + panic("unexpected EmitEvent call") +} + +func (NoopRuntimeInterface) ValueExists(_, _ []byte) (exists bool, err error) { + panic("unexpected ValueExists call") +} + +func (NoopRuntimeInterface) GenerateUUID() (uint64, error) { + panic("unexpected GenerateUUID call") +} + +func (NoopRuntimeInterface) GetComputationLimit() uint64 { + panic("unexpected GetComputationLimit call") +} + +func (NoopRuntimeInterface) SetComputationUsed(_ uint64) error { + panic("unexpected SetComputationUsed call") +} + +func (NoopRuntimeInterface) DecodeArgument(_ []byte, _ cadence.Type) (cadence.Value, error) { + panic("unexpected DecodeArgument call") +} + +func (NoopRuntimeInterface) GetCurrentBlockHeight() (uint64, error) { + panic("unexpected GetCurrentBlockHeight call") +} + +func (NoopRuntimeInterface) GetBlockAtHeight(_ uint64) (block runtime.Block, exists bool, err error) { + panic("unexpected GetBlockAtHeight call") +} + +func (NoopRuntimeInterface) ReadRandom([]byte) error { + panic("unexpected ReadRandom call") +} + +func (NoopRuntimeInterface) VerifySignature(_ []byte, _ string, _ []byte, _ []byte, _ runtime.SignatureAlgorithm, _ runtime.HashAlgorithm) (bool, error) { + panic("unexpected VerifySignature call") +} + +func (NoopRuntimeInterface) Hash(_ []byte, _ string, _ runtime.HashAlgorithm) ([]byte, error) { + panic("unexpected Hash call") +} + +func (NoopRuntimeInterface) GetAccountBalance(_ common.Address) (value uint64, err error) { + panic("unexpected GetAccountBalance call") +} + +func (NoopRuntimeInterface) GetAccountAvailableBalance(_ common.Address) (value uint64, err error) { + panic("unexpected GetAccountAvailableBalance call") +} + +func (NoopRuntimeInterface) GetStorageUsed(_ runtime.Address) (value uint64, err error) { + panic("unexpected GetStorageUsed call") +} + +func (NoopRuntimeInterface) GetStorageCapacity(_ runtime.Address) (value uint64, err error) { + panic("unexpected GetStorageCapacity call") +} + +func (NoopRuntimeInterface) ImplementationDebugLog(_ string) error { + panic("unexpected ImplementationDebugLog call") +} + +func (NoopRuntimeInterface) ValidatePublicKey(_ *runtime.PublicKey) error { + panic("unexpected ValidatePublicKey call") +} + +func (NoopRuntimeInterface) GetAccountContractNames(_ runtime.Address) ([]string, error) { + panic("unexpected GetAccountContractNames call") +} + +func (NoopRuntimeInterface) AllocateStorageIndex(_ []byte) (atree.StorageIndex, error) { + panic("unexpected AllocateStorageIndex call") +} + +func (NoopRuntimeInterface) ComputationUsed() (uint64, error) { + panic("unexpected ComputationUsed call") +} + +func (NoopRuntimeInterface) MemoryUsed() (uint64, error) { + panic("unexpected MemoryUsed call") +} + +func (NoopRuntimeInterface) InteractionUsed() (uint64, error) { + panic("unexpected InteractionUsed call") +} + +func (NoopRuntimeInterface) SetInterpreterSharedState(_ *interpreter.SharedState) { + panic("unexpected SetInterpreterSharedState call") +} + +func (NoopRuntimeInterface) GetInterpreterSharedState() *interpreter.SharedState { + panic("unexpected GetInterpreterSharedState call") +} + +func (NoopRuntimeInterface) AccountKeysCount(_ runtime.Address) (uint64, error) { + panic("unexpected AccountKeysCount call") +} + +func (NoopRuntimeInterface) BLSVerifyPOP(_ *runtime.PublicKey, _ []byte) (bool, error) { + panic("unexpected BLSVerifyPOP call") +} + +func (NoopRuntimeInterface) BLSAggregateSignatures(_ [][]byte) ([]byte, error) { + panic("unexpected BLSAggregateSignatures call") +} + +func (NoopRuntimeInterface) BLSAggregatePublicKeys(_ []*runtime.PublicKey) (*runtime.PublicKey, error) { + panic("unexpected BLSAggregatePublicKeys call") +} + +func (NoopRuntimeInterface) ResourceOwnerChanged(_ *interpreter.Interpreter, _ *interpreter.CompositeValue, _ common.Address, _ common.Address) { + panic("unexpected ResourceOwnerChanged call") +} + +func (NoopRuntimeInterface) GenerateAccountID(_ common.Address) (uint64, error) { + panic("unexpected GenerateAccountID call") +} + +func (NoopRuntimeInterface) RecordTrace(_ string, _ runtime.Location, _ time.Duration, _ []attribute.KeyValue) { + panic("unexpected RecordTrace call") +} diff --git a/cmd/util/ledger/migrations/cadence_value_validation_test.go b/cmd/util/ledger/migrations/cadence_value_validation_test.go new file mode 100644 index 00000000000..ab52742a5fd --- /dev/null +++ b/cmd/util/ledger/migrations/cadence_value_validation_test.go @@ -0,0 +1,268 @@ +package migrations + +import ( + "bytes" + "fmt" + "strconv" + "testing" + + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + + "github.com/onflow/flow-go/fvm/environment" + "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/ledger/common/convert" + "github.com/onflow/flow-go/model/flow" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" +) + +func TestValidateCadenceValues(t *testing.T) { + address, err := common.HexToAddress("0x1") + require.NoError(t, err) + + domain := common.PathDomainStorage.Identifier() + + t.Run("no mismatch", func(t *testing.T) { + log := zerolog.New(zerolog.NewTestWriter(t)) + + err := validateCadenceValues( + address, + createTestPayloads(t, address, domain), + createTestPayloads(t, address, domain), + log, + false, + ) + require.NoError(t, err) + }) + + t.Run("has mismatch", func(t *testing.T) { + var w bytes.Buffer + log := zerolog.New(&w) + + createPayloads := func(nestedArrayValue interpreter.UInt64Value) []*ledger.Payload { + + // Create account status payload + accountStatus := environment.NewAccountStatus() + accountStatusPayload := ledger.NewPayload( + convert.RegisterIDToLedgerKey( + flow.AccountStatusRegisterID(flow.ConvertAddress(address)), + ), + accountStatus.ToBytes(), + ) + + mr, err := newMigratorRuntime(address, []*ledger.Payload{accountStatusPayload}) + require.NoError(t, err) + + // Create new storage map + storageMap := mr.Storage.GetStorageMap(mr.Address, domain, true) + + // Add Cadence ArrayValue with nested CadenceArray + nestedArray := interpreter.NewArrayValue( + mr.Interpreter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeUInt64, + }, + address, + interpreter.NewUnmeteredUInt64Value(0), + nestedArrayValue, + ) + + storageMap.WriteValue( + mr.Interpreter, + interpreter.StringStorageMapKey(strconv.FormatUint(storageMap.Count(), 10)), + interpreter.NewArrayValue( + mr.Interpreter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeAnyStruct, + }, + address, + nestedArray, + ), + ) + + err = mr.Storage.Commit(mr.Interpreter, false) + require.NoError(t, err) + + // finalize the transaction + result, err := mr.TransactionState.FinalizeMainTransaction() + require.NoError(t, err) + + payloads := make([]*ledger.Payload, 0, len(result.WriteSet)) + for id, value := range result.WriteSet { + key := convert.RegisterIDToLedgerKey(id) + payloads = append(payloads, ledger.NewPayload(key, value)) + } + + return payloads + } + + oldPayloads := createPayloads(interpreter.NewUnmeteredUInt64Value(1)) + newPayloads := createPayloads(interpreter.NewUnmeteredUInt64Value(2)) + wantErrorMsg := "failed to validate value for address 0000000000000001, domain storage, key 0: failed to validate ([AnyStruct][0]).([UInt64][1]): values differ: 1 (interpreter.UInt64Value) != 2 (interpreter.UInt64Value)" + wantVerboseMsg := "{\"level\":\"info\",\"address\":\"0000000000000001\",\"domain\":\"storage\",\"key\":\"0\",\"trace\":\"failed to validate ([AnyStruct][0]).([UInt64][1]): values differ: 1 (interpreter.UInt64Value) != 2 (interpreter.UInt64Value)\",\"old value\":\"[[0, 1]]\",\"new value\":\"[[0, 2]]\",\"message\":\"failed to validate value\"}\n" + + // Disable verbose logging + err := validateCadenceValues( + address, + oldPayloads, + newPayloads, + log, + false, + ) + require.ErrorContains(t, err, wantErrorMsg) + require.Equal(t, 0, w.Len()) + + // Enable verbose logging + err = validateCadenceValues( + address, + oldPayloads, + newPayloads, + log, + true, + ) + require.ErrorContains(t, err, wantErrorMsg) + require.Equal(t, wantVerboseMsg, w.String()) + }) +} + +func createTestPayloads(t *testing.T, address common.Address, domain string) []*ledger.Payload { + + // Create account status payload + accountStatus := environment.NewAccountStatus() + accountStatusPayload := ledger.NewPayload( + convert.RegisterIDToLedgerKey( + flow.AccountStatusRegisterID(flow.ConvertAddress(address)), + ), + accountStatus.ToBytes(), + ) + + mr, err := newMigratorRuntime(address, []*ledger.Payload{accountStatusPayload}) + require.NoError(t, err) + + // Create new storage map + storageMap := mr.Storage.GetStorageMap(mr.Address, domain, true) + + // Add Cadence UInt64Value + storageMap.WriteValue( + mr.Interpreter, + interpreter.StringStorageMapKey(strconv.FormatUint(storageMap.Count(), 10)), + interpreter.NewUnmeteredUInt64Value(1), + ) + + // Add Cadence SomeValue + storageMap.WriteValue( + mr.Interpreter, + interpreter.StringStorageMapKey(strconv.FormatUint(storageMap.Count(), 10)), + interpreter.NewUnmeteredSomeValueNonCopying(interpreter.NewUnmeteredStringValue("InnerValueString")), + ) + + // Add Cadence ArrayValue + const arrayCount = 10 + i := uint64(0) + storageMap.WriteValue( + mr.Interpreter, + interpreter.StringStorageMapKey(strconv.FormatUint(storageMap.Count(), 10)), + interpreter.NewArrayValueWithIterator( + mr.Interpreter, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeAnyStruct, + }, + address, + 0, + func() interpreter.Value { + if i == arrayCount { + return nil + } + v := interpreter.NewUnmeteredUInt64Value(i) + i++ + return v + }, + ), + ) + + // Add Cadence DictionaryValue + const dictCount = 10 + dictValues := make([]interpreter.Value, 0, dictCount*2) + for i := 0; i < dictCount; i++ { + k := interpreter.NewUnmeteredUInt64Value(uint64(i)) + v := interpreter.NewUnmeteredStringValue(fmt.Sprintf("value %d", i)) + dictValues = append(dictValues, k, v) + } + + storageMap.WriteValue( + mr.Interpreter, + interpreter.StringStorageMapKey(strconv.FormatUint(storageMap.Count(), 10)), + interpreter.NewDictionaryValueWithAddress( + mr.Interpreter, + interpreter.EmptyLocationRange, + interpreter.DictionaryStaticType{ + KeyType: interpreter.PrimitiveStaticTypeUInt64, + ValueType: interpreter.PrimitiveStaticTypeString, + }, + address, + dictValues..., + ), + ) + + // Add Cadence CompositeValue + storageMap.WriteValue( + mr.Interpreter, + interpreter.StringStorageMapKey(strconv.FormatUint(storageMap.Count(), 10)), + interpreter.NewCompositeValue( + mr.Interpreter, + interpreter.EmptyLocationRange, + common.StringLocation("test"), + "Test", + common.CompositeKindStructure, + []interpreter.CompositeField{ + {Name: "field1", Value: interpreter.NewUnmeteredStringValue("value1")}, + {Name: "field2", Value: interpreter.NewUnmeteredStringValue("value2")}, + }, + address, + ), + ) + + // Add Cadence DictionaryValue with nested CadenceArray + nestedArrayValue := interpreter.NewArrayValue( + mr.Interpreter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeUInt64, + }, + address, + interpreter.NewUnmeteredUInt64Value(0), + ) + + storageMap.WriteValue( + mr.Interpreter, + interpreter.StringStorageMapKey(strconv.FormatUint(storageMap.Count(), 10)), + interpreter.NewArrayValue( + mr.Interpreter, + interpreter.EmptyLocationRange, + interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeAnyStruct, + }, + address, + nestedArrayValue, + ), + ) + + err = mr.Storage.Commit(mr.Interpreter, false) + require.NoError(t, err) + + // finalize the transaction + result, err := mr.TransactionState.FinalizeMainTransaction() + require.NoError(t, err) + + payloads := make([]*ledger.Payload, 0, len(result.WriteSet)) + for id, value := range result.WriteSet { + key := convert.RegisterIDToLedgerKey(id) + payloads = append(payloads, ledger.NewPayload(key, value)) + } + + return payloads +}