Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce parameter validation for structs #10

Merged
merged 3 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions solidity/contracts/Module.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,41 @@ abstract contract Module is IModule {
function _getId(IOracle.Dispute calldata _dispute) internal pure returns (bytes32 _id) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we will still need the _getId function for disputes because some modules rely on the id for their functionality or to emit events. For instance: https://github.com/defi-wonderland/prophet-modules/blob/dev/solidity/contracts/modules/dispute/BondedDisputeModule.sol#L41

_id = keccak256(abi.encode(_dispute));
}

/**
* @notice Validates the correctness of a request-response pair
*
* @param _request The request to compute the id for
* @param _response The response to compute the id for
* @return _responseId The id the response
*/
function _validateResponse(
IOracle.Request calldata _request,
IOracle.Response calldata _response
) internal pure returns (bytes32 _responseId) {
bytes32 _requestId = _getId(_request);
_responseId = _getId(_response);
if (_response.requestId != _requestId) revert Module_InvalidResponseBody();
}

/**
* @notice Validates the correctness of a request-response-dispute triplet
*
* @param _request The request to compute the id for
* @param _response The response to compute the id for
* @param _dispute The dispute to compute the id for
* @return _disputeId The id the dispute
*/
function _validateDispute(
IOracle.Request calldata _request,
IOracle.Response calldata _response,
IOracle.Dispute calldata _dispute
) internal pure returns (bytes32 _disputeId) {
bytes32 _requestId = _getId(_request);
bytes32 _responseId = _getId(_response);
_disputeId = _getId(_dispute);

if (_dispute.requestId != _requestId || _dispute.responseId != _responseId) revert Module_InvalidDisputeBody();
if (_response.requestId != _requestId) revert Module_InvalidResponseBody();
}
}
124 changes: 53 additions & 71 deletions solidity/contracts/Oracle.sol
Original file line number Diff line number Diff line change
Expand Up @@ -103,29 +103,23 @@ contract Oracle is IOracle {
Request calldata _request,
Response calldata _response
) external returns (bytes32 _responseId) {
bytes32 _requestId = _getId(_request);
_responseId = _validateResponse(_request, _response);

// The caller must be the proposer, unless the response is coming from a dispute module
if (msg.sender != _response.proposer && msg.sender != address(_request.disputeModule)) {
revert Oracle_InvalidResponseBody();
}

// The request id must match the response's request id
if (_response.requestId != _requestId) {
revert Oracle_InvalidResponseBody();
}

if (finalizedAt[_requestId] != 0) {
revert Oracle_AlreadyFinalized(_requestId);
if (finalizedAt[_response.requestId] != 0) {
revert Oracle_AlreadyFinalized(_response.requestId);
}

_responseId = _getId(_response);
_participants[_requestId] = abi.encodePacked(_participants[_requestId], _response.proposer);
_participants[_response.requestId] = abi.encodePacked(_participants[_response.requestId], _response.proposer);
IResponseModule(_request.responseModule).propose(_request, _response, msg.sender);
_responseIds[_requestId] = abi.encodePacked(_responseIds[_requestId], _responseId);
_responseIds[_response.requestId] = abi.encodePacked(_responseIds[_response.requestId], _responseId);
createdAt[_responseId] = uint128(block.number);

emit ResponseProposed(_requestId, _responseId, _response, block.number);
emit ResponseProposed(_response.requestId, _responseId, _response, block.number);
}

/// @inheritdoc IOracle
Expand All @@ -134,49 +128,38 @@ contract Oracle is IOracle {
Response calldata _response,
Dispute calldata _dispute
) external returns (bytes32 _disputeId) {
bytes32 _requestId = _getId(_request);
bytes32 _responseId = _getId(_response);

if (_requestId != _response.requestId) {
revert Oracle_InvalidResponseBody();
}

if (_dispute.responseId != _responseId || _dispute.disputer != msg.sender) {
revert Oracle_InvalidDisputeBody();
}
_disputeId = _validateDispute(_request, _response, _dispute);

// TODO: Check for createdAt instead?
// if(_participants[_requestId].length == 0) {
// revert();
// }

if (finalizedAt[_requestId] != 0) {
revert Oracle_AlreadyFinalized(_requestId);
if (finalizedAt[_response.requestId] != 0) {
revert Oracle_AlreadyFinalized(_response.requestId);
}

// TODO: Allow multiple disputes per response to prevent an attacker starting and losing a dispute,
// making it impossible for non-malicious actors to dispute a response?
if (disputeOf[_responseId] != bytes32(0)) {
revert Oracle_ResponseAlreadyDisputed(_responseId);
if (disputeOf[_dispute.responseId] != bytes32(0)) {
revert Oracle_ResponseAlreadyDisputed(_dispute.responseId);
}

_disputeId = _getId(_dispute);
_participants[_requestId] = abi.encodePacked(_participants[_requestId], msg.sender);
_participants[_response.requestId] = abi.encodePacked(_participants[_response.requestId], msg.sender);
disputeStatus[_disputeId] = DisputeStatus.Active;
disputeOf[_responseId] = _disputeId;
disputeOf[_dispute.responseId] = _disputeId;
createdAt[_disputeId] = uint128(block.number);

IDisputeModule(_request.disputeModule).disputeResponse(_request, _response, _dispute);

emit ResponseDisputed(_responseId, _disputeId, _dispute, block.number);
emit ResponseDisputed(_dispute.responseId, _disputeId, _dispute, block.number);
}

/// @inheritdoc IOracle
function escalateDispute(Request calldata _request, Response calldata _response, Dispute calldata _dispute) external {
bytes32 _requestId = _getId(_request);
bytes32 _disputeId = _getId(_dispute);
bytes32 _disputeId = _validateDispute(_request, _response, _dispute);

if (_dispute.requestId != _requestId || disputeOf[_dispute.responseId] != _disputeId) {
if (disputeOf[_dispute.responseId] != _disputeId) {
revert Oracle_InvalidDisputeId(_disputeId);
}

Expand All @@ -200,10 +183,9 @@ contract Oracle is IOracle {

/// @inheritdoc IOracle
function resolveDispute(Request calldata _request, Response calldata _response, Dispute calldata _dispute) external {
bytes32 _requestId = _getId(_request);
bytes32 _disputeId = _getId(_dispute);
bytes32 _disputeId = _validateDispute(_request, _response, _dispute);

if (_dispute.requestId != _requestId || disputeOf[_dispute.responseId] != _disputeId) {
if (disputeOf[_dispute.responseId] != _disputeId) {
revert Oracle_InvalidDisputeId(_disputeId);
}

Expand All @@ -229,12 +211,9 @@ contract Oracle is IOracle {
Dispute calldata _dispute,
DisputeStatus _status
) external {
bytes32 _disputeId = _getId(_dispute);
bytes32 _requestId = _getId(_request);
bytes32 _disputeId = _validateDispute(_request, _response, _dispute);

if (_response.requestId != _requestId || createdAt[_requestId] == 0) revert Oracle_InvalidRequestBody();

if (disputeOf[_getId(_response)] != _disputeId) {
if (disputeOf[_dispute.responseId] != _disputeId) {
revert Oracle_InvalidDisputeId(_disputeId);
}

Expand Down Expand Up @@ -307,16 +286,15 @@ contract Oracle is IOracle {

/// @inheritdoc IOracle
function finalize(IOracle.Request calldata _request, IOracle.Response calldata _response) external {
bytes32 _requestId = _getId(_request);
bytes32 _responseId;
bytes32 _responseId = _validateResponse(_request, _response);

if (finalizedAt[_requestId] != 0) {
revert Oracle_AlreadyFinalized(_requestId);
if (finalizedAt[_response.requestId] != 0) {
revert Oracle_AlreadyFinalized(_response.requestId);
}

// Finalizing without a response (by passing a Response with `requestId` == 0x0)
if (_response.requestId == bytes32(0)) {
bytes32[] memory _responses = getResponseIds(_requestId);
bytes32[] memory _responses = getResponseIds(_response.requestId);
uint256 _responsesAmount = _responses.length;

if (_responsesAmount != 0) {
Expand All @@ -338,10 +316,7 @@ contract Oracle is IOracle {
_responseId = bytes32(0);
}
} else {
// Finalizing with a response
_responseId = _getId(_response);

if (_response.requestId != _requestId) {
if (_response.requestId != _response.requestId) {
revert Oracle_InvalidFinalizedResponse(_responseId);
}

Expand All @@ -351,10 +326,10 @@ contract Oracle is IOracle {
revert Oracle_InvalidFinalizedResponse(_responseId);
}

_finalizedResponses[_requestId] = _responseId;
_finalizedResponses[_response.requestId] = _responseId;
}

finalizedAt[_requestId] = uint128(block.number);
finalizedAt[_response.requestId] = uint128(block.number);

if (address(_request.finalityModule) != address(0)) {
IFinalityModule(_request.finalityModule).finalizeRequest(_request, _response, msg.sender);
Expand All @@ -368,7 +343,7 @@ contract Oracle is IOracle {
IResponseModule(_request.responseModule).finalizeRequest(_request, _response, msg.sender);
IRequestModule(_request.requestModule).finalizeRequest(_request, _response, msg.sender);

emit OracleRequestFinalized(_requestId, _responseId, msg.sender, block.number);
emit OracleRequestFinalized(_response.requestId, _responseId, msg.sender, block.number);
}

/**
Expand All @@ -384,7 +359,7 @@ contract Oracle is IOracle {
// @audit what about removing nonces? or how we avoid nonce clashing?
if (_requestNonce != _request.nonce || msg.sender != _request.requester) revert Oracle_InvalidRequestBody();

_requestId = _getId(_request);
_requestId = keccak256(abi.encode(_request));
nonceToRequestId[_requestNonce] = _requestId;
createdAt[_requestId] = uint128(block.number);

Expand All @@ -404,32 +379,39 @@ contract Oracle is IOracle {
}

/**
* @notice Computes the id a given request
* @notice Validates the correctness of a request-response pair
*
* @param _request The request to compute the id for
* @return _id The id the request
*/
function _getId(IOracle.Request calldata _request) internal pure returns (bytes32 _id) {
_id = keccak256(abi.encode(_request));
}

/**
* @notice Computes the id a given response
*
* @param _response The response to compute the id for
* @return _id The id the response
* @return _responseId The id the response
*/
function _getId(IOracle.Response calldata _response) internal pure returns (bytes32 _id) {
_id = keccak256(abi.encode(_response));
function _validateResponse(
Request calldata _request,
Response calldata _response
) internal pure returns (bytes32 _responseId) {
bytes32 _requestId = keccak256(abi.encode(_request));
_responseId = keccak256(abi.encode(_response));
if (_response.requestId != _requestId) revert Oracle_InvalidResponseBody();
}

/**
* @notice Computes the id a given dispute
* @notice Validates the correctness of a request-response-dispute triplet
*
* @param _request The request to compute the id for
* @param _response The response to compute the id for
* @param _dispute The dispute to compute the id for
* @return _id The id the dispute
* @return _disputeId The id the dispute
*/
function _getId(IOracle.Dispute calldata _dispute) internal pure returns (bytes32 _id) {
_id = keccak256(abi.encode(_dispute));
function _validateDispute(
Request calldata _request,
Response calldata _response,
Dispute calldata _dispute
) internal pure returns (bytes32 _disputeId) {
bytes32 _requestId = keccak256(abi.encode(_request));
bytes32 _responseId = keccak256(abi.encode(_response));
_disputeId = keccak256(abi.encode(_dispute));

if (_dispute.requestId != _requestId || _dispute.responseId != _responseId) revert Oracle_InvalidDisputeBody();
if (_response.requestId != _requestId) revert Oracle_InvalidResponseBody();
}
}
12 changes: 11 additions & 1 deletion solidity/interfaces/IModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,19 @@ interface IModule {
*/
error Module_OnlyOracle();

/**
* @notice Thrown when the response provided does not match the request
*/
error Module_InvalidResponseBody();

/**
* @notice Thrown when the dispute provided does not match the request or response
*/
error Module_InvalidDisputeBody();

/*///////////////////////////////////////////////////////////////
VARIABLES
//////////////////////////////////////////////////////////////*/
//////////////////////////////////////////////////////////////*/

/**
* @notice Returns the address of the oracle
Expand Down
19 changes: 16 additions & 3 deletions solidity/test/unit/Oracle.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,13 @@ contract Unit_ResolveDispute is BaseTest {
function test_resolveDispute_revertsIfNoResolutionModule() public {
// Clear the resolution module
mockRequest.resolutionModule = address(0);
mockDispute.requestId = _getId(mockRequest);
bytes32 _requestId = _getId(mockRequest);

mockResponse.requestId = _requestId;
bytes32 _responseId = _getId(mockResponse);

mockDispute.requestId = _requestId;
mockDispute.responseId = _responseId;
bytes32 _disputeId = _getId(mockDispute);

// Mock the dispute
Expand Down Expand Up @@ -701,7 +707,7 @@ contract Unit_Finalize is BaseTest {
oracle.mock_addResponseId(_requestId, _responseId);

// Test: finalize the request
vm.expectRevert(abi.encodeWithSelector(IOracle.Oracle_InvalidFinalizedResponse.selector, _responseId));
vm.expectRevert(IOracle.Oracle_InvalidResponseBody.selector);
vm.prank(_caller);
oracle.finalize(mockRequest, mockResponse);
}
Expand Down Expand Up @@ -846,7 +852,14 @@ contract Unit_EscalateDispute is BaseTest {

function test_escalateDispute_noResolutionModule() public {
mockRequest.resolutionModule = address(0);
mockDispute.requestId = _getId(mockRequest);

bytes32 _requestId = _getId(mockRequest);

mockResponse.requestId = _requestId;
bytes32 _responseId = _getId(mockResponse);

mockDispute.requestId = _requestId;
mockDispute.responseId = _responseId;
bytes32 _disputeId = _getId(mockDispute);

oracle.mock_setDisputeOf(_getId(mockResponse), _disputeId);
Expand Down