From f5c8cb396f0606ab0d978e147a77fdc7ede753d2 Mon Sep 17 00:00:00 2001 From: Gas One Cent <86567384+gas1cent@users.noreply.github.com> Date: Wed, 29 Nov 2023 22:18:01 +0400 Subject: [PATCH] fix: finalization and unit tests --- solidity/contracts/Oracle.sol | 44 +++++++++++++++++++++++++-------- solidity/test/unit/Oracle.t.sol | 12 +++++++-- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/solidity/contracts/Oracle.sol b/solidity/contracts/Oracle.sol index 345f94d..ba3e5a0 100644 --- a/solidity/contracts/Oracle.sol +++ b/solidity/contracts/Oracle.sol @@ -253,6 +253,26 @@ contract Oracle is IOracle { } } + function _matchBytes32(bytes32 _sought, bytes memory _bytes) internal pure returns (bool _found) { + assembly { + let length := mload(_bytes) + let i := 0 + + // Iterate 32-bytes chunks of the list + for {} lt(i, length) { i := add(i, 32) } { + // Load the address at index i + let _chunk := mload(add(add(_bytes, 0x20), i)) + + // Shift the address to the right by 96 bits and compare with _sought + if eq(_chunk, _sought) { + // Set _found to true and return + _found := 1 + break + } + } + } + } + /// @inheritdoc IOracle function allowedModule(bytes32 _requestId, address _module) external view returns (bool _isAllowed) { _isAllowed = _matchBytes(_module, _allowedModules[_requestId]); @@ -286,15 +306,13 @@ contract Oracle is IOracle { /// @inheritdoc IOracle function finalize(IOracle.Request calldata _request, IOracle.Response calldata _response) external { - bytes32 _responseId = _validateResponse(_request, _response); - - if (finalizedAt[_response.requestId] != 0) { - revert Oracle_AlreadyFinalized(_response.requestId); - } + bytes32 _requestId; + bytes32 _responseId; // Finalizing without a response (by passing a Response with `requestId` == 0x0) if (_response.requestId == bytes32(0)) { - bytes32[] memory _responses = getResponseIds(_response.requestId); + _requestId = keccak256(abi.encode(_request)); + bytes32[] memory _responses = getResponseIds(_requestId); uint256 _responsesAmount = _responses.length; if (_responsesAmount != 0) { @@ -316,7 +334,9 @@ contract Oracle is IOracle { _responseId = bytes32(0); } } else { - if (_response.requestId != _response.requestId) { + _responseId = _validateResponse(_request, _response); + _requestId = _response.requestId; + if (!_matchBytes32(_responseId, _responseIds[_requestId])) { revert Oracle_InvalidFinalizedResponse(_responseId); } @@ -326,10 +346,14 @@ contract Oracle is IOracle { revert Oracle_InvalidFinalizedResponse(_responseId); } - _finalizedResponses[_response.requestId] = _responseId; + _finalizedResponses[_requestId] = _responseId; + } + + if (finalizedAt[_requestId] != 0) { + revert Oracle_AlreadyFinalized(_requestId); } - finalizedAt[_response.requestId] = uint128(block.number); + finalizedAt[_requestId] = uint128(block.number); if (address(_request.finalityModule) != address(0)) { IFinalityModule(_request.finalityModule).finalizeRequest(_request, _response, msg.sender); @@ -343,7 +367,7 @@ contract Oracle is IOracle { IResponseModule(_request.responseModule).finalizeRequest(_request, _response, msg.sender); IRequestModule(_request.requestModule).finalizeRequest(_request, _response, msg.sender); - emit OracleRequestFinalized(_response.requestId, _responseId, msg.sender, block.number); + emit OracleRequestFinalized(_requestId, _responseId, msg.sender, block.number); } /** diff --git a/solidity/test/unit/Oracle.t.sol b/solidity/test/unit/Oracle.t.sol index e593330..bd1d5a0 100644 --- a/solidity/test/unit/Oracle.t.sol +++ b/solidity/test/unit/Oracle.t.sol @@ -664,6 +664,8 @@ contract Unit_Finalize is BaseTest { ) public setResolutionAndFinality(_useResolutionAndFinality) { bytes32 _requestId = _getId(mockRequest); mockResponse.requestId = _requestId; + bytes32 _responseId = _getId(mockResponse); + oracle.mock_addResponseId(_requestId, _responseId); // Mock the finalize call on all modules bytes memory _calldata = abi.encodeCall(IModule.finalizeRequest, (mockRequest, mockResponse, _caller)); @@ -679,7 +681,7 @@ contract Unit_Finalize is BaseTest { // Check: emits OracleRequestFinalized event? _expectEmit(address(oracle)); - emit OracleRequestFinalized(_requestId, _getId(mockResponse), _caller, block.number); + emit OracleRequestFinalized(_requestId, _responseId, _caller, block.number); // Test: finalize the request vm.prank(_caller); @@ -691,6 +693,8 @@ contract Unit_Finalize is BaseTest { // Test: finalize a finalized request oracle.mock_setFinalizedAt(_requestId, uint128(block.number)); + oracle.mock_addResponseId(_requestId, _getId(mockResponse)); + vm.expectRevert(abi.encodeWithSelector(IOracle.Oracle_AlreadyFinalized.selector, _requestId)); vm.prank(_caller); oracle.finalize(mockRequest, mockResponse); @@ -700,6 +704,8 @@ contract Unit_Finalize is BaseTest { * @notice Test the response validation, its requestId should match the id of the provided request */ function test_finalize_revertsInvalidRequestId(address _caller, bytes32 _requestId) public assumeFuzzable(_caller) { + vm.assume(_requestId != bytes32(0) && _requestId != _getId(mockRequest)); + mockResponse.requestId = _requestId; bytes32 _responseId = _getId(mockResponse); @@ -723,6 +729,7 @@ contract Unit_Finalize is BaseTest { vm.assume(_caller != address(0)); bytes32 _requestId = _getId(mockRequest); + mockResponse.requestId = bytes32(0); // Create mock request and store it bytes memory _calldata = abi.encodeCall(IModule.finalizeRequest, (mockRequest, mockResponse, _caller)); @@ -738,7 +745,7 @@ contract Unit_Finalize is BaseTest { // Check: emits OracleRequestFinalized event? _expectEmit(address(oracle)); - emit OracleRequestFinalized(_requestId, _getId(mockResponse), _caller, block.number); + emit OracleRequestFinalized(_requestId, bytes32(0), _caller, block.number); // Test: finalize the request vm.prank(_caller); @@ -799,6 +806,7 @@ contract Unit_Finalize is BaseTest { bytes32 _disputeId = _getId(mockDispute); // The response must be disputed + oracle.mock_addResponseId(_requestId, _responseId); oracle.mock_setDisputeOf(_responseId, _disputeId); oracle.mock_setDisputeStatus(_disputeId, _disputeStatus);