-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #680 from zama-ai/DecryptionOracle
feat: new DecryptionOracle contract
- Loading branch information
Showing
21 changed files
with
1,956 additions
and
2,496 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
* | ||
!lib/** | ||
!gateway/** | ||
!decryption/** | ||
!config/** | ||
!package.json | ||
!README.md | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
// SPDX-License-Identifier: BSD-3-Clause-Clear | ||
|
||
pragma solidity ^0.8.24; | ||
|
||
import "../lib/TFHE.sol"; | ||
import "../lib/Impl.sol"; | ||
|
||
interface IKMSVerifier { | ||
function verifyDecryptionEIP712KMSSignatures( | ||
address aclAddress, | ||
uint256[] memory handlesList, | ||
bytes memory decryptedResult, | ||
bytes[] memory signatures | ||
) external returns (bool); | ||
} | ||
|
||
interface IDecryptionOracle { | ||
function requestDecryption(uint256[] calldata ctsHandles, bytes4 callbackSelector) external returns (uint256); | ||
} | ||
|
||
struct DecryptionOracleConfigStruct { | ||
address DecryptionOracleAddress; | ||
} | ||
|
||
abstract contract DecryptionOracleCaller { | ||
error HandlesAlreadySavedForRequestID(); | ||
error NoHandleFoundForRequestID(); | ||
error InvalidKMSSignatures(); | ||
error UnsupportedHandleType(); | ||
|
||
mapping(uint256 => ebool[]) private paramsEBool; | ||
mapping(uint256 => euint4[]) private paramsEUint4; | ||
mapping(uint256 => euint8[]) private paramsEUint8; | ||
mapping(uint256 => euint16[]) private paramsEUint16; | ||
mapping(uint256 => euint32[]) private paramsEUint32; | ||
mapping(uint256 => euint64[]) private paramsEUint64; | ||
mapping(uint256 => eaddress[]) private paramsEAddress; | ||
mapping(uint256 => address[]) private paramsAddress; | ||
mapping(uint256 => uint256[]) private paramsUint256; | ||
mapping(uint256 => uint256[]) private requestedHandles; | ||
|
||
constructor() {} | ||
|
||
function addParamsEBool(uint256 requestID, ebool _ebool) internal { | ||
paramsEBool[requestID].push(_ebool); | ||
} | ||
|
||
function addParamsEUint4(uint256 requestID, euint4 _euint4) internal { | ||
paramsEUint4[requestID].push(_euint4); | ||
} | ||
|
||
function addParamsEUint8(uint256 requestID, euint8 _euint8) internal { | ||
paramsEUint8[requestID].push(_euint8); | ||
} | ||
|
||
function addParamsEUint16(uint256 requestID, euint16 _euint16) internal { | ||
paramsEUint16[requestID].push(_euint16); | ||
} | ||
|
||
function addParamsEUint32(uint256 requestID, euint32 _euint32) internal { | ||
paramsEUint32[requestID].push(_euint32); | ||
} | ||
|
||
function addParamsEUint64(uint256 requestID, euint64 _euint64) internal { | ||
paramsEUint64[requestID].push(_euint64); | ||
} | ||
|
||
function addParamsEAddress(uint256 requestID, eaddress _eaddress) internal { | ||
paramsEAddress[requestID].push(_eaddress); | ||
} | ||
|
||
function addParamsAddress(uint256 requestID, address _address) internal { | ||
paramsAddress[requestID].push(_address); | ||
} | ||
|
||
function addParamsUint256(uint256 requestID, uint256 _uint) internal { | ||
paramsUint256[requestID].push(_uint); | ||
} | ||
|
||
function saveRequestedHandles(uint256 requestID, uint256[] memory handlesList) internal { | ||
if (requestedHandles[requestID].length != 0) { | ||
revert HandlesAlreadySavedForRequestID(); | ||
} | ||
requestedHandles[requestID] = handlesList; | ||
} | ||
|
||
function loadRequestedHandles(uint256 requestID) internal view returns (uint256[] memory) { | ||
if (requestedHandles[requestID].length == 0) { | ||
revert NoHandleFoundForRequestID(); | ||
} | ||
return requestedHandles[requestID]; | ||
} | ||
|
||
function getParamsEBool(uint256 requestID) internal view returns (ebool[] memory) { | ||
return paramsEBool[requestID]; | ||
} | ||
|
||
function getParamsEUint4(uint256 requestID) internal view returns (euint4[] memory) { | ||
return paramsEUint4[requestID]; | ||
} | ||
|
||
function getParamsEUint8(uint256 requestID) internal view returns (euint8[] memory) { | ||
return paramsEUint8[requestID]; | ||
} | ||
|
||
function getParamsEUint16(uint256 requestID) internal view returns (euint16[] memory) { | ||
return paramsEUint16[requestID]; | ||
} | ||
|
||
function getParamsEUint32(uint256 requestID) internal view returns (euint32[] memory) { | ||
return paramsEUint32[requestID]; | ||
} | ||
|
||
function getParamsEUint64(uint256 requestID) internal view returns (euint64[] memory) { | ||
return paramsEUint64[requestID]; | ||
} | ||
|
||
function getParamsEAddress(uint256 requestID) internal view returns (eaddress[] memory) { | ||
return paramsEAddress[requestID]; | ||
} | ||
|
||
function getParamsAddress(uint256 requestID) internal view returns (address[] memory) { | ||
return paramsAddress[requestID]; | ||
} | ||
|
||
function getParamsUint256(uint256 requestID) internal view returns (uint256[] memory) { | ||
return paramsUint256[requestID]; | ||
} | ||
|
||
// keccak256(abi.encode(uint256(keccak256("fhevm.storage.DecryptionOracleConfig")) - 1)) & ~bytes32(uint256(0xff)) | ||
bytes32 private constant DecryptionOracleLocation = | ||
0xde725b831312c6c6cbab76ff954da6836e26d78fc467781757c36be9c3d67d00; | ||
|
||
function getDecryptionOracleConfig() internal pure returns (DecryptionOracleConfigStruct storage $) { | ||
assembly { | ||
$.slot := DecryptionOracleLocation | ||
} | ||
} | ||
|
||
function setDecryptionOracle(address decryptionOracleAddress) internal { | ||
DecryptionOracleConfigStruct storage $ = getDecryptionOracleConfig(); | ||
$.DecryptionOracleAddress = decryptionOracleAddress; | ||
} | ||
|
||
function gatewayContractAddress() internal view returns (address) { | ||
DecryptionOracleConfigStruct storage $ = getDecryptionOracleConfig(); | ||
return $.DecryptionOracleAddress; | ||
} | ||
|
||
function toUint256(ebool newCT) internal pure returns (uint256 ct) { | ||
ct = ebool.unwrap(newCT); | ||
} | ||
|
||
function toUint256(euint4 newCT) internal pure returns (uint256 ct) { | ||
ct = euint4.unwrap(newCT); | ||
} | ||
|
||
function toUint256(euint8 newCT) internal pure returns (uint256 ct) { | ||
ct = euint8.unwrap(newCT); | ||
} | ||
|
||
function toUint256(euint16 newCT) internal pure returns (uint256 ct) { | ||
ct = euint16.unwrap(newCT); | ||
} | ||
|
||
function toUint256(euint32 newCT) internal pure returns (uint256 ct) { | ||
ct = euint32.unwrap(newCT); | ||
} | ||
|
||
function toUint256(euint64 newCT) internal pure returns (uint256 ct) { | ||
ct = euint64.unwrap(newCT); | ||
} | ||
|
||
function toUint256(euint128 newCT) internal pure returns (uint256 ct) { | ||
ct = euint128.unwrap(newCT); | ||
} | ||
|
||
function toUint256(eaddress newCT) internal pure returns (uint256 ct) { | ||
ct = eaddress.unwrap(newCT); | ||
} | ||
|
||
function toUint256(euint256 newCT) internal pure returns (uint256 ct) { | ||
ct = euint256.unwrap(newCT); | ||
} | ||
|
||
function toUint256(ebytes64 newCT) internal pure returns (uint256 ct) { | ||
ct = ebytes64.unwrap(newCT); | ||
} | ||
|
||
function toUint256(ebytes128 newCT) internal pure returns (uint256 ct) { | ||
ct = ebytes128.unwrap(newCT); | ||
} | ||
|
||
function toUint256(ebytes256 newCT) internal pure returns (uint256 ct) { | ||
ct = ebytes256.unwrap(newCT); | ||
} | ||
|
||
function requestDecryption( | ||
uint256[] memory ctsHandles, | ||
bytes4 callbackSelector | ||
) internal returns (uint256 requestID) { | ||
FHEVMConfigStruct storage $ = Impl.getFHEVMConfig(); | ||
IACL($.ACLAddress).allowForDecryption(ctsHandles); | ||
DecryptionOracleConfigStruct storage $$ = getDecryptionOracleConfig(); | ||
requestID = IDecryptionOracle($$.DecryptionOracleAddress).requestDecryption(ctsHandles, callbackSelector); | ||
saveRequestedHandles(requestID, ctsHandles); | ||
} | ||
|
||
/// @dev this function should be called inside the callback function the dApp contract to verify the signatures | ||
function verifySignatures(uint256[] memory handlesList, bytes[] memory signatures) internal returns (bool) { | ||
uint256 start = 4 + 32; // start position after skipping the selector (4 bytes) and the first argument (index, 32 bytes) | ||
uint256 length = getSignedDataLength(handlesList); | ||
bytes memory decryptedResult = new bytes(length); | ||
assembly { | ||
calldatacopy(add(decryptedResult, 0x20), start, length) // Copy the relevant part of calldata to decryptedResult memory | ||
} | ||
FHEVMConfigStruct storage $ = Impl.getFHEVMConfig(); | ||
return | ||
IKMSVerifier($.KMSVerifierAddress).verifyDecryptionEIP712KMSSignatures( | ||
$.ACLAddress, | ||
handlesList, | ||
decryptedResult, | ||
signatures | ||
); | ||
} | ||
|
||
function getSignedDataLength(uint256[] memory handlesList) private pure returns (uint256) { | ||
uint256 handlesListlen = handlesList.length; | ||
uint256 signedDataLength; | ||
for (uint256 i = 0; i < handlesListlen; i++) { | ||
uint8 typeCt = uint8(handlesList[i] >> 8); | ||
if (typeCt < 9) { | ||
signedDataLength += 32; | ||
} else if (typeCt == 9) { | ||
//ebytes64 | ||
signedDataLength += 128; | ||
} else if (typeCt == 10) { | ||
//ebytes128 | ||
signedDataLength += 192; | ||
} else if (typeCt == 11) { | ||
//ebytes256 | ||
signedDataLength += 320; | ||
} else { | ||
revert UnsupportedHandleType(); | ||
} | ||
} | ||
signedDataLength += 32; // add offset of signatures | ||
return signedDataLength; | ||
} | ||
|
||
modifier checkSignatures(uint256 requestID, bytes[] memory signatures) { | ||
uint256[] memory handlesList = loadRequestedHandles(requestID); | ||
bool isVerified = verifySignatures(handlesList, signatures); | ||
if (!isVerified) { | ||
revert InvalidKMSSignatures(); | ||
} | ||
_; | ||
} | ||
} |
Oops, something went wrong.