Skip to content

Commit

Permalink
Merge pull request #241 from zama-ai/consistent-shift-operators
Browse files Browse the repository at this point in the history
made shl and shr more consistent with cleartext shift operators by restraining second operand to 8 bits
  • Loading branch information
jatZama authored Dec 21, 2023
2 parents d7e0e96 + bc6649d commit 1a6d3ee
Show file tree
Hide file tree
Showing 8 changed files with 759 additions and 1,211 deletions.
3 changes: 3 additions & 0 deletions codegen/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export type Operator = {
fheLibName?: string;
binarySolidityOperator?: string;
unarySolidityOperator?: string;
shiftOperator?: boolean;
};

export type CodegenContext = {
Expand Down Expand Up @@ -125,6 +126,7 @@ export const ALL_OPERATORS: Operator[] = [
arguments: OperatorArguments.Binary,
returnType: ReturnType.Uint,
leftScalarEncrypt: true,
shiftOperator: true,
},
{
name: 'shr',
Expand All @@ -134,6 +136,7 @@ export const ALL_OPERATORS: Operator[] = [
arguments: OperatorArguments.Binary,
returnType: ReturnType.Uint,
leftScalarEncrypt: true,
shiftOperator: true,
},
{
name: 'eq',
Expand Down
50 changes: 4 additions & 46 deletions codegen/overloadTests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,6 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0xff, 0xffff], output: 0xff00 },
{ inputs: [0xff, 0xff00], output: 0xffff },
],
shl_euint8_euint16: [
// TODO: should shl output 8bit with 16bit be like that?
{ inputs: [0xff, 0x0100], output: 0xff },
{ inputs: [0x02, 0x0001], output: 0x04 },
],
shr_euint8_euint16: [
// TODO: should shr output 8bit with 16bit be like that?
{ inputs: [0xff, 0x0100], output: 0xff },
{ inputs: [0xff, 0x0001], output: 0x7f },
],
eq_euint8_euint16: [
{ inputs: [0xff, 0x00ff], output: true },
{ inputs: [0xff, 0x01ff], output: false },
Expand Down Expand Up @@ -187,18 +177,6 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0x10, 0x00010000], output: 0x00010010 },
{ inputs: [0x11, 0x00010010], output: 0x00010001 },
],
shl_euint8_euint32: [
// C compiler emits the same
{ inputs: [0x10, 0x00010000], output: 0x10 },
{ inputs: [0x1f, 0x00010000], output: 0x1f },
],
shr_euint8_euint32: [
// C compiler emits the same
{ inputs: [0x10, 0x00010000], output: 0x10 },
{ inputs: [0x1f, 0x00010000], output: 0x1f },
{ inputs: [0x10, 0x00000001], output: 0x8 },
{ inputs: [0x1f, 0x00000001], output: 0xf },
],
eq_euint8_euint32: [
{ inputs: [0x01, 0x00000001], output: true },
{ inputs: [0x01, 0x00010001], output: false },
Expand Down Expand Up @@ -265,18 +243,10 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0x10, 0x01], output: 0x20 },
{ inputs: [0x10, 0x02], output: 0x40 },
],
shl_uint8_euint8: [
{ inputs: [0x10, 0x01], output: 0x20 },
{ inputs: [0x10, 0x02], output: 0x40 },
],
shr_euint8_uint8: [
{ inputs: [0x10, 0x01], output: 0x08 },
{ inputs: [0x10, 0x02], output: 0x04 },
],
shr_uint8_euint8: [
{ inputs: [0x10, 0x01], output: 0x08 },
{ inputs: [0x10, 0x02], output: 0x04 },
],
eq_euint8_uint8: [
{ inputs: [0x10, 0x10], output: true },
{ inputs: [0x10, 0x02], output: false },
Expand Down Expand Up @@ -375,7 +345,9 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0x10f0, 0xf2], output: 0x1002 },
],
shl_euint16_euint8: [{ inputs: [0x1010, 0x02], output: 0x4040 }],
shl_euint16_uint8: [{ inputs: [0x1010, 0x02], output: 0x4040 }],
shr_euint16_euint8: [{ inputs: [0x1010, 0x02], output: 0x0404 }],
shr_euint16_uint8: [{ inputs: [0x1010, 0x02], output: 0x0404 }],
eq_euint16_euint8: [
{ inputs: [0x0010, 0x10], output: true },
{ inputs: [0x0110, 0x10], output: false },
Expand Down Expand Up @@ -429,8 +401,6 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0x0200, 0x0002], output: 0x0202 },
{ inputs: [0x0210, 0x0012], output: 0x0202 },
],
shl_euint16_euint16: [{ inputs: [0x0200, 0x0002], output: 0x0800 }],
shr_euint16_euint16: [{ inputs: [0x0200, 0x0002], output: 0x0080 }],
eq_euint16_euint16: [
{ inputs: [0x0200, 0x0002], output: false },
{ inputs: [0x0200, 0x0200], output: true },
Expand Down Expand Up @@ -487,8 +457,6 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0x0202, 0x00010000], output: 0x00010202 },
{ inputs: [0x0202, 0x00010002], output: 0x00010200 },
],
shl_euint16_euint32: [{ inputs: [0x0202, 0x00000002], output: 0x00000808 }],
shr_euint16_euint32: [{ inputs: [0x0202, 0x00000002], output: 0x00000080 }],
eq_euint16_euint32: [
{ inputs: [0x0202, 0x00010202], output: false },
{ inputs: [0x0202, 0x00000202], output: true },
Expand Down Expand Up @@ -535,10 +503,6 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
mul_uint16_euint16: [{ inputs: [0x0202, 0x0003], output: 0x0606 }],
div_euint16_uint16: [{ inputs: [0x0606, 0x0003], output: 0x0202 }],
rem_euint16_uint16: [{ inputs: [0x0608, 0x0003], output: 0x0002 }],
shl_euint16_uint16: [{ inputs: [0x0606, 0x0003], output: 0x3030 }],
shl_uint16_euint16: [{ inputs: [0x0606, 0x0003], output: 0x3030 }],
shr_euint16_uint16: [{ inputs: [0x0606, 0x0003], output: 0x00c0 }],
shr_uint16_euint16: [{ inputs: [0x0606, 0x0003], output: 0x00c0 }],
eq_euint16_uint16: [
{ inputs: [0x0606, 0x0606], output: true },
{ inputs: [0x0606, 0x0605], output: false },
Expand Down Expand Up @@ -631,7 +595,9 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0x03010003, 0x03], output: 0x03010000 },
],
shl_euint32_euint8: [{ inputs: [0x03010000, 0x03], output: 0x18080000 }],
shl_euint32_uint8: [{ inputs: [0x03010000, 0x03], output: 0x18080000 }],
shr_euint32_euint8: [{ inputs: [0x03010000, 0x03], output: 0x00602000 }],
shr_euint32_uint8: [{ inputs: [0x03010000, 0x03], output: 0x00602000 }],
eq_euint32_euint8: [
{ inputs: [0x00000003, 0x03], output: true },
{ inputs: [0x03000003, 0x03], output: false },
Expand Down Expand Up @@ -682,8 +648,6 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0x03000023, 0x1003], output: 0x03001023 },
],
xor_euint32_euint16: [{ inputs: [0x03000023, 0x1003], output: 0x03001020 }],
shl_euint32_euint16: [{ inputs: [0x03000000, 0x0002], output: 0x0c000000 }],
shr_euint32_euint16: [{ inputs: [0x03000000, 0x0002], output: 0x00c00000 }],
eq_euint32_euint16: [
{ inputs: [0x00001000, 0x1000], output: true },
{ inputs: [0x01001000, 0x1000], output: false },
Expand Down Expand Up @@ -737,8 +701,6 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
{ inputs: [0x00321000, 0x54000000], output: 0x54321000 },
{ inputs: [0x00321000, 0x54030000], output: 0x54311000 },
],
shl_euint32_euint32: [{ inputs: [0x00321000, 0x00000002], output: 0x00c84000 }],
shr_euint32_euint32: [{ inputs: [0x00321000, 0x00000002], output: 0x000c8400 }],
eq_euint32_euint32: [
{ inputs: [0x00321000, 0x00321000], output: true },
{ inputs: [0x00321000, 0x00321001], output: false },
Expand Down Expand Up @@ -785,10 +747,6 @@ export const overloadTests: { [methodName: string]: OverloadTest[] } = {
mul_uint32_euint32: [{ inputs: [0x00342000, 0x00000100], output: 0x34200000 }],
div_euint32_uint32: [{ inputs: [0x00342000, 0x00000100], output: 0x00003420 }],
rem_euint32_uint32: [{ inputs: [0x00342039, 0x00000100], output: 0x00000039 }],
shl_euint32_uint32: [{ inputs: [0x00342000, 0x00000001], output: 0x00684000 }],
shl_uint32_euint32: [{ inputs: [0x00342000, 0x00000001], output: 0x00684000 }],
shr_euint32_uint32: [{ inputs: [0x00342000, 0x00000001], output: 0x001a1000 }],
shr_uint32_euint32: [{ inputs: [0x00342000, 0x00000001], output: 0x001a1000 }],
eq_euint32_uint32: [
{ inputs: [0x00342000, 0x00342000], output: true },
{ inputs: [0x00342000, 0x00342001], output: false },
Expand Down
87 changes: 85 additions & 2 deletions codegen/templates.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,19 @@ library TFHE {

supportedBits.forEach((lhsBits) => {
supportedBits.forEach((rhsBits) => {
operators.forEach((operator) => res.push(tfheEncryptedOperator(lhsBits, rhsBits, operator, signatures)));
operators.forEach((operator) => {
if (!operator.shiftOperator) res.push(tfheEncryptedOperator(lhsBits, rhsBits, operator, signatures));
});
});
operators.forEach((operator) => {
if (!operator.shiftOperator) res.push(tfheScalarOperator(lhsBits, lhsBits, operator, signatures));
});
});

supportedBits.forEach((bits) => {
operators.forEach((operator) => {
if (operator.shiftOperator) res.push(tfheShiftOperators(bits, operator, signatures));
});
operators.forEach((operator) => res.push(tfheScalarOperator(lhsBits, lhsBits, operator, signatures)));
});

// TODO: Decide whether we want to have mixed-inputs for CMUX
Expand Down Expand Up @@ -395,6 +405,79 @@ function tfheScalarOperator(
return res.join('');
}

function tfheShiftOperators(inputBits: number, operator: Operator, signatures: OverloadSignature[]): string {
const res: string[] = [];

// Code and test for shift(euint{inputBits},euint8}
const outputBits = inputBits;
const lhsBits = inputBits;
const rhsBits = 8;
const castRightToLeft = lhsBits > rhsBits;

const returnType = `euint${outputBits}`;

const returnTypeOverload: ArgumentType = ArgumentType.EUint;
let scalarFlag = ', false';

const leftExpr = 'a';
const rightExpr = castRightToLeft ? `asEuint${outputBits}(b)` : 'b';
let implExpression = `Impl.${operator.name}(euint${outputBits}.unwrap(${leftExpr}), euint${outputBits}.unwrap(${rightExpr})${scalarFlag})`;

signatures.push({
name: operator.name,
arguments: [
{ type: ArgumentType.EUint, bits: lhsBits },
{ type: ArgumentType.EUint, bits: rhsBits },
],
returnType: { type: returnTypeOverload, bits: outputBits },
});
res.push(`
// Evaluate ${operator.name}(a, b) and return the result.
function ${operator.name}(euint${lhsBits} a, euint${rhsBits} b) internal pure returns (${returnType}) {
if (!isInitialized(a)) {
a = asEuint${lhsBits}(0);
}
if (!isInitialized(b)) {
b = asEuint${rhsBits}(0);
}
return ${returnType}.wrap(${implExpression});
}
`);

// Code and test for shift(euint{inputBits},uint8}
scalarFlag = ', true';
const leftOpName = operator.name;
var implExpressionA = `Impl.${operator.name}(euint${outputBits}.unwrap(a), uint256(b)${scalarFlag})`;
var implExpressionB = `Impl.${leftOpName}(euint${outputBits}.unwrap(b), uint256(a)${scalarFlag})`;
var maybeEncryptLeft = '';
if (operator.leftScalarEncrypt) {
// workaround until tfhe-rs left scalar support:
// do the trivial encryption and preserve order of operations
scalarFlag = ', false';
maybeEncryptLeft = `euint${outputBits} aEnc = asEuint${outputBits}(a);`;
implExpressionB = `Impl.${leftOpName}(euint${outputBits}.unwrap(aEnc), euint${8}.unwrap(b)${scalarFlag})`;
}
signatures.push({
name: operator.name,
arguments: [
{ type: ArgumentType.EUint, bits: lhsBits },
{ type: ArgumentType.Uint, bits: rhsBits },
],
returnType: { type: returnTypeOverload, bits: outputBits },
});
res.push(`
// Evaluate ${operator.name}(a, b) and return the result.
function ${operator.name}(euint${lhsBits} a, uint${rhsBits} b) internal pure returns (${returnType}) {
if (!isInitialized(a)) {
a = asEuint${lhsBits}(0);
}
return ${returnType}.wrap(${implExpressionA});
}
`);

return res.join('');
}

function tfheCmux(inputBits: number): string {
if (inputBits == 8) {
return `
Expand Down
Loading

0 comments on commit 1a6d3ee

Please sign in to comment.