diff --git a/migrations/entitlements/migration_test.go b/migrations/entitlements/migration_test.go index 7717beca44..3faad22223 100644 --- a/migrations/entitlements/migration_test.go +++ b/migrations/entitlements/migration_test.go @@ -3518,3 +3518,218 @@ func TestIntersectionTypeWithIntersectionLegacyType(t *testing.T) { require.Equal(t, expectedType, typeValue.Type) })() } + +func TestUseAfterMigrationFailure(t *testing.T) { + + t.Parallel() + + locationRange := interpreter.EmptyLocationRange + + ledger := NewTestLedger(nil, nil) + + storageMapKey := interpreter.StringStorageMapKey("dict") + newTestValue := func() interpreter.Value { + return interpreter.NewUnmeteredStringValue("test") + } + + const fooBarQualifiedIdentifier = "Foo.Bar" + testAddress := common.Address{0x42} + fooAddressLocation := common.NewAddressLocation(nil, testAddress, "Foo") + + newStorageAndInterpreter := func(t *testing.T) (*runtime.Storage, *interpreter.Interpreter) { + storage := runtime.NewStorage(ledger, nil) + inter, err := interpreter.NewInterpreter( + nil, + utils.TestLocation, + &interpreter.Config{ + Storage: storage, + // NOTE: disabled, because encoded and decoded values are expected to not match + AtreeValueValidationEnabled: false, + AtreeStorageValidationEnabled: true, + }, + ) + require.NoError(t, err) + + return storage, inter + } + + newCompositeType := func() *interpreter.CompositeStaticType { + return interpreter.NewCompositeStaticType( + nil, + fooAddressLocation, + fooBarQualifiedIdentifier, + common.NewTypeIDFromQualifiedName( + nil, + fooAddressLocation, + fooBarQualifiedIdentifier, + ), + ) + } + + // Prepare + (func() { + + storage, inter := newStorageAndInterpreter(t) + + dictionaryStaticType := interpreter.NewDictionaryStaticType( + nil, + interpreter.PrimitiveStaticTypeMetaType, + interpreter.PrimitiveStaticTypeString, + ) + dictValue := interpreter.NewDictionaryValue(inter, locationRange, dictionaryStaticType) + + refType := interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + newCompositeType(), + ) + refType.HasLegacyIsAuthorized = true + refType.LegacyIsAuthorized = true + + legacyRefType := &migrations.LegacyReferenceType{ + ReferenceStaticType: refType, + } + + optType := interpreter.NewOptionalStaticType( + nil, + legacyRefType, + ) + + legacyOptType := &migrations.LegacyOptionalType{ + OptionalStaticType: optType, + } + + typeValue := interpreter.NewUnmeteredTypeValue(legacyOptType) + + dictValue.Insert( + inter, + locationRange, + typeValue, + newTestValue(), + ) + + // Note: ID is in the old format + assert.Equal(t, + common.TypeID("auth&A.4200000000000000.Foo.Bar?"), + legacyOptType.ID(), + ) + + storageMap := storage.GetStorageMap( + testAddress, + common.PathDomainStorage.Identifier(), + true, + ) + + storageMap.SetValue(inter, + storageMapKey, + dictValue.Transfer( + inter, + locationRange, + atree.Address(testAddress), + false, + nil, + nil, + true, // dictValue is standalone + ), + ) + + err := storage.Commit(inter, false) + require.NoError(t, err) + + err = storage.CheckHealth() + require.NoError(t, err) + })() + + // Migrate + (func() { + + storage, inter := newStorageAndInterpreter(t) + + const importErrorMessage = "cannot import" + + inter.SharedState.Config.ImportLocationHandler = + func(inter *interpreter.Interpreter, location common.Location) interpreter.Import { + panic(importErrorMessage) + } + + migration, err := migrations.NewStorageMigration(inter, storage, "test", testAddress) + require.NoError(t, err) + + reporter := newTestReporter() + + migration.Migrate( + migration.NewValueMigrationsPathMigrator( + reporter, + NewEntitlementsMigration(inter), + ), + ) + + err = migration.Commit() + require.NoError(t, err) + + // Assert + + err = storage.CheckHealth() + require.NoError(t, err) + + require.Len(t, reporter.errors, 1) + + assert.ErrorContains(t, reporter.errors[0], importErrorMessage) + + require.Empty(t, reporter.migrated) + })() + + // Load + (func() { + + storage, inter := newStorageAndInterpreter(t) + + err := storage.CheckHealth() + require.NoError(t, err) + + storageMap := storage.GetStorageMap( + testAddress, + common.PathDomainStorage.Identifier(), + false, + ) + storedValue := storageMap.ReadValue(inter, storageMapKey) + + require.IsType(t, &interpreter.DictionaryValue{}, storedValue) + + dictValue := storedValue.(*interpreter.DictionaryValue) + + refType := interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + newCompositeType(), + ) + refType.HasLegacyIsAuthorized = true + refType.LegacyIsAuthorized = true + + optType := interpreter.NewOptionalStaticType( + nil, + refType, + ) + + typeValue := interpreter.NewUnmeteredTypeValue(optType) + + // Note: ID is in the new format + assert.Equal(t, + common.TypeID("(&A.4200000000000000.Foo.Bar)?"), + optType.ID(), + ) + + assert.Equal(t, 1, dictValue.Count()) + + // Key did not get migrated, so is inaccessible using the "new" type value + _, ok := dictValue.Get(inter, locationRange, typeValue) + require.False(t, ok) + + // But the key is still accessible using the "old" type value + legacyKey := migrations.LegacyKey(typeValue) + + value, ok := dictValue.Get(inter, locationRange, legacyKey) + require.True(t, ok) + require.Equal(t, newTestValue(), value) + })() +} diff --git a/migrations/migration.go b/migrations/migration.go index c395c2c63d..1cde187b6f 100644 --- a/migrations/migration.go +++ b/migrations/migration.go @@ -363,7 +363,7 @@ func (m *StorageMigration) MigrateNestedValue( // The mutating iterator is only able to read new keys, // as it recalculates the stored values' hashes. - m.migrateDictionaryKeys( + keys := m.migrateDictionaryKeys( storageKey, storageMapKey, dictionary, @@ -379,6 +379,7 @@ func (m *StorageMigration) MigrateNestedValue( valueMigrations, reporter, allowMutation, + keys, ) case *interpreter.PublishedValue: @@ -468,6 +469,11 @@ func (m *StorageMigration) MigrateNestedValue( } +type migratedDictionaryKey struct { + key interpreter.Value + migrated bool +} + func (m *StorageMigration) migrateDictionaryKeys( storageKey interpreter.StorageKey, storageMapKey interpreter.StorageMapKey, @@ -475,7 +481,7 @@ func (m *StorageMigration) migrateDictionaryKeys( valueMigrations []ValueMigration, reporter Reporter, allowMutation bool, -) { +) (migratedKeys []migratedDictionaryKey) { inter := m.interpreter var existingKeys []interpreter.Value @@ -505,6 +511,11 @@ func (m *StorageMigration) migrateDictionaryKeys( ) if newKey == nil { + migratedKeys = append(migratedKeys, migratedDictionaryKey{ + key: existingKey, + migrated: false, + }) + continue } @@ -523,7 +534,7 @@ func (m *StorageMigration) migrateDictionaryKeys( // Remove the old key-value pair - existingKey = legacyKey(existingKey) + existingKey = LegacyKey(existingKey) existingKeyStorable, existingValueStorable := dictionary.RemoveWithoutTransfer( inter, emptyLocationRange, @@ -645,8 +656,15 @@ func (m *StorageMigration) migrateDictionaryKeys( newKey, existingValue, ) + + migratedKeys = append(migratedKeys, migratedDictionaryKey{ + key: newKey, + migrated: true, + }) } } + + return } func (m *StorageMigration) migrateDictionaryValues( @@ -656,37 +674,21 @@ func (m *StorageMigration) migrateDictionaryValues( valueMigrations []ValueMigration, reporter Reporter, allowMutation bool, + migratedDictionaryKeys []migratedDictionaryKey, ) { - inter := m.interpreter - type keyValuePair struct { - key, value interpreter.Value - } - - var existingKeysAndValues []keyValuePair - - dictionary.Iterate( - inter, - emptyLocationRange, - func(key, value interpreter.Value) (resume bool) { - - existingKeysAndValues = append( - existingKeysAndValues, - keyValuePair{ - key: key, - value: value, - }, - ) + for _, migratedDictionaryKey := range migratedDictionaryKeys { - // Continue iteration - return true - }, - ) + existingKey := migratedDictionaryKey.key + if !migratedDictionaryKey.migrated { + existingKey = LegacyKey(existingKey) + } - for _, existingKeyAndValue := range existingKeysAndValues { - existingKey := existingKeyAndValue.key - existingValue := existingKeyAndValue.value + existingValue, ok := dictionary.Get(inter, emptyLocationRange, existingKey) + if !ok { + panic(errors.NewUnexpectedError("failed to get existing value for key: %s", existingKey)) + } newValue := m.MigrateNestedValue( storageKey, @@ -811,8 +813,8 @@ func (m *StorageMigration) migrate( ) } -// legacyKey return the same type with the "old" hash/ID generation function. -func legacyKey(key interpreter.Value) interpreter.Value { +// LegacyKey return the same type with the "old" hash/ID generation function. +func LegacyKey(key interpreter.Value) interpreter.Value { switch key := key.(type) { case interpreter.TypeValue: legacyType := legacyType(key.Type) diff --git a/migrations/migration_test.go b/migrations/migration_test.go index 93739a23a0..c708215fc2 100644 --- a/migrations/migration_test.go +++ b/migrations/migration_test.go @@ -2485,12 +2485,12 @@ func TestDictionaryKeyConflict(t *testing.T) { dictionaryValue, ) - // NOTE: use legacyKey to ensure the key is encoded in old format + // NOTE: use LegacyKey to ensure the key is encoded in old format dictionaryValue.InsertWithoutTransfer( inter, emptyLocationRange, - legacyKey(dictionaryKey1), + LegacyKey(dictionaryKey1), interpreter.NewArrayValue( inter, emptyLocationRange, @@ -2503,7 +2503,7 @@ func TestDictionaryKeyConflict(t *testing.T) { dictionaryValue.InsertWithoutTransfer( inter, emptyLocationRange, - legacyKey(dictionaryKey2), + LegacyKey(dictionaryKey2), interpreter.NewArrayValue( inter, emptyLocationRange, @@ -2516,7 +2516,7 @@ func TestDictionaryKeyConflict(t *testing.T) { oldValue1, ok := dictionaryValue.Get( inter, emptyLocationRange, - legacyKey(dictionaryKey1), + LegacyKey(dictionaryKey1), ) require.True(t, ok) @@ -2535,7 +2535,7 @@ func TestDictionaryKeyConflict(t *testing.T) { oldValue2, ok := dictionaryValue.Get( inter, emptyLocationRange, - legacyKey(dictionaryKey2), + LegacyKey(dictionaryKey2), ) require.True(t, ok) diff --git a/npm-packages/cadence-parser/package.json b/npm-packages/cadence-parser/package.json index b752684c6f..0692d98249 100644 --- a/npm-packages/cadence-parser/package.json +++ b/npm-packages/cadence-parser/package.json @@ -1,6 +1,6 @@ { "name": "@onflow/cadence-parser", - "version": "1.0.0-preview.29", + "version": "1.0.0-preview.31", "description": "The Cadence parser", "homepage": "https://github.com/onflow/cadence", "repository": { diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index e747a87682..f490cfe12f 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -3246,14 +3246,14 @@ func TestPragmaUpdates(t *testing.T) { const oldCode = ` access(all) contract Test { access(all) resource R {} - access(all) struct interface I {} + access(all) struct S {} } ` const newCode = ` access(all) contract Test { #removedType(R) - #removedType(I) + #removedType(S) } ` @@ -3261,6 +3261,44 @@ func TestPragmaUpdates(t *testing.T) { require.NoError(t, err) }) + testWithValidators(t, "#removedType does not allow resource interface type removal", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + access(all) resource interface R {} + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(R) + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.MissingDeclarationError + require.ErrorAs(t, err, &expectedErr) + }) + + testWithValidators(t, "#removedType does not allow struct interface type removal", func(t *testing.T, withC1Upgrade bool) { + + const oldCode = ` + access(all) contract Test { + access(all) struct interface S {} + } + ` + + const newCode = ` + access(all) contract Test { + #removedType(S) + } + ` + + err := testDeployAndUpdate(t, "Test", oldCode, newCode, withC1Upgrade) + var expectedErr *stdlib.MissingDeclarationError + require.ErrorAs(t, err, &expectedErr) + }) + testWithValidators(t, "#removedType can be added", func(t *testing.T, withC1Upgrade bool) { const oldCode = ` diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index d43840d9dd..a11df15f79 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -347,19 +347,22 @@ func (validator *ContractUpdateValidator) checkNestedDeclarationRemoval( newContainingDeclaration ast.Declaration, removedTypes *orderedmap.OrderedMap[string, struct{}], ) { + declarationKind := nestedDeclaration.DeclarationKind() + // OK to remove events - they are not stored - if nestedDeclaration.DeclarationKind() == common.DeclarationKindEvent { + if declarationKind == common.DeclarationKindEvent { return } - // OK to remove a type if it is included in a #removedType pragma - if removedTypes.Contains(nestedDeclaration.DeclarationIdentifier().Identifier) { + // OK to remove a type if it is included in a #removedType pragma, and it is not an interface + if removedTypes.Contains(nestedDeclaration.DeclarationIdentifier().Identifier) && + !declarationKind.IsInterfaceDeclaration() { return } validator.report(&MissingDeclarationError{ Name: nestedDeclaration.DeclarationIdentifier().Identifier, - Kind: nestedDeclaration.DeclarationKind(), + Kind: declarationKind, Range: ast.NewUnmeteredRangeFromPositioned( newContainingDeclaration.DeclarationIdentifier(), ), @@ -841,7 +844,7 @@ func (*MissingDeclarationError) IsUserError() {} func (e *MissingDeclarationError) Error() string { return fmt.Sprintf( "missing %s declaration `%s`", - e.Kind, + e.Kind.Name(), e.Name, ) }