diff --git a/protocol/src/cross_psi/company.rs b/protocol/src/cross_psi/company.rs index dd13fab..738fa9b 100644 --- a/protocol/src/cross_psi/company.rs +++ b/protocol/src/cross_psi/company.rs @@ -51,7 +51,7 @@ pub struct CompanyCrossPsi { self_intersection_indices: Arc>>, //TODO: WARN: this is single column only (yet) - additive_mask: Arc>>, + additive_mask: Arc>>>, partner_shares: Arc>>, self_shares: Arc>>>, } @@ -207,7 +207,7 @@ impl CompanyCrossPsiProtocol for CompanyCrossPsi { } } - fn generate_additive_shares(&self, feature_id: usize, values: TPayload) { + fn generate_additive_shares(&self, feature_index: usize, values: TPayload) { let t = timer::Builder::new() .label("server") .silent(true) @@ -227,7 +227,7 @@ impl CompanyCrossPsiProtocol for CompanyCrossPsi { // Generate random mask { - *self.additive_mask.clone().write().unwrap() = rand_bigints(filtered_values.len()); + self.additive_mask.clone().write().unwrap().insert(feature_index, rand_bigints(filtered_values.len())); } if let (Ok(key), Ok(mask), Ok(mut partner_shares)) = ( @@ -237,9 +237,9 @@ impl CompanyCrossPsiProtocol for CompanyCrossPsi { ) { let res = self .he_cipher - .subtract_plaintext(&key, filtered_values, &mask); + .subtract_plaintext(&key, filtered_values, &mask[&feature_index]); t.qps("masking values in the intersection", res.len()); - partner_shares.insert(feature_id, res); + partner_shares.insert(feature_index, res); } else { panic!("Unable to add additive shares with the intersection") } @@ -336,29 +336,33 @@ impl Reveal for CompanyCrossPsi { self.additive_mask.clone().read(), self.self_shares.clone().write(), ) { + + let mut out: Vec> = Vec::with_capacity(self.get_self_num_features() + self.get_partner_num_features()); let max_val = BigInt::one() << 64; + for feature_index in 0.. self.get_partner_num_features() + self.get_self_num_features() { + if feature_index < self.get_partner_num_features() { + let partner_shares: Vec = additive_mask[&feature_index] + .iter() + .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) + .collect::>(); + out.push(partner_shares); + } + else { + let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); - let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); + for index in indices.iter() { + filtered_shares.push(self_shares[&(feature_index - self.get_partner_num_features())][*index] + .clone()); + } - for index in indices.iter() { - filtered_shares.push((self_shares[&0][*index]).clone()); + let company_shares = filtered_shares + .iter() + .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) + .collect::>(); + + out.push(company_shares); + } } - self_shares.remove(&0); - - let company_shares = filtered_shares - .iter() - .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) - .collect::>(); - - let partner_shares: Vec = additive_mask - .iter() - .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) - .collect::>(); - - let mut out: Vec> = - Vec::with_capacity(self.get_self_num_features() + self.get_partner_num_features()); - out.push(partner_shares); - out.push(company_shares); info!("revealing columns to output file"); common::files::write_u64cols_to_file(&mut out, path).unwrap(); } else { diff --git a/protocol/src/cross_psi/partner.rs b/protocol/src/cross_psi/partner.rs index 8e89c97..a9111b3 100644 --- a/protocol/src/cross_psi/partner.rs +++ b/protocol/src/cross_psi/partner.rs @@ -40,7 +40,7 @@ pub struct PartnerCrossPsi { plaintext_features: Arc>, company_permutation: Arc>>, self_permutation: Arc>>, - additive_mask: Arc>>, + additive_mask: Arc>>>, self_shares: Arc>>>, company_intersection_indices: Arc>>, } @@ -215,9 +215,9 @@ impl PartnerCrossPsiProtocol for PartnerCrossPsi { ) } - fn generate_additive_shares(&self, _: usize, values: TPayload) -> TPayload { + fn generate_additive_shares(&self, feature_index: usize, values: TPayload) -> TPayload { { - *self.additive_mask.clone().write().unwrap() = rand_bigints(values.len()); + self.additive_mask.clone().write().unwrap().insert(feature_index, rand_bigints(values.len())); } if let (Ok(key), Ok(mask)) = ( @@ -225,7 +225,7 @@ impl PartnerCrossPsiProtocol for PartnerCrossPsi { self.additive_mask.clone().read(), ) { self.he_cipher - .subtract_plaintext(key.deref(), values, &mask) + .subtract_plaintext(key.deref(), values, &mask[&feature_index]) } else { panic!("Cannot mask with additive shares") } @@ -243,37 +243,38 @@ impl PartnerCrossPsiProtocol for PartnerCrossPsi { impl Reveal for PartnerCrossPsi { fn reveal>(&self, path: T) { - if let (Ok(indices), Ok(mut self_shares), Ok(mut additive_mask)) = ( + if let (Ok(indices), Ok(self_shares), Ok(additive_mask)) = ( self.company_intersection_indices.clone().read(), self.self_shares.clone().write(), self.additive_mask.clone().write(), ) { let output_mod = BigInt::one() << 64; let n = BigInt::one() << 1024; - - let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); - - for index in indices.iter() { - filtered_shares.push(additive_mask[*index].clone()); - } - additive_mask.clear(); - - let company_shares = filtered_shares - .iter() - .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) - .collect::>(); - - let partner_shares = self_shares - .remove(&0) - .unwrap() - .iter() - .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) - .collect::>(); - let mut out: Vec> = - Vec::with_capacity(self.get_self_num_features() + self.get_company_num_features()); - out.push(partner_shares); - out.push(company_shares); + Vec::with_capacity(self.get_self_num_features() + self.get_company_num_features()); + + for feature_index in 0..self.get_self_num_features() + self.get_company_num_features() { + + if feature_index < self.get_self_num_features() { + let partner_shares = self_shares[&feature_index] + .iter() + .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) + .collect::>(); + out.push(partner_shares); + } + else { + let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); + for index in indices.iter() { + filtered_shares.push(additive_mask[&(feature_index - self.get_self_num_features())][*index] + .clone()); + }; + let company_shares = filtered_shares + .iter() + .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) + .collect::>(); + out.push(company_shares); + } + } info!("revealing columns to output file"); common::files::write_u64cols_to_file(&mut out, path).unwrap(); }