From 1f5da9927a117429f938ed697e15bc968ae0121c Mon Sep 17 00:00:00 2001 From: moebius <0xmoebius@tutanota.com> Date: Wed, 22 Nov 2023 13:29:57 -0300 Subject: [PATCH] feat: add parameter validation to module and update unit tests --- solidity/contracts/Module.sol | 41 +++++++++++++++++------------ solidity/contracts/Oracle.sol | 46 +++++++-------------------------- solidity/interfaces/IModule.sol | 12 ++++++++- solidity/test/unit/Oracle.t.sol | 19 +++++++++++--- 4 files changed, 61 insertions(+), 57 deletions(-) diff --git a/solidity/contracts/Module.sol b/solidity/contracts/Module.sol index afc2a8f..d11298a 100644 --- a/solidity/contracts/Module.sol +++ b/solidity/contracts/Module.sol @@ -28,32 +28,39 @@ abstract contract Module is IModule { ) external virtual onlyOracle {} /** - * @notice Computes the id a given request + * @notice Validates the correctness of a request-response pair * * @param _request The request to compute the id for - * @return _id The id the request - */ - function _getId(IOracle.Request calldata _request) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_request)); - } - - /** - * @notice Computes the id a given response - * * @param _response The response to compute the id for - * @return _id The id the response + * @return _responseId The id the response */ - function _getId(IOracle.Response calldata _response) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_response)); + function _validateResponse( + IOracle.Request calldata _request, + IOracle.Response calldata _response + ) internal pure returns (bytes32 _responseId) { + bytes32 _requestId = keccak256(abi.encode(_request)); + _responseId = keccak256(abi.encode(_response)); + if (_response.requestId != _requestId) revert Module_InvalidResponseBody(); } /** - * @notice Computes the id a given dispute + * @notice Validates the correctness of a request-response-dispute triplet * + * @param _request The request to compute the id for + * @param _response The response to compute the id for * @param _dispute The dispute to compute the id for - * @return _id The id the dispute + * @return _disputeId The id the dispute */ - function _getId(IOracle.Dispute calldata _dispute) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_dispute)); + function _validateDispute( + IOracle.Request calldata _request, + IOracle.Response calldata _response, + IOracle.Dispute calldata _dispute + ) internal pure returns (bytes32 _disputeId) { + bytes32 _requestId = keccak256(abi.encode(_request)); + bytes32 _responseId = keccak256(abi.encode(_response)); + _disputeId = keccak256(abi.encode(_dispute)); + + if (_dispute.requestId != _requestId || _dispute.responseId != _responseId) revert Module_InvalidDisputeBody(); + if (_response.requestId != _requestId) revert Module_InvalidResponseBody(); } } diff --git a/solidity/contracts/Oracle.sol b/solidity/contracts/Oracle.sol index e8f0a5b..345f94d 100644 --- a/solidity/contracts/Oracle.sol +++ b/solidity/contracts/Oracle.sol @@ -159,6 +159,10 @@ contract Oracle is IOracle { function escalateDispute(Request calldata _request, Response calldata _response, Dispute calldata _dispute) external { bytes32 _disputeId = _validateDispute(_request, _response, _dispute); + if (disputeOf[_dispute.responseId] != _disputeId) { + revert Oracle_InvalidDisputeId(_disputeId); + } + if (disputeStatus[_disputeId] != DisputeStatus.Active) { revert Oracle_CannotEscalate(_disputeId); } @@ -355,7 +359,7 @@ contract Oracle is IOracle { // @audit what about removing nonces? or how we avoid nonce clashing? if (_requestNonce != _request.nonce || msg.sender != _request.requester) revert Oracle_InvalidRequestBody(); - _requestId = _getId(_request); + _requestId = keccak256(abi.encode(_request)); nonceToRequestId[_requestNonce] = _requestId; createdAt[_requestId] = uint128(block.number); @@ -374,36 +378,6 @@ contract Oracle is IOracle { emit RequestCreated(_requestId, _request, _ipfsHash, block.number); } - /** - * @notice Computes the id a given request - * - * @param _request The request to compute the id for - * @return _id The id the request - */ - function _getId(IOracle.Request calldata _request) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_request)); - } - - /** - * @notice Computes the id a given response - * - * @param _response The response to compute the id for - * @return _id The id the response - */ - function _getId(IOracle.Response calldata _response) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_response)); - } - - /** - * @notice Computes the id a given dispute - * - * @param _dispute The dispute to compute the id for - * @return _id The id the dispute - */ - function _getId(IOracle.Dispute calldata _dispute) internal pure returns (bytes32 _id) { - _id = keccak256(abi.encode(_dispute)); - } - /** * @notice Validates the correctness of a request-response pair * @@ -415,8 +389,8 @@ contract Oracle is IOracle { Request calldata _request, Response calldata _response ) internal pure returns (bytes32 _responseId) { - bytes32 _requestId = _getId(_request); - _responseId = _getId(_response); + bytes32 _requestId = keccak256(abi.encode(_request)); + _responseId = keccak256(abi.encode(_response)); if (_response.requestId != _requestId) revert Oracle_InvalidResponseBody(); } @@ -433,9 +407,9 @@ contract Oracle is IOracle { Response calldata _response, Dispute calldata _dispute ) internal pure returns (bytes32 _disputeId) { - bytes32 _requestId = _getId(_request); - bytes32 _responseId = _getId(_response); - _disputeId = _getId(_dispute); + bytes32 _requestId = keccak256(abi.encode(_request)); + bytes32 _responseId = keccak256(abi.encode(_response)); + _disputeId = keccak256(abi.encode(_dispute)); if (_dispute.requestId != _requestId || _dispute.responseId != _responseId) revert Oracle_InvalidDisputeBody(); if (_response.requestId != _requestId) revert Oracle_InvalidResponseBody(); diff --git a/solidity/interfaces/IModule.sol b/solidity/interfaces/IModule.sol index a0be7fb..e285826 100644 --- a/solidity/interfaces/IModule.sol +++ b/solidity/interfaces/IModule.sol @@ -29,9 +29,19 @@ interface IModule { */ error Module_OnlyOracle(); + /** + * @notice Thrown when the response provided does not match the request + */ + error Module_InvalidResponseBody(); + + /** + * @notice Thrown when the dispute provided does not match the request or response + */ + error Module_InvalidDisputeBody(); + /*/////////////////////////////////////////////////////////////// VARIABLES - //////////////////////////////////////////////////////////////*/ + //////////////////////////////////////////////////////////////*/ /** * @notice Returns the address of the oracle diff --git a/solidity/test/unit/Oracle.t.sol b/solidity/test/unit/Oracle.t.sol index 4057117..8046062 100644 --- a/solidity/test/unit/Oracle.t.sol +++ b/solidity/test/unit/Oracle.t.sol @@ -568,7 +568,13 @@ contract Unit_ResolveDispute is BaseTest { function test_resolveDispute_revertsIfNoResolutionModule() public { // Clear the resolution module mockRequest.resolutionModule = address(0); - mockDispute.requestId = _getId(mockRequest); + bytes32 _requestId = _getId(mockRequest); + + mockResponse.requestId = _requestId; + bytes32 _responseId = _getId(mockResponse); + + mockDispute.requestId = _requestId; + mockDispute.responseId = _responseId; bytes32 _disputeId = _getId(mockDispute); // Mock the dispute @@ -701,7 +707,7 @@ contract Unit_Finalize is BaseTest { oracle.mock_addResponseId(_requestId, _responseId); // Test: finalize the request - vm.expectRevert(abi.encodeWithSelector(IOracle.Oracle_InvalidFinalizedResponse.selector, _responseId)); + vm.expectRevert(IOracle.Oracle_InvalidResponseBody.selector); vm.prank(_caller); oracle.finalize(mockRequest, mockResponse); } @@ -846,7 +852,14 @@ contract Unit_EscalateDispute is BaseTest { function test_escalateDispute_noResolutionModule() public { mockRequest.resolutionModule = address(0); - mockDispute.requestId = _getId(mockRequest); + + bytes32 _requestId = _getId(mockRequest); + + mockResponse.requestId = _requestId; + bytes32 _responseId = _getId(mockResponse); + + mockDispute.requestId = _requestId; + mockDispute.responseId = _responseId; bytes32 _disputeId = _getId(mockDispute); oracle.mock_setDisputeOf(_getId(mockResponse), _disputeId);