Skip to content

Commit

Permalink
feat: introduce parameter validation for structs (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xmoebius authored Nov 24, 2023
1 parent 7d4c31a commit c8bd873
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 75 deletions.
37 changes: 37 additions & 0 deletions solidity/contracts/Module.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,41 @@ abstract contract Module is IModule {
function _getId(IOracle.Dispute calldata _dispute) internal pure returns (bytes32 _id) {
_id = keccak256(abi.encode(_dispute));
}

/**
* @notice Validates the correctness of a request-response pair
*
* @param _request The request to compute the id for
* @param _response The response to compute the id for
* @return _responseId The id the response
*/
function _validateResponse(
IOracle.Request calldata _request,
IOracle.Response calldata _response
) internal pure returns (bytes32 _responseId) {
bytes32 _requestId = _getId(_request);
_responseId = _getId(_response);
if (_response.requestId != _requestId) revert Module_InvalidResponseBody();
}

/**
* @notice Validates the correctness of a request-response-dispute triplet
*
* @param _request The request to compute the id for
* @param _response The response to compute the id for
* @param _dispute The dispute to compute the id for
* @return _disputeId The id the dispute
*/
function _validateDispute(
IOracle.Request calldata _request,
IOracle.Response calldata _response,
IOracle.Dispute calldata _dispute
) internal pure returns (bytes32 _disputeId) {
bytes32 _requestId = _getId(_request);
bytes32 _responseId = _getId(_response);
_disputeId = _getId(_dispute);

if (_dispute.requestId != _requestId || _dispute.responseId != _responseId) revert Module_InvalidDisputeBody();
if (_response.requestId != _requestId) revert Module_InvalidResponseBody();
}
}
124 changes: 53 additions & 71 deletions solidity/contracts/Oracle.sol
Original file line number Diff line number Diff line change
Expand Up @@ -103,29 +103,23 @@ contract Oracle is IOracle {
Request calldata _request,
Response calldata _response
) external returns (bytes32 _responseId) {
bytes32 _requestId = _getId(_request);
_responseId = _validateResponse(_request, _response);

// The caller must be the proposer, unless the response is coming from a dispute module
if (msg.sender != _response.proposer && msg.sender != address(_request.disputeModule)) {
revert Oracle_InvalidResponseBody();
}

// The request id must match the response's request id
if (_response.requestId != _requestId) {
revert Oracle_InvalidResponseBody();
}

if (finalizedAt[_requestId] != 0) {
revert Oracle_AlreadyFinalized(_requestId);
if (finalizedAt[_response.requestId] != 0) {
revert Oracle_AlreadyFinalized(_response.requestId);
}

_responseId = _getId(_response);
_participants[_requestId] = abi.encodePacked(_participants[_requestId], _response.proposer);
_participants[_response.requestId] = abi.encodePacked(_participants[_response.requestId], _response.proposer);
IResponseModule(_request.responseModule).propose(_request, _response, msg.sender);
_responseIds[_requestId] = abi.encodePacked(_responseIds[_requestId], _responseId);
_responseIds[_response.requestId] = abi.encodePacked(_responseIds[_response.requestId], _responseId);
createdAt[_responseId] = uint128(block.number);

emit ResponseProposed(_requestId, _responseId, _response, block.number);
emit ResponseProposed(_response.requestId, _responseId, _response, block.number);
}

/// @inheritdoc IOracle
Expand All @@ -134,49 +128,38 @@ contract Oracle is IOracle {
Response calldata _response,
Dispute calldata _dispute
) external returns (bytes32 _disputeId) {
bytes32 _requestId = _getId(_request);
bytes32 _responseId = _getId(_response);

if (_requestId != _response.requestId) {
revert Oracle_InvalidResponseBody();
}

if (_dispute.responseId != _responseId || _dispute.disputer != msg.sender) {
revert Oracle_InvalidDisputeBody();
}
_disputeId = _validateDispute(_request, _response, _dispute);

// TODO: Check for createdAt instead?
// if(_participants[_requestId].length == 0) {
// revert();
// }

if (finalizedAt[_requestId] != 0) {
revert Oracle_AlreadyFinalized(_requestId);
if (finalizedAt[_response.requestId] != 0) {
revert Oracle_AlreadyFinalized(_response.requestId);
}

// TODO: Allow multiple disputes per response to prevent an attacker starting and losing a dispute,
// making it impossible for non-malicious actors to dispute a response?
if (disputeOf[_responseId] != bytes32(0)) {
revert Oracle_ResponseAlreadyDisputed(_responseId);
if (disputeOf[_dispute.responseId] != bytes32(0)) {
revert Oracle_ResponseAlreadyDisputed(_dispute.responseId);
}

_disputeId = _getId(_dispute);
_participants[_requestId] = abi.encodePacked(_participants[_requestId], msg.sender);
_participants[_response.requestId] = abi.encodePacked(_participants[_response.requestId], msg.sender);
disputeStatus[_disputeId] = DisputeStatus.Active;
disputeOf[_responseId] = _disputeId;
disputeOf[_dispute.responseId] = _disputeId;
createdAt[_disputeId] = uint128(block.number);

IDisputeModule(_request.disputeModule).disputeResponse(_request, _response, _dispute);

emit ResponseDisputed(_responseId, _disputeId, _dispute, block.number);
emit ResponseDisputed(_dispute.responseId, _disputeId, _dispute, block.number);
}

/// @inheritdoc IOracle
function escalateDispute(Request calldata _request, Response calldata _response, Dispute calldata _dispute) external {
bytes32 _requestId = _getId(_request);
bytes32 _disputeId = _getId(_dispute);
bytes32 _disputeId = _validateDispute(_request, _response, _dispute);

if (_dispute.requestId != _requestId || disputeOf[_dispute.responseId] != _disputeId) {
if (disputeOf[_dispute.responseId] != _disputeId) {
revert Oracle_InvalidDisputeId(_disputeId);
}

Expand All @@ -200,10 +183,9 @@ contract Oracle is IOracle {

/// @inheritdoc IOracle
function resolveDispute(Request calldata _request, Response calldata _response, Dispute calldata _dispute) external {
bytes32 _requestId = _getId(_request);
bytes32 _disputeId = _getId(_dispute);
bytes32 _disputeId = _validateDispute(_request, _response, _dispute);

if (_dispute.requestId != _requestId || disputeOf[_dispute.responseId] != _disputeId) {
if (disputeOf[_dispute.responseId] != _disputeId) {
revert Oracle_InvalidDisputeId(_disputeId);
}

Expand All @@ -229,12 +211,9 @@ contract Oracle is IOracle {
Dispute calldata _dispute,
DisputeStatus _status
) external {
bytes32 _disputeId = _getId(_dispute);
bytes32 _requestId = _getId(_request);
bytes32 _disputeId = _validateDispute(_request, _response, _dispute);

if (_response.requestId != _requestId || createdAt[_requestId] == 0) revert Oracle_InvalidRequestBody();

if (disputeOf[_getId(_response)] != _disputeId) {
if (disputeOf[_dispute.responseId] != _disputeId) {
revert Oracle_InvalidDisputeId(_disputeId);
}

Expand Down Expand Up @@ -307,16 +286,15 @@ contract Oracle is IOracle {

/// @inheritdoc IOracle
function finalize(IOracle.Request calldata _request, IOracle.Response calldata _response) external {
bytes32 _requestId = _getId(_request);
bytes32 _responseId;
bytes32 _responseId = _validateResponse(_request, _response);

if (finalizedAt[_requestId] != 0) {
revert Oracle_AlreadyFinalized(_requestId);
if (finalizedAt[_response.requestId] != 0) {
revert Oracle_AlreadyFinalized(_response.requestId);
}

// Finalizing without a response (by passing a Response with `requestId` == 0x0)
if (_response.requestId == bytes32(0)) {
bytes32[] memory _responses = getResponseIds(_requestId);
bytes32[] memory _responses = getResponseIds(_response.requestId);
uint256 _responsesAmount = _responses.length;

if (_responsesAmount != 0) {
Expand All @@ -338,10 +316,7 @@ contract Oracle is IOracle {
_responseId = bytes32(0);
}
} else {
// Finalizing with a response
_responseId = _getId(_response);

if (_response.requestId != _requestId) {
if (_response.requestId != _response.requestId) {
revert Oracle_InvalidFinalizedResponse(_responseId);
}

Expand All @@ -351,10 +326,10 @@ contract Oracle is IOracle {
revert Oracle_InvalidFinalizedResponse(_responseId);
}

_finalizedResponses[_requestId] = _responseId;
_finalizedResponses[_response.requestId] = _responseId;
}

finalizedAt[_requestId] = uint128(block.number);
finalizedAt[_response.requestId] = uint128(block.number);

if (address(_request.finalityModule) != address(0)) {
IFinalityModule(_request.finalityModule).finalizeRequest(_request, _response, msg.sender);
Expand All @@ -368,7 +343,7 @@ contract Oracle is IOracle {
IResponseModule(_request.responseModule).finalizeRequest(_request, _response, msg.sender);
IRequestModule(_request.requestModule).finalizeRequest(_request, _response, msg.sender);

emit OracleRequestFinalized(_requestId, _responseId, msg.sender, block.number);
emit OracleRequestFinalized(_response.requestId, _responseId, msg.sender, block.number);
}

/**
Expand All @@ -384,7 +359,7 @@ contract Oracle is IOracle {
// @audit what about removing nonces? or how we avoid nonce clashing?
if (_requestNonce != _request.nonce || msg.sender != _request.requester) revert Oracle_InvalidRequestBody();

_requestId = _getId(_request);
_requestId = keccak256(abi.encode(_request));
nonceToRequestId[_requestNonce] = _requestId;
createdAt[_requestId] = uint128(block.number);

Expand All @@ -404,32 +379,39 @@ contract Oracle is IOracle {
}

/**
* @notice Computes the id a given request
* @notice Validates the correctness of a request-response pair
*
* @param _request The request to compute the id for
* @return _id The id the request
*/
function _getId(IOracle.Request calldata _request) internal pure returns (bytes32 _id) {
_id = keccak256(abi.encode(_request));
}

/**
* @notice Computes the id a given response
*
* @param _response The response to compute the id for
* @return _id The id the response
* @return _responseId The id the response
*/
function _getId(IOracle.Response calldata _response) internal pure returns (bytes32 _id) {
_id = keccak256(abi.encode(_response));
function _validateResponse(
Request calldata _request,
Response calldata _response
) internal pure returns (bytes32 _responseId) {
bytes32 _requestId = keccak256(abi.encode(_request));
_responseId = keccak256(abi.encode(_response));
if (_response.requestId != _requestId) revert Oracle_InvalidResponseBody();
}

/**
* @notice Computes the id a given dispute
* @notice Validates the correctness of a request-response-dispute triplet
*
* @param _request The request to compute the id for
* @param _response The response to compute the id for
* @param _dispute The dispute to compute the id for
* @return _id The id the dispute
* @return _disputeId The id the dispute
*/
function _getId(IOracle.Dispute calldata _dispute) internal pure returns (bytes32 _id) {
_id = keccak256(abi.encode(_dispute));
function _validateDispute(
Request calldata _request,
Response calldata _response,
Dispute calldata _dispute
) internal pure returns (bytes32 _disputeId) {
bytes32 _requestId = keccak256(abi.encode(_request));
bytes32 _responseId = keccak256(abi.encode(_response));
_disputeId = keccak256(abi.encode(_dispute));

if (_dispute.requestId != _requestId || _dispute.responseId != _responseId) revert Oracle_InvalidDisputeBody();
if (_response.requestId != _requestId) revert Oracle_InvalidResponseBody();
}
}
12 changes: 11 additions & 1 deletion solidity/interfaces/IModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,19 @@ interface IModule {
*/
error Module_OnlyOracle();

/**
* @notice Thrown when the response provided does not match the request
*/
error Module_InvalidResponseBody();

/**
* @notice Thrown when the dispute provided does not match the request or response
*/
error Module_InvalidDisputeBody();

/*///////////////////////////////////////////////////////////////
VARIABLES
//////////////////////////////////////////////////////////////*/
//////////////////////////////////////////////////////////////*/

/**
* @notice Returns the address of the oracle
Expand Down
19 changes: 16 additions & 3 deletions solidity/test/unit/Oracle.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,13 @@ contract Unit_ResolveDispute is BaseTest {
function test_resolveDispute_revertsIfNoResolutionModule() public {
// Clear the resolution module
mockRequest.resolutionModule = address(0);
mockDispute.requestId = _getId(mockRequest);
bytes32 _requestId = _getId(mockRequest);

mockResponse.requestId = _requestId;
bytes32 _responseId = _getId(mockResponse);

mockDispute.requestId = _requestId;
mockDispute.responseId = _responseId;
bytes32 _disputeId = _getId(mockDispute);

// Mock the dispute
Expand Down Expand Up @@ -701,7 +707,7 @@ contract Unit_Finalize is BaseTest {
oracle.mock_addResponseId(_requestId, _responseId);

// Test: finalize the request
vm.expectRevert(abi.encodeWithSelector(IOracle.Oracle_InvalidFinalizedResponse.selector, _responseId));
vm.expectRevert(IOracle.Oracle_InvalidResponseBody.selector);
vm.prank(_caller);
oracle.finalize(mockRequest, mockResponse);
}
Expand Down Expand Up @@ -846,7 +852,14 @@ contract Unit_EscalateDispute is BaseTest {

function test_escalateDispute_noResolutionModule() public {
mockRequest.resolutionModule = address(0);
mockDispute.requestId = _getId(mockRequest);

bytes32 _requestId = _getId(mockRequest);

mockResponse.requestId = _requestId;
bytes32 _responseId = _getId(mockResponse);

mockDispute.requestId = _requestId;
mockDispute.responseId = _responseId;
bytes32 _disputeId = _getId(mockDispute);

oracle.mock_setDisputeOf(_getId(mockResponse), _disputeId);
Expand Down

0 comments on commit c8bd873

Please sign in to comment.