Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid validation error when function parameter has internal function pointer #1038

Merged
merged 11 commits into from
Jun 18, 2024
Merged
4 changes: 4 additions & 0 deletions packages/core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 1.34.1 (2024-06-18)

- Fix unexpected validation error when function parameter has internal function pointer. ([#1038](https://github.com/OpenZeppelin/openzeppelin-upgrades/pull/1038))

## 1.34.0 (2024-06-12)

- Fix storage layout comparison for function types, disallow internal functions in storage. ([#1032](https://github.com/OpenZeppelin/openzeppelin-upgrades/pull/1032))
Expand Down
70 changes: 0 additions & 70 deletions packages/core/contracts/test/Validations.sol
Original file line number Diff line number Diff line change
Expand Up @@ -206,73 +206,3 @@ contract TransitiveLibraryIsUnsafe {
DirectLibrary.f2();
}
}

contract StructExternalFunctionPointer {
struct S {
function(bool) external foo;
}
}

contract StructInternalFunctionPointer {
struct S {
function(bool) internal foo;
}
}

contract StructImpliedInternalFunctionPointer {
struct S {
function(bool) foo;
}
}

struct StandaloneStructInternalFn {
function(bool) internal foo;
}

contract UsesStandaloneStructInternalFn {
StandaloneStructInternalFn bad;
}

contract StructUsesStandaloneStructInternalFn {
struct Bad {
StandaloneStructInternalFn bad;
}
}

contract RecursiveStructInternalFn {
StructUsesStandaloneStructInternalFn.Bad bad;
}

contract MappingRecursiveStructInternalFn {
mapping(address => mapping(address => StructUsesStandaloneStructInternalFn.Bad)) bad;
}

contract ArrayRecursiveStructInternalFn {
StructUsesStandaloneStructInternalFn.Bad[][] bad;
}

contract SelfRecursiveMappingStructInternalFn {
struct SelfRecursive {
mapping(address => SelfRecursive) selfReference;
mapping(address => StructUsesStandaloneStructInternalFn.Bad) bad;
}
}

contract SelfRecursiveArrayStructInternalFn {
struct SelfRecursiveArray {
SelfRecursiveArray[] selfReference;
StructUsesStandaloneStructInternalFn.Bad[] bad;
}
}

contract ExternalFunctionPointer {
function(bool) external foo;
}

contract InternalFunctionPointer {
function(bool) internal foo;
}

contract ImpliedInternalFunctionPointer {
function(bool) foo;
}
116 changes: 116 additions & 0 deletions packages/core/contracts/test/ValidationsFunctionPointers.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;

contract NamespacedExternalFunctionPointer {
/// @custom:storage-location erc7201:example.main
struct S {
function(bool) external foo;
}
}

contract NamespacedInternalFunctionPointer {
/// @custom:storage-location erc7201:example.main
struct S {
function(bool) internal foo;
}
}

contract NamespacedInternalFunctionPointerUsed {
/// @custom:storage-location erc7201:example.main
struct S {
function(bool) internal foo;
}
S s; // NOTE: This is unsafe usage of a namespace!
}

contract StructInternalFunctionPointerUsed {
// not a namespace, but it is referenced
struct S {
function(bool) internal foo;
}
S s;
}

contract NonNamespacedInternalFunctionPointer {
// not a namespace, and not referenced
struct S {
function(bool) internal foo;
}
}

contract NamespacedImpliedInternalFunctionPointer {
/// @custom:storage-location erc7201:example.main
struct S {
function(bool) foo;
}
}

struct StandaloneStructInternalFn {
function(bool) internal foo;
}

contract UsesStandaloneStructInternalFn {
StandaloneStructInternalFn bad;
}

struct StructUsesStandaloneStructInternalFn {
StandaloneStructInternalFn bad;
}

contract NamespacedUsesStandaloneStructInternalFn {
/// @custom:storage-location erc7201:example.main
struct Bad {
StandaloneStructInternalFn bad;
}
}

contract RecursiveStructInternalFn {
StructUsesStandaloneStructInternalFn bad;
}

contract MappingRecursiveStructInternalFn {
mapping(address => mapping(address => StructUsesStandaloneStructInternalFn)) bad;
}

contract ArrayRecursiveStructInternalFn {
StructUsesStandaloneStructInternalFn[][] bad;
}

contract SelfRecursiveMappingStructInternalFn {
/// @custom:storage-location erc7201:example.main
struct SelfRecursive {
mapping(address => SelfRecursive) selfReference;
mapping(address => StructUsesStandaloneStructInternalFn) bad;
}
}

contract SelfRecursiveArrayStructInternalFn {
/// @custom:storage-location erc7201:example.main
struct SelfRecursiveArray {
SelfRecursiveArray[] selfReference;
StructUsesStandaloneStructInternalFn[] bad;
}
}

contract ExternalFunctionPointer {
function(bool) external foo;
}

contract InternalFunctionPointer {
function(bool) internal foo;
}

contract ImpliedInternalFunctionPointer {
function(bool) foo;
}

contract FunctionWithInternalFunctionPointer {
uint208 x;

function doOp(
function(uint208, uint208) view returns (uint208) op,
uint208 y
) internal view returns (uint208) {
return op(x, y);
}
}
2 changes: 1 addition & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@openzeppelin/upgrades-core",
"version": "1.34.0",
"version": "1.34.1",
"description": "",
"repository": "https://github.com/OpenZeppelin/openzeppelin-upgrades/tree/master/packages/core",
"license": "MIT",
Expand Down
56 changes: 38 additions & 18 deletions packages/core/src/validate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
assertUpgradeSafe,
ValidationOptions,
RunValidation,
ValidationErrors,
} from './validate';
import { solcInputOutputDecoder } from './src-decoder';

Expand All @@ -23,6 +24,7 @@ test.before(async t => {
'contracts/test/ValidationsNatspec.sol:HasNonEmptyConstructorNatspec1',
'contracts/test/Proxiable.sol:ChildOfProxiable',
'contracts/test/ValidationsUDVT.sol:ValidationsUDVT',
'contracts/test/ValidationsFunctionPointers.sol:InternalFunctionPointer',
];

t.context.validation = {} as RunValidation;
Expand All @@ -38,11 +40,21 @@ test.before(async t => {
}
});

function testValid(name: string, kind: ValidationOptions['kind'], valid: boolean) {
testOverride(name, kind, {}, valid);
function testValid(name: string, kind: ValidationOptions['kind'], valid: boolean, numExpectedErrors?: number) {
testOverride(name, kind, {}, valid, numExpectedErrors);
}

function testOverride(name: string, kind: ValidationOptions['kind'], opts: ValidationOptions, valid: boolean) {
function testOverride(
name: string,
kind: ValidationOptions['kind'],
opts: ValidationOptions,
valid: boolean,
numExpectedErrors?: number,
) {
if (numExpectedErrors !== undefined && numExpectedErrors > 0 && valid) {
throw new Error('Cannot expect errors for a valid contract');
}

const optKeys = Object.keys(opts);
const describeOpts = optKeys.length > 0 ? '(' + optKeys.join(', ') + ')' : '';
const testName = [valid ? 'accepts' : 'rejects', kind, name, describeOpts].join(' ');
Expand All @@ -52,7 +64,10 @@ function testOverride(name: string, kind: ValidationOptions['kind'], opts: Valid
if (valid) {
t.notThrows(assertUpgSafe);
} else {
t.throws(assertUpgSafe);
const error = t.throws(assertUpgSafe) as ValidationErrors;
if (numExpectedErrors !== undefined) {
t.is(error.errors.length, numExpectedErrors);
}
}
});
}
Expand Down Expand Up @@ -140,18 +155,31 @@ testValid('TransitiveLibraryIsUnsafe', 'transparent', false);
testValid('contracts/test/ValidationsSameNameSafe.sol:SameName', 'transparent', true);
testValid('contracts/test/ValidationsSameNameUnsafe.sol:SameName', 'transparent', false);

testValid('StructExternalFunctionPointer', 'transparent', true);
testValid('StructInternalFunctionPointer', 'transparent', false);
testValid('StructImpliedInternalFunctionPointer', 'transparent', false);
test('ambiguous name', t => {
const error = t.throws(() => getContractVersion(t.context.validation, 'SameName'));
t.is(
error?.message,
'Contract SameName is ambiguous. Use one of the following:\n' +
'contracts/test/ValidationsSameNameSafe.sol:SameName\n' +
'contracts/test/ValidationsSameNameUnsafe.sol:SameName',
);
});

testValid('NamespacedExternalFunctionPointer', 'transparent', true);
testValid('NamespacedInternalFunctionPointer', 'transparent', false);
testValid('NamespacedInternalFunctionPointerUsed', 'transparent', false, 1);
testValid('StructInternalFunctionPointerUsed', 'transparent', false, 1);
testValid('NonNamespacedInternalFunctionPointer', 'transparent', true);
testValid('NamespacedImpliedInternalFunctionPointer', 'transparent', false);
testOverride(
'StructImpliedInternalFunctionPointer',
'NamespacedImpliedInternalFunctionPointer',
'transparent',
{ unsafeAllow: ['internal-function-storage'] },
true,
);

testValid('UsesStandaloneStructInternalFn', 'transparent', false);
testValid('StructUsesStandaloneStructInternalFn', 'transparent', false);
testValid('NamespacedUsesStandaloneStructInternalFn', 'transparent', false);
testValid('RecursiveStructInternalFn', 'transparent', false);
testValid('MappingRecursiveStructInternalFn', 'transparent', false);
testValid('ArrayRecursiveStructInternalFn', 'transparent', false);
Expand All @@ -163,12 +191,4 @@ testValid('InternalFunctionPointer', 'transparent', false);
testValid('ImpliedInternalFunctionPointer', 'transparent', false);
testOverride('ImpliedInternalFunctionPointer', 'transparent', { unsafeAllow: ['internal-function-storage'] }, true);

test('ambiguous name', t => {
const error = t.throws(() => getContractVersion(t.context.validation, 'SameName'));
t.is(
error?.message,
'Contract SameName is ambiguous. Use one of the following:\n' +
'contracts/test/ValidationsSameNameSafe.sol:SameName\n' +
'contracts/test/ValidationsSameNameUnsafe.sol:SameName',
);
});
testValid('FunctionWithInternalFunctionPointer', 'transparent', true);
31 changes: 30 additions & 1 deletion packages/core/src/validate/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type {
StructDefinition,
TypeName,
UserDefinedTypeName,
VariableDeclaration,
} from 'solidity-ast';
import debug from '../utils/debug';

Expand Down Expand Up @@ -586,7 +587,7 @@ function* getInternalFunctionStorageErrors(
): Generator<ValidationError> {
// Note: Solidity does not allow annotations for non-public state variables, nor recursive types for public variables,
// so annotations cannot be used to skip these checks.
for (const variableDec of findAll('VariableDeclaration', contractOrStructDef)) {
for (const variableDec of getVariableDeclarations(contractOrStructDef, visitedNodeIds)) {
if (variableDec.typeName?.nodeType === 'FunctionTypeName' && variableDec.typeName.visibility === 'internal') {
// Find internal function types directly in this node's scope
yield {
Expand All @@ -608,6 +609,34 @@ function* getInternalFunctionStorageErrors(
}
}

/**
* Gets variables declared directly in a contract or in a struct definition.
*
* If this is a contract with struct definitions annotated with a storage location according to ERC-7201,
* then the struct members are also included.
*/
function* getVariableDeclarations(
contractOrStructDef: ContractDefinition | StructDefinition,
visitedNodeIds: Set<number>,
): Generator<VariableDeclaration, void, undefined> {
if (contractOrStructDef.nodeType === 'ContractDefinition') {
for (const node of contractOrStructDef.nodes) {
if (node.nodeType === 'VariableDeclaration') {
yield node;
} else if (
node.nodeType === 'StructDefinition' &&
getStorageLocationAnnotation(node) !== undefined &&
!visitedNodeIds.has(node.id)
) {
visitedNodeIds.add(node.id);
yield* getVariableDeclarations(node, visitedNodeIds);
}
}
} else if (contractOrStructDef.nodeType === 'StructDefinition') {
yield* contractOrStructDef.members;
}
}

/**
* Recursively traverse array and mapping types to find user-defined types (which may be struct references).
*/
Expand Down
Loading