From 478e3e39f14a5e1fc271f3531dc73d71168188f7 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 14 Sep 2023 23:03:11 +0800 Subject: [PATCH] add lookup gadget add rwlookup test: heaplify add pre-computed challenge mechanism fix type mutabble in bench make clippy happy optimise constraints wip add lookupsnark reformat constraints scope to immutable stepcircuit --- src/gadgets/lookup.rs | 915 ++++++++++++++++++++++++++++++++++++ src/gadgets/mod.rs | 1 + src/lib.rs | 511 ++++++++++++++++++++ src/spartan/lookupsnark.rs | 927 +++++++++++++++++++++++++++++++++++++ src/spartan/mod.rs | 3 +- src/spartan/ppsnark.rs | 38 +- src/spartan/sumcheck.rs | 9 +- 7 files changed, 2386 insertions(+), 18 deletions(-) create mode 100644 src/gadgets/lookup.rs create mode 100644 src/spartan/lookupsnark.rs diff --git a/src/gadgets/lookup.rs b/src/gadgets/lookup.rs new file mode 100644 index 000000000..0ba739a56 --- /dev/null +++ b/src/gadgets/lookup.rs @@ -0,0 +1,915 @@ +//! This module implements lookup gadget for applications built with Nova. 
+use std::cmp::max; +use std::collections::BTreeMap; + +use bellpepper::gadgets::Assignment; +use bellpepper_core::{num::AllocatedNum, ConstraintSystem, LinearCombination, SynthesisError}; +use std::cmp::Ord; + +use crate::constants::NUM_CHALLENGE_BITS; +use crate::gadgets::nonnative::util::Num; +use crate::gadgets::utils::alloc_const; +use crate::spartan::math::Math; +use crate::traits::ROCircuitTrait; +use crate::traits::ROConstants; +use crate::traits::ROTrait; +use crate::traits::{Group, ROConstantsCircuit}; +use ff::{Field, PrimeField}; + +use super::utils::scalar_as_base; +use super::utils::{alloc_one, conditionally_select2, le_bits_to_num}; + +/// rw trace +#[derive(Clone, Debug)] +pub enum RWTrace { + /// read + Read(T, T, T), // addr, read_value, read_counter + /// write + Write(T, T, T, T), // addr, read_value, write_value, read_counter +} + +/// Lookup in R1CS +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum TableType { + /// read only + ReadOnly, + /// write + ReadWrite, +} + +/// for build up a lookup trace +#[derive(Clone)] +pub struct LookupTrace { + expected_rw_trace: Vec>, + rw_trace_allocated_num: Vec>>, + max_cap_rwcounter_log2: usize, + table_type: TableType, + cursor: usize, +} + +impl LookupTrace { + /// read value from table + pub fn read::Scalar>>( + &mut self, + mut cs: CS, + addr: &AllocatedNum, + ) -> Result, SynthesisError> + where + ::Scalar: Ord + PartialEq + Eq, + { + assert!( + self.cursor < self.expected_rw_trace.len(), + "cursor {} out of range with expected length {}", + self.cursor, + self.expected_rw_trace.len() + ); + if let RWTrace::Read(expected_addr, expected_read_value, expected_read_counter) = + self.expected_rw_trace[self.cursor] + { + if let Some(key) = addr.get_value() { + assert!( + key == expected_addr, + "read address {:?} mismatch with expected {:?}", + key, + expected_addr + ); + } + let read_value = + AllocatedNum::alloc(cs.namespace(|| "read_value"), || Ok(expected_read_value))?; + let read_counter = 
AllocatedNum::alloc(cs.namespace(|| "read_counter"), || { + Ok(expected_read_counter) + })?; + self + .rw_trace_allocated_num + .push(RWTrace::Read::>( + addr.clone(), + read_value.clone(), + read_counter, + )); // append read trace + + self.cursor += 1; + Ok(read_value) + } else { + Err(SynthesisError::AssignmentMissing) + } + } + + /// write value to lookup table + pub fn write::Scalar>>( + &mut self, + mut cs: CS, + addr: &AllocatedNum, + value: &AllocatedNum, + ) -> Result<(), SynthesisError> + where + ::Scalar: Ord, + { + assert!( + self.cursor < self.expected_rw_trace.len(), + "cursor {} out of range with expected length {}", + self.cursor, + self.expected_rw_trace.len() + ); + if let RWTrace::Write( + expected_addr, + expected_read_value, + expected_write_value, + expected_read_counter, + ) = self.expected_rw_trace[self.cursor] + { + if let Some((addr, value)) = addr.get_value().zip(value.get_value()) { + assert!( + addr == expected_addr, + "write address {:?} mismatch with expected {:?}", + addr, + expected_addr + ); + assert!( + value == expected_write_value, + "write value {:?} mismatch with expected {:?}", + value, + expected_write_value + ); + } + let expected_read_value = + AllocatedNum::alloc(cs.namespace(|| "read_value"), || Ok(expected_read_value))?; + let expected_read_counter = AllocatedNum::alloc(cs.namespace(|| "read_counter"), || { + Ok(expected_read_counter) + })?; + self.rw_trace_allocated_num.push(RWTrace::Write( + addr.clone(), + expected_read_value, + value.clone(), + expected_read_counter, + )); // append write trace + self.cursor += 1; + Ok(()) + } else { + Err(SynthesisError::AssignmentMissing) + } + } + + /// commit rw_trace to lookup + #[allow(clippy::too_many_arguments)] + pub fn commit::Scalar>>( + &mut self, + mut cs: CS, + ro_const: ROConstantsCircuit, + prev_intermediate_gamma: &AllocatedNum, + gamma: &AllocatedNum, + prev_R: &AllocatedNum, + prev_W: &AllocatedNum, + prev_rw_counter: &AllocatedNum, + ) -> Result< + ( + 
AllocatedNum, + AllocatedNum, + AllocatedNum, + AllocatedNum, + ), + SynthesisError, + > + where + ::Scalar: Ord, + G: Group::Scalar>, + G2: Group::Scalar>, + { + let mut ro = G2::ROCircuit::new( + ro_const, + 1 + 3 * self.expected_rw_trace.len(), // prev_challenge + [(address, value, counter)] + ); + ro.absorb(prev_intermediate_gamma); + let rw_trace_allocated_num = self.rw_trace_allocated_num.clone(); + let (next_R, next_W, next_rw_counter) = rw_trace_allocated_num.iter().enumerate().try_fold( + (prev_R.clone(), prev_W.clone(), prev_rw_counter.clone()), + |(prev_R, prev_W, prev_rw_counter), (i, rwtrace)| match rwtrace { + RWTrace::Read(addr, read_value, expected_read_counter) => { + let (next_R, next_W, next_rw_counter) = self.rw_operation_circuit( + cs.namespace(|| format!("{}th read ", i)), + addr, + gamma, + read_value, + read_value, + &prev_R, + &prev_W, + expected_read_counter, + &prev_rw_counter, + )?; + ro.absorb(addr); + ro.absorb(read_value); + ro.absorb(expected_read_counter); + Ok::< + ( + AllocatedNum, + AllocatedNum, + AllocatedNum, + ), + SynthesisError, + >((next_R, next_W, next_rw_counter)) + } + RWTrace::Write(addr, read_value, write_value, read_counter) => { + let (next_R, next_W, next_rw_counter) = self.rw_operation_circuit( + cs.namespace(|| format!("{}th write ", i)), + addr, + gamma, + read_value, + write_value, + &prev_R, + &prev_W, + read_counter, + &prev_rw_counter, + )?; + ro.absorb(addr); + ro.absorb(read_value); + ro.absorb(read_counter); + Ok::< + ( + AllocatedNum, + AllocatedNum, + AllocatedNum, + ), + SynthesisError, + >((next_R, next_W, next_rw_counter)) + } + }, + )?; + let hash_bits = ro.squeeze(cs.namespace(|| "challenge"), NUM_CHALLENGE_BITS)?; + let hash = le_bits_to_num(cs.namespace(|| "bits to hash"), &hash_bits)?; + Ok((next_R, next_W, next_rw_counter, hash)) + } + + #[allow(clippy::too_many_arguments)] + fn rw_operation_circuit>( + &mut self, + mut cs: CS, + addr: &AllocatedNum, + // challenges: &(AllocatedNum, 
AllocatedNum), + gamma: &AllocatedNum, + read_value: &AllocatedNum, + write_value: &AllocatedNum, + prev_R: &AllocatedNum, + prev_W: &AllocatedNum, + read_counter: &AllocatedNum, + prev_rw_counter: &AllocatedNum, + ) -> Result<(AllocatedNum, AllocatedNum, AllocatedNum), SynthesisError> + where + F: Ord, + { + // update R + let gamma_square = gamma.mul(cs.namespace(|| "gamme^2"), gamma)?; + // read_value_term = gamma * value + let read_value_term = gamma.mul(cs.namespace(|| "read_value_term"), read_value)?; + // counter_term = gamma^2 * counter + let read_counter_term = gamma_square.mul(cs.namespace(|| "read_counter_term"), read_counter)?; + // new_R = R * (gamma - (addr + gamma * value + gamma^2 * counter)) + let new_R = AllocatedNum::alloc(cs.namespace(|| "new_R"), || { + prev_R + .get_value() + .zip(gamma.get_value()) + .zip(addr.get_value()) + .zip(read_value_term.get_value()) + .zip(read_counter_term.get_value()) + .map(|((((R, gamma), addr), value_term), counter_term)| { + R * (gamma - (addr + value_term + counter_term)) + }) + .ok_or(SynthesisError::AssignmentMissing) + })?; + let mut r_blc = LinearCombination::::zero(); + r_blc = r_blc + gamma.get_variable() + - addr.get_variable() + - read_value_term.get_variable() + - read_counter_term.get_variable(); + cs.enforce( + || "R update", + |lc| lc + prev_R.get_variable(), + |_| r_blc, + |lc| lc + new_R.get_variable(), + ); + + let alloc_num_one = alloc_one(cs.namespace(|| "one"))?; + // max{read_counter, rw_counter} logic on read-write lookup + // read_counter on read-only + // - max{read_counter, rw_counter} if read-write table + // - read_counter if read-only table + // +1 will be hadle later + let (write_counter, write_counter_term) = if self.table_type == TableType::ReadWrite { + // write_counter = read_counter < prev_rw_counter ? 
prev_rw_counter: read_counter + // TODO optimise with `max` table lookup to save more constraints + let lt = less_than( + cs.namespace(|| "read_counter < a"), + read_counter, + prev_rw_counter, + self.max_cap_rwcounter_log2, + )?; + let write_counter = conditionally_select2( + cs.namespace(|| { + "write_counter = read_counter < prev_rw_counter ? prev_rw_counter: read_counter" + }), + prev_rw_counter, + read_counter, + <, + )?; + let write_counter_term = + gamma_square.mul(cs.namespace(|| "write_counter_term"), &write_counter)?; + (write_counter, write_counter_term) + } else { + (read_counter.clone(), read_counter_term) + }; + + // update W + // write_value_term = gamma * value + let write_value_term = gamma.mul(cs.namespace(|| "write_value_term"), write_value)?; + let new_W = AllocatedNum::alloc(cs.namespace(|| "new_W"), || { + prev_W + .get_value() + .zip(gamma.get_value()) + .zip(addr.get_value()) + .zip(write_value_term.get_value()) + .zip(write_counter_term.get_value()) + .zip(gamma_square.get_value()) + .map( + |(((((W, gamma), addr), value_term), write_counter_term), gamma_square)| { + W * (gamma - (addr + value_term + write_counter_term + gamma_square)) + }, + ) + .ok_or(SynthesisError::AssignmentMissing) + })?; + // new_W = W * (gamma - (addr + gamma * value + gamma^2 * counter + gamma^2))) + let mut w_blc = LinearCombination::::zero(); + w_blc = w_blc + gamma.get_variable() + - addr.get_variable() + - write_value_term.get_variable() + - write_counter_term.get_variable() + - gamma_square.get_variable(); + cs.enforce( + || "W update", + |lc| lc + prev_W.get_variable(), + |_| w_blc, + |lc| lc + new_W.get_variable(), + ); + let new_rw_counter = add_allocated_num( + cs.namespace(|| "new_rw_counter"), + &write_counter, + &alloc_num_one, + )?; + Ok((new_R, new_W, new_rw_counter)) + } +} + +/// for build up a lookup trace +pub struct LookupTraceBuilder<'a, G: Group> { + lookup: &'a mut Lookup, + rw_trace: Vec>, + map_aux: BTreeMap, +} + +impl<'a, G: Group> 
LookupTraceBuilder<'a, G> { + /// start a new transaction simulated + pub fn new(lookup: &'a mut Lookup) -> LookupTraceBuilder<'a, G> { + LookupTraceBuilder { + lookup, + rw_trace: vec![], + map_aux: BTreeMap::new(), + } + } + + /// read value from table + pub fn read(&mut self, addr: G::Scalar) -> G::Scalar + where + ::Scalar: Ord, + { + let key = &addr; + let (value, _) = self.map_aux.entry(*key).or_insert_with(|| { + self + .lookup + .map_aux + .get(key) + .cloned() + .unwrap_or((G::Scalar::ZERO, G::Scalar::ZERO)) + }); + self + .rw_trace + .push(RWTrace::Read(addr, *value, G::Scalar::ZERO)); + *value + } + /// write value to lookup table + pub fn write(&mut self, addr: G::Scalar, value: G::Scalar) + where + ::Scalar: Ord, + { + let _ = self.map_aux.insert( + addr, + ( + value, + G::Scalar::ZERO, // zero counter doens't matter, real counter will provided in snapshot stage + ), + ); + self.rw_trace.push(RWTrace::Write( + addr, + G::Scalar::ZERO, + value, + G::Scalar::ZERO, + )); // append read trace + } + + /// commit rw_trace to lookup + pub fn snapshot( + &mut self, + ro_consts: ROConstants, + prev_intermediate_gamma: G::Scalar, + ) -> (G::Scalar, LookupTrace) + where + ::Scalar: Ord, + G: Group::Scalar>, + G2: Group::Scalar>, + { + let mut hasher: ::RO = + ::RO::new(ro_consts, 1 + self.rw_trace.len() * 3); + hasher.absorb(prev_intermediate_gamma); + + self.rw_trace = self + .rw_trace + .iter() + .map(|rwtrace| { + let (addr, (read_value, read_counter)) = match rwtrace { + RWTrace::Read(addr, _, _) => (addr, self.lookup.rw_operation(*addr, None)), + RWTrace::Write(addr, _, write_value, _) => { + (addr, self.lookup.rw_operation(*addr, Some(*write_value))) + } + }; + hasher.absorb(*addr); + hasher.absorb(read_value); + hasher.absorb(read_counter); + match rwtrace { + RWTrace::Read(..) 
=> RWTrace::Read(*addr, read_value, read_counter), + RWTrace::Write(_, _, write_value, _) => { + RWTrace::Write(*addr, read_value, *write_value, read_counter) + } + } + }) + .collect(); + let hash_bits = hasher.squeeze(NUM_CHALLENGE_BITS); + let rw_trace = self.rw_trace.to_vec(); + self.rw_trace.clear(); + let next_intermediate_gamma = scalar_as_base::(hash_bits); + ( + next_intermediate_gamma, + LookupTrace { + expected_rw_trace: rw_trace, + rw_trace_allocated_num: vec![], + cursor: 0, + max_cap_rwcounter_log2: self.lookup.max_cap_rwcounter_log2, + table_type: self.lookup.table_type.clone(), + }, + ) + } +} + +/// Lookup in R1CS +#[derive(Clone, Debug)] +pub struct Lookup { + pub(crate) map_aux: BTreeMap, // (value, counter) + /// map_aux_dirty only include the modified fields of `map_aux`, thats why called dirty + map_aux_dirty: BTreeMap, // (value, counter) + rw_counter: F, + pub(crate) table_type: TableType, // read only or read-write + pub(crate) max_cap_rwcounter_log2: usize, // max cap for rw_counter operation in bits +} + +impl Lookup { + /// new lookup table + pub fn new( + max_cap_rwcounter: usize, + table_type: TableType, + initial_table: Vec<(F, F)>, + ) -> Lookup + where + F: Ord, + { + let max_cap_rwcounter_log2 = max_cap_rwcounter.log_2(); + Self { + map_aux: initial_table + .into_iter() + .map(|(addr, value)| (addr, (value, F::ZERO))) + .collect(), + map_aux_dirty: BTreeMap::new(), + rw_counter: F::ZERO, + table_type, + max_cap_rwcounter_log2, + } + } + + /// get table vector + /// very costly operation + pub fn get_table(&self) -> Vec<(F, F, F)> { + self + .map_aux + .iter() + .map(|(addr, (value, counter))| (*addr, *value, *counter)) + .collect() + } + + /// table size + pub fn table_size(&self) -> usize { + self.map_aux.len() + } + + fn rw_operation(&mut self, addr: F, external_value: Option) -> (F, F) + where + F: Ord, + { + // write operations + if external_value.is_some() { + debug_assert!(self.table_type == TableType::ReadWrite) // table need 
to set as rw + } + let (read_value, read_counter) = self + .map_aux + .get(&addr) + .cloned() + .unwrap_or((F::from(0), F::from(0))); + + let (write_value, write_counter) = ( + external_value.unwrap_or(read_value), + if self.table_type == TableType::ReadOnly { + read_counter + } else { + max(self.rw_counter, read_counter) + } + F::ONE, + ); + self.map_aux.insert(addr, (write_value, write_counter)); + self + .map_aux_dirty + .insert(addr, (write_value, write_counter)); + self.rw_counter = write_counter; + (read_value, read_counter) + } + + // fn write(&mut self, addr: AllocatedNum, value: F) {} +} + +/// c = a + b where a, b is AllocatedNum +pub fn add_allocated_num>( + mut cs: CS, + a: &AllocatedNum, + b: &AllocatedNum, +) -> Result, SynthesisError> { + let c = AllocatedNum::alloc(cs.namespace(|| "c"), || { + Ok(*a.get_value().get()? + b.get_value().get()?) + })?; + cs.enforce( + || "c = a+b", + |lc| lc + a.get_variable() + b.get_variable(), + |lc| lc + CS::one(), + |lc| lc + c.get_variable(), + ); + Ok(c) +} + +/// a < b ? 
1 : 0 +pub fn less_than>( + mut cs: CS, + a: &AllocatedNum, + b: &AllocatedNum, + n_bits: usize, +) -> Result, SynthesisError> { + assert!(n_bits < 64, "not support n_bits {n_bits} >= 64"); + let range = alloc_const( + cs.namespace(|| "range"), + F::from(2_usize.pow(n_bits as u32) as u64), + )?; + // diff = (lhs - rhs) + (if lt { range } else { 0 }); + let diff = Num::alloc(cs.namespace(|| "diff"), || { + a.get_value() + .zip(b.get_value()) + .zip(range.get_value()) + .map(|((a, b), range)| { + let lt = a < b; + (a - b) + (if lt { range } else { F::ZERO }) + }) + .ok_or(SynthesisError::AssignmentMissing) + })?; + diff.fits_in_bits(cs.namespace(|| "diff fit in bits"), n_bits)?; + let diff = diff.as_allocated_num(cs.namespace(|| "diff_alloc_num"))?; + let lt = AllocatedNum::alloc(cs.namespace(|| "lt"), || { + a.get_value() + .zip(b.get_value()) + .map(|(a, b)| F::from(u64::from(a < b))) + .ok_or(SynthesisError::AssignmentMissing) + })?; + cs.enforce( + || "lt is bit", + |lc| lc + lt.get_variable(), + |lc| lc + CS::one() - lt.get_variable(), + |lc| lc, + ); + cs.enforce( + || "lt ⋅ range == diff - lhs + rhs", + |lc| lc + lt.get_variable(), + |lc| lc + range.get_variable(), + |lc| lc + diff.get_variable() - a.get_variable() + b.get_variable(), + ); + Ok(lt) +} + +#[cfg(test)] +mod test { + use crate::{ + // bellpepper::test_shape_cs::TestShapeCS, + constants::NUM_CHALLENGE_BITS, + gadgets::{ + lookup::{LookupTraceBuilder, TableType}, + utils::{alloc_one, alloc_zero, scalar_as_base}, + }, + provider::poseidon::PoseidonConstantsCircuit, + traits::{Group, ROConstantsCircuit}, + }; + use ff::Field; + + use super::Lookup; + use crate::traits::ROTrait; + use bellpepper_core::{num::AllocatedNum, test_cs::TestConstraintSystem, ConstraintSystem}; + + #[test] + fn test_lookup_simulation() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + let ro_consts: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + + // let mut cs: TestShapeCS = 
TestShapeCS::new(); + let initial_table = vec![ + (::Scalar::ZERO, ::Scalar::ZERO), + (::Scalar::ONE, ::Scalar::ONE), + ]; + let mut lookup = + Lookup::<::Scalar>::new(1024, TableType::ReadWrite, initial_table); + let mut lookup_trace_builder = LookupTraceBuilder::::new(&mut lookup); + let prev_intermediate_gamma = ::Scalar::ONE; + let read_value = lookup_trace_builder.read(::Scalar::ZERO); + assert_eq!(read_value, ::Scalar::ZERO); + let read_value = lookup_trace_builder.read(::Scalar::ONE); + assert_eq!(read_value, ::Scalar::ONE); + lookup_trace_builder.write( + ::Scalar::ZERO, + ::Scalar::from(111), + ); + let read_value = lookup_trace_builder.read(::Scalar::ZERO); + assert_eq!(read_value, ::Scalar::from(111),); + + let (next_intermediate_gamma, _) = + lookup_trace_builder.snapshot::(ro_consts.clone(), prev_intermediate_gamma); + + let mut hasher = ::RO::new(ro_consts, 1 + 3 * 4); + hasher.absorb(prev_intermediate_gamma); + hasher.absorb(::Scalar::ZERO); // addr + hasher.absorb(::Scalar::ZERO); // value + hasher.absorb(::Scalar::ZERO); // counter + hasher.absorb(::Scalar::ONE); // addr + hasher.absorb(::Scalar::ONE); // value + hasher.absorb(::Scalar::ZERO); // counter + hasher.absorb(::Scalar::ZERO); // addr + hasher.absorb(::Scalar::ZERO); // value + hasher.absorb(::Scalar::ONE); // counter + hasher.absorb(::Scalar::ZERO); // addr + hasher.absorb(::Scalar::from(111)); // value + hasher.absorb(::Scalar::from(3)); // counter + let res = hasher.squeeze(NUM_CHALLENGE_BITS); + assert_eq!(scalar_as_base::(res), next_intermediate_gamma); + } + + #[test] + fn test_read_twice_on_readonly() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + let ro_consts: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + + let mut cs = TestConstraintSystem::<::Scalar>::new(); + // let mut cs: TestShapeCS = TestShapeCS::new(); + let initial_table = vec![ + ( + ::Scalar::ZERO, + ::Scalar::from(101), + ), + (::Scalar::ONE, ::Scalar::ZERO), + ]; 
+ let mut lookup = Lookup::<::Scalar>::new(1024, TableType::ReadOnly, initial_table); + let mut lookup_trace_builder = LookupTraceBuilder::::new(&mut lookup); + let gamma = AllocatedNum::alloc(cs.namespace(|| "gamma"), || { + Ok(::Scalar::from(2)) + }) + .unwrap(); + let zero = alloc_zero(cs.namespace(|| "zero")).unwrap(); + let one = alloc_one(cs.namespace(|| "one")).unwrap(); + let prev_intermediate_gamma = &one; + let prev_rw_counter = &zero; + let addr = zero.clone(); + let read_value = lookup_trace_builder.read(addr.get_value().unwrap()); + assert_eq!(read_value, ::Scalar::from(101)); + let read_value = lookup_trace_builder.read(addr.get_value().unwrap()); + assert_eq!(read_value, ::Scalar::from(101)); + let (_, mut lookup_trace) = lookup_trace_builder.snapshot::( + ro_consts.clone(), + prev_intermediate_gamma.get_value().unwrap(), + ); + + let read_value = lookup_trace + .read(cs.namespace(|| "read_value1"), &addr) + .unwrap(); + assert_eq!( + read_value.get_value(), + Some(::Scalar::from(101)) + ); + + let read_value = lookup_trace + .read(cs.namespace(|| "read_value2"), &addr) + .unwrap(); + assert_eq!( + read_value.get_value(), + Some(::Scalar::from(101)) + ); + + let (prev_W, prev_R) = (&one, &one); + let (next_R, next_W, next_rw_counter, next_intermediate_gamma) = lookup_trace + .commit::( + cs.namespace(|| "commit"), + ro_consts.clone(), + prev_intermediate_gamma, + &gamma, + prev_W, + prev_R, + prev_rw_counter, + ) + .unwrap(); + assert_eq!( + next_rw_counter.get_value(), + Some(::Scalar::from(2)) + ); + // next_R check + assert_eq!( + next_R.get_value(), + prev_R + .get_value() + .zip(gamma.get_value()) + .zip(addr.get_value()) + .zip(read_value.get_value()) + .map(|(((prev_R, gamma), addr), read_value)| prev_R + * (gamma - (addr + gamma * read_value + gamma * gamma * ::Scalar::ZERO)) + * (gamma - (addr + gamma * read_value + gamma * gamma * ::Scalar::ONE))) + ); + // next_W check + assert_eq!( + next_W.get_value(), + prev_W + .get_value() + 
.zip(gamma.get_value()) + .zip(addr.get_value()) + .zip(read_value.get_value()) + .map(|(((prev_W, gamma), addr), read_value)| { + prev_W + * (gamma - (addr + gamma * read_value + gamma * gamma * (::Scalar::ONE))) + * (gamma + - (addr + gamma * read_value + gamma * gamma * (::Scalar::from(2)))) + }), + ); + + let mut hasher = ::RO::new(ro_consts, 7); + hasher.absorb(prev_intermediate_gamma.get_value().unwrap()); + hasher.absorb(addr.get_value().unwrap()); + hasher.absorb(read_value.get_value().unwrap()); + hasher.absorb(::Scalar::ZERO); + hasher.absorb(addr.get_value().unwrap()); + hasher.absorb(read_value.get_value().unwrap()); + hasher.absorb(::Scalar::ONE); + let res = hasher.squeeze(NUM_CHALLENGE_BITS); + assert_eq!( + scalar_as_base::(res), + next_intermediate_gamma.get_value().unwrap() + ); + // TODO check rics is_sat + // let (_, _) = cs.r1cs_shape_with_commitmentkey(); + // let (U1, W1) = cs.r1cs_instance_and_witness(&shape, &ck).unwrap(); + + // // Make sure that the first instance is satisfiable + // assert!(shape.is_sat(&ck, &U1, &W1).is_ok()); + } + + #[test] + fn test_write_read_on_rwlookup() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + let ro_consts: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + + let mut cs = TestConstraintSystem::<::Scalar>::new(); + // let mut cs: TestShapeCS = TestShapeCS::new(); + let initial_table = vec![ + (::Scalar::ZERO, ::Scalar::ZERO), + (::Scalar::ONE, ::Scalar::ZERO), + ]; + let mut lookup = + Lookup::<::Scalar>::new(1024, TableType::ReadWrite, initial_table); + let mut lookup_trace_builder = LookupTraceBuilder::::new(&mut lookup); + let gamma = AllocatedNum::alloc(cs.namespace(|| "gamma"), || { + Ok(::Scalar::from(2)) + }) + .unwrap(); + let zero = alloc_zero(cs.namespace(|| "zero")).unwrap(); + let one = alloc_one(cs.namespace(|| "one")).unwrap(); + let prev_intermediate_gamma = &one; + let prev_rw_counter = &zero; + let addr = zero.clone(); + let write_value_1 = 
AllocatedNum::alloc(cs.namespace(|| "write value 1"), || { + Ok(::Scalar::from(101)) + }) + .unwrap(); + lookup_trace_builder.write( + addr.get_value().unwrap(), + write_value_1.get_value().unwrap(), + ); + let read_value = lookup_trace_builder.read(addr.get_value().unwrap()); + // cs.namespace(|| "read_value 1"), + assert_eq!(read_value, ::Scalar::from(101)); + let (_, mut lookup_trace) = lookup_trace_builder.snapshot::( + ro_consts.clone(), + prev_intermediate_gamma.get_value().unwrap(), + ); + lookup_trace + .write(cs.namespace(|| "write_value 1"), &addr, &write_value_1) + .unwrap(); + let read_value = lookup_trace + .read(cs.namespace(|| "read_value 1"), &addr) + .unwrap(); + assert_eq!( + read_value.get_value(), + Some(::Scalar::from(101)) + ); + + let (prev_W, prev_R) = (&one, &one); + let (next_R, next_W, next_rw_counter, next_intermediate_gamma) = lookup_trace + .commit::( + cs.namespace(|| "commit"), + ro_consts.clone(), + prev_intermediate_gamma, + &gamma, + prev_W, + prev_R, + prev_rw_counter, + ) + .unwrap(); + assert_eq!( + next_rw_counter.get_value(), + Some(::Scalar::from(2)) + ); + // next_R check + assert_eq!( + next_R.get_value(), + prev_R + .get_value() + .zip(gamma.get_value()) + .zip(addr.get_value()) + .zip(read_value.get_value()) + .map(|(((prev_R, gamma), addr), read_value)| prev_R + * (gamma + - (addr + + gamma * ::Scalar::ZERO + + gamma * gamma * ::Scalar::ZERO)) + * (gamma - (addr + gamma * read_value + gamma * gamma * ::Scalar::ONE))) + ); + // next_W check + assert_eq!( + next_W.get_value(), + prev_W + .get_value() + .zip(gamma.get_value()) + .zip(addr.get_value()) + .zip(read_value.get_value()) + .map(|(((prev_W, gamma), addr), read_value)| { + prev_W + * (gamma - (addr + gamma * read_value + gamma * gamma * (::Scalar::ONE))) + * (gamma + - (addr + gamma * read_value + gamma * gamma * (::Scalar::from(2)))) + }), + ); + + let mut hasher = ::RO::new(ro_consts, 7); + hasher.absorb(prev_intermediate_gamma.get_value().unwrap()); + 
hasher.absorb(addr.get_value().unwrap()); + hasher.absorb(::Scalar::ZERO); + hasher.absorb(::Scalar::ZERO); + hasher.absorb(addr.get_value().unwrap()); + hasher.absorb(read_value.get_value().unwrap()); + hasher.absorb(::Scalar::ONE); + let res = hasher.squeeze(NUM_CHALLENGE_BITS); + assert_eq!( + scalar_as_base::(res), + next_intermediate_gamma.get_value().unwrap() + ); + // TODO check rics is_sat + // let (_, _) = cs.r1cs_shape_with_commitmentkey(); + // let (U1, W1) = cs.r1cs_instance_and_witness(&shape, &ck).unwrap(); + + // // Make sure that the first instance is satisfiable + // assert!(shape.is_sat(&ck, &U1, &W1).is_ok()); + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index d42474626..90ec38d89 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,5 +1,6 @@ //! This module implements various gadgets necessary for Nova and applications built with Nova. pub mod ecc; +pub mod lookup; pub(crate) mod nonnative; pub(crate) mod r1cs; pub(crate) mod utils; diff --git a/src/lib.rs b/src/lib.rs index 31afd9c10..939f79266 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -939,9 +939,16 @@ type CE = ::CE; #[cfg(test)] mod tests { + use crate::bellpepper::test_shape_cs::TestShapeCS; + use crate::constants::NUM_CHALLENGE_BITS; + use crate::gadgets::lookup::{less_than, Lookup, LookupTrace, LookupTraceBuilder, TableType}; + use crate::gadgets::utils::conditionally_select2; use crate::provider::bn256_grumpkin::{bn256, grumpkin}; use crate::provider::pedersen::CommitmentKeyExtTrait; + use crate::provider::poseidon::PoseidonConstantsCircuit; use crate::provider::secp_secq::{secp256k1, secq256k1}; + use crate::spartan::lookupsnark::LookupSNARK; + use crate::spartan::math::Math; use core::fmt::Write; use super::*; @@ -950,8 +957,10 @@ mod tests { type SPrime = spartan::ppsnark::RelaxedR1CSSNARK>; use ::bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; + use bellpepper_core::Namespace; use core::marker::PhantomData; use ff::PrimeField; + use 
tap::TapOptional; use traits::circuit::TrivialCircuit; #[derive(Clone, Debug, Default)] @@ -1644,4 +1653,506 @@ mod tests { test_ivc_base_with::(); test_ivc_base_with::(); } + + fn print_constraints_name_on_error_index(err: &NovaError, c_primary: &C1) + where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit, + { + match err { + NovaError::UnSatIndex(index) => { + let augmented_circuit_params_primary = + NovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + + // let (mut circuit_primary, z0_primary) = HeapifyCircuit::new(ro_consts); + let ro_consts_circuit_primary: ROConstantsCircuit = ROConstantsCircuit::::default(); + let circuit_primary: NovaAugmentedCircuit<'_, G2, C1> = NovaAugmentedCircuit::new( + &augmented_circuit_params_primary, + None, + c_primary, + ro_consts_circuit_primary, + ); + // let mut cs: ShapeCS = ShapeCS::new(); + // let _ = circuit_primary.synthesize(&mut cs); + let mut cs: TestShapeCS = TestShapeCS::new(); + let _ = circuit_primary.synthesize(&mut cs); + cs.constraints + .get(*index) + .tap_some(|constraint| println!("failed at constraint {}", constraint.3)); + } + error => unimplemented!("{:?}", error), + } + } + + #[test] + fn test_ivc_rwlookup() { + type G1 = pasta_curves::pallas::Point; + type G2 = pasta_curves::vesta::Point; + + // rw lookup to serve as a non-deterministic advices. 
+ #[derive(Clone)] + struct HeapifyCircuit + where + ::Scalar: Ord, + G1: Group::Scalar>, + G2: Group::Scalar>, + { + lookup_trace: LookupTrace, + ro_consts: ROConstantsCircuit, + max_value_bits: usize, + _phantom: PhantomData, + } + + impl HeapifyCircuit + where + ::Scalar: Ord, + G1: Group::Scalar>, + G2: Group::Scalar>, + { + fn new( + initial_table: &Lookup, + ro_consts_circuit: ROConstantsCircuit, + ) -> (Vec, Lookup, G1::Scalar) { + let n = initial_table.table_size(); + let initial_table = initial_table.clone(); + + let initial_index = (n - 4) / 2; + let max_value_bits = (n - 1).log_2() + 1; // + 1 as a buffer + let initial_intermediate_gamma = ::Scalar::from(1); + + let mut lookup = initial_table.clone(); + let num_steps = initial_index; + let mut intermediate_gamma = initial_intermediate_gamma; + // simulate folding step lookup io + let mut primary_circuits = vec![]; + let ro_consts = <::RO as ROTrait< + ::Base, + ::Scalar, + >>::Constants::default(); + for i in 0..num_steps + 1 { + let mut lookup_trace_builder = LookupTraceBuilder::::new(&mut lookup); + let addr = G1::Scalar::from((num_steps - i) as u64); + let parent = lookup_trace_builder.read(addr); + let left_child = lookup_trace_builder.read(G1::Scalar::from(2) * addr + G1::Scalar::ONE); + let right_child = + lookup_trace_builder.read(G1::Scalar::from(2) * addr + G1::Scalar::from(2)); + // swap left pair + let (new_parent_left, new_left_child) = if left_child < parent { + (left_child, parent) + } else { + (parent, left_child) + }; + lookup_trace_builder.write(addr, new_parent_left); + lookup_trace_builder.write( + G1::Scalar::from(2) * addr + G1::Scalar::from(1), + new_left_child, + ); + let (new_parent_right, new_right_child) = if right_child < new_parent_left { + (right_child, new_parent_left) + } else { + (new_parent_left, right_child) + }; + lookup_trace_builder.write(addr, new_parent_right); + lookup_trace_builder.write( + G1::Scalar::from(2) * addr + G1::Scalar::from(2), + new_right_child, + ); 
+ let res = lookup_trace_builder.snapshot::(ro_consts.clone(), intermediate_gamma); + intermediate_gamma = res.0; + let (_, lookup_trace) = res; + primary_circuits.push(Self { + lookup_trace, + ro_consts: ro_consts_circuit.clone(), + max_value_bits, + _phantom: PhantomData:: {}, + }); + } + + (primary_circuits, lookup, intermediate_gamma) + } + + fn get_z0( + ck: &CommitmentKey, + final_table: &Lookup, + intermediate_gamma: G1::Scalar, + ) -> Vec + where + G1: Group::Scalar>, + G2: Group::Scalar>, + { + let n = final_table.table_size(); + let initial_index = (n - 4) / 2; + let (initial_intermediate_gamma, init_prev_R, init_prev_W, init_rw_counter) = ( + ::Scalar::from(1), + ::Scalar::from(1), + ::Scalar::from(1), + ::Scalar::from(0), + ); + + let ro_consts = <::RO as ROTrait< + ::Base, + ::Scalar, + >>::Constants::default(); + + let final_values: Vec<::Scalar> = final_table + .get_table() + .iter() + .map(|(_, value, _)| *value) + .collect(); + let final_counters: Vec<::Scalar> = final_table + .get_table() + .iter() + .map(|(_, _, counter)| *counter) + .collect(); + + // final_value and final_commitment + let ( + (comm_final_value_cordx, comm_final_value_cordy, comm_final_value_infinity), + (comm_final_counter_cordx, comm_final_counter_cordy, comm_final_counter_infinity), + ) = rayon::join( + || G1::CE::commit(ck, &final_values).to_coordinates(), + || G1::CE::commit(ck, &final_counters).to_coordinates(), + ); + + let mut hasher = ::RO::new(ro_consts, 7); + hasher.absorb(intermediate_gamma); + hasher.absorb(scalar_as_base::(comm_final_value_cordx)); + hasher.absorb(scalar_as_base::(comm_final_value_cordy)); + hasher.absorb(scalar_as_base::(G2::Scalar::from( + comm_final_value_infinity as u64, + ))); + hasher.absorb(scalar_as_base::(comm_final_counter_cordx)); + hasher.absorb(scalar_as_base::(comm_final_counter_cordy)); + hasher.absorb(scalar_as_base::(G2::Scalar::from( + comm_final_counter_infinity as u64, + ))); + + let hash_bits = 
hasher.squeeze(NUM_CHALLENGE_BITS); + let gamma = scalar_as_base::(hash_bits); + vec![ + initial_intermediate_gamma, + gamma, + init_prev_R, + init_prev_W, + init_rw_counter, + G1::Scalar::from(initial_index as u64), + ] + } + } + + impl, G2: Group> StepCircuit + for HeapifyCircuit + where + G1::Scalar: Ord, + G1: Group::Scalar>, + G2: Group::Scalar>, + { + fn arity(&self) -> usize { + 6 + } + + fn synthesize>( + &self, + cs: &mut CS, + z: &[AllocatedNum], + ) -> Result>, SynthesisError> { + let mut lookup_trace = self.lookup_trace.clone(); + let prev_intermediate_gamma = &z[0]; + let gamma = &z[1]; + let prev_R = &z[2]; + let prev_W = &z[3]; + let prev_rw_counter = &z[4]; + let index = &z[5]; + + let left_child_index = AllocatedNum::alloc(cs.namespace(|| "left_child_index"), || { + index + .get_value() + .map(|i| i.mul(F::from(2)) + F::ONE) + .ok_or(SynthesisError::AssignmentMissing) + })?; + cs.enforce( + || "(2*index + 1) * 1 = left_child_index", + |lc| lc + (F::from(2), index.get_variable()) + CS::one(), + |lc| lc + CS::one(), + |lc| lc + left_child_index.get_variable(), + ); + let right_child_index = AllocatedNum::alloc(cs.namespace(|| "right_child_index"), || { + left_child_index + .get_value() + .map(|i| i + F::ONE) + .ok_or(SynthesisError::AssignmentMissing) + })?; + cs.enforce( + || "(left_child_index + 1) * 1 = right_child_index", + |lc| lc + left_child_index.get_variable() + CS::one(), + |lc| lc + CS::one(), + |lc| lc + right_child_index.get_variable(), + ); + let parent = lookup_trace.read(cs.namespace(|| "parent"), index)?; + let left_child = lookup_trace.read(cs.namespace(|| "left_child"), &left_child_index)?; + let right_child = lookup_trace.read(cs.namespace(|| "right_child"), &right_child_index)?; + + let is_left_child_smaller = less_than( + cs.namespace(|| "left_child < parent"), + &left_child, + &parent, + self.max_value_bits, + )?; + + let new_parent_left = conditionally_select2( + cs.namespace(|| "new_left_pair_parent"), + &left_child, + 
&parent,
+        &is_left_child_smaller,
+      )?;
+
+      let new_left_child = conditionally_select2(
+        cs.namespace(|| "new_left_pair_child"),
+        &parent,
+        &left_child,
+        &is_left_child_smaller,
+      )?;
+
+      lookup_trace.write(
+        cs.namespace(|| "write_left_pair_parent"),
+        index,
+        &new_parent_left,
+      )?;
+      lookup_trace.write(
+        cs.namespace(|| "write_left_pair_child"),
+        &left_child_index,
+        &new_left_child,
+      )?;
+
+      let is_right_child_smaller = less_than(
+        cs.namespace(|| "right_child < parent"),
+        &right_child,
+        &new_parent_left,
+        self.max_value_bits,
+      )?;
+
+      let new_parent_right = conditionally_select2(
+        cs.namespace(|| "new_right_pair_parent"),
+        &right_child,
+        &new_parent_left,
+        &is_right_child_smaller,
+      )?;
+      let new_right_child = conditionally_select2(
+        cs.namespace(|| "new_right_pair_child"),
+        &new_parent_left,
+        &right_child,
+        &is_right_child_smaller,
+      )?;
+      lookup_trace.write(
+        cs.namespace(|| "write_right_pair_parent"),
+        index,
+        &new_parent_right,
+      )?;
+      // fix: this namespace previously duplicated "write_left_pair_child";
+      // namespace paths must be unique per constraint system scope
+      lookup_trace.write(
+        cs.namespace(|| "write_right_pair_child"),
+        &right_child_index,
+        &new_right_child,
+      )?;
+
+      // commit the rw change
+      let (next_R, next_W, next_rw_counter, next_intermediate_gamma) = lookup_trace
+        .commit::>::Root>>(
+          cs.namespace(|| "commit"),
+          self.ro_consts.clone(),
+          prev_intermediate_gamma,
+          gamma,
+          prev_W,
+          prev_R,
+          prev_rw_counter,
+        )?;
+
+      let next_index = AllocatedNum::alloc(cs.namespace(|| "next_index"), || {
+        index
+          .get_value()
+          .map(|index| index - G1::Scalar::from(1))
+          .ok_or(SynthesisError::AssignmentMissing)
+      })?;
+      cs.enforce(
+        || "(next_index + 1) * 1 = index",
+        |lc| lc + next_index.get_variable() + CS::one(),
+        |lc| lc + CS::one(),
+        |lc| lc + index.get_variable(),
+      );
+      Ok(vec![
+        next_intermediate_gamma,
+        gamma.clone(),
+        next_R,
+        next_W,
+        next_rw_counter,
+        next_index,
+      ])
+    }
+  }
+
+  /// A trivial step circuit that simply returns the input
+  #[derive(Clone, Debug, Default, PartialEq, Eq)]
+  pub struct TrivialTestCircuit {
+    _p:
PhantomData, + } + + impl StepCircuit for TrivialTestCircuit + where + F: PrimeField, + { + fn arity(&self) -> usize { + 1 + } + + fn synthesize>( + &self, + _cs: &mut CS, + z: &[AllocatedNum], + ) -> Result>, SynthesisError> { + Ok(z.to_vec()) + } + } + + let heap_size: usize = 4; + + let ro_consts: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + + let initial_table = { + let mut initial_table = (0..heap_size - 1) + .map(|i| { + ( + ::Base::from(i as u64), + ::Base::from((heap_size - 2 - i) as u64), + ) + }) + .collect::::Base, ::Base)>>(); + initial_table.push(( + ::Base::from(heap_size as u64 - 1), + ::Base::ZERO, + )); // attach 1 dummy element to assure table size is power of 2 + Lookup::new(heap_size * 4, TableType::ReadWrite, initial_table) + }; + + let (circuit_primaries, final_table, intermediate_gamma) = + HeapifyCircuit::new(&initial_table, ro_consts); + // let mut circuit_primary = TrivialTestCircuit::default(); + // let z0_primary = vec![::Scalar::ZERO; 6]; + + let circuit_secondary = TrivialTestCircuit::default(); + // let mut circuit_primary = TrivialTestCircuit::default(); + + // produce public parameters + let pp_hint1 = Some(SPrime::::commitment_key_floor()); + let pp_hint2 = Some(SPrime::::commitment_key_floor()); + let pp = PublicParams::< + G1, + G2, + HeapifyCircuit, + TrivialTestCircuit<::Scalar>, + >::new( + &circuit_primaries[0], + &circuit_secondary, + pp_hint1, + pp_hint2, + ); + + let z0_primary = + HeapifyCircuit::::get_z0(&pp.ck_primary, &initial_table, intermediate_gamma); + // println!("num constraints {:?}", pp.num_constraints()); + + // 5th is initial index. 
+ // +1 for index end with 0 + let num_steps = u32::from_le_bytes(z0_primary[5].to_repr()[0..4].try_into().unwrap()) + 1; + + let z0_secondary = vec![::Scalar::ZERO; 1]; + + // produce a recursive SNARK + let mut recursive_snark: RecursiveSNARK< + G1, + G2, + HeapifyCircuit, + TrivialTestCircuit<::Scalar>, + > = RecursiveSNARK::new( + &pp, + &circuit_primaries[0], + &circuit_secondary, + z0_primary.clone(), + z0_secondary.clone(), + ); + + for i in 0..num_steps { + println!("step i {}", i); + let res = recursive_snark.prove_step( + &pp, + &circuit_primaries[i as usize], + &circuit_secondary.clone(), + z0_primary.clone(), + z0_secondary.clone(), + ); + res + .clone() + .map_err(|err| println!("err {:?}", err)) + .unwrap(); + assert!(res.is_ok()); + } + // verify the recursive SNARK + let res = recursive_snark.verify(&pp, num_steps as usize, &z0_primary, &z0_secondary); + res + .clone() + .map_err(|err| { + print_constraints_name_on_error_index::(&err, &circuit_primaries[0]) + }) + .unwrap(); + assert!(res.is_ok()); + /* + let next_gamma = &z[0]; + let gamma = &z[1]; + let next_R = &z[2]; + let next_W = &z[3]; + let next_rw_counter = &z[4]; + let next_index = &z[5]; + */ + let (zn_primary, _) = res.unwrap(); + + // TODO move below check to LookupSNARK + // assert_eq!(zn_primary[0], zn_primary[1]); // challenge == pre_compute_challenge + + assert_eq!(::Scalar::from(1).neg(), zn_primary[5]); // last index == -1 + + let number_of_iterated_nodes = (heap_size - 4) / 2 + 1; + assert_eq!( + ::Scalar::from((number_of_iterated_nodes * 7) as u64), + zn_primary[4] + ); // rw counter = number_of_iterated_nodes * (3r + 4w) operations + + assert_eq!(pp.circuit_shape_primary.r1cs_shape.num_cons, 12598); + assert_eq!(pp.circuit_shape_primary.r1cs_shape.num_vars, 12604); + assert_eq!(pp.circuit_shape_secondary.r1cs_shape.num_cons, 10347); + assert_eq!(pp.circuit_shape_secondary.r1cs_shape.num_vars, 10329); + + println!("zn_primary {:?}", zn_primary); + + let gamma = zn_primary[1]; + 
let read_row = zn_primary[2]; + let write_row = zn_primary[3]; + + // lookup snark prove/verify + type EE = crate::provider::ipa_pc::EvaluationEngine; + let (pk, vk) = + LookupSNARK::::setup(&pp.ck_primary, &initial_table.get_table()).unwrap(); + let snark_proof = LookupSNARK::::prove( + &pp.ck_primary, + &pk, + gamma, + read_row, + write_row, + initial_table.get_table(), + final_table.get_table(), + ) + .unwrap(); + + let res = snark_proof.verify(&vk); + let _ = res.clone().map_err(|err| println!("{:?}", err)); + res.unwrap() + } } diff --git a/src/spartan/lookupsnark.rs b/src/spartan/lookupsnark.rs new file mode 100644 index 000000000..30f9c2899 --- /dev/null +++ b/src/spartan/lookupsnark.rs @@ -0,0 +1,927 @@ +//! This module implements LookupSNARK which leverage memory-offline-check skills +use crate::{ + digest::{DigestComputer, SimpleDigestible}, + errors::NovaError, + spartan::{ + math::Math, + polys::{ + eq::EqPolynomial, + multilinear::MultilinearPolynomial, + univariate::{CompressedUniPoly, UniPoly}, + }, + powers, + sumcheck::SumcheckProof, + PolyEvalInstance, PolyEvalWitness, + }, + traits::{ + commitment::{CommitmentEngineTrait, CommitmentTrait}, + evaluation::EvaluationEngineTrait, + Group, TranscriptEngineTrait, + }, + Commitment, CommitmentKey, CompressedCommitment, +}; +use abomonation::Abomonation; +use abomonation_derive::Abomonation; +use core::marker::PhantomData; +use ff::{Field, PrimeField}; + +use crate::spartan::ppsnark::vec_to_arr; +use once_cell::sync::OnceCell; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use std::ops::Deref; + +use super::ppsnark::{IdentityPolynomial, ProductSumcheckInstance, SumcheckEngine}; + +/// A type that represents the prover's key +#[derive(Clone, Serialize, Deserialize, Abomonation)] +#[serde(bound = "")] +#[abomonation_bounds(where ::Repr: Abomonation)] +pub struct ProverKey> { + pk_ee: EE::ProverKey, + comm_init_value: Commitment, + #[abomonate_with(::Repr)] + vk_digest: G::Scalar, // digest 
of verifier's key +} + +/// A type that represents the verifier's key +#[derive(Clone, Serialize, Deserialize, Abomonation)] +#[serde(bound = "")] +#[abomonation_bounds(where ::Repr: Abomonation)] +pub struct VerifierKey> { + N: usize, // table size + vk_ee: EE::VerifierKey, + comm_init_value: Commitment, + #[abomonation_skip] + #[serde(skip, default = "OnceCell::new")] + digest: OnceCell, +} + +impl> VerifierKey { + fn new(vk_ee: EE::VerifierKey, table_size: usize, comm_init_value: Commitment) -> Self { + VerifierKey { + vk_ee, + digest: Default::default(), + comm_init_value, + N: table_size, + } + } + + /// Returns the digest of the verifier's key + pub fn digest(&self) -> G::Scalar { + self + .digest + .get_or_try_init(|| { + let dc = DigestComputer::new(self); + dc.digest() + }) + .cloned() + .expect("Failure to retrieve digest!") + } +} + +impl> SimpleDigestible for VerifierKey {} + +/// MemoryOfflineSumcheckInstance +pub struct MemoryOfflineSumcheckInstance(ProductSumcheckInstance); + +impl Deref for MemoryOfflineSumcheckInstance { + type Target = ProductSumcheckInstance; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl MemoryOfflineSumcheckInstance { + /// new a productsumcheck instance + pub fn new( + ck: &CommitmentKey, + input_vec: Vec>, // list of input vectors + transcript: &mut G::TE, + ) -> Result { + let inner = ProductSumcheckInstance::new(ck, input_vec, transcript)?; + Ok(MemoryOfflineSumcheckInstance(inner)) + } +} + +impl SumcheckEngine for MemoryOfflineSumcheckInstance { + fn initial_claims(&self) -> Vec { + self.0.claims.to_vec() + } + + fn degree(&self) -> usize { + self.0.degree() + } + + fn size(&self) -> usize { + self.0.size() + } + + fn evaluation_points(&self) -> Vec> { + self.0.evaluation_points() + } + + fn bound(&mut self, r: &G::Scalar) { + self.0.bound(r) + } + + fn final_claims(&self) -> Vec> { + self.0.final_claims() + } +} + +#[allow(unused)] +/// LookupSNARK +pub struct LookupSNARK> { + a: PhantomData<(G, EE)>, + 
+ // commitment to oracles for the inner sum-check + comm_final_counter: CompressedCommitment, + comm_final_value: CompressedCommitment, + + read_row: G::Scalar, + write_row: G::Scalar, + gamma: G::Scalar, + + comm_output_arr: [CompressedCommitment; 2], + claims_product_arr: [G::Scalar; 2], + + eval_left_arr: [G::Scalar; 2], + eval_right_arr: [G::Scalar; 2], + eval_output_arr: [G::Scalar; 2], + eval_input_arr: [G::Scalar; 2], + eval_output2_arr: [G::Scalar; 2], + + // satisfiability sum-check + sc_sat: SumcheckProof, + + eval_init_value_at_r_prod: G::Scalar, + eval_final_value_at_r_prod: G::Scalar, + eval_final_counter_at_r_prod: G::Scalar, + + // batch openings of all multilinear polynomials + sc_proof_batch: SumcheckProof, + evals_batch_arr: [G::Scalar; 4], + eval_arg: EE::EvaluationArgument, +} + +impl> LookupSNARK +where + ::Repr: Abomonation, +{ + /// setup + pub fn setup( + ck: &CommitmentKey, + initial_table: &Vec<(G::Scalar, G::Scalar, G::Scalar)>, + ) -> Result<(ProverKey, VerifierKey), NovaError> { + // check the provided commitment key meets minimal requirements + // assert!(ck.length() >= Self::commitment_key_floor()(S)); + let init_values: Vec<::Scalar> = + initial_table.iter().map(|(_, value, _)| *value).collect(); + + let comm_init_value = G::CE::commit(ck, &init_values); + + let (pk_ee, vk_ee) = EE::setup(ck); + let table_size = initial_table.len(); + + let vk = VerifierKey::new(vk_ee, table_size, comm_init_value); + + let pk = ProverKey { + pk_ee, + comm_init_value, + vk_digest: vk.digest(), + }; + + Ok((pk, vk)) + } + /// produces a succinct proof of satisfiability of a `LookupSNARK` instance + #[tracing::instrument(skip_all, name = "LookupSNARK::prove")] + pub fn prove( + ck: &CommitmentKey, + pk: &ProverKey, + gamma: G::Scalar, + read_row: G::Scalar, + write_row: G::Scalar, + initial_table: Vec<(G::Scalar, G::Scalar, G::Scalar)>, + final_table: Vec<(G::Scalar, G::Scalar, G::Scalar)>, + ) -> Result { + // a list of polynomial evaluation claims 
that will be batched + let mut w_u_vec = Vec::new(); + + let gamma_square = gamma * gamma; + let hash_func = |addr: &G::Scalar, val: &G::Scalar, ts: &G::Scalar| -> G::Scalar { + gamma - (*ts * gamma_square + *val * gamma + *addr) + }; + // init_row + // TODO: initial_table need to be put in setup phase + let initial_row: Vec = initial_table + .iter() + .map(|(addr, value, counter)| hash_func(addr, value, counter)) + .collect(); + // audit_row + let audit_row: Vec = final_table + .iter() + .map(|(addr, value, counter)| hash_func(addr, value, counter)) + .collect(); + let mut transcript = G::TE::new(b"LookupSNARK"); + // append the verifier key (which includes commitment to R1CS matrices) and the read_row/write_row to the transcript + transcript.absorb(b"vk", &pk.vk_digest); + transcript.absorb(b"read_row", &read_row); + transcript.absorb(b"write_row", &write_row); + transcript.absorb(b"gamma", &gamma); + + let init_values: Vec<::Scalar> = + initial_table.iter().map(|(_, value, _)| *value).collect(); + let final_values: Vec<::Scalar> = + final_table.iter().map(|(_, value, _)| *value).collect(); + let final_counters: Vec<::Scalar> = + final_table.iter().map(|(_, _, counter)| *counter).collect(); + // TODO add comm_final_value, comm_final_counter to gamma challange + // which means we need to move final_values, final_counters commitment at earlier + let comm_init_value = pk.comm_init_value; + let (comm_final_value, comm_final_counter) = rayon::join( + || G::CE::commit(ck, &final_values), + || G::CE::commit(ck, &final_counters), + ); + // add commitment into the challenge + transcript.absorb(b"e", &[comm_final_value, comm_final_counter].as_slice()); + + let mut memory_offline_sc_inst = + MemoryOfflineSumcheckInstance::::new(ck, vec![initial_row, audit_row], &mut transcript) + .unwrap(); + + // sanity check: claimed_prod_init_row * write_row - claimed_prod_audit_row * read_row = 0 + let initial_claims = memory_offline_sc_inst.initial_claims(); + let 
(claimed_prod_init_row, claimed_prod_audit_row) = (initial_claims[0], initial_claims[1]); + assert_eq!(claimed_prod_init_row * write_row - read_row * claimed_prod_audit_row, ::Scalar::ZERO, "claimed_prod_init_row {:?} * write_row {:?} - claimed_prod_audit_row {:?} * read_row {:?} = {:?}", + claimed_prod_init_row, + write_row, + claimed_prod_audit_row, + read_row, + claimed_prod_init_row * write_row - read_row * claimed_prod_audit_row + ); + + // generate sumcheck proof + let num_claims = initial_claims.len(); + let coeffs = { + let s = transcript.squeeze(b"r").unwrap(); + let mut s_vec = vec![s]; + for i in 1..num_claims { + s_vec.push(s_vec[i - 1] * s); + } + s_vec + }; + // compute the joint claim + let claim = initial_claims + .iter() + .zip(coeffs.iter()) + .map(|(c_1, c_2)| *c_1 * c_2) + .sum(); + let mut e = claim; + let mut r_sat: Vec = Vec::new(); + let mut cubic_polys: Vec> = Vec::new(); + let num_rounds = memory_offline_sc_inst.size().log_2(); + + for _i in 0..num_rounds { + let mut evals: Vec> = Vec::new(); + evals.extend(memory_offline_sc_inst.evaluation_points()); + + let evals_combined_0 = (0..evals.len()).map(|i| evals[i][0] * coeffs[i]).sum(); + let evals_combined_2 = (0..evals.len()).map(|i| evals[i][1] * coeffs[i]).sum(); + let evals_combined_3 = (0..evals.len()).map(|i| evals[i][2] * coeffs[i]).sum(); + + let evals = vec![ + evals_combined_0, + e - evals_combined_0, + evals_combined_2, + evals_combined_3, + ]; + let poly = UniPoly::from_evals(&evals); + + // append the prover's message to the transcript + transcript.absorb(b"p", &poly); + + // derive the verifier's challenge for the next round + let r_i = transcript.squeeze(b"c").unwrap(); + r_sat.push(r_i); + + memory_offline_sc_inst.bound(&r_i); + + e = poly.evaluate(&r_i); + cubic_polys.push(poly.compress()); + } + + let final_claims = memory_offline_sc_inst.final_claims(); + + let sc_sat = SumcheckProof::::new(cubic_polys); + + // claims[0] is about the Eq polynomial, which the verifier 
computes directly + // claims[1] =? weighed sum of left(rand) + // claims[2] =? weighted sum of right(rand) + // claims[3] =? weighted sum of output(rand), which is easy to verify by querying output + // we also need to prove that output(output.len()-2) = claimed_product + let eval_left_vec = final_claims[1].clone(); + let eval_right_vec = final_claims[2].clone(); + let eval_output_vec = final_claims[3].clone(); + + let eval_vec = vec![ + eval_left_vec.clone(), + eval_right_vec.clone(), + eval_output_vec.clone(), + ] + .concat(); + // absorb all the claimed evaluations + transcript.absorb(b"e", &eval_vec.as_slice()); + + // we now combine eval_left = left(rand) and eval_right = right(rand) + // into claims about input and output + let c = transcript.squeeze(b"c").unwrap(); + + // eval = (G::Scalar::ONE - c) * eval_left + c * eval_right + // eval is claimed evaluation of input||output(r, c), which can be proven by proving input(r[1..], c) and output(r[1..], c) + let rand_ext = { + let mut r = r_sat.clone(); + r.extend(&[c]); + r + }; + let r_prod = rand_ext[1..].to_vec(); + + let eval_input_vec = memory_offline_sc_inst + .input_vec + .iter() + .map(|i| MultilinearPolynomial::evaluate_with(i, &r_prod)) + .collect::>(); + + let eval_output2_vec = memory_offline_sc_inst + .output_vec + .iter() + .map(|o| MultilinearPolynomial::evaluate_with(o, &r_prod)) + .collect::>(); + + // add claimed evaluations to the transcript + let evals = eval_input_vec + .clone() + .into_iter() + .chain(eval_output2_vec.clone()) + .collect::>(); + transcript.absorb(b"e", &evals.as_slice()); + + // squeeze a challenge to combine multiple claims into one + let powers_of_rho = { + let s = transcript.squeeze(b"r")?; + let mut s_vec = vec![s]; + for i in 1..memory_offline_sc_inst.initial_claims().len() { + s_vec.push(s_vec[i - 1] * s); + } + s_vec + }; + + // take weighted sum (random linear combination) of input, output, and their commitments + // product is `initial claim` + let product: 
::Scalar = memory_offline_sc_inst + .claims + .iter() + .zip(powers_of_rho.iter()) + .map(|(e, p)| *e * p) + .sum(); + + let eval_output: ::Scalar = eval_output_vec + .iter() + .zip(powers_of_rho.iter()) + .map(|(e, p)| *e * p) + .sum(); + + let comm_output = memory_offline_sc_inst + .comm_output_vec + .iter() + .zip(powers_of_rho.iter()) + .map(|(c, r_i)| *c * *r_i) + .fold(Commitment::::default(), |acc, item| acc + item); + + let weighted_sum = |W: &[Vec], s: &[G::Scalar]| -> Vec { + assert_eq!(W.len(), s.len()); + let mut p = vec![::Scalar::ZERO; W[0].len()]; + for i in 0..W.len() { + for (j, item) in W[i].iter().enumerate().take(W[i].len()) { + p[j] += *item * s[i] + } + } + p + }; + + let poly_output = weighted_sum(&memory_offline_sc_inst.output_vec, &powers_of_rho); + + let eval_output2: ::Scalar = eval_output2_vec + .iter() + .zip(powers_of_rho.iter()) + .map(|(e, p)| *e * p) + .sum(); + + // eval_output = output(r_sat) + w_u_vec.push(( + PolyEvalWitness:: { + p: poly_output.clone(), + }, + PolyEvalInstance:: { + c: comm_output, + x: r_sat.clone(), + e: eval_output, + }, + )); + + // claimed_product = output(1, ..., 1, 0) + let x = { + let mut x = vec![G::Scalar::ONE; r_sat.len()]; + x[r_sat.len() - 1] = G::Scalar::ZERO; + x + }; + w_u_vec.push(( + PolyEvalWitness { + p: poly_output.clone(), + }, + PolyEvalInstance { + c: comm_output, + x, + e: product, + }, + )); + + // eval_output2 = output(r_prod) + w_u_vec.push(( + PolyEvalWitness { p: poly_output }, + PolyEvalInstance { + c: comm_output, + x: r_prod.clone(), + e: eval_output2, + }, + )); + + let evals = [ + &init_values, // init value (all init ts are 0) + &final_values, + &final_counters, + ] + .into_par_iter() + .map(|p| MultilinearPolynomial::evaluate_with(p, &r_prod.clone())) + .collect::>(); + + let eval_init_value_at_r_prod = evals[0]; + let eval_final_value_at_r_prod = evals[1]; + let eval_final_counter_at_r_prod = evals[2]; + + // we can batch all the claims + transcript.absorb( + b"e", + &[ + 
eval_init_value_at_r_prod, + eval_final_value_at_r_prod, + eval_final_counter_at_r_prod, + ] + .as_slice(), + ); + + // generate challenge for rlc + let c = transcript.squeeze(b"c")?; + let eval_vec = [ + eval_init_value_at_r_prod, + eval_final_value_at_r_prod, + eval_final_counter_at_r_prod, + ]; + let comm_vec = [comm_init_value, comm_final_value, comm_final_counter]; + let poly_vec = [ + &init_values.to_vec(), + &final_values.to_vec(), + &final_counters.to_vec(), + ]; + let w = PolyEvalWitness::batch(&poly_vec, &c); + let u = PolyEvalInstance::batch(&comm_vec, &r_prod, &eval_vec, &c); + + // add the claim to prove for later + w_u_vec.push((w, u)); + + // We will now reduce a vector of claims of evaluations at different points into claims about them at the same point. + // For example, eval_W =? W(r_y[1..]) and eval_W =? E(r_x) into + // two claims: eval_W_prime =? W(rz) and eval_E_prime =? E(rz) + // We can them combine the two into one: eval_W_prime + gamma * eval_E_prime =? (W + gamma*E)(rz), + // where gamma is a public challenge + // Since commitments to W and E are homomorphic, the verifier can compute a commitment + // to the batched polynomial. 
+ assert!(w_u_vec.len() >= 2); + + let (w_vec, u_vec): (Vec>, Vec>) = + w_u_vec.into_iter().unzip(); + let w_vec_padded = PolyEvalWitness::pad(&w_vec); // pad the polynomials to be of the same size + let u_vec_padded = PolyEvalInstance::pad(&u_vec); // pad the evaluation points + + // generate a challenge + let rho = transcript.squeeze(b"r")?; + let num_claims = w_vec_padded.len(); + let powers_of_rho = powers::(&rho, num_claims); + let claim_batch_joint: ::Scalar = u_vec_padded + .iter() + .zip(powers_of_rho.iter()) + .map(|(u, p)| u.e * p) + .sum(); + + let mut polys_left: Vec> = w_vec_padded + .iter() + .map(|w| MultilinearPolynomial::new(w.p.clone())) + .collect(); + let mut polys_right: Vec> = u_vec_padded + .iter() + .map(|u| MultilinearPolynomial::new(EqPolynomial::new(u.x.clone()).evals())) + .collect(); + + let num_rounds_z = u_vec_padded[0].x.len(); + let comb_func = |poly_A_comp: &G::Scalar, poly_B_comp: &G::Scalar| -> G::Scalar { + *poly_A_comp * *poly_B_comp + }; + let (sc_proof_batch, r_z, claims_batch) = SumcheckProof::::prove_quad_batch( + &claim_batch_joint, + num_rounds_z, + &mut polys_left, + &mut polys_right, + &powers_of_rho, + comb_func, + &mut transcript, + )?; + + let (claims_batch_left, _): (Vec, Vec) = claims_batch; + + transcript.absorb(b"l", &claims_batch_left.as_slice()); + + // we now combine evaluation claims at the same point rz into one + let gamma = transcript.squeeze(b"g")?; + let powers_of_gamma: Vec = powers::(&gamma, num_claims); + let comm_joint = u_vec_padded + .iter() + .zip(powers_of_gamma.iter()) + .map(|(u, g_i)| u.c * *g_i) + .fold(Commitment::::default(), |acc, item| acc + item); + let poly_joint = PolyEvalWitness::weighted_sum(&w_vec_padded, &powers_of_gamma); + let eval_joint: ::Scalar = claims_batch_left + .iter() + .zip(powers_of_gamma.iter()) + .map(|(e, g_i)| *e * *g_i) + .sum(); + + let eval_arg = EE::prove( + ck, + &pk.pk_ee, + &mut transcript, + &comm_joint, + &poly_joint.p, + &r_z, + &eval_joint, + )?; + + 
println!( + "debug: prove: before going to compress {:?}", + memory_offline_sc_inst.comm_output_vec + ); + + println!("debug: prove: comm_final_value {:?}", comm_final_value,); + + println!( + "debug: prove 2: comm_final_value {:?}", + Commitment::::decompress(&comm_final_value.compress())? + ); + Ok(LookupSNARK { + comm_final_counter: comm_final_counter.compress(), + comm_final_value: comm_final_value.compress(), + + read_row, + write_row, + gamma, + + comm_output_arr: vec_to_arr( + memory_offline_sc_inst + .comm_output_vec + .iter() + .map(|c| c.compress()) + .collect::>>(), + ), + claims_product_arr: vec_to_arr(memory_offline_sc_inst.claims.clone()), + + sc_sat, + + eval_left_arr: vec_to_arr(eval_left_vec), + eval_right_arr: vec_to_arr(eval_right_vec), + eval_output_arr: vec_to_arr(eval_output_vec), + eval_input_arr: vec_to_arr(eval_input_vec), + eval_output2_arr: vec_to_arr(eval_output2_vec), + + eval_init_value_at_r_prod, + eval_final_value_at_r_prod, + eval_final_counter_at_r_prod, + + sc_proof_batch, + evals_batch_arr: vec_to_arr(claims_batch_left), + eval_arg, + a: PhantomData {}, + }) + } + + /// verifies a proof of satisfiability of a `RelaxedR1CS` instance + pub fn verify(&self, vk: &VerifierKey) -> Result<(), NovaError> { + let mut transcript = G::TE::new(b"LookupSNARK"); + let mut u_vec: Vec> = Vec::new(); + let comm_final_value = Commitment::::decompress(&self.comm_final_value)?; + println!( + "debug: verify Commitment::::decompress(comm_final_value) {:?}", + Commitment::::decompress(&self.comm_final_value)? 
+ ); + let comm_final_counter = Commitment::::decompress(&self.comm_final_counter)?; + + // append the verifier key (including commitment to R1CS matrices) and the RelaxedR1CSInstance to the transcript + transcript.absorb(b"vk", &vk.digest()); + transcript.absorb(b"read_row", &self.read_row); + transcript.absorb(b"write_row", &self.write_row); + transcript.absorb(b"gamma", &self.gamma); + + // add commitment into the challenge + transcript.absorb(b"e", &[comm_final_value, comm_final_counter].as_slice()); + + let num_rounds_sat = vk.N.log_2(); + + // hash function + let gamma_square = self.gamma * self.gamma; + let hash_func = |addr: &G::Scalar, val: &G::Scalar, ts: &G::Scalar| -> G::Scalar { + self.gamma - (*ts * gamma_square + *val * self.gamma + *addr) + }; + + // check claimed_prod_init_row * write_row - claimed_prod_audit_row * read_row = 0 + // sanity check: any of them might not be 0 + assert!( + self.claims_product_arr[0] * self.write_row * self.claims_product_arr[1] * self.read_row + != G::Scalar::ZERO, + "any of claims_product_arr {:?}, write_row {:?}, read_row {:?} = 0", + self.claims_product_arr, + self.write_row, + self.read_row + ); + if self.claims_product_arr[0] * self.write_row - self.claims_product_arr[1] * self.read_row + != G::Scalar::ZERO + { + return Err(NovaError::InvalidMultisetProof); + } + + let comm_output_vec = self + .comm_output_arr + .iter() + .map(|c| Commitment::::decompress(c)) + .collect::>, NovaError>>()?; + + transcript.absorb(b"o", &comm_output_vec.as_slice()); + println!( + "debug: verify comm_output_vec {:?}", + comm_output_vec.as_slice() + ); + transcript.absorb(b"c", &self.claims_product_arr.as_slice()); + + let num_rounds = vk.N.log_2(); + let rand_eq = (0..num_rounds) + .map(|_i| transcript.squeeze(b"e")) + .collect::, NovaError>>()?; + + let num_claims = 2; + let coeffs = { + let s = transcript.squeeze(b"r")?; + let mut s_vec = vec![s]; + for i in 1..num_claims { + s_vec.push(s_vec[i - 1] * s); + } + s_vec + }; + + let 
initial_claims = self + .claims_product_arr + .iter() + .zip(coeffs.iter()) + .map(|(e, p)| *e * p) + .sum(); + let (claim_mem_sat_final, r_sat) = + self + .sc_sat + .verify(initial_claims, num_rounds_sat, 3, &mut transcript)?; + let rand_eq_bound_r_sat = EqPolynomial::new(rand_eq).evaluate(&r_sat); + let claim_mem_final_expected: G::Scalar = (0..2) + .map(|i| { + coeffs[i] + * rand_eq_bound_r_sat + * (self.eval_left_arr[i] * self.eval_right_arr[i] - self.eval_output_arr[i]) + }) + .sum(); + + if claim_mem_final_expected != claim_mem_sat_final { + println!( + "claim_mem_final_expected {:?} != claim_mem_sat_final {:?}", + claim_mem_final_expected, claim_mem_sat_final + ); + return Err(NovaError::InvalidSumcheckProof); + } + + // claims from the end of the sum-check + let eval_vec = [] + .into_iter() + .chain(self.eval_left_arr) + .chain(self.eval_right_arr) + .chain(self.eval_output_arr) + .collect::>(); + + transcript.absorb(b"e", &eval_vec.as_slice()); + // we now combine eval_left = left(rand) and eval_right = right(rand) + // into claims about input and output + let c = transcript.squeeze(b"c")?; + + // eval = (G::Scalar::ONE - c) * eval_left + c * eval_right + // eval is claimed evaluation of input||output(r, c), which can be proven by proving input(r[1..], c) and output(r[1..], c) + let rand_ext = { + let mut r = r_sat.clone(); + r.extend(&[c]); + r + }; + let r_prod = rand_ext[1..].to_vec(); + + // add claimed evaluations to the transcript + let evals = self + .eval_input_arr + .into_iter() + .chain(self.eval_output2_arr) + .collect::>(); + transcript.absorb(b"e", &evals.as_slice()); + + // squeeze a challenge to combine multiple claims into one + let powers_of_rho = { + let s = transcript.squeeze(b"r")?; + let mut s_vec = vec![s]; + for i in 1..num_claims { + s_vec.push(s_vec[i - 1] * s); + } + s_vec + }; + + // take weighted sum of input, output, and their commitments + let product = self + .claims_product_arr + .iter() + .zip(powers_of_rho.iter()) + 
.map(|(e, p)| *e * p) + .sum(); + + let eval_output = self + .eval_output_arr + .iter() + .zip(powers_of_rho.iter()) + .map(|(e, p)| *e * p) + .sum(); + + let comm_output = comm_output_vec + .iter() + .zip(powers_of_rho.iter()) + .map(|(c, r_i)| *c * *r_i) + .fold(Commitment::::default(), |acc, item| acc + item); + + let eval_output2 = self + .eval_output2_arr + .iter() + .zip(powers_of_rho.iter()) + .map(|(e, p)| *e * p) + .sum(); + + // eval_output = output(r_sat) + u_vec.push(PolyEvalInstance { + c: comm_output, + x: r_sat.clone(), + e: eval_output, + }); + + // claimed_product = output(1, ..., 1, 0) + let x = { + let mut x = vec![G::Scalar::ONE; r_sat.len()]; + x[r_sat.len() - 1] = G::Scalar::ZERO; + x + }; + u_vec.push(PolyEvalInstance { + c: comm_output, + x, + e: product, + }); + + // eval_output2 = output(r_prod) + u_vec.push(PolyEvalInstance { + c: comm_output, + x: r_prod.clone(), + e: eval_output2, + }); + + // we can batch all the claims + transcript.absorb( + b"e", + &[ + self.eval_init_value_at_r_prod, + self.eval_final_value_at_r_prod, + self.eval_final_counter_at_r_prod, + ] + .as_slice(), + ); + let c = transcript.squeeze(b"c")?; + let eval_vec = [ + self.eval_init_value_at_r_prod, + self.eval_final_value_at_r_prod, + self.eval_final_counter_at_r_prod, + ]; + let comm_vec = vec![vk.comm_init_value, comm_final_value, comm_final_counter]; + let u = PolyEvalInstance::batch(&comm_vec, &r_prod, &eval_vec, &c); + + // add the claim to prove for later + u_vec.push(u); + + // finish the final step of the sum-check + let (claim_init_expected_row, claim_audit_expected_row) = { + let addr = IdentityPolynomial::new(r_prod.len()).evaluate(&r_prod); + ( + hash_func(&addr, &self.eval_init_value_at_r_prod, &G::Scalar::ZERO), + hash_func( + &addr, + &self.eval_final_value_at_r_prod, + &self.eval_final_counter_at_r_prod, + ), + ) + }; + + // multiset check for the row + if claim_init_expected_row != self.eval_input_arr[0] + || claim_audit_expected_row != 
self.eval_input_arr[1] + { + return Err(NovaError::InvalidSumcheckProof); + } + + let u_vec_padded = PolyEvalInstance::pad(&u_vec); // pad the evaluation points + + // generate a challenge + let rho = transcript.squeeze(b"r")?; + let num_claims = u_vec.len(); + let powers_of_rho = powers::(&rho, num_claims); + let claim_batch_joint = u_vec_padded + .iter() + .zip(powers_of_rho.iter()) + .map(|(u, p)| u.e * p) + .sum(); + + let num_rounds_z = u_vec_padded[0].x.len(); + let (claim_batch_final, r_z) = + self + .sc_proof_batch + .verify(claim_batch_joint, num_rounds_z, 2, &mut transcript)?; + + let claim_batch_final_expected = { + let poly_rz = EqPolynomial::new(r_z.clone()); + let evals = u_vec_padded + .iter() + .map(|u| poly_rz.evaluate(&u.x)) + .collect::>(); + + evals + .iter() + .zip(self.evals_batch_arr.iter()) + .zip(powers_of_rho.iter()) + .map(|((e_i, p_i), rho_i)| *e_i * *p_i * rho_i) + .sum() + }; + + if claim_batch_final != claim_batch_final_expected { + return Err(NovaError::InvalidSumcheckProof); + } + + transcript.absorb(b"l", &self.evals_batch_arr.as_slice()); + + // we now combine evaluation claims at the same point rz into one + let gamma = transcript.squeeze(b"g")?; + let powers_of_gamma: Vec = powers::(&gamma, num_claims); + let comm_joint = u_vec_padded + .iter() + .zip(powers_of_gamma.iter()) + .map(|(u, g_i)| u.c * *g_i) + .fold(Commitment::::default(), |acc, item| acc + item); + let eval_joint = self + .evals_batch_arr + .iter() + .zip(powers_of_gamma.iter()) + .map(|(e, g_i)| *e * *g_i) + .sum(); + + // verify + EE::verify( + &vk.vk_ee, + &mut transcript, + &comm_joint, + &r_z, + &eval_joint, + &self.eval_arg, + )?; + + Ok(()) + } +} diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs index e3cd204b5..2ac889fa6 100644 --- a/src/spartan/mod.rs +++ b/src/spartan/mod.rs @@ -6,11 +6,12 @@ //! //! In polynomial.rs we also provide foundational types and functions for manipulating multilinear polynomials. 
pub mod direct; +pub mod lookupsnark; pub(crate) mod math; pub mod polys; pub mod ppsnark; pub mod snark; -mod sumcheck; +pub mod sumcheck; use crate::{traits::Group, Commitment}; use ff::Field; diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index a83b45bdc..82755253a 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -34,12 +34,12 @@ use once_cell::sync::OnceCell; use rayon::prelude::*; use serde::{Deserialize, Serialize}; -fn vec_to_arr(v: Vec) -> [T; N] { +pub(crate) fn vec_to_arr(v: Vec) -> [T; N] { v.try_into() .unwrap_or_else(|v: Vec| panic!("Expected a Vec of length {} but it was {}", N, v.len())) } -struct IdentityPolynomial { +pub(crate) struct IdentityPolynomial { ell: usize, _p: PhantomData, } @@ -321,12 +321,13 @@ pub trait SumcheckEngine { fn final_claims(&self) -> Vec>; } -struct ProductSumcheckInstance { +/// ProductSumcheckInstance +pub struct ProductSumcheckInstance { pub(crate) claims: Vec, // claimed products pub(crate) comm_output_vec: Vec>, - input_vec: Vec>, - output_vec: Vec>, + pub(crate) input_vec: Vec>, + pub(crate) output_vec: Vec>, poly_A: MultilinearPolynomial, poly_B_vec: Vec>, @@ -335,6 +336,7 @@ struct ProductSumcheckInstance { } impl ProductSumcheckInstance { + /// new a productsumcheck instance pub fn new( ck: &CommitmentKey, input_vec: Vec>, // list of input vectors @@ -408,6 +410,7 @@ impl ProductSumcheckInstance { // absorb the output commitment and the claimed product transcript.absorb(b"o", &comm_output_vec.as_slice()); + println!("prove comm_output_vec {:?}", comm_output_vec.as_slice()); transcript.absorb(b"c", &claims.as_slice()); // generate randomness for the eq polynomial @@ -1019,6 +1022,7 @@ where let comm_vec = vec![comm_Az, comm_Bz, comm_Cz]; let poly_vec = vec![&Az, &Bz, &Cz]; transcript.absorb(b"e", &eval_vec.as_slice()); // c_vec is already in the transcript + // note: c is used for RLC let c = transcript.squeeze(b"c")?; let w = PolyEvalWitness::batch(&poly_vec, &c); let u = 
PolyEvalInstance::batch(&comm_vec, &tau, &eval_vec, &c); @@ -1129,6 +1133,7 @@ where &mut transcript, )?; + // r_sat is the sumcheck challenge let (sc_sat, r_sat, claims_mem, claims_outer, claims_inner) = Self::prove_inner( &mut mem_sc_inst, &mut outer_sc_inst, @@ -1145,7 +1150,7 @@ where let eval_right_vec = claims_mem[2].clone(); let eval_output_vec = claims_mem[3].clone(); - // claims from the end of sum-check + // claims from the end of sum-check, i.e. final claims let (eval_Az, eval_Bz): (G::Scalar, G::Scalar) = (claims_outer[0][1], claims_outer[0][2]); let eval_Cz = MultilinearPolynomial::evaluate_with(&Cz, &r_sat); let eval_E = MultilinearPolynomial::evaluate_with(&E, &r_sat); @@ -1177,16 +1182,17 @@ where r.extend(&[c]); r }; + let r_prod = rand_ext[1..].to_vec(); let eval_input_vec = mem_sc_inst .input_vec .iter() - .map(|i| MultilinearPolynomial::evaluate_with(i, &rand_ext[1..])) + .map(|i| MultilinearPolynomial::evaluate_with(i, &r_prod)) .collect::>(); let eval_output2_vec = mem_sc_inst .output_vec .iter() - .map(|o| MultilinearPolynomial::evaluate_with(o, &rand_ext[1..])) + .map(|o| MultilinearPolynomial::evaluate_with(o, &r_prod)) .collect::>(); // add claimed evaluations to the transcript @@ -1207,7 +1213,8 @@ where s_vec }; - // take weighted sum of input, output, and their commitments + // take weighted sum (random linear combination) of input, output, and their commitments + // product is `initial claim` let product = mem_sc_inst .claims .iter() @@ -1276,17 +1283,16 @@ where }, )); - // eval_output2 = output(rand_ext[1..]) + // eval_output2 = output(r_prod) w_u_vec.push(( PolyEvalWitness { p: poly_output }, PolyEvalInstance { c: comm_output, - x: rand_ext[1..].to_vec(), + x: r_prod.clone(), e: eval_output2, }, )); - let r_prod = rand_ext[1..].to_vec(); // row-related and col-related claims of polynomial evaluations to aid the final check of the sum-check let evals = [ &pk.S_repr.row, @@ -1299,7 +1305,7 @@ where &pk.S_repr.col_audit_ts, ] 
.into_par_iter() - .map(|p| MultilinearPolynomial::evaluate_with(p, &r_prod)) + .map(|p| MultilinearPolynomial::evaluate_with(p, &r_prod.clone())) .collect::>(); let eval_row = evals[0]; @@ -1699,6 +1705,7 @@ where r.extend(&[c]); r }; + let r_prod = rand_ext[1..].to_vec(); // add claimed evaluations to the transcript let evals = self @@ -1765,14 +1772,13 @@ where e: product, }); - // eval_output2 = output(rand_ext[1..]) + // eval_output2 = output(r_prod) u_vec.push(PolyEvalInstance { c: comm_output, - x: rand_ext[1..].to_vec(), + x: r_prod.clone(), e: eval_output2, }); - let r_prod = rand_ext[1..].to_vec(); // row-related and col-related claims of polynomial evaluations to aid the final check of the sum-check // we can batch all the claims transcript.absorb( diff --git a/src/spartan/sumcheck.rs b/src/spartan/sumcheck.rs index e3a733e99..28efd9518 100644 --- a/src/spartan/sumcheck.rs +++ b/src/spartan/sumcheck.rs @@ -1,3 +1,6 @@ +//! define sumcheck module +#![allow(clippy::too_many_arguments)] +#![allow(clippy::type_complexity)] use crate::errors::NovaError; use crate::spartan::polys::{ multilinear::MultilinearPolynomial, @@ -8,17 +11,20 @@ use ff::Field; use rayon::prelude::*; use serde::{Deserialize, Serialize}; +/// SumcheckProof #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(bound = "")] -pub(crate) struct SumcheckProof { +pub struct SumcheckProof { compressed_polys: Vec>, } impl SumcheckProof { + /// new a sumcheck proof pub fn new(compressed_polys: Vec>) -> Self { Self { compressed_polys } } + /// verify sumcheck proof pub fn verify( &self, claim: G::Scalar, @@ -137,6 +143,7 @@ impl SumcheckProof { )) } + /// prove_quad_batch pub fn prove_quad_batch( claim: &G::Scalar, num_rounds: usize,