From 32a227484e0b5b28a06c8af6debad56ecc22cfdf Mon Sep 17 00:00:00 2001 From: Eric Lau Date: Tue, 18 Jun 2024 18:07:42 -0400 Subject: [PATCH] Avoid validation error when function parameter has internal function pointer (#1038) --- packages/core/CHANGELOG.md | 4 + packages/core/contracts/test/Validations.sol | 70 ----------- .../test/ValidationsFunctionPointers.sol | 116 ++++++++++++++++++ packages/core/package.json | 2 +- packages/core/src/validate.test.ts | 56 ++++++--- packages/core/src/validate/run.ts | 31 ++++- 6 files changed, 189 insertions(+), 90 deletions(-) create mode 100644 packages/core/contracts/test/ValidationsFunctionPointers.sol diff --git a/packages/core/CHANGELOG.md b/packages/core/CHANGELOG.md index 50bd237a3..866beb00f 100644 --- a/packages/core/CHANGELOG.md +++ b/packages/core/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 1.34.1 (2024-06-18) + +- Fix unexpected validation error when function parameter has internal function pointer. ([#1038](https://github.com/OpenZeppelin/openzeppelin-upgrades/pull/1038)) + ## 1.34.0 (2024-06-12) - Fix storage layout comparison for function types, disallow internal functions in storage. ([#1032](https://github.com/OpenZeppelin/openzeppelin-upgrades/pull/1032)) diff --git a/packages/core/contracts/test/Validations.sol b/packages/core/contracts/test/Validations.sol index 5e9856497..cbec23c36 100644 --- a/packages/core/contracts/test/Validations.sol +++ b/packages/core/contracts/test/Validations.sol @@ -206,73 +206,3 @@ contract TransitiveLibraryIsUnsafe { DirectLibrary.f2(); } } - -contract StructExternalFunctionPointer { - struct S { - function(bool) external foo; - } -} - -contract StructInternalFunctionPointer { - struct S { - function(bool) internal foo; - } -} - -contract StructImpliedInternalFunctionPointer { - struct S { - function(bool) foo; - } -} - -struct StandaloneStructInternalFn { - function(bool) internal foo; -} - -contract UsesStandaloneStructInternalFn { - StandaloneStructInternalFn bad; -} - -contract StructUsesStandaloneStructInternalFn { - struct Bad { - StandaloneStructInternalFn bad; - } -} - -contract RecursiveStructInternalFn { - StructUsesStandaloneStructInternalFn.Bad bad; -} - -contract MappingRecursiveStructInternalFn { - mapping(address => mapping(address => StructUsesStandaloneStructInternalFn.Bad)) bad; -} - -contract ArrayRecursiveStructInternalFn { - StructUsesStandaloneStructInternalFn.Bad[][] bad; -} - -contract SelfRecursiveMappingStructInternalFn { - struct SelfRecursive { - mapping(address => SelfRecursive) selfReference; - mapping(address => StructUsesStandaloneStructInternalFn.Bad) bad; - } -} - -contract SelfRecursiveArrayStructInternalFn { - struct SelfRecursiveArray { - SelfRecursiveArray[] selfReference; - StructUsesStandaloneStructInternalFn.Bad[] bad; - } -} - -contract ExternalFunctionPointer { - function(bool) external foo; -} - -contract InternalFunctionPointer { - function(bool) internal foo; -} - -contract ImpliedInternalFunctionPointer { - function(bool) foo; -} diff --git a/packages/core/contracts/test/ValidationsFunctionPointers.sol b/packages/core/contracts/test/ValidationsFunctionPointers.sol new file mode 100644 index 000000000..96d4c056c --- /dev/null +++ b/packages/core/contracts/test/ValidationsFunctionPointers.sol @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; + +contract NamespacedExternalFunctionPointer { + /// @custom:storage-location erc7201:example.main + struct S { + function(bool) external foo; + } +} + +contract NamespacedInternalFunctionPointer { + /// @custom:storage-location erc7201:example.main + struct S { + function(bool) internal foo; + } +} + +contract NamespacedInternalFunctionPointerUsed { + /// @custom:storage-location erc7201:example.main + struct S { + function(bool) internal foo; + } + S s; // NOTE: This is unsafe usage of a namespace! +} + +contract StructInternalFunctionPointerUsed { + // not a namespace, but it is referenced + struct S { + function(bool) internal foo; + } + S s; +} + +contract NonNamespacedInternalFunctionPointer { + // not a namespace, and not referenced + struct S { + function(bool) internal foo; + } +} + +contract NamespacedImpliedInternalFunctionPointer { + /// @custom:storage-location erc7201:example.main + struct S { + function(bool) foo; + } +} + +struct StandaloneStructInternalFn { + function(bool) internal foo; +} + +contract UsesStandaloneStructInternalFn { + StandaloneStructInternalFn bad; +} + +struct StructUsesStandaloneStructInternalFn { + StandaloneStructInternalFn bad; +} + +contract NamespacedUsesStandaloneStructInternalFn { + /// @custom:storage-location erc7201:example.main + struct Bad { + StandaloneStructInternalFn bad; + } +} + +contract RecursiveStructInternalFn { + StructUsesStandaloneStructInternalFn bad; +} + +contract MappingRecursiveStructInternalFn { + mapping(address => mapping(address => StructUsesStandaloneStructInternalFn)) bad; +} + +contract ArrayRecursiveStructInternalFn { + StructUsesStandaloneStructInternalFn[][] bad; +} + +contract SelfRecursiveMappingStructInternalFn { + /// @custom:storage-location erc7201:example.main + struct SelfRecursive { + mapping(address => SelfRecursive) selfReference; + mapping(address => StructUsesStandaloneStructInternalFn) bad; + } +} + +contract SelfRecursiveArrayStructInternalFn { + /// @custom:storage-location erc7201:example.main + struct SelfRecursiveArray { + SelfRecursiveArray[] selfReference; + StructUsesStandaloneStructInternalFn[] bad; + } +} + +contract ExternalFunctionPointer { + function(bool) external foo; +} + +contract InternalFunctionPointer { + function(bool) internal foo; +} + +contract ImpliedInternalFunctionPointer { + function(bool) foo; +} + +contract FunctionWithInternalFunctionPointer { + uint208 x; + + function doOp( + function(uint208, uint208) view returns (uint208) op, + uint208 y + ) internal view returns (uint208) { + return op(x, y); + } +} diff --git a/packages/core/package.json b/packages/core/package.json index c12e5b53c..e5dce6b1a 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -1,6 +1,6 @@ { "name": "@openzeppelin/upgrades-core", - "version": "1.34.0", + "version": "1.34.1", "description": "", "repository": "https://github.com/OpenZeppelin/openzeppelin-upgrades/tree/master/packages/core", "license": "MIT", diff --git a/packages/core/src/validate.test.ts b/packages/core/src/validate.test.ts index 8be4e7bf4..982ed0bc0 100644 --- a/packages/core/src/validate.test.ts +++ b/packages/core/src/validate.test.ts @@ -8,6 +8,7 @@ import { assertUpgradeSafe, ValidationOptions, RunValidation, + ValidationErrors, } from './validate'; import { solcInputOutputDecoder } from './src-decoder'; @@ -23,6 +24,7 @@ test.before(async t => { 'contracts/test/ValidationsNatspec.sol:HasNonEmptyConstructorNatspec1', 'contracts/test/Proxiable.sol:ChildOfProxiable', 'contracts/test/ValidationsUDVT.sol:ValidationsUDVT', + 'contracts/test/ValidationsFunctionPointers.sol:InternalFunctionPointer', ]; t.context.validation = {} as RunValidation; @@ -38,11 +40,21 @@ test.before(async t => { } }); -function testValid(name: string, kind: ValidationOptions['kind'], valid: boolean) { - testOverride(name, kind, {}, valid); +function testValid(name: string, kind: ValidationOptions['kind'], valid: boolean, numExpectedErrors?: number) { + testOverride(name, kind, {}, valid, numExpectedErrors); } -function testOverride(name: string, kind: ValidationOptions['kind'], opts: ValidationOptions, valid: boolean) { +function testOverride( + name: string, + kind: ValidationOptions['kind'], + opts: ValidationOptions, + valid: boolean, + numExpectedErrors?: number, +) { + if (numExpectedErrors !== undefined && numExpectedErrors > 0 && valid) { + throw new Error('Cannot expect errors for a valid contract'); + } + const optKeys = Object.keys(opts); const describeOpts = optKeys.length > 0 ? '(' + optKeys.join(', ') + ')' : ''; const testName = [valid ? 'accepts' : 'rejects', kind, name, describeOpts].join(' '); @@ -52,7 +64,10 @@ function testOverride(name: string, kind: ValidationOptions['kind'], opts: Valid if (valid) { t.notThrows(assertUpgSafe); } else { - t.throws(assertUpgSafe); + const error = t.throws(assertUpgSafe) as ValidationErrors; + if (numExpectedErrors !== undefined) { + t.is(error.errors.length, numExpectedErrors); + } } }); } @@ -140,18 +155,31 @@ testValid('TransitiveLibraryIsUnsafe', 'transparent', false); testValid('contracts/test/ValidationsSameNameSafe.sol:SameName', 'transparent', true); testValid('contracts/test/ValidationsSameNameUnsafe.sol:SameName', 'transparent', false); -testValid('StructExternalFunctionPointer', 'transparent', true); -testValid('StructInternalFunctionPointer', 'transparent', false); -testValid('StructImpliedInternalFunctionPointer', 'transparent', false); +test('ambiguous name', t => { + const error = t.throws(() => getContractVersion(t.context.validation, 'SameName')); + t.is( + error?.message, + 'Contract SameName is ambiguous. Use one of the following:\n' + + 'contracts/test/ValidationsSameNameSafe.sol:SameName\n' + + 'contracts/test/ValidationsSameNameUnsafe.sol:SameName', + ); +}); + +testValid('NamespacedExternalFunctionPointer', 'transparent', true); +testValid('NamespacedInternalFunctionPointer', 'transparent', false); +testValid('NamespacedInternalFunctionPointerUsed', 'transparent', false, 1); +testValid('StructInternalFunctionPointerUsed', 'transparent', false, 1); +testValid('NonNamespacedInternalFunctionPointer', 'transparent', true); +testValid('NamespacedImpliedInternalFunctionPointer', 'transparent', false); testOverride( - 'StructImpliedInternalFunctionPointer', + 'NamespacedImpliedInternalFunctionPointer', 'transparent', { unsafeAllow: ['internal-function-storage'] }, true, ); testValid('UsesStandaloneStructInternalFn', 'transparent', false); -testValid('StructUsesStandaloneStructInternalFn', 'transparent', false); +testValid('NamespacedUsesStandaloneStructInternalFn', 'transparent', false); testValid('RecursiveStructInternalFn', 'transparent', false); testValid('MappingRecursiveStructInternalFn', 'transparent', false); testValid('ArrayRecursiveStructInternalFn', 'transparent', false); @@ -163,12 +191,4 @@ testValid('InternalFunctionPointer', 'transparent', false); testValid('ImpliedInternalFunctionPointer', 'transparent', false); testOverride('ImpliedInternalFunctionPointer', 'transparent', { unsafeAllow: ['internal-function-storage'] }, true); -test('ambiguous name', t => { - const error = t.throws(() => getContractVersion(t.context.validation, 'SameName')); - t.is( - error?.message, - 'Contract SameName is ambiguous. Use one of the following:\n' + - 'contracts/test/ValidationsSameNameSafe.sol:SameName\n' + - 'contracts/test/ValidationsSameNameUnsafe.sol:SameName', - ); -}); +testValid('FunctionWithInternalFunctionPointer', 'transparent', true); diff --git a/packages/core/src/validate/run.ts b/packages/core/src/validate/run.ts index 2565314da..1435d83a2 100644 --- a/packages/core/src/validate/run.ts +++ b/packages/core/src/validate/run.ts @@ -6,6 +6,7 @@ import type { StructDefinition, TypeName, UserDefinedTypeName, + VariableDeclaration, } from 'solidity-ast'; import debug from '../utils/debug'; @@ -586,7 +587,7 @@ function* getInternalFunctionStorageErrors( ): Generator { // Note: Solidity does not allow annotations for non-public state variables, nor recursive types for public variables, // so annotations cannot be used to skip these checks. - for (const variableDec of findAll('VariableDeclaration', contractOrStructDef)) { + for (const variableDec of getVariableDeclarations(contractOrStructDef, visitedNodeIds)) { if (variableDec.typeName?.nodeType === 'FunctionTypeName' && variableDec.typeName.visibility === 'internal') { // Find internal function types directly in this node's scope yield { @@ -608,6 +609,34 @@ function* getInternalFunctionStorageErrors( } } +/** + * Gets variables declared directly in a contract or in a struct definition. + * + * If this is a contract with struct definitions annotated with a storage location according to ERC-7201, + * then the struct members are also included. + */ +function* getVariableDeclarations( + contractOrStructDef: ContractDefinition | StructDefinition, + visitedNodeIds: Set, +): Generator { + if (contractOrStructDef.nodeType === 'ContractDefinition') { + for (const node of contractOrStructDef.nodes) { + if (node.nodeType === 'VariableDeclaration') { + yield node; + } else if ( + node.nodeType === 'StructDefinition' && + getStorageLocationAnnotation(node) !== undefined && + !visitedNodeIds.has(node.id) + ) { + visitedNodeIds.add(node.id); + yield* getVariableDeclarations(node, visitedNodeIds); + } + } + } else if (contractOrStructDef.nodeType === 'StructDefinition') { + yield* contractOrStructDef.members; + } +} + /** * Recursively traverse array and mapping types to find user-defined types (which may be struct references). */