Skip to content

Commit

Permalink
Avoid validation error when function parameter has internal function …
Browse files Browse the repository at this point in the history
…pointer (#1038)
  • Loading branch information
ericglau authored Jun 18, 2024
1 parent b05a954 commit 32a2274
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 90 deletions.
4 changes: 4 additions & 0 deletions packages/core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 1.34.1 (2024-06-18)

- Fix unexpected validation error when function parameter has internal function pointer. ([#1038](https://github.com/OpenZeppelin/openzeppelin-upgrades/pull/1038))

## 1.34.0 (2024-06-12)

- Fix storage layout comparison for function types, disallow internal functions in storage. ([#1032](https://github.com/OpenZeppelin/openzeppelin-upgrades/pull/1032))
Expand Down
70 changes: 0 additions & 70 deletions packages/core/contracts/test/Validations.sol
Original file line number Diff line number Diff line change
Expand Up @@ -206,73 +206,3 @@ contract TransitiveLibraryIsUnsafe {
DirectLibrary.f2();
}
}

contract StructExternalFunctionPointer {
struct S {
function(bool) external foo;
}
}

contract StructInternalFunctionPointer {
struct S {
function(bool) internal foo;
}
}

contract StructImpliedInternalFunctionPointer {
struct S {
function(bool) foo;
}
}

struct StandaloneStructInternalFn {
function(bool) internal foo;
}

contract UsesStandaloneStructInternalFn {
StandaloneStructInternalFn bad;
}

contract StructUsesStandaloneStructInternalFn {
struct Bad {
StandaloneStructInternalFn bad;
}
}

contract RecursiveStructInternalFn {
StructUsesStandaloneStructInternalFn.Bad bad;
}

contract MappingRecursiveStructInternalFn {
mapping(address => mapping(address => StructUsesStandaloneStructInternalFn.Bad)) bad;
}

contract ArrayRecursiveStructInternalFn {
StructUsesStandaloneStructInternalFn.Bad[][] bad;
}

contract SelfRecursiveMappingStructInternalFn {
struct SelfRecursive {
mapping(address => SelfRecursive) selfReference;
mapping(address => StructUsesStandaloneStructInternalFn.Bad) bad;
}
}

contract SelfRecursiveArrayStructInternalFn {
struct SelfRecursiveArray {
SelfRecursiveArray[] selfReference;
StructUsesStandaloneStructInternalFn.Bad[] bad;
}
}

contract ExternalFunctionPointer {
function(bool) external foo;
}

contract InternalFunctionPointer {
function(bool) internal foo;
}

contract ImpliedInternalFunctionPointer {
function(bool) foo;
}
116 changes: 116 additions & 0 deletions packages/core/contracts/test/ValidationsFunctionPointers.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;

contract NamespacedExternalFunctionPointer {
/// @custom:storage-location erc7201:example.main
struct S {
function(bool) external foo;
}
}

contract NamespacedInternalFunctionPointer {
/// @custom:storage-location erc7201:example.main
struct S {
function(bool) internal foo;
}
}

contract NamespacedInternalFunctionPointerUsed {
/// @custom:storage-location erc7201:example.main
struct S {
function(bool) internal foo;
}
S s; // NOTE: This is unsafe usage of a namespace!
}

contract StructInternalFunctionPointerUsed {
// not a namespace, but it is referenced
struct S {
function(bool) internal foo;
}
S s;
}

contract NonNamespacedInternalFunctionPointer {
// not a namespace, and not referenced
struct S {
function(bool) internal foo;
}
}

contract NamespacedImpliedInternalFunctionPointer {
/// @custom:storage-location erc7201:example.main
struct S {
function(bool) foo;
}
}

struct StandaloneStructInternalFn {
function(bool) internal foo;
}

contract UsesStandaloneStructInternalFn {
StandaloneStructInternalFn bad;
}

struct StructUsesStandaloneStructInternalFn {
StandaloneStructInternalFn bad;
}

contract NamespacedUsesStandaloneStructInternalFn {
/// @custom:storage-location erc7201:example.main
struct Bad {
StandaloneStructInternalFn bad;
}
}

contract RecursiveStructInternalFn {
StructUsesStandaloneStructInternalFn bad;
}

contract MappingRecursiveStructInternalFn {
mapping(address => mapping(address => StructUsesStandaloneStructInternalFn)) bad;
}

contract ArrayRecursiveStructInternalFn {
StructUsesStandaloneStructInternalFn[][] bad;
}

contract SelfRecursiveMappingStructInternalFn {
/// @custom:storage-location erc7201:example.main
struct SelfRecursive {
mapping(address => SelfRecursive) selfReference;
mapping(address => StructUsesStandaloneStructInternalFn) bad;
}
}

contract SelfRecursiveArrayStructInternalFn {
/// @custom:storage-location erc7201:example.main
struct SelfRecursiveArray {
SelfRecursiveArray[] selfReference;
StructUsesStandaloneStructInternalFn[] bad;
}
}

contract ExternalFunctionPointer {
function(bool) external foo;
}

contract InternalFunctionPointer {
function(bool) internal foo;
}

contract ImpliedInternalFunctionPointer {
function(bool) foo;
}

contract FunctionWithInternalFunctionPointer {
uint208 x;

function doOp(
function(uint208, uint208) view returns (uint208) op,
uint208 y
) internal view returns (uint208) {
return op(x, y);
}
}
2 changes: 1 addition & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@openzeppelin/upgrades-core",
"version": "1.34.0",
"version": "1.34.1",
"description": "",
"repository": "https://github.com/OpenZeppelin/openzeppelin-upgrades/tree/master/packages/core",
"license": "MIT",
Expand Down
56 changes: 38 additions & 18 deletions packages/core/src/validate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
assertUpgradeSafe,
ValidationOptions,
RunValidation,
ValidationErrors,
} from './validate';
import { solcInputOutputDecoder } from './src-decoder';

Expand All @@ -23,6 +24,7 @@ test.before(async t => {
'contracts/test/ValidationsNatspec.sol:HasNonEmptyConstructorNatspec1',
'contracts/test/Proxiable.sol:ChildOfProxiable',
'contracts/test/ValidationsUDVT.sol:ValidationsUDVT',
'contracts/test/ValidationsFunctionPointers.sol:InternalFunctionPointer',
];

t.context.validation = {} as RunValidation;
Expand All @@ -38,11 +40,21 @@ test.before(async t => {
}
});

function testValid(name: string, kind: ValidationOptions['kind'], valid: boolean) {
testOverride(name, kind, {}, valid);
function testValid(name: string, kind: ValidationOptions['kind'], valid: boolean, numExpectedErrors?: number) {
testOverride(name, kind, {}, valid, numExpectedErrors);
}

function testOverride(name: string, kind: ValidationOptions['kind'], opts: ValidationOptions, valid: boolean) {
function testOverride(
name: string,
kind: ValidationOptions['kind'],
opts: ValidationOptions,
valid: boolean,
numExpectedErrors?: number,
) {
if (numExpectedErrors !== undefined && numExpectedErrors > 0 && valid) {
throw new Error('Cannot expect errors for a valid contract');
}

const optKeys = Object.keys(opts);
const describeOpts = optKeys.length > 0 ? '(' + optKeys.join(', ') + ')' : '';
const testName = [valid ? 'accepts' : 'rejects', kind, name, describeOpts].join(' ');
Expand All @@ -52,7 +64,10 @@ function testOverride(name: string, kind: ValidationOptions['kind'], opts: Valid
if (valid) {
t.notThrows(assertUpgSafe);
} else {
t.throws(assertUpgSafe);
const error = t.throws(assertUpgSafe) as ValidationErrors;
if (numExpectedErrors !== undefined) {
t.is(error.errors.length, numExpectedErrors);
}
}
});
}
Expand Down Expand Up @@ -140,18 +155,31 @@ testValid('TransitiveLibraryIsUnsafe', 'transparent', false);
testValid('contracts/test/ValidationsSameNameSafe.sol:SameName', 'transparent', true);
testValid('contracts/test/ValidationsSameNameUnsafe.sol:SameName', 'transparent', false);

testValid('StructExternalFunctionPointer', 'transparent', true);
testValid('StructInternalFunctionPointer', 'transparent', false);
testValid('StructImpliedInternalFunctionPointer', 'transparent', false);
test('ambiguous name', t => {
const error = t.throws(() => getContractVersion(t.context.validation, 'SameName'));
t.is(
error?.message,
'Contract SameName is ambiguous. Use one of the following:\n' +
'contracts/test/ValidationsSameNameSafe.sol:SameName\n' +
'contracts/test/ValidationsSameNameUnsafe.sol:SameName',
);
});

testValid('NamespacedExternalFunctionPointer', 'transparent', true);
testValid('NamespacedInternalFunctionPointer', 'transparent', false);
testValid('NamespacedInternalFunctionPointerUsed', 'transparent', false, 1);
testValid('StructInternalFunctionPointerUsed', 'transparent', false, 1);
testValid('NonNamespacedInternalFunctionPointer', 'transparent', true);
testValid('NamespacedImpliedInternalFunctionPointer', 'transparent', false);
testOverride(
'StructImpliedInternalFunctionPointer',
'NamespacedImpliedInternalFunctionPointer',
'transparent',
{ unsafeAllow: ['internal-function-storage'] },
true,
);

testValid('UsesStandaloneStructInternalFn', 'transparent', false);
testValid('StructUsesStandaloneStructInternalFn', 'transparent', false);
testValid('NamespacedUsesStandaloneStructInternalFn', 'transparent', false);
testValid('RecursiveStructInternalFn', 'transparent', false);
testValid('MappingRecursiveStructInternalFn', 'transparent', false);
testValid('ArrayRecursiveStructInternalFn', 'transparent', false);
Expand All @@ -163,12 +191,4 @@ testValid('InternalFunctionPointer', 'transparent', false);
testValid('ImpliedInternalFunctionPointer', 'transparent', false);
testOverride('ImpliedInternalFunctionPointer', 'transparent', { unsafeAllow: ['internal-function-storage'] }, true);

test('ambiguous name', t => {
const error = t.throws(() => getContractVersion(t.context.validation, 'SameName'));
t.is(
error?.message,
'Contract SameName is ambiguous. Use one of the following:\n' +
'contracts/test/ValidationsSameNameSafe.sol:SameName\n' +
'contracts/test/ValidationsSameNameUnsafe.sol:SameName',
);
});
testValid('FunctionWithInternalFunctionPointer', 'transparent', true);
31 changes: 30 additions & 1 deletion packages/core/src/validate/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type {
StructDefinition,
TypeName,
UserDefinedTypeName,
VariableDeclaration,
} from 'solidity-ast';
import debug from '../utils/debug';

Expand Down Expand Up @@ -586,7 +587,7 @@ function* getInternalFunctionStorageErrors(
): Generator<ValidationError> {
// Note: Solidity does not allow annotations for non-public state variables, nor recursive types for public variables,
// so annotations cannot be used to skip these checks.
for (const variableDec of findAll('VariableDeclaration', contractOrStructDef)) {
for (const variableDec of getVariableDeclarations(contractOrStructDef, visitedNodeIds)) {
if (variableDec.typeName?.nodeType === 'FunctionTypeName' && variableDec.typeName.visibility === 'internal') {
// Find internal function types directly in this node's scope
yield {
Expand All @@ -608,6 +609,34 @@ function* getInternalFunctionStorageErrors(
}
}

/**
* Gets variables declared directly in a contract or in a struct definition.
*
* If this is a contract with struct definitions annotated with a storage location according to ERC-7201,
* then the struct members are also included.
*/
function* getVariableDeclarations(
contractOrStructDef: ContractDefinition | StructDefinition,
visitedNodeIds: Set<number>,
): Generator<VariableDeclaration, void, undefined> {
if (contractOrStructDef.nodeType === 'ContractDefinition') {
for (const node of contractOrStructDef.nodes) {
if (node.nodeType === 'VariableDeclaration') {
yield node;
} else if (
node.nodeType === 'StructDefinition' &&
getStorageLocationAnnotation(node) !== undefined &&
!visitedNodeIds.has(node.id)
) {
visitedNodeIds.add(node.id);
yield* getVariableDeclarations(node, visitedNodeIds);
}
}
} else if (contractOrStructDef.nodeType === 'StructDefinition') {
yield* contractOrStructDef.members;
}
}

/**
* Recursively traverse array and mapping types to find user-defined types (which may be struct references).
*/
Expand Down

0 comments on commit 32a2274

Please sign in to comment.