Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid validation error when function parameter has internal function pointer #1038

Merged
merged 11 commits into from
Jun 18, 2024
4 changes: 4 additions & 0 deletions packages/core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## Unreleases

- Fix unexpected validation error when function parameter has internal function pointer.

## 1.34.0 (2024-06-12)

- Fix storage layout comparison for function types, disallow internal functions in storage. ([#1032](https://github.com/OpenZeppelin/openzeppelin-upgrades/pull/1032))
Expand Down
18 changes: 18 additions & 0 deletions packages/core/contracts/test/Validations.sol
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ contract StructInternalFunctionPointer {
}
}

contract StructInternalFunctionPointerUsed {
struct S {
function(bool) internal foo;
}
S s;
}

contract StructImpliedInternalFunctionPointer {
struct S {
function(bool) foo;
Expand Down Expand Up @@ -276,3 +283,14 @@ contract InternalFunctionPointer {
contract ImpliedInternalFunctionPointer {
function(bool) foo;
}

contract FunctionWithInternalFunctionPointer {
uint208 x;

function doOp(
function(uint208, uint208) view returns (uint208) op,
uint208 y
) internal view returns (uint208) {
return op(x, y);
}
}
25 changes: 21 additions & 4 deletions packages/core/src/validate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
assertUpgradeSafe,
ValidationOptions,
RunValidation,
ValidationErrors,
} from './validate';
import { solcInputOutputDecoder } from './src-decoder';

Expand Down Expand Up @@ -38,11 +39,21 @@ test.before(async t => {
}
});

function testValid(name: string, kind: ValidationOptions['kind'], valid: boolean) {
testOverride(name, kind, {}, valid);
function testValid(name: string, kind: ValidationOptions['kind'], valid: boolean, numExpectedErrors?: number) {
testOverride(name, kind, {}, valid, numExpectedErrors);
}

function testOverride(name: string, kind: ValidationOptions['kind'], opts: ValidationOptions, valid: boolean) {
function testOverride(
name: string,
kind: ValidationOptions['kind'],
opts: ValidationOptions,
valid: boolean,
numExpectedErrors?: number,
) {
if (numExpectedErrors !== undefined && numExpectedErrors > 0 && valid) {
throw new Error('Cannot expect errors for a valid contract');
}

const optKeys = Object.keys(opts);
const describeOpts = optKeys.length > 0 ? '(' + optKeys.join(', ') + ')' : '';
const testName = [valid ? 'accepts' : 'rejects', kind, name, describeOpts].join(' ');
Expand All @@ -52,7 +63,10 @@ function testOverride(name: string, kind: ValidationOptions['kind'], opts: Valid
if (valid) {
t.notThrows(assertUpgSafe);
} else {
t.throws(assertUpgSafe);
const error = t.throws(assertUpgSafe) as ValidationErrors;
if (numExpectedErrors !== undefined) {
t.is(error.errors.length, numExpectedErrors);
}
}
});
}
Expand Down Expand Up @@ -142,6 +156,7 @@ testValid('contracts/test/ValidationsSameNameUnsafe.sol:SameName', 'transparent'

testValid('StructExternalFunctionPointer', 'transparent', true);
testValid('StructInternalFunctionPointer', 'transparent', false);
testValid('StructInternalFunctionPointerUsed', 'transparent', false, 1);
testValid('StructImpliedInternalFunctionPointer', 'transparent', false);
testOverride(
'StructImpliedInternalFunctionPointer',
Expand All @@ -163,6 +178,8 @@ testValid('InternalFunctionPointer', 'transparent', false);
testValid('ImpliedInternalFunctionPointer', 'transparent', false);
testOverride('ImpliedInternalFunctionPointer', 'transparent', { unsafeAllow: ['internal-function-storage'] }, true);

testValid('FunctionWithInternalFunctionPointer', 'transparent', true);

test('ambiguous name', t => {
const error = t.throws(() => getContractVersion(t.context.validation, 'SameName'));
t.is(
Expand Down
30 changes: 29 additions & 1 deletion packages/core/src/validate/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type {
StructDefinition,
TypeName,
UserDefinedTypeName,
VariableDeclaration,
} from 'solidity-ast';
import debug from '../utils/debug';

Expand Down Expand Up @@ -586,7 +587,7 @@ function* getInternalFunctionStorageErrors(
): Generator<ValidationError> {
// Note: Solidity does not allow annotations for non-public state variables, nor recursive types for public variables,
// so annotations cannot be used to skip these checks.
for (const variableDec of findAll('VariableDeclaration', contractOrStructDef)) {
for (const variableDec of getVariableDeclarations(contractOrStructDef, visitedNodeIds)) {
if (variableDec.typeName?.nodeType === 'FunctionTypeName' && variableDec.typeName.visibility === 'internal') {
// Find internal function types directly in this node's scope
yield {
Expand All @@ -608,6 +609,33 @@ function* getInternalFunctionStorageErrors(
}
}

/**
* Gets variables declared directly in a contract and its struct definitions, or in a struct definition.
*/
function getVariableDeclarations(
contractOrStructDef: ContractDefinition | StructDefinition,
visitedNodeIds: Set<number>,
): VariableDeclaration[] {
ericglau marked this conversation as resolved.
Show resolved Hide resolved
const results: VariableDeclaration[] = [];
if (contractOrStructDef.nodeType === 'ContractDefinition') {
for (const node of contractOrStructDef.nodes) {
if (node.nodeType === 'VariableDeclaration') {
results.push(node);
} else if (node.nodeType === 'StructDefinition' && !visitedNodeIds.has(node.id)) {
visitedNodeIds.add(node.id);
results.push(...getVariableDeclarations(node, visitedNodeIds));
}
ericglau marked this conversation as resolved.
Show resolved Hide resolved
}
} else if (contractOrStructDef.nodeType === 'StructDefinition') {
for (const member of contractOrStructDef.members) {
if (member.nodeType === 'VariableDeclaration') {
ericglau marked this conversation as resolved.
Show resolved Hide resolved
results.push(member);
}
}
}
return results;
}

/**
* Recursively traverse array and mapping types to find user-defined types (which may be struct references).
*/
Expand Down
Loading