From 75af5a468242cadf1e97c3bbc9233baa8682938a Mon Sep 17 00:00:00 2001 From: Gas <86567384+gas1cent@users.noreply.github.com> Date: Tue, 12 Dec 2023 15:28:39 +0400 Subject: [PATCH] refactor: reuse `matchBytes` for address and bytes32 (#14) --- solidity/contracts/Oracle.sol | 54 ++++++++++++----------------------- 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/solidity/contracts/Oracle.sol b/solidity/contracts/Oracle.sol index 6291041..74e1a93 100644 --- a/solidity/contracts/Oracle.sol +++ b/solidity/contracts/Oracle.sol @@ -222,46 +222,28 @@ contract Oracle is IOracle { } /** - * @notice Confirms wether the address is in the list or not + * @notice Matches a bytes32 value against an encoded list of values * - * @param _sought The address to look for - * @param _bytes The list of addresses packed together - * @return _found Whether the address was found or not + * @param _sought The value to be matched + * @param _list The encoded list of values + * @param _chunkSize The size of each chunk in bytes + * @return _found Whether the value was found or not */ - function _matchBytes(address _sought, bytes memory _bytes) internal pure returns (bool _found) { + function _matchBytes(bytes32 _sought, bytes memory _list, uint256 _chunkSize) internal pure returns (bool _found) { assembly { - let length := mload(_bytes) + let length := mload(_list) let i := 0 + let shiftBy := sub(256, mul(_chunkSize, 8)) - // Iterate 20-bytes chunks of the list - for {} lt(i, length) { i := add(i, 20) } { - // Load the address at index i - let _chunk := mload(add(add(_bytes, 0x20), i)) + // Iterate N-bytes chunks of the list + for {} lt(i, length) { i := add(i, _chunkSize) } { + // Load the value at index i + let _chunk := mload(add(add(_list, 0x20), i)) - // Shift the address to the right by 96 bits and compare with _sought - if eq(shr(96, _chunk), _sought) { + // Shift the value to the right and compare with _sought + if eq(shr(shiftBy, _chunk), _sought) { // Set _found to true and return - _found := 1 - break - } - } - } - } - - function _matchBytes32(bytes32 _sought, bytes memory _bytes) internal pure returns (bool _found) { - assembly { - let length := mload(_bytes) - let i := 0 - - // Iterate 32-bytes chunks of the list - for {} lt(i, length) { i := add(i, 32) } { - // Load the address at index i - let _chunk := mload(add(add(_bytes, 0x20), i)) - - // Shift the address to the right by 96 bits and compare with _sought - if eq(_chunk, _sought) { - // Set _found to true and return - _found := 1 + _found := true break } } @@ -270,12 +252,12 @@ contract Oracle is IOracle { /// @inheritdoc IOracle function allowedModule(bytes32 _requestId, address _module) external view returns (bool _isAllowed) { - _isAllowed = _matchBytes(_module, _allowedModules[_requestId]); + _isAllowed = _matchBytes(bytes32(uint256(uint160(_module))), _allowedModules[_requestId], 20); } /// @inheritdoc IOracle function isParticipant(bytes32 _requestId, address _user) external view returns (bool _isParticipant) { - _isParticipant = _matchBytes(_user, _participants[_requestId]); + _isParticipant = _matchBytes(bytes32(uint256(uint160(_user))), _participants[_requestId], 20); } /// @inheritdoc IOracle @@ -375,7 +357,7 @@ contract Oracle is IOracle { ) internal returns (bytes32 _requestId, bytes32 _responseId) { _responseId = _validateResponse(_request, _response); _requestId = _response.requestId; - if (!_matchBytes32(_responseId, _responseIds[_requestId])) { + if (!_matchBytes(_responseId, _responseIds[_requestId], 32)) { revert Oracle_InvalidFinalizedResponse(); }