From c8bd87364a7de56fe03df36adf5c7d613f3df418 Mon Sep 17 00:00:00 2001 From: moebius <132487952+0xmoebius@users.noreply.github.com> Date: Fri, 24 Nov 2023 04:37:17 -0300 Subject: [PATCH] feat: introduce parameter validation for structs (#10) --- solidity/contracts/Module.sol | 37 ++++++++++ solidity/contracts/Oracle.sol | 124 ++++++++++++++------------------ solidity/interfaces/IModule.sol | 12 +++- solidity/test/unit/Oracle.t.sol | 19 ++++- 4 files changed, 117 insertions(+), 75 deletions(-) diff --git a/solidity/contracts/Module.sol b/solidity/contracts/Module.sol index afc2a8f..de5233c 100644 --- a/solidity/contracts/Module.sol +++ b/solidity/contracts/Module.sol @@ -56,4 +56,41 @@ abstract contract Module is IModule { function _getId(IOracle.Dispute calldata _dispute) internal pure returns (bytes32 _id) { _id = keccak256(abi.encode(_dispute)); } + + /** + * @notice Validates the correctness of a request-response pair + * + * @param _request The request to compute the id for + * @param _response The response to compute the id for + * @return _responseId The id the response + */ + function _validateResponse( + IOracle.Request calldata _request, + IOracle.Response calldata _response + ) internal pure returns (bytes32 _responseId) { + bytes32 _requestId = _getId(_request); + _responseId = _getId(_response); + if (_response.requestId != _requestId) revert Module_InvalidResponseBody(); + } + + /** + * @notice Validates the correctness of a request-response-dispute triplet + * + * @param _request The request to compute the id for + * @param _response The response to compute the id for + * @param _dispute The dispute to compute the id for + * @return _disputeId The id the dispute + */ + function _validateDispute( + IOracle.Request calldata _request, + IOracle.Response calldata _response, + IOracle.Dispute calldata _dispute + ) internal pure returns (bytes32 _disputeId) { + bytes32 _requestId = _getId(_request); + bytes32 _responseId = _getId(_response); + _disputeId = _getId(_dispute); + + if (_dispute.requestId != _requestId || _dispute.responseId != _responseId) revert Module_InvalidDisputeBody(); + if (_response.requestId != _requestId) revert Module_InvalidResponseBody(); + } } diff --git a/solidity/contracts/Oracle.sol b/solidity/contracts/Oracle.sol index e4fd462..345f94d 100644 --- a/solidity/contracts/Oracle.sol +++ b/solidity/contracts/Oracle.sol @@ -103,29 +103,23 @@ contract Oracle is IOracle { Request calldata _request, Response calldata _response ) external returns (bytes32 _responseId) { - bytes32 _requestId = _getId(_request); + _responseId = _validateResponse(_request, _response); // The caller must be the proposer, unless the response is coming from a dispute module if (msg.sender != _response.proposer && msg.sender != address(_request.disputeModule)) { revert Oracle_InvalidResponseBody(); } - // The request id must match the response's request id - if (_response.requestId != _requestId) { - revert Oracle_InvalidResponseBody(); - } - - if (finalizedAt[_requestId] != 0) { - revert Oracle_AlreadyFinalized(_requestId); + if (finalizedAt[_response.requestId] != 0) { + revert Oracle_AlreadyFinalized(_response.requestId); } - _responseId = _getId(_response); - _participants[_requestId] = abi.encodePacked(_participants[_requestId], _response.proposer); + _participants[_response.requestId] = abi.encodePacked(_participants[_response.requestId], _response.proposer); IResponseModule(_request.responseModule).propose(_request, _response, msg.sender); - _responseIds[_requestId] = abi.encodePacked(_responseIds[_requestId], _responseId); + _responseIds[_response.requestId] = abi.encodePacked(_responseIds[_response.requestId], _responseId); createdAt[_responseId] = uint128(block.number); - emit ResponseProposed(_requestId, _responseId, _response, block.number); + emit ResponseProposed(_response.requestId, _responseId, _response, block.number); } /// @inheritdoc IOracle @@ -134,49 +128,38 @@ contract Oracle is IOracle { Response calldata _response, Dispute calldata _dispute ) external returns (bytes32 _disputeId) { - bytes32 _requestId = _getId(_request); - bytes32 _responseId = _getId(_response); - - if (_requestId != _response.requestId) { - revert Oracle_InvalidResponseBody(); - } - - if (_dispute.responseId != _responseId || _dispute.disputer != msg.sender) { - revert Oracle_InvalidDisputeBody(); - } + _disputeId = _validateDispute(_request, _response, _dispute); // TODO: Check for createdAt instead? // if(_participants[_requestId].length == 0) { // revert(); // } - if (finalizedAt[_requestId] != 0) { - revert Oracle_AlreadyFinalized(_requestId); + if (finalizedAt[_response.requestId] != 0) { + revert Oracle_AlreadyFinalized(_response.requestId); } // TODO: Allow multiple disputes per response to prevent an attacker starting and losing a dispute, // making it impossible for non-malicious actors to dispute a response? - if (disputeOf[_responseId] != bytes32(0)) { - revert Oracle_ResponseAlreadyDisputed(_responseId); + if (disputeOf[_dispute.responseId] != bytes32(0)) { + revert Oracle_ResponseAlreadyDisputed(_dispute.responseId); } - _disputeId = _getId(_dispute); - _participants[_requestId] = abi.encodePacked(_participants[_requestId], msg.sender); + _participants[_response.requestId] = abi.encodePacked(_participants[_response.requestId], msg.sender); disputeStatus[_disputeId] = DisputeStatus.Active; - disputeOf[_responseId] = _disputeId; + disputeOf[_dispute.responseId] = _disputeId; createdAt[_disputeId] = uint128(block.number); IDisputeModule(_request.disputeModule).disputeResponse(_request, _response, _dispute); - emit ResponseDisputed(_responseId, _disputeId, _dispute, block.number); + emit ResponseDisputed(_dispute.responseId, _disputeId, _dispute, block.number); } /// @inheritdoc IOracle function escalateDispute(Request calldata _request, Response calldata _response, Dispute calldata _dispute) external { - bytes32 _requestId = _getId(_request); - bytes32 _disputeId = _getId(_dispute); + bytes32 _disputeId = _validateDispute(_request, _response, _dispute); - if (_dispute.requestId != _requestId || disputeOf[_dispute.responseId] != _disputeId) { + if (disputeOf[_dispute.responseId] != _disputeId) { revert Oracle_InvalidDisputeId(_disputeId); } @@ -200,10 +183,9 @@ contract Oracle is IOracle { /// @inheritdoc IOracle function resolveDispute(Request calldata _request, Response calldata _response, Dispute calldata _dispute) external { - bytes32 _requestId = _getId(_request); - bytes32 _disputeId = _getId(_dispute); + bytes32 _disputeId = _validateDispute(_request, _response, _dispute); - if (_dispute.requestId != _requestId || disputeOf[_dispute.responseId] != _disputeId) { + if (disputeOf[_dispute.responseId] != _disputeId) { revert Oracle_InvalidDisputeId(_disputeId); } @@ -229,12 +211,9 @@ contract Oracle is IOracle { Dispute calldata _dispute, DisputeStatus _status ) external { - bytes32 _disputeId = _getId(_dispute); - bytes32 _requestId = _getId(_request); + bytes32 _disputeId = _validateDispute(_request, _response, _dispute); - if (_response.requestId != _requestId || createdAt[_requestId] == 0) revert Oracle_InvalidRequestBody(); - - if (disputeOf[_getId(_response)] != _disputeId) { + if (disputeOf[_dispute.responseId] != _disputeId) { revert Oracle_InvalidDisputeId(_disputeId); } @@ -307,16 +286,15 @@ contract Oracle is IOracle { /// @inheritdoc IOracle function finalize(IOracle.Request calldata _request, IOracle.Response calldata _response) external { - bytes32 _requestId = _getId(_request); - bytes32 _responseId; + bytes32 _responseId = _validateResponse(_request, _response); - if (finalizedAt[_requestId] != 0) { - revert Oracle_AlreadyFinalized(_requestId); + if (finalizedAt[_response.requestId] != 0) { + revert Oracle_AlreadyFinalized(_response.requestId); } // Finalizing without a response (by passing a Response with `requestId` == 0x0) if (_response.requestId == bytes32(0)) { - bytes32[] memory _responses = getResponseIds(_requestId); + bytes32[] memory _responses = getResponseIds(_response.requestId); uint256 _responsesAmount = _responses.length; if (_responsesAmount != 0) { @@ -338,10 +316,7 @@ contract Oracle is IOracle { _responseId = bytes32(0); } } else { - // Finalizing with a response - _responseId = _getId(_response); - - if (_response.requestId != _requestId) { + if (_response.requestId != _response.requestId) { revert Oracle_InvalidFinalizedResponse(_responseId); } @@ -351,10 +326,10 @@ contract Oracle is IOracle { revert Oracle_InvalidFinalizedResponse(_responseId); } - _finalizedResponses[_requestId] = _responseId; + _finalizedResponses[_response.requestId] = _responseId; } - finalizedAt[_requestId] = uint128(block.number); + finalizedAt[_response.requestId] = uint128(block.number); if (address(_request.finalityModule) != address(0)) { IFinalityModule(_request.finalityModule).finalizeRequest(_request, _response, msg.sender); @@ -368,7 +343,7 @@ contract Oracle is IOracle { IResponseModule(_request.responseModule).finalizeRequest(_request, _response, msg.sender); IRequestModule(_request.requestModule).finalizeRequest(_request, _response, msg.sender); - emit OracleRequestFinalized(_requestId, _responseId, msg.sender, block.number); + emit OracleRequestFinalized(_response.requestId, _responseId, msg.sender, block.number); } /** @@ -384,7 +359,7 @@ contract Oracle is IOracle { // @audit what about removing nonces? or how we avoid nonce clashing? if (_requestNonce != _request.nonce || msg.sender != _request.requester) revert Oracle_InvalidRequestBody(); - _requestId = _getId(_request); + _requestId = keccak256(abi.encode(_request)); nonceToRequestId[_requestNonce] = _requestId; createdAt[_requestId] = uint128(block.number); @@ -404,32 +379,39 @@ contract Oracle is IOracle { } /** - * @notice Computes the id a given request + * @notice Validates the correctness of a request-response pair * * @param _request The request to compute the id for - * @return _id The id the request - */ - function _getId(IOracle.Request calldata _request) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_request)); - } - - /** - * @notice Computes the id a given response - * * @param _response The response to compute the id for - * @return _id The id the response + * @return _responseId The id the response */ - function _getId(IOracle.Response calldata _response) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_response)); + function _validateResponse( + Request calldata _request, + Response calldata _response + ) internal pure returns (bytes32 _responseId) { + bytes32 _requestId = keccak256(abi.encode(_request)); + _responseId = keccak256(abi.encode(_response)); + if (_response.requestId != _requestId) revert Oracle_InvalidResponseBody(); } /** - * @notice Computes the id a given dispute + * @notice Validates the correctness of a request-response-dispute triplet * + * @param _request The request to compute the id for + * @param _response The response to compute the id for * @param _dispute The dispute to compute the id for - * @return _id The id the dispute + * @return _disputeId The id the dispute */ - function _getId(IOracle.Dispute calldata _dispute) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_dispute)); + function _validateDispute( + Request calldata _request, + Response calldata _response, + Dispute calldata _dispute + ) internal pure returns (bytes32 _disputeId) { + bytes32 _requestId = keccak256(abi.encode(_request)); + bytes32 _responseId = keccak256(abi.encode(_response)); + _disputeId = keccak256(abi.encode(_dispute)); + + if (_dispute.requestId != _requestId || _dispute.responseId != _responseId) revert Oracle_InvalidDisputeBody(); + if (_response.requestId != _requestId) revert Oracle_InvalidResponseBody(); } } diff --git a/solidity/interfaces/IModule.sol b/solidity/interfaces/IModule.sol index a0be7fb..e285826 100644 --- a/solidity/interfaces/IModule.sol +++ b/solidity/interfaces/IModule.sol @@ -29,9 +29,19 @@ interface IModule { */ error Module_OnlyOracle(); + /** + * @notice Thrown when the response provided does not match the request + */ + error Module_InvalidResponseBody(); + + /** + * @notice Thrown when the dispute provided does not match the request or response + */ + error Module_InvalidDisputeBody(); + /*/////////////////////////////////////////////////////////////// VARIABLES - //////////////////////////////////////////////////////////////*/ + //////////////////////////////////////////////////////////////*/ /** * @notice Returns the address of the oracle diff --git a/solidity/test/unit/Oracle.t.sol b/solidity/test/unit/Oracle.t.sol index 4057117..8046062 100644 --- a/solidity/test/unit/Oracle.t.sol +++ b/solidity/test/unit/Oracle.t.sol @@ -568,7 +568,13 @@ contract Unit_ResolveDispute is BaseTest { function test_resolveDispute_revertsIfNoResolutionModule() public { // Clear the resolution module mockRequest.resolutionModule = address(0); - mockDispute.requestId = _getId(mockRequest); + bytes32 _requestId = _getId(mockRequest); + + mockResponse.requestId = _requestId; + bytes32 _responseId = _getId(mockResponse); + + mockDispute.requestId = _requestId; + mockDispute.responseId = _responseId; bytes32 _disputeId = _getId(mockDispute); // Mock the dispute @@ -701,7 +707,7 @@ contract Unit_Finalize is BaseTest { oracle.mock_addResponseId(_requestId, _responseId); // Test: finalize the request - vm.expectRevert(abi.encodeWithSelector(IOracle.Oracle_InvalidFinalizedResponse.selector, _responseId)); + vm.expectRevert(IOracle.Oracle_InvalidResponseBody.selector); vm.prank(_caller); oracle.finalize(mockRequest, mockResponse); } @@ -846,7 +852,14 @@ contract Unit_EscalateDispute is BaseTest { function test_escalateDispute_noResolutionModule() public { mockRequest.resolutionModule = address(0); - mockDispute.requestId = _getId(mockRequest); + + bytes32 _requestId = _getId(mockRequest); + + mockResponse.requestId = _requestId; + bytes32 _responseId = _getId(mockResponse); + + mockDispute.requestId = _requestId; + mockDispute.responseId = _responseId; bytes32 _disputeId = _getId(mockDispute); oracle.mock_setDisputeOf(_getId(mockResponse), _disputeId);