Skip to content
This repository has been archived by the owner on Jan 11, 2024. It is now read-only.

Commit

Permalink
precompile diamond function selectors at compile time (#407)
Browse files Browse the repository at this point in the history
  • Loading branch information
snissn authored Dec 20, 2023
1 parent 3d2820c commit cd1a85e
Show file tree
Hide file tree
Showing 14 changed files with 303 additions and 89 deletions.
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ make coverage-for-mac: | forge
genhtml -o coverage_report lcov.info --branch-coverage --ignore-errors category
./tools/check_coverage.sh

prepare: fmt lint test slither
prepare: build-selector-library fmt lint test slither


build-selector-library: | forge
python scripts/python/build_selector_library.py
npx prettier -w test/helpers/SelectorLibrary.sol

# Forge is used by the ipc-solidity-actors compilation steps.
.PHONY: forge
Expand Down
116 changes: 116 additions & 0 deletions scripts/python/build_selector_library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import argparse
import glob
import json
import os
import subprocess
import sys
from eth_abi import encode
from json.decoder import JSONDecodeError

def writeToFile(selector_storage_content):
# Define the file path
file_path = 'test/helpers/SelectorLibrary.sol'

# Write the content to the file
with open(file_path, 'w') as file:
file.write(selector_storage_content)


def generate_solidity_function(contract_selectors):
solidity_code = "// SPDX-License-Identifier: MIT OR Apache-2.0\npragma solidity ^0.8.19;\n"
solidity_code += "library SelectorLibrary {\n"
solidity_code += " function resolveSelectors(string memory facetName) public pure returns (bytes4[] memory facetSelectors) {\n"

for contract_name, selectors in contract_selectors.items():
solidity_code += f' if (keccak256(abi.encodePacked(facetName)) == keccak256(abi.encodePacked("{contract_name}"))) {{\n'
solidity_code += f' return abi.decode(hex"{selectors}", (bytes4[]));\n'
solidity_code += " }\n"

solidity_code += " revert(\"Selector not found\");\n"
solidity_code += " }\n"
solidity_code += "}\n"
return solidity_code

def format_selector(selector_bytes):
hex_str = selector_bytes.hex()
if len(hex_str) % 2 != 0:
hex_str = '0' + hex_str # Add a leading zero if the length is odd
return hex_str

def parse_selectors(encoded_selectors):
# Assuming the encoded selectors are in the format provided in your example
decoded = bytes.fromhex(encoded_selectors[2:]) # Skip the "0x" prefix
return [format_selector(decoded[i:i+4]) for i in range(0, len(decoded), 4)] # Return in chunks of 4 bytes

def get_selectors(contract):
"""This function gets the selectors of the functions of the target contract."""

res = subprocess.run(
["forge", "inspect", contract, "methodIdentifiers"], capture_output=True)
res = res.stdout.decode()
try:
res = json.loads(res)
except JSONDecodeError as e:
print("failed to load JSON:", e)
print("forge output:", res)
print("contract:", contract)
sys.exit(1)

selectors = []
for signature in res:
selector = res[signature]
selectors.append(bytes.fromhex(selector))

enc = encode(["bytes4[]"], [selectors])
return "" + enc.hex()

def main():
contract_selectors = {}
filepaths_to_target = [
'src/GatewayDiamond.sol',
'src/SubnetActorDiamond.sol',
'src/SubnetRegistryDiamond.sol',
'src/diamond/DiamondCutFacet.sol',
'src/diamond/DiamondLoupeFacet.sol',
'src/gateway/GatewayGetterFacet.sol',
'src/gateway/GatewayManagerFacet.sol',
'src/gateway/GatewayMessengerFacet.sol',
'src/gateway/GatewayRouterFacet.sol',
'src/subnet/SubnetActorGetterFacet.sol',
'src/subnet/SubnetActorManagerFacet.sol',
'src/subnetregistry/RegisterSubnetFacet.sol',
'src/subnetregistry/SubnetGetterFacet.sol',
'test/helpers/ERC20PresetFixedSupply.sol',
'test/helpers/NumberContractFacetEight.sol',
'test/helpers/NumberContractFacetSeven.sol',
'test/helpers/SelectorLibrary.sol',
'test/helpers/TestUtils.sol',
'test/mocks/SubnetActorManagerFacetMock.sol',
]

for filepath in filepaths_to_target:

# Extract just the contract name (without path and .sol extension)
contract_name = os.path.splitext(os.path.basename(filepath))[0]

#skip lib or interfaces
if contract_name.startswith("Lib") or contract_name.startswith("I") or contract_name.endswith("Helper"):
continue

# Format full path
# Call get_selectors for each contract
try:
selectors = get_selectors(filepath + ':' + contract_name)
if selectors:
contract_selectors[contract_name] = selectors
except Exception as oops:
print(f"Error processing {filepath}: {oops}")


# Print the final JSON
solidity_library_code = generate_solidity_function(contract_selectors)
writeToFile(solidity_library_code)

if __name__ == "__main__":
main()

47 changes: 0 additions & 47 deletions scripts/python/get_selectors.py

This file was deleted.

33 changes: 17 additions & 16 deletions test/IntegrationTestBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {GatewayGetterFacet} from "../src/gateway/GatewayGetterFacet.sol";
import {GatewayMessengerFacet} from "../src/gateway/GatewayMessengerFacet.sol";
import {GatewayManagerFacet} from "../src/gateway/GatewayManagerFacet.sol";
import {GatewayRouterFacet} from "../src/gateway/GatewayRouterFacet.sol";
import {SubnetActorManagerFacetMock} from "./mocks/SubnetActor.sol";
import {SubnetActorManagerFacetMock} from "./mocks/SubnetActorManagerFacetMock.sol";
import {SubnetActorManagerFacet} from "../src/subnet/SubnetActorManagerFacet.sol";
import {SubnetActorGetterFacet} from "../src/subnet/SubnetActorGetterFacet.sol";
import {SubnetRegistryDiamond} from "../src/SubnetRegistryDiamond.sol";
Expand All @@ -33,6 +33,7 @@ import {DiamondLoupeFacet} from "../src/diamond/DiamondLoupeFacet.sol";
import {DiamondCutFacet} from "../src/diamond/DiamondCutFacet.sol";
import {SupplySourceHelper} from "../src/lib/SupplySourceHelper.sol";
import {TestUtils} from "./helpers/TestUtils.sol";
import {SelectorLibrary} from "./helpers/SelectorLibrary.sol";

contract TestParams {
uint64 constant MAX_NONCE = type(uint64).max;
Expand Down Expand Up @@ -71,10 +72,10 @@ contract TestRegistry is Test, TestParams {
SubnetGetterFacet registrySubnetGetterFacet;

constructor() {
registerSubnetFacetSelectors = TestUtils.generateSelectors(vm, "RegisterSubnetFacet");
registerSubnetGetterFacetSelectors = TestUtils.generateSelectors(vm, "SubnetGetterFacet");
registerCutterSelectors = TestUtils.generateSelectors(vm, "DiamondCutFacet");
registerLouperSelectors = TestUtils.generateSelectors(vm, "DiamondLoupeFacet");
registerSubnetFacetSelectors = SelectorLibrary.resolveSelectors("RegisterSubnetFacet");
registerSubnetGetterFacetSelectors = SelectorLibrary.resolveSelectors("SubnetGetterFacet");
registerCutterSelectors = SelectorLibrary.resolveSelectors("DiamondCutFacet");
registerLouperSelectors = SelectorLibrary.resolveSelectors("DiamondLoupeFacet");
}
}

Expand All @@ -95,12 +96,12 @@ contract TestGatewayActor is Test, TestParams {
DiamondLoupeFacet gwLouper;

constructor() {
gwRouterSelectors = TestUtils.generateSelectors(vm, "GatewayRouterFacet");
gwGetterSelectors = TestUtils.generateSelectors(vm, "GatewayGetterFacet");
gwManagerSelectors = TestUtils.generateSelectors(vm, "GatewayManagerFacet");
gwMessengerSelectors = TestUtils.generateSelectors(vm, "GatewayMessengerFacet");
gwCutterSelectors = TestUtils.generateSelectors(vm, "DiamondCutFacet");
gwLoupeSelectors = TestUtils.generateSelectors(vm, "DiamondLoupeFacet");
gwRouterSelectors = SelectorLibrary.resolveSelectors("GatewayRouterFacet");
gwGetterSelectors = SelectorLibrary.resolveSelectors("GatewayGetterFacet");
gwManagerSelectors = SelectorLibrary.resolveSelectors("GatewayManagerFacet");
gwMessengerSelectors = SelectorLibrary.resolveSelectors("GatewayMessengerFacet");
gwCutterSelectors = SelectorLibrary.resolveSelectors("DiamondCutFacet");
gwLoupeSelectors = SelectorLibrary.resolveSelectors("DiamondLoupeFacet");
}

function defaultGatewayParams() internal pure virtual returns (GatewayDiamond.ConstructorParams memory) {
Expand Down Expand Up @@ -133,11 +134,11 @@ contract TestSubnetActor is Test, TestParams {
DiamondLoupeFacet saLouper;

constructor() {
saGetterSelectors = TestUtils.generateSelectors(vm, "SubnetActorGetterFacet");
saManagerSelectors = TestUtils.generateSelectors(vm, "SubnetActorManagerFacet");
saManagerMockedSelectors = TestUtils.generateSelectors(vm, "SubnetActorManagerFacetMock");
saCutterSelectors = TestUtils.generateSelectors(vm, "DiamondCutFacet");
saLouperSelectors = TestUtils.generateSelectors(vm, "DiamondLoupeFacet");
saGetterSelectors = SelectorLibrary.resolveSelectors("SubnetActorGetterFacet");
saManagerSelectors = SelectorLibrary.resolveSelectors("SubnetActorManagerFacet");
saManagerMockedSelectors = SelectorLibrary.resolveSelectors("SubnetActorManagerFacetMock");
saCutterSelectors = SelectorLibrary.resolveSelectors("DiamondCutFacet");
saLouperSelectors = SelectorLibrary.resolveSelectors("DiamondLoupeFacet");
}

function defaultSubnetActorParamsWithGateway(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
// SPDX-License-Identifier: MIT OR Apache-2.0
pragma solidity 0.8.19;

contract NumberContractFacetSeven {
function getNum() external pure returns (uint8) {
return 7;
}
}

contract NumberContractFacetEight {
function getNum() external pure returns (uint8) {
return 8;
Expand Down
8 changes: 8 additions & 0 deletions test/helpers/NumberContractFacetSeven.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// SPDX-License-Identifier: MIT OR Apache-2.0
pragma solidity 0.8.19;

contract NumberContractFacetSeven {
function getNum() external pure returns (uint8) {
return 7;
}
}
Loading

0 comments on commit cd1a85e

Please sign in to comment.