From 5efc296214e817169ab2b1617c9b231795762a57 Mon Sep 17 00:00:00 2001 From: zfscgy Date: Thu, 4 Feb 2021 13:15:48 +0800 Subject: [PATCH 1/3] support for multiple features for company/partner --- protocol/src/cross_psi/company.rs | 583 ++++++++++++++---------------- protocol/src/cross_psi/partner.rs | 410 +++++++++------------ 2 files changed, 448 insertions(+), 545 deletions(-) diff --git a/protocol/src/cross_psi/company.rs b/protocol/src/cross_psi/company.rs index dd13fab..43301e3 100644 --- a/protocol/src/cross_psi/company.rs +++ b/protocol/src/cross_psi/company.rs @@ -1,368 +1,337 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -extern crate common; -extern crate crypto; +extern crate csv; -use log::info; -use std::{ - collections::HashMap, - path::Path, - sync::{Arc, RwLock}, +use std::sync::{Arc, RwLock}; + +use crypto::{ + eccipher::{gen_scalar, ECCipher}, + prelude::*, +}; +#[cfg(target_arch = "wasm32")] +use crypto::eccipher::ECRistrettoSequential as ECRistretto; +#[cfg(not(target_arch = "wasm32"))] +use crypto::eccipher::ECRistrettoParallel as ECRistretto; + +use common::{ + files, + permutations::{permute, undo_permute}, + timer, }; use crate::{ - cross_psi::traits::*, - fileio::load_data_with_features, - shared::{LoadData, Reveal, ShareableEncKey, TFeatures}, + fileio::{load_data, load_json, KeyedCSV}, + private_id::traits::CompanyPrivateIdProtocol, }; -use common::timer; -use crypto::{ - eccipher, - eccipher::{gen_scalar, ECCipher, ECRistrettoParallel}, - he, - he::PaillierParallel, - prelude::{rand_bigints, BigInt, ByteBuffer, EncryptionKey, Scalar, TPayload}, -}; +use super::{fill_permute, ProtocolError}; #[derive(Debug)] -pub struct CompanyCrossPsi { - ec_cipher: eccipher::ECRistrettoParallel, - he_cipher: he::PaillierParallel, +pub struct CompanyPrivateId { + private_keys: (Scalar, Scalar), + ec_cipher: ECRistretto, + // TODO: consider using dyn pid::crypto::ECCipher trait? + plain_data: Arc>, + permutation: Arc>>, - ec_key: Scalar, - partner_he_public_key: Arc>, + v_company: Arc>>, + e_company: Arc>>, + e_partner: Arc>>, - self_num_records: Arc>, - self_num_features: Arc>, - partner_num_records: Arc>, - partner_num_features: Arc>, + s_prime_company: Arc>>, + s_prime_partner: Arc>>, - plaintext_keys: Arc>>, - plaintext_features: Arc>, - - self_permutation: Arc>>, - // These are double encrypted - once by partner - // and once by company - encrypted_company_keys: Arc>, - - partner_intersection_mask: Arc>>, - self_intersection_indices: Arc>>, - - //TODO: WARN: this is single column only (yet) - additive_mask: Arc>>, - partner_shares: Arc>>, - self_shares: Arc>>>, + id_map: Arc>>>, } -impl CompanyCrossPsi { - pub fn new() -> CompanyCrossPsi { - CompanyCrossPsi { - ec_cipher: ECRistrettoParallel::new(), - he_cipher: PaillierParallel::new(), - - ec_key: gen_scalar(), - partner_he_public_key: Arc::new(RwLock::new(EncryptionKey { - n: BigInt::zero(), - nn: BigInt::zero(), - })), - - self_num_records: Arc::new(RwLock::default()), - self_num_features: Arc::new(RwLock::default()), - partner_num_records: Arc::new(RwLock::default()), - partner_num_features: Arc::new(RwLock::default()), - - plaintext_features: Arc::new(RwLock::default()), - plaintext_keys: Arc::new(RwLock::default()), - - self_permutation: Arc::new(RwLock::default()), - - encrypted_company_keys: Arc::new(RwLock::default()), - - partner_intersection_mask: Arc::new(RwLock::default()), - self_intersection_indices: Arc::new(RwLock::default()), - - additive_mask: Arc::new(RwLock::default()), - partner_shares: Arc::new(RwLock::default()), - self_shares: Arc::new(RwLock::default()), +impl CompanyPrivateId { + pub fn new() -> CompanyPrivateId { + CompanyPrivateId { + private_keys: (gen_scalar(), gen_scalar()), + ec_cipher: ECRistretto::default(), + plain_data: Arc::new(RwLock::default()), + permutation: Arc::new(RwLock::default()), + v_company: Arc::new(RwLock::default()), + e_company: Arc::new(RwLock::default()), + e_partner: Arc::new(RwLock::default()), + s_prime_company: Arc::new(RwLock::default()), + s_prime_partner: Arc::new(RwLock::default()), + id_map: Arc::new(RwLock::default()), } } - pub fn get_self_num_features(&self) -> usize { - *self.self_num_features.clone().read().unwrap() - } - - pub fn get_self_num_records(&self) -> usize { - *self.self_num_records.clone().read().unwrap() - } - - pub fn get_partner_num_features(&self) -> usize { - *self.partner_num_features.clone().read().unwrap() - } - - pub fn get_partner_num_records(&self) -> usize { - *self.partner_num_records.clone().read().unwrap() - } - - pub fn set_partner_num_features(&self, partner_num_features: usize) { - *self.partner_num_features.clone().write().unwrap() = partner_num_features; - } - - pub fn set_partner_num_records(&self, partner_num_records: usize) { - *self.partner_num_records.clone().write().unwrap() = partner_num_records; - } - - pub fn set_partner_he_public_key(&self, partner_he_pub_key: EncryptionKey) { - *self.partner_he_public_key.clone().write().unwrap() = partner_he_pub_key; + pub fn load_data(&self, path: &str, input_with_headers: bool) { + load_data(self.plain_data.clone(), path, input_with_headers); + fill_permute( + self.permutation.clone(), + (*self.plain_data.clone().read().unwrap()).records.len(), + ); } - pub fn fill_permute_self(&self) { - if let Ok(mut permute) = self.self_permutation.clone().write() { - permute.clear(); - permute.append(&mut common::permutations::gen_permute_pattern( - self.get_self_num_records(), - )); + pub fn load_json(&self, json: &str, input_with_headers: bool) -> bool { + let success = load_json(self.plain_data.clone(), json, input_with_headers); + if success { + fill_permute( + self.permutation.clone(), + (*self.plain_data.clone().read().unwrap()).records.len(), + ); } + success } } -impl Default for CompanyCrossPsi { +impl Default for CompanyPrivateId { fn default() -> Self { Self::new() } } -impl LoadData for CompanyCrossPsi { - fn load_data(&self, input_path: T) - where - T: AsRef, - { - load_data_with_features( - input_path, - self.plaintext_keys.clone(), - self.plaintext_features.clone(), - self.self_num_features.clone(), - self.self_num_records.clone(), - ) +impl CompanyPrivateIdProtocol for CompanyPrivateId { + fn set_encrypted_company(&self, name: String, data: TPayload) -> Result<(), ProtocolError> { + match name.as_str() { + "e_company" => self + .e_company + .clone() + .write() + .map(|mut d| { + let t = timer::Timer::new_silent("Load e_company"); + d.append(&mut self.ec_cipher.to_points(&data)); + t.qps("deserialize", data.len()); + }) + .map_err(|_| { + ProtocolError::ErrorDeserialization("Cannot load e_company".to_string()) + }), + "v_company" => self + .v_company + .clone() + .write() + .map(|mut d| { + let t = timer::Timer::new_silent("Load v_company"); + d.append(&mut self.ec_cipher.to_points(&data)); + t.qps("deserialize", data.len()); + }) + .map_err(|_| { + ProtocolError::ErrorDeserialization("Cannot load v_company".to_string()) + }), + _ => panic!("wrong name"), + } } -} -impl ShareableEncKey for CompanyCrossPsi { - fn get_he_public_key(&self) -> EncryptionKey { - (*self.he_cipher.enc_key.clone()).clone() + fn set_encrypted_partner_keys(&self, u_partner_payload: TPayload) -> Result<(), ProtocolError> { + self.e_partner + .clone() + .write() + .map(|mut data| { + let t = timer::Timer::new_silent("load_u_partner"); + if data.is_empty() { + data.extend( + &self + .ec_cipher + .to_points_encrypt(&u_partner_payload, &self.private_keys.0), + ); + t.qps("deserialize_exp", u_partner_payload.len()); + } + }) + .map_err(|err| { + error!("Cannot load e_company {}", err); + ProtocolError::ErrorDeserialization("cannot load u_partner".to_string()) + }) } -} -impl CompanyCrossPsiProtocol for CompanyCrossPsi { - fn get_permuted_keys(&self) -> TPayload { - let t = timer::Builder::new() - .label("u_company") - .size(self.get_self_num_records()) - .build(); - - if let (Ok(perm), Ok(mut text)) = ( - self.self_permutation.clone().read(), - self.plaintext_keys.clone().write(), - ) { - common::permutations::permute(perm.as_slice(), &mut text); - let res = self - .ec_cipher - .hash_encrypt_to_bytes(text.as_slice(), &self.ec_key); - t.qps("keys EC enc", res.len()); - res - } else { - panic!("Unable to make u_company keys happen") - } + fn write_partner_to_id_map( + &self, + s_prime_partner_payload: TPayload, + na_val: Option<&String>, + ) -> Result<(), ProtocolError> { + self.id_map + .clone() + .write() + .map(|mut data| { + let t = timer::Timer::new_silent("load_s_prime_partner"); + if data.is_empty() { + for k in &s_prime_partner_payload { + let record = (*self.plain_data.clone().read().unwrap()) + .get_empty_record_with_key(k.to_string(), na_val); + data.push(record); + } + t.qps("deserialize_exp", s_prime_partner_payload.len()); + } + }) + .map_err(|err| { + error!("Cannot load s_double_prime_partner {}", err); + ProtocolError::ErrorDeserialization( + "cannot load s_double_prime_partner".to_string(), + ) + }) } - fn get_permuted_features(&self, feature_id: usize) -> TPayload { - let t = timer::Builder::new() - .silent(true) - .label("u_company") - .size(self.get_self_num_records()) - .build(); - - if let (Ok(perm), Ok(mut features)) = ( - self.self_permutation.clone().read(), - self.plaintext_features.clone().write(), - ) { - let feature = &mut features[feature_id]; - common::permutations::permute(perm.as_slice(), feature); - - let res = self.he_cipher.enc_serialise_u64(&feature); - t.qps(format!("feature {} HE enc", feature_id).as_str(), res.len()); - res - } else { - panic!("Cannot HE encrypt column {} ", feature_id); + fn get_permuted_keys(&self) -> Result { + match self.plain_data.clone().read() { + Ok(pdata) => { + let t = timer::Timer::new_silent("u_company"); + let plain_keys = pdata.get_plain_keys(); + let mut u = self + .ec_cipher + .hash_encrypt_to_bytes(&plain_keys.as_slice(), &self.private_keys.0); + t.qps("encryption", u.len()); + + self.permutation + .clone() + .read() + .map(|pm| { + permute(&pm, &mut u); + t.qps("permutation", pm.len()); + u + }) + .map_err(|err| { + error!("Cannot permute {}", err); + ProtocolError::ErrorEncryption("cannot permute u_company".to_string()) + }) + } + Err(e) => { + error!("Unable to encrypt UCompany: {}", e); + Err(ProtocolError::ErrorEncryption( + "cannot encrypt UCompany".to_string(), + )) + } } } - fn set_encrypted_company_keys(&self, mut data: TPayload) { - if let Ok(mut keys) = self.encrypted_company_keys.clone().write() { - keys.clear(); - keys.extend(data.drain(..)) - } else { - panic!("Cannot upload e_company keys"); - } + fn get_encrypted_partner_keys(&self) -> Result { + self.e_partner + .clone() + .read() + .map(|data| { + let t = timer::Timer::new_silent("v_partner"); + let u = self.ec_cipher.encrypt_to_bytes(&data, &self.private_keys.1); + t.qps("exp_serialize", u.len()); + u + }) + .map_err(|err| { + error!("Unable to encrypt VPartner: {}", err); + ProtocolError::ErrorDeserialization("cannot encrypt VPartner".to_string()) + }) } - fn generate_additive_shares(&self, feature_id: usize, values: TPayload) { - let t = timer::Builder::new() - .label("server") - .silent(true) - .extra_label("additive shares mask") - .build(); - let filtered_values: TPayload = - if let Ok(mask) = self.partner_intersection_mask.clone().read() { - values - .iter() - .zip(mask.iter()) - .filter(|(_, &b)| b) - .map(|(a, _)| a.clone()) - .collect::() - } else { - panic!("unable to get masked vals") - }; - - // Generate random mask - { - *self.additive_mask.clone().write().unwrap() = rand_bigints(filtered_values.len()); - } - - if let (Ok(key), Ok(mask), Ok(mut partner_shares)) = ( - self.partner_he_public_key.clone().read(), - self.additive_mask.clone().read(), - self.partner_shares.clone().write(), + fn calculate_set_diff(&self) -> Result<(), ProtocolError> { + match ( + self.e_partner.clone().read(), + self.e_company.clone().read(), + self.s_prime_company.clone().write(), + self.s_prime_partner.clone().write(), ) { - let res = self - .he_cipher - .subtract_plaintext(&key, filtered_values, &mask); - t.qps("masking values in the intersection", res.len()); - partner_shares.insert(feature_id, res); - } else { - panic!("Unable to add additive shares with the intersection") - } - } - - fn get_shares(&self, feature_index: usize) -> TPayload { - if let Ok(mut shares) = self.partner_shares.clone().write() { - if !shares.contains_key(&feature_index) { - panic!("No feature_index {} for shares", feature_index); + (Ok(e_partner), Ok(e_company), Ok(mut s_prime_company), Ok(mut s_prime_partner)) => { + let e_company_bytes = self + .ec_cipher + .encrypt_to_bytes(&e_company, &self.private_keys.1); + let e_partner_bytes = self + .ec_cipher + .encrypt_to_bytes(&e_partner, &self.private_keys.1); + + s_prime_partner.clear(); + s_prime_partner.extend(common::vectors::subtract_set( + &e_partner_bytes, + &e_company_bytes, + )); + + s_prime_company.clear(); + s_prime_company.extend(common::vectors::subtract_set( + &e_company_bytes, + &e_partner_bytes, + )); + Ok(()) + } + _ => { + error!("Unable to obtain locks to buffers for set diff operation"); + Err(ProtocolError::ErrorCalcSetDiff( + "cannot calculate set difference".to_string(), + )) } - shares.remove(&feature_index).unwrap() - } else { - panic!("Unable to read shares"); } } - fn set_self_shares(&self, feature_index: usize, data: TPayload) { - if let Ok(mut shares) = self.self_shares.clone().write() { - info!( - "Saving self-shares for feature index {} len {}", - feature_index, - data.len() - ); - shares.insert(feature_index, self.he_cipher.decrypt(data)); - } else { - panic!("Unable to write shares"); + fn get_set_diff_output(&self, name: String) -> Result { + match name.as_str() { + "s_prime_partner" => self + .s_prime_partner + .clone() + .read() + .map(|data| data.to_vec()) + .map_err(|err| { + error!("Unable to get s_prime_partner: {}", err); + ProtocolError::ErrorDeserialization("cannot obtain s_prime_partner".to_string()) + }), + "s_prime_company" => self + .s_prime_company + .clone() + .read() + .map(|data| data.to_vec()) + .map_err(|err| { + error!("Unable to get s_prime_company: {}", err); + ProtocolError::ErrorDeserialization("cannot obtain s_prime_company".to_string()) + }), + _ => panic!("wrong name"), } } - fn calculate_intersection(&self, keys: TPayload) { - let partner_keys = self.ec_cipher.to_bytes( - &self - .ec_cipher - .to_points_encrypt(keys.as_slice(), &self.ec_key), - ); - - // find the index of the intersection - - if let (Ok(company_keys), Ok(mut partner_mask), Ok(mut company_indices)) = ( - self.encrypted_company_keys.clone().read(), - self.partner_intersection_mask.clone().write(), - self.self_intersection_indices.clone().write(), + fn write_company_to_id_map(&self) { + match ( + self.permutation.clone().read(), + self.plain_data.clone().read(), + self.v_company.clone().read(), + self.id_map.clone().write(), ) { - if company_keys.is_empty() { - panic!("e_partner keys should be uploaded after e_company keys are uploaded"); - } - - partner_mask.clear(); - - partner_mask.extend(common::vectors::vec_intersection_mask( - partner_keys.as_slice(), - company_keys.as_slice(), - )); + (Ok(pm), Ok(plain_data), Ok(v_company), Ok(mut id_map)) => { + let mut company_encrypt = self.ec_cipher.encrypt(&v_company, &self.private_keys.1); + undo_permute(&pm, &mut company_encrypt); + for (k, v) in self + .ec_cipher + .to_bytes(&company_encrypt) + .iter() + .zip(plain_data.get_plain_keys().iter()) + { + let record = plain_data.get_record_with_keys(k.to_string(), &v); + id_map.push(record); + } - // TODO: can this be a parallel forall - for (flag, partner_key) in partner_mask.iter().zip(&partner_keys) { - if *flag { - let index = company_keys - .iter() - .position(|x| *x == *partner_key) - .unwrap(); - company_indices.push(index); + if !plain_data.headers.is_empty() { + id_map.insert(0, plain_data.headers.clone()); } } - - info!( - "Company-Partner Intersection size: {}", - company_indices.len() - ); - } else { - panic!("Unable to find interesection"); + _ => panic!("Cannot make v"), } } - fn get_company_indices(&self) -> TPayload { - if let Ok(indices) = self.self_intersection_indices.clone().read() { - let mut index_buffer: TPayload = Vec::with_capacity(indices.len()); - for index in indices.iter() { - index_buffer.push(ByteBuffer { - buffer: (*index as u64).to_le_bytes().to_vec(), - }); - } - index_buffer - } else { - panic!("Unable to fetch company indices"); - } + fn print_id_map(&self, limit: usize, input_with_headers: bool, use_row_numbers: bool) { + let _ = self + .id_map + .clone() + .read() + .map(|data| { + files::write_vec_to_stdout(&data, limit, input_with_headers, use_row_numbers) + .unwrap() + }) + .map_err(|_| {}); } -} - -impl Reveal for CompanyCrossPsi { - fn reveal>(&self, path: T) { - if let (Ok(indices), Ok(additive_mask), Ok(mut self_shares)) = ( - self.self_intersection_indices.clone().read(), - self.additive_mask.clone().read(), - self.self_shares.clone().write(), - ) { - let max_val = BigInt::one() << 64; - let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); - - for index in indices.iter() { - filtered_shares.push((self_shares[&0][*index]).clone()); - } - self_shares.remove(&0); - - let company_shares = filtered_shares - .iter() - .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) - .collect::>(); - - let partner_shares: Vec = additive_mask - .iter() - .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) - .collect::>(); + fn save_id_map( + &self, + path: &str, + input_with_headers: bool, + use_row_numbers: bool, + ) -> Result<(), ProtocolError> { + self.id_map + .clone() + .write() + .map(|mut data| { + files::write_vec_to_csv(&mut data, path, input_with_headers, use_row_numbers) + .unwrap(); + }) + .map_err(|_| ProtocolError::ErrorIO("Unable to write company view to file".to_string())) + } - let mut out: Vec> = - Vec::with_capacity(self.get_self_num_features() + self.get_partner_num_features()); - out.push(partner_shares); - out.push(company_shares); - info!("revealing columns to output file"); - common::files::write_u64cols_to_file(&mut out, path).unwrap(); - } else { - panic!("Unable to reveal"); - } + fn stringify_id_map(&self, use_row_numbers: bool) -> String { + files::stringify_id_map(self.id_map.clone(), use_row_numbers) } } diff --git a/protocol/src/cross_psi/partner.rs b/protocol/src/cross_psi/partner.rs index 8e89c97..9b536c8 100644 --- a/protocol/src/cross_psi/partner.rs +++ b/protocol/src/cross_psi/partner.rs @@ -1,281 +1,215 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -extern crate common; -extern crate crypto; +extern crate csv; -use log::info; -use std::{ - collections::HashMap, - ops::Deref, - path::Path, - sync::{Arc, RwLock}, +use crypto::{ + eccipher::{gen_scalar, ECCipher}, + prelude::*, }; +#[cfg(target_arch = "wasm32")] +use crypto::eccipher::ECRistrettoSequential as ECRistretto; +#[cfg(not(target_arch = "wasm32"))] +use crypto::eccipher::ECRistrettoParallel as ECRistretto; use crate::{ - cross_psi::traits::*, - fileio::load_data_with_features, - shared::{LoadData, Reveal, ShareableEncKey, TFeatures}, + fileio::{load_data, load_json, KeyedCSV}, + private_id::traits::PartnerPrivateIdProtocol, }; -use common::timer; -use crypto::{ - eccipher, - eccipher::{gen_scalar, ECCipher}, - he, - prelude::{mod_sub, rand_bigints, BigInt, EncryptionKey, Scalar, TPayload}, +use common::{ + files, + permutations::{gen_permute_pattern, permute, undo_permute}, + timer, }; -#[derive(Debug)] -pub struct PartnerCrossPsi { - ec_cipher: eccipher::ECRistrettoParallel, - he_cipher: he::PaillierParallel, - ec_key: Scalar, - company_he_public_key: Arc>, - self_num_records: Arc>, - self_num_features: Arc>, - company_num_records: Arc>, - company_num_features: Arc>, - plaintext_keys: Arc>>, - plaintext_features: Arc>, - company_permutation: Arc>>, - self_permutation: Arc>>, - additive_mask: Arc>>, - self_shares: Arc>>>, - company_intersection_indices: Arc>>, -} - -impl PartnerCrossPsi { - pub fn new() -> PartnerCrossPsi { - PartnerCrossPsi { - ec_cipher: eccipher::ECRistrettoParallel::new(), - he_cipher: he::PaillierParallel::new(), - ec_key: gen_scalar(), - company_he_public_key: Arc::new(RwLock::new(EncryptionKey { - n: BigInt::zero(), - nn: BigInt::zero(), - })), - self_num_records: Arc::new(RwLock::default()), - self_num_features: Arc::new(RwLock::default()), - company_num_records: Arc::new(RwLock::default()), - company_num_features: Arc::new(RwLock::default()), - plaintext_keys: Arc::new(RwLock::default()), - plaintext_features: Arc::new(RwLock::default()), - company_permutation: Arc::new(RwLock::default()), - self_permutation: Arc::new(RwLock::default()), - additive_mask: Arc::new(RwLock::default()), - self_shares: Arc::new(RwLock::default()), - company_intersection_indices: Arc::new(RwLock::default()), - } - } - - pub fn set_company_intersection_indices(&self, mut indices: Vec) { - if let Ok(mut company_indices) = self.company_intersection_indices.clone().write() { - company_indices.clear(); - company_indices.extend(indices.drain(..)); - } else { - panic!("Cannot set indices"); - } - } - - pub fn get_self_num_features(&self) -> usize { - *self.self_num_features.clone().read().unwrap() - } - - pub fn get_self_num_records(&self) -> usize { - *self.self_num_records.clone().read().unwrap() - } - - pub fn get_company_num_features(&self) -> usize { - *self.company_num_features.clone().read().unwrap() - } +use std::sync::{Arc, RwLock}; - pub fn get_company_num_records(&self) -> usize { - *self.company_num_records.clone().read().unwrap() - } +use super::{fill_permute, ProtocolError}; - pub fn set_company_num_records(&self, company_num_records: usize) { - *self.company_num_records.clone().write().unwrap() = company_num_records; - } - - pub fn set_company_num_features(&self, company_num_features: usize) { - *self.company_num_features.clone().write().unwrap() = company_num_features; - } +pub struct PartnerPrivateId { + private_keys: (Scalar, Scalar), + ec_cipher: ECRistretto, + plain_data: Arc>, + permutation: Arc>>, + id_map: Arc>>>, +} - pub fn set_company_he_public_key(&self, company_he_public_key: EncryptionKey) { - *self.company_he_public_key.clone().write().unwrap() = company_he_public_key; +impl PartnerPrivateId { + pub fn new() -> PartnerPrivateId { + PartnerPrivateId { + private_keys: (gen_scalar(), gen_scalar()), + ec_cipher: ECRistretto::default(), + plain_data: Arc::new(RwLock::default()), + permutation: Arc::new(RwLock::default()), + id_map: Arc::new(RwLock::default()), + } } - pub fn fill_permute_company(&self, length: usize) { - if let Ok(mut permute) = self.company_permutation.clone().write() { - permute.clear(); - permute.append(&mut common::permutations::gen_permute_pattern(length)); - } + pub fn load_data(&self, path: &str, input_with_headers: bool) -> Result<(), ProtocolError> { + load_data(self.plain_data.clone(), path, input_with_headers); + Ok(()) } - pub fn fill_permute_self(&self) { - if let Ok(mut permute) = self.self_permutation.clone().write() { - permute.clear(); - permute.append(&mut common::permutations::gen_permute_pattern( - self.get_self_num_records(), - )); - } + pub fn load_json(&self, path: &str, input_with_headers: bool) -> Result { + Ok(load_json(self.plain_data.clone(), path, input_with_headers)) } - pub fn permute(&self, values: &mut Vec) { - common::permutations::permute( - self.company_permutation.clone().read().unwrap().as_slice(), - values, - ); + pub fn get_size(&self) -> usize { + self.plain_data.clone().read().unwrap().records.len() } } -impl Default for PartnerCrossPsi { +impl Default for PartnerPrivateId { fn default() -> Self { Self::new() } } -impl LoadData for PartnerCrossPsi { - fn load_data(&self, input_path: T) - where - T: AsRef, - { - load_data_with_features( - input_path, - self.plaintext_keys.clone(), - self.plaintext_features.clone(), - self.self_num_features.clone(), - self.self_num_records.clone(), - ) - } -} - -impl ShareableEncKey for PartnerCrossPsi { - fn get_he_public_key(&self) -> EncryptionKey { - (*self.he_cipher.enc_key.clone()).clone() - } -} - -impl PartnerCrossPsiProtocol for PartnerCrossPsi { - fn get_permuted_keys(&self) -> TPayload { - timer::Builder::new() - .label("u_partner") - .extra_label("keys EC enc") - .size(self.get_self_num_records()) - .build(); - - if let (Ok(perm), Ok(mut text)) = ( - self.self_permutation.clone().read(), - self.plaintext_keys.clone().write(), - ) { - common::permutations::permute(perm.as_slice(), &mut text); - self.ec_cipher - .hash_encrypt_to_bytes(text.as_slice(), &self.ec_key) - } else { - panic!("Could not permute and encrypt keys"); - } - } - - fn get_permuted_features(&self, feature_index: usize) -> TPayload { - let t = timer::Builder::new() - .silent(true) - .label("u_partner") - .size(self.get_self_num_records()) - .build(); +impl PartnerPrivateIdProtocol for PartnerPrivateId { + fn gen_permute_pattern(&self) -> Result<(), ProtocolError> { + fill_permute( + self.permutation.clone(), + (*self.plain_data.clone().read().unwrap()).records.len(), + ); + Ok(()) + } + + fn permute_hash_to_bytes(&self) -> Result { + match self.plain_data.clone().read() { + Ok(pdata) => { + #[cfg(not(target_arch="wasm32"))] let t = timer::Timer::new_silent("u_partner"); + let plain_keys = pdata.get_plain_keys(); + let mut u = self + .ec_cipher + .hash_encrypt_to_bytes(&plain_keys.as_slice(), &self.private_keys.0); + #[cfg(not(target_arch="wasm32"))] t.qps("encryption", u.len()); + + self.permutation + .clone() + .read() + .map(|pm| { + permute(&pm, &mut u); + #[cfg(not(target_arch="wasm32"))] t.qps("permutation", pm.len()); + u + }) + .map_err(|err| { + error!("error in permute {}", err); + ProtocolError::ErrorEncryption("unable to encrypt data".to_string()) + }) + } - if let (Ok(perm), Ok(mut features)) = ( - self.self_permutation.clone().read(), - self.plaintext_features.clone().write(), - ) { - let feature_column = &mut features[feature_index]; - common::permutations::permute(perm.as_slice(), feature_column); - let res = self.he_cipher.enc_serialise_u64(&feature_column); - t.qps( - format!("column {} HE enc", feature_index).as_str(), - res.len(), - ); - res - } else { - panic!("Cannot HE encrypt column {} ", feature_index); + Err(e) => { + error!("Unable to encrypt plain_data: {}", e); + Err(ProtocolError::ErrorEncryption( + "unable to encrypt data".to_string(), + )) + } } } - fn encrypt(&self, keys: TPayload) -> TPayload { - timer::Builder::new() - .label("e_company") - .extra_label("keys EC enc + srlz") - .size(keys.len()) - .build(); - - self.ec_cipher.to_bytes( - self.ec_cipher - .to_points_encrypt(keys.as_slice(), &self.ec_key) - .as_slice(), - ) - } - - fn generate_additive_shares(&self, _: usize, values: TPayload) -> TPayload { + //TODO: return result + fn encrypt_permute(&self, company: TPayload) -> (TPayload, TPayload) { + #[cfg(not(target_arch="wasm32"))] let t = timer::Timer::new_silent("encrypt_permute_company"); + let mut encrypt_company = self + .ec_cipher + .to_points_encrypt(&company, &self.private_keys.0); + #[cfg(not(target_arch="wasm32"))] t.qps("encrypt_company", encrypt_company.len()); + let v_company = self + .ec_cipher + .encrypt_to_bytes(&encrypt_company, &self.private_keys.1); + #[cfg(not(target_arch="wasm32"))] t.qps("v_company", v_company.len()); { - *self.additive_mask.clone().write().unwrap() = rand_bigints(values.len()); - } - - if let (Ok(key), Ok(mask)) = ( - self.company_he_public_key.clone().read(), - self.additive_mask.clone().read(), - ) { - self.he_cipher - .subtract_plaintext(key.deref(), values, &mask) - } else { - panic!("Cannot mask with additive shares") + let rand_permutation = gen_permute_pattern(encrypt_company.len()); + // TODO: BUG why is this undo_permute + // undo_permute(&rand_permutation, &mut e_company_dsrlz); + permute(&rand_permutation, &mut encrypt_company); } + (self.ec_cipher.to_bytes(&encrypt_company), v_company) } - fn set_self_shares(&self, feature_index: usize, data: TPayload) { - if let Ok(mut shares) = self.self_shares.clone().write() { - info!("Saving self-shares for feature {}", feature_index); - shares.insert(feature_index, self.he_cipher.decrypt(data)); - } else { - panic!("Unable to write shares"); - } + fn encrypt(&self, partner: TPayload) -> Result { + let ep = self + .ec_cipher + .to_points_encrypt(&partner, &self.private_keys.1); + Ok(self.ec_cipher.to_bytes(&ep)) } -} -impl Reveal for PartnerCrossPsi { - fn reveal>(&self, path: T) { - if let (Ok(indices), Ok(mut self_shares), Ok(mut additive_mask)) = ( - self.company_intersection_indices.clone().read(), - self.self_shares.clone().write(), - self.additive_mask.clone().write(), + fn create_id_map(&self, partner: TPayload, company: TPayload, na_val: Option<&str>) { + match ( + self.permutation.clone().read(), + self.plain_data.clone().read(), + self.id_map.clone().write(), ) { - let output_mod = BigInt::one() << 64; - let n = BigInt::one() << 1024; - - let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); - - for index in indices.iter() { - filtered_shares.push(additive_mask[*index].clone()); + (Ok(pm), Ok(plain_data), Ok(mut id_map)) => { + let mut partner_encrypt = self + .ec_cipher + .to_points_encrypt(&partner, &self.private_keys.1); + undo_permute(&pm, &mut partner_encrypt); + + for (k, v) in self + .ec_cipher + .to_bytes(&partner_encrypt) + .iter() + .zip(plain_data.get_plain_keys().iter()) + { + let record = plain_data.get_record_with_keys(k.to_string(), &v); + id_map.push(record); + } + + for k in self + .ec_cipher + .to_bytes( + &self + .ec_cipher + .to_points_encrypt(&company, &self.private_keys.1), + ) + .iter() + { + let record = plain_data.get_empty_record_with_key( + k.to_string(), + na_val.map(String::from).as_ref(), + ); + id_map.push(record); + } + + if !plain_data.headers.is_empty() { + id_map.insert(0, plain_data.headers.clone()); + } } - additive_mask.clear(); - - let company_shares = filtered_shares - .iter() - .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) - .collect::>(); - - let partner_shares = self_shares - .remove(&0) - .unwrap() - .iter() - .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) - .collect::>(); - - let mut out: Vec> = - Vec::with_capacity(self.get_self_num_features() + self.get_company_num_features()); - out.push(partner_shares); - out.push(company_shares); - info!("revealing columns to output file"); - common::files::write_u64cols_to_file(&mut out, path).unwrap(); + _ => panic!("Cannot make v"), } } + + fn print_id_map(&self, limit: usize, input_with_headers: bool, use_row_numbers: bool) { + let _ = self + .id_map + .clone() + .read() + .map(|data| { + files::write_vec_to_stdout(&data, limit, input_with_headers, use_row_numbers) + .unwrap() + }) + .map_err(|_| {}); + } + + fn save_id_map( + &self, + path: &str, + input_with_headers: bool, + use_row_numbers: bool, + ) -> Result<(), ProtocolError> { + self.id_map + .clone() + .write() + .map(|mut data| { + files::write_vec_to_csv(&mut data, path, input_with_headers, use_row_numbers) + .unwrap(); + }) + .map_err(|_| ProtocolError::ErrorIO("Unable to write partner view to file".to_string())) + } + + fn stringify_id_map(&self, use_row_numbers: bool) -> String { + files::stringify_id_map(self.id_map.clone(), use_row_numbers) + } } From be8d426327a855fe3e51775a819521a624005c85 Mon Sep 17 00:00:00 2001 From: zfscgy Date: Thu, 4 Feb 2021 13:22:36 +0800 Subject: [PATCH 2/3] Revert "support for multiple features for company/partner" This reverts commit 5efc296214e817169ab2b1617c9b231795762a57. --- protocol/src/cross_psi/company.rs | 583 ++++++++++++++++-------------- protocol/src/cross_psi/partner.rs | 410 ++++++++++++--------- 2 files changed, 545 insertions(+), 448 deletions(-) diff --git a/protocol/src/cross_psi/company.rs b/protocol/src/cross_psi/company.rs index 43301e3..dd13fab 100644 --- a/protocol/src/cross_psi/company.rs +++ b/protocol/src/cross_psi/company.rs @@ -1,337 +1,368 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -extern crate csv; +extern crate common; +extern crate crypto; -use std::sync::{Arc, RwLock}; - -use crypto::{ - eccipher::{gen_scalar, ECCipher}, - prelude::*, -}; -#[cfg(target_arch = "wasm32")] -use crypto::eccipher::ECRistrettoSequential as ECRistretto; -#[cfg(not(target_arch = "wasm32"))] -use crypto::eccipher::ECRistrettoParallel as ECRistretto; - -use common::{ - files, - permutations::{permute, undo_permute}, - timer, +use log::info; +use std::{ + collections::HashMap, + path::Path, + sync::{Arc, RwLock}, }; use crate::{ - fileio::{load_data, load_json, KeyedCSV}, - private_id::traits::CompanyPrivateIdProtocol, + cross_psi::traits::*, + fileio::load_data_with_features, + shared::{LoadData, Reveal, ShareableEncKey, TFeatures}, }; +use common::timer; -use super::{fill_permute, ProtocolError}; +use crypto::{ + eccipher, + eccipher::{gen_scalar, ECCipher, ECRistrettoParallel}, + he, + he::PaillierParallel, + prelude::{rand_bigints, BigInt, ByteBuffer, EncryptionKey, Scalar, TPayload}, +}; #[derive(Debug)] -pub struct CompanyPrivateId { - private_keys: (Scalar, Scalar), - ec_cipher: ECRistretto, - // TODO: consider using dyn pid::crypto::ECCipher trait? - plain_data: Arc>, - permutation: Arc>>, +pub struct CompanyCrossPsi { + ec_cipher: eccipher::ECRistrettoParallel, + he_cipher: he::PaillierParallel, - v_company: Arc>>, - e_company: Arc>>, - e_partner: Arc>>, + ec_key: Scalar, + partner_he_public_key: Arc>, - s_prime_company: Arc>>, - s_prime_partner: Arc>>, + self_num_records: Arc>, + self_num_features: Arc>, + partner_num_records: Arc>, + partner_num_features: Arc>, - id_map: Arc>>>, + plaintext_keys: Arc>>, + plaintext_features: Arc>, + + self_permutation: Arc>>, + // These are double encrypted - once by partner + // and once by company + encrypted_company_keys: Arc>, + + partner_intersection_mask: Arc>>, + self_intersection_indices: Arc>>, + + //TODO: WARN: this is single column only (yet) + additive_mask: Arc>>, + partner_shares: Arc>>, + self_shares: Arc>>>, } -impl CompanyPrivateId { - pub fn new() -> CompanyPrivateId { - CompanyPrivateId { - private_keys: (gen_scalar(), gen_scalar()), - ec_cipher: ECRistretto::default(), - plain_data: Arc::new(RwLock::default()), - permutation: Arc::new(RwLock::default()), - v_company: Arc::new(RwLock::default()), - e_company: Arc::new(RwLock::default()), - e_partner: Arc::new(RwLock::default()), - s_prime_company: Arc::new(RwLock::default()), - s_prime_partner: Arc::new(RwLock::default()), - id_map: Arc::new(RwLock::default()), +impl CompanyCrossPsi { + pub fn new() -> CompanyCrossPsi { + CompanyCrossPsi { + ec_cipher: ECRistrettoParallel::new(), + he_cipher: PaillierParallel::new(), + + ec_key: gen_scalar(), + partner_he_public_key: Arc::new(RwLock::new(EncryptionKey { + n: BigInt::zero(), + nn: BigInt::zero(), + })), + + self_num_records: Arc::new(RwLock::default()), + self_num_features: Arc::new(RwLock::default()), + partner_num_records: Arc::new(RwLock::default()), + partner_num_features: Arc::new(RwLock::default()), + + plaintext_features: Arc::new(RwLock::default()), + plaintext_keys: Arc::new(RwLock::default()), + + self_permutation: Arc::new(RwLock::default()), + + encrypted_company_keys: Arc::new(RwLock::default()), + + partner_intersection_mask: Arc::new(RwLock::default()), + self_intersection_indices: Arc::new(RwLock::default()), + + additive_mask: Arc::new(RwLock::default()), + partner_shares: Arc::new(RwLock::default()), + self_shares: Arc::new(RwLock::default()), } } - pub fn load_data(&self, path: &str, input_with_headers: bool) { - load_data(self.plain_data.clone(), path, input_with_headers); - fill_permute( - self.permutation.clone(), - (*self.plain_data.clone().read().unwrap()).records.len(), - ); + pub fn get_self_num_features(&self) -> usize { + *self.self_num_features.clone().read().unwrap() } - pub fn load_json(&self, json: &str, input_with_headers: bool) -> bool { - let success = load_json(self.plain_data.clone(), json, input_with_headers); - if success { - fill_permute( - self.permutation.clone(), - (*self.plain_data.clone().read().unwrap()).records.len(), - ); + pub fn get_self_num_records(&self) -> usize { + *self.self_num_records.clone().read().unwrap() + } + + pub fn get_partner_num_features(&self) -> usize { + *self.partner_num_features.clone().read().unwrap() + } + + pub fn get_partner_num_records(&self) -> usize { + *self.partner_num_records.clone().read().unwrap() + } + + pub fn set_partner_num_features(&self, partner_num_features: usize) { + *self.partner_num_features.clone().write().unwrap() = partner_num_features; + } + + pub fn set_partner_num_records(&self, partner_num_records: usize) { + *self.partner_num_records.clone().write().unwrap() = partner_num_records; + } + + pub fn set_partner_he_public_key(&self, partner_he_pub_key: EncryptionKey) { + *self.partner_he_public_key.clone().write().unwrap() = partner_he_pub_key; + } + + pub fn fill_permute_self(&self) { + if let Ok(mut permute) = self.self_permutation.clone().write() { + permute.clear(); + permute.append(&mut common::permutations::gen_permute_pattern( + self.get_self_num_records(), + )); } - success } } -impl Default for CompanyPrivateId { +impl Default for CompanyCrossPsi { fn default() -> Self { Self::new() } } -impl CompanyPrivateIdProtocol for CompanyPrivateId { - fn set_encrypted_company(&self, name: String, data: TPayload) -> Result<(), ProtocolError> { - match name.as_str() { - "e_company" => self - .e_company - .clone() - .write() - .map(|mut d| { - let t = timer::Timer::new_silent("Load e_company"); - d.append(&mut self.ec_cipher.to_points(&data)); - t.qps("deserialize", data.len()); - }) - .map_err(|_| { - ProtocolError::ErrorDeserialization("Cannot load e_company".to_string()) - }), - "v_company" => self - .v_company - .clone() - .write() - .map(|mut d| { - let t = timer::Timer::new_silent("Load v_company"); - d.append(&mut self.ec_cipher.to_points(&data)); - t.qps("deserialize", data.len()); - }) - .map_err(|_| { - ProtocolError::ErrorDeserialization("Cannot load v_company".to_string()) - }), - _ => panic!("wrong name"), - } +impl LoadData for CompanyCrossPsi { + fn load_data(&self, input_path: T) + where + T: AsRef, + { + load_data_with_features( + input_path, + self.plaintext_keys.clone(), + self.plaintext_features.clone(), + self.self_num_features.clone(), + self.self_num_records.clone(), + ) } +} - fn set_encrypted_partner_keys(&self, u_partner_payload: TPayload) -> Result<(), ProtocolError> { - self.e_partner - .clone() - .write() - .map(|mut data| { - let t = timer::Timer::new_silent("load_u_partner"); - if data.is_empty() { - data.extend( - &self - .ec_cipher - .to_points_encrypt(&u_partner_payload, &self.private_keys.0), - ); - t.qps("deserialize_exp", u_partner_payload.len()); - } - }) - .map_err(|err| { - error!("Cannot load e_company {}", err); - ProtocolError::ErrorDeserialization("cannot load u_partner".to_string()) - }) +impl ShareableEncKey for CompanyCrossPsi { + fn get_he_public_key(&self) -> EncryptionKey { + (*self.he_cipher.enc_key.clone()).clone() } +} - fn write_partner_to_id_map( - &self, - s_prime_partner_payload: TPayload, - na_val: Option<&String>, - ) -> Result<(), ProtocolError> { - self.id_map - .clone() - .write() - .map(|mut data| { - let t = timer::Timer::new_silent("load_s_prime_partner"); - if data.is_empty() { - for k in &s_prime_partner_payload { - let record = (*self.plain_data.clone().read().unwrap()) - .get_empty_record_with_key(k.to_string(), na_val); - data.push(record); - } - t.qps("deserialize_exp", s_prime_partner_payload.len()); - } - }) - .map_err(|err| { - error!("Cannot load s_double_prime_partner {}", err); - ProtocolError::ErrorDeserialization( - "cannot load s_double_prime_partner".to_string(), - ) - }) +impl CompanyCrossPsiProtocol for CompanyCrossPsi { + fn get_permuted_keys(&self) -> TPayload { + let t = timer::Builder::new() + .label("u_company") + .size(self.get_self_num_records()) + .build(); + + if let (Ok(perm), Ok(mut text)) = ( + self.self_permutation.clone().read(), + self.plaintext_keys.clone().write(), + ) { + common::permutations::permute(perm.as_slice(), &mut text); + let res = self + .ec_cipher + .hash_encrypt_to_bytes(text.as_slice(), &self.ec_key); + t.qps("keys EC enc", res.len()); + res + } else { + panic!("Unable to make u_company keys happen") + } } - fn get_permuted_keys(&self) -> Result { - match self.plain_data.clone().read() { - Ok(pdata) => { - let t = timer::Timer::new_silent("u_company"); - let plain_keys = pdata.get_plain_keys(); - let mut u = self - .ec_cipher - .hash_encrypt_to_bytes(&plain_keys.as_slice(), &self.private_keys.0); - t.qps("encryption", u.len()); - - self.permutation - .clone() - .read() - .map(|pm| { - permute(&pm, &mut u); - t.qps("permutation", pm.len()); - u - }) - .map_err(|err| { - error!("Cannot permute {}", err); - ProtocolError::ErrorEncryption("cannot permute u_company".to_string()) - }) - } - Err(e) => { - error!("Unable to encrypt UCompany: {}", e); - Err(ProtocolError::ErrorEncryption( - "cannot encrypt UCompany".to_string(), - )) - } + fn get_permuted_features(&self, feature_id: usize) -> TPayload { + let t = timer::Builder::new() + .silent(true) + .label("u_company") + .size(self.get_self_num_records()) + .build(); + + if let (Ok(perm), Ok(mut features)) = ( + self.self_permutation.clone().read(), + self.plaintext_features.clone().write(), + ) { + let feature = &mut features[feature_id]; + common::permutations::permute(perm.as_slice(), feature); + + let res = self.he_cipher.enc_serialise_u64(&feature); + t.qps(format!("feature {} HE enc", feature_id).as_str(), res.len()); + res + } else { + panic!("Cannot HE encrypt column {} ", feature_id); } } - fn get_encrypted_partner_keys(&self) -> Result { - self.e_partner - .clone() - .read() - .map(|data| { - let t = timer::Timer::new_silent("v_partner"); - let u = self.ec_cipher.encrypt_to_bytes(&data, &self.private_keys.1); - t.qps("exp_serialize", u.len()); - u - }) - .map_err(|err| { - error!("Unable to encrypt VPartner: {}", err); - ProtocolError::ErrorDeserialization("cannot encrypt VPartner".to_string()) - }) + fn set_encrypted_company_keys(&self, mut data: TPayload) { + if let Ok(mut keys) = self.encrypted_company_keys.clone().write() { + keys.clear(); + keys.extend(data.drain(..)) + } else { + panic!("Cannot upload e_company keys"); + } } - fn calculate_set_diff(&self) -> Result<(), ProtocolError> { - match ( - self.e_partner.clone().read(), - self.e_company.clone().read(), - self.s_prime_company.clone().write(), - self.s_prime_partner.clone().write(), + fn generate_additive_shares(&self, feature_id: usize, values: TPayload) { + let t = timer::Builder::new() + .label("server") + .silent(true) + .extra_label("additive shares mask") + .build(); + let filtered_values: TPayload = + if let Ok(mask) = self.partner_intersection_mask.clone().read() { + values + .iter() + .zip(mask.iter()) + .filter(|(_, &b)| b) + .map(|(a, _)| a.clone()) + .collect::() + } else { + panic!("unable to get masked vals") + }; + + // Generate random mask + { + *self.additive_mask.clone().write().unwrap() = rand_bigints(filtered_values.len()); + } + + if let (Ok(key), Ok(mask), Ok(mut partner_shares)) = ( + self.partner_he_public_key.clone().read(), + self.additive_mask.clone().read(), + self.partner_shares.clone().write(), ) { - (Ok(e_partner), Ok(e_company), Ok(mut s_prime_company), Ok(mut s_prime_partner)) => { - let e_company_bytes = self - .ec_cipher - .encrypt_to_bytes(&e_company, &self.private_keys.1); - let e_partner_bytes = self - .ec_cipher - .encrypt_to_bytes(&e_partner, &self.private_keys.1); - - s_prime_partner.clear(); - s_prime_partner.extend(common::vectors::subtract_set( - &e_partner_bytes, - &e_company_bytes, - )); - - s_prime_company.clear(); - s_prime_company.extend(common::vectors::subtract_set( - &e_company_bytes, - &e_partner_bytes, - )); - Ok(()) - } - _ => { - error!("Unable to obtain locks to buffers for set diff operation"); - Err(ProtocolError::ErrorCalcSetDiff( - "cannot calculate set difference".to_string(), - )) + let res = self + .he_cipher + .subtract_plaintext(&key, filtered_values, &mask); + t.qps("masking values in the intersection", res.len()); + partner_shares.insert(feature_id, res); + } else { + panic!("Unable to add additive shares with the intersection") + } + } + + fn get_shares(&self, feature_index: usize) -> TPayload { + if let Ok(mut shares) = self.partner_shares.clone().write() { + if !shares.contains_key(&feature_index) { + panic!("No feature_index {} for shares", feature_index); } + shares.remove(&feature_index).unwrap() + } else { + panic!("Unable to read shares"); } } - fn get_set_diff_output(&self, name: String) -> Result { - match name.as_str() { - "s_prime_partner" => self - .s_prime_partner - .clone() - .read() - .map(|data| data.to_vec()) - .map_err(|err| { - error!("Unable to get s_prime_partner: {}", err); - ProtocolError::ErrorDeserialization("cannot obtain s_prime_partner".to_string()) - }), - "s_prime_company" => self - .s_prime_company - .clone() - .read() - .map(|data| data.to_vec()) - .map_err(|err| { - error!("Unable to get s_prime_company: {}", err); - ProtocolError::ErrorDeserialization("cannot obtain s_prime_company".to_string()) - }), - _ => panic!("wrong name"), + fn set_self_shares(&self, feature_index: usize, data: TPayload) { + if let Ok(mut shares) = self.self_shares.clone().write() { + info!( + "Saving self-shares for feature index {} len {}", + feature_index, + data.len() + ); + shares.insert(feature_index, self.he_cipher.decrypt(data)); + } else { + panic!("Unable to write shares"); } } - fn write_company_to_id_map(&self) { - match ( - self.permutation.clone().read(), - self.plain_data.clone().read(), - self.v_company.clone().read(), - self.id_map.clone().write(), + fn calculate_intersection(&self, keys: TPayload) { + let partner_keys = self.ec_cipher.to_bytes( + &self + .ec_cipher + .to_points_encrypt(keys.as_slice(), &self.ec_key), + ); + + // find the index of the intersection + + if let (Ok(company_keys), Ok(mut partner_mask), Ok(mut company_indices)) = ( + self.encrypted_company_keys.clone().read(), + self.partner_intersection_mask.clone().write(), + self.self_intersection_indices.clone().write(), ) { - (Ok(pm), Ok(plain_data), Ok(v_company), Ok(mut id_map)) => { - let mut company_encrypt = self.ec_cipher.encrypt(&v_company, &self.private_keys.1); - undo_permute(&pm, &mut company_encrypt); - for (k, v) in self - .ec_cipher - .to_bytes(&company_encrypt) - .iter() - .zip(plain_data.get_plain_keys().iter()) - { - let record = plain_data.get_record_with_keys(k.to_string(), &v); - id_map.push(record); - } + if company_keys.is_empty() { + panic!("e_partner keys should be uploaded after e_company keys are uploaded"); + } + + partner_mask.clear(); + + partner_mask.extend(common::vectors::vec_intersection_mask( + partner_keys.as_slice(), + company_keys.as_slice(), + )); - if !plain_data.headers.is_empty() { - id_map.insert(0, plain_data.headers.clone()); + // TODO: can this be a parallel forall + for (flag, partner_key) in partner_mask.iter().zip(&partner_keys) { + if *flag { + let index = company_keys + .iter() + .position(|x| *x == *partner_key) + .unwrap(); + company_indices.push(index); } } - _ => panic!("Cannot make v"), + + info!( + "Company-Partner Intersection size: {}", + company_indices.len() + ); + } else { + panic!("Unable to find interesection"); } } - fn print_id_map(&self, limit: usize, input_with_headers: bool, use_row_numbers: bool) { - let _ = self - .id_map - .clone() - .read() - .map(|data| { - files::write_vec_to_stdout(&data, limit, input_with_headers, use_row_numbers) - .unwrap() - }) - .map_err(|_| {}); + fn get_company_indices(&self) -> TPayload { + if let Ok(indices) = self.self_intersection_indices.clone().read() { + let mut index_buffer: TPayload = Vec::with_capacity(indices.len()); + for index in indices.iter() { + index_buffer.push(ByteBuffer { + buffer: (*index as u64).to_le_bytes().to_vec(), + }); + } + index_buffer + } else { + panic!("Unable to fetch company indices"); + } } +} - fn save_id_map( - &self, - path: &str, - input_with_headers: bool, - use_row_numbers: bool, - ) -> Result<(), ProtocolError> { - self.id_map - .clone() - .write() - .map(|mut data| { - files::write_vec_to_csv(&mut data, path, input_with_headers, use_row_numbers) - .unwrap(); - }) - .map_err(|_| ProtocolError::ErrorIO("Unable to write company view to file".to_string())) - } +impl Reveal for CompanyCrossPsi { + fn reveal>(&self, path: T) { + if let (Ok(indices), Ok(additive_mask), Ok(mut self_shares)) = ( + self.self_intersection_indices.clone().read(), + self.additive_mask.clone().read(), + self.self_shares.clone().write(), + ) { + let max_val = BigInt::one() << 64; - fn stringify_id_map(&self, use_row_numbers: bool) -> String { - files::stringify_id_map(self.id_map.clone(), use_row_numbers) + let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); + + for index in indices.iter() { + filtered_shares.push((self_shares[&0][*index]).clone()); + } + self_shares.remove(&0); + + let company_shares = filtered_shares + .iter() + .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) + .collect::>(); + + let partner_shares: Vec = additive_mask + .iter() + .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) + .collect::>(); + + let mut out: Vec> = + Vec::with_capacity(self.get_self_num_features() + self.get_partner_num_features()); + out.push(partner_shares); + out.push(company_shares); + info!("revealing columns to output file"); + common::files::write_u64cols_to_file(&mut out, path).unwrap(); + } else { + panic!("Unable to reveal"); + } } } diff --git a/protocol/src/cross_psi/partner.rs b/protocol/src/cross_psi/partner.rs index 9b536c8..8e89c97 100644 --- a/protocol/src/cross_psi/partner.rs +++ b/protocol/src/cross_psi/partner.rs @@ -1,215 +1,281 @@ // Copyright (c) Facebook, Inc. and its affiliates. // SPDX-License-Identifier: Apache-2.0 -extern crate csv; +extern crate common; +extern crate crypto; -use crypto::{ - eccipher::{gen_scalar, ECCipher}, - prelude::*, +use log::info; +use std::{ + collections::HashMap, + ops::Deref, + path::Path, + sync::{Arc, RwLock}, }; -#[cfg(target_arch = "wasm32")] -use crypto::eccipher::ECRistrettoSequential as ECRistretto; -#[cfg(not(target_arch = "wasm32"))] -use crypto::eccipher::ECRistrettoParallel as ECRistretto; use crate::{ - fileio::{load_data, load_json, KeyedCSV}, - private_id::traits::PartnerPrivateIdProtocol, + cross_psi::traits::*, + fileio::load_data_with_features, + shared::{LoadData, Reveal, ShareableEncKey, TFeatures}, }; +use common::timer; -use common::{ - files, - permutations::{gen_permute_pattern, permute, undo_permute}, - timer, +use crypto::{ + eccipher, + eccipher::{gen_scalar, ECCipher}, + he, + prelude::{mod_sub, rand_bigints, BigInt, EncryptionKey, Scalar, TPayload}, }; -use std::sync::{Arc, RwLock}; - -use super::{fill_permute, ProtocolError}; - -pub struct PartnerPrivateId { - private_keys: (Scalar, Scalar), - ec_cipher: ECRistretto, - plain_data: Arc>, - permutation: Arc>>, - id_map: Arc>>>, +#[derive(Debug)] +pub struct PartnerCrossPsi { + ec_cipher: eccipher::ECRistrettoParallel, + he_cipher: he::PaillierParallel, + ec_key: Scalar, + company_he_public_key: Arc>, + self_num_records: Arc>, + self_num_features: Arc>, + company_num_records: Arc>, + company_num_features: Arc>, + plaintext_keys: Arc>>, + plaintext_features: Arc>, + company_permutation: Arc>>, + self_permutation: Arc>>, + additive_mask: Arc>>, + self_shares: Arc>>>, + company_intersection_indices: Arc>>, } -impl PartnerPrivateId { - pub fn new() -> PartnerPrivateId { - PartnerPrivateId { - private_keys: (gen_scalar(), gen_scalar()), - ec_cipher: ECRistretto::default(), - plain_data: Arc::new(RwLock::default()), - permutation: Arc::new(RwLock::default()), - id_map: Arc::new(RwLock::default()), +impl PartnerCrossPsi { + pub fn new() -> PartnerCrossPsi { + PartnerCrossPsi { + ec_cipher: eccipher::ECRistrettoParallel::new(), + he_cipher: he::PaillierParallel::new(), + ec_key: gen_scalar(), + company_he_public_key: Arc::new(RwLock::new(EncryptionKey { + n: BigInt::zero(), + nn: BigInt::zero(), + })), + self_num_records: Arc::new(RwLock::default()), + self_num_features: Arc::new(RwLock::default()), + company_num_records: Arc::new(RwLock::default()), + company_num_features: Arc::new(RwLock::default()), + plaintext_keys: Arc::new(RwLock::default()), + plaintext_features: Arc::new(RwLock::default()), + company_permutation: Arc::new(RwLock::default()), + self_permutation: Arc::new(RwLock::default()), + additive_mask: Arc::new(RwLock::default()), + self_shares: Arc::new(RwLock::default()), + company_intersection_indices: Arc::new(RwLock::default()), + } + } + + pub fn set_company_intersection_indices(&self, mut indices: Vec) { + if let Ok(mut company_indices) = self.company_intersection_indices.clone().write() { + company_indices.clear(); + company_indices.extend(indices.drain(..)); + } else { + panic!("Cannot set indices"); } } - pub fn load_data(&self, path: &str, input_with_headers: bool) -> Result<(), ProtocolError> { - load_data(self.plain_data.clone(), path, input_with_headers); - Ok(()) + pub fn get_self_num_features(&self) -> usize { + *self.self_num_features.clone().read().unwrap() + } + + pub fn get_self_num_records(&self) -> usize { + *self.self_num_records.clone().read().unwrap() + } + + pub fn get_company_num_features(&self) -> usize { + *self.company_num_features.clone().read().unwrap() + } + + pub fn get_company_num_records(&self) -> usize { + *self.company_num_records.clone().read().unwrap() + } + + pub fn set_company_num_records(&self, company_num_records: usize) { + *self.company_num_records.clone().write().unwrap() = company_num_records; + } + + pub fn set_company_num_features(&self, company_num_features: usize) { + *self.company_num_features.clone().write().unwrap() = company_num_features; + } + + pub fn set_company_he_public_key(&self, company_he_public_key: EncryptionKey) { + *self.company_he_public_key.clone().write().unwrap() = company_he_public_key; } - pub fn load_json(&self, path: &str, input_with_headers: bool) -> Result { - Ok(load_json(self.plain_data.clone(), path, input_with_headers)) + pub fn fill_permute_company(&self, length: usize) { + if let Ok(mut permute) = self.company_permutation.clone().write() { + permute.clear(); + permute.append(&mut common::permutations::gen_permute_pattern(length)); + } } - pub fn get_size(&self) -> usize { - self.plain_data.clone().read().unwrap().records.len() + pub fn fill_permute_self(&self) { + if let Ok(mut permute) = self.self_permutation.clone().write() { + permute.clear(); + permute.append(&mut common::permutations::gen_permute_pattern( + self.get_self_num_records(), + )); + } + } + + pub fn permute(&self, values: &mut Vec) { + common::permutations::permute( + self.company_permutation.clone().read().unwrap().as_slice(), + values, + ); } } -impl Default for PartnerPrivateId { +impl Default for PartnerCrossPsi { fn default() -> Self { Self::new() } } -impl PartnerPrivateIdProtocol for PartnerPrivateId { - fn gen_permute_pattern(&self) -> Result<(), ProtocolError> { - fill_permute( - self.permutation.clone(), - (*self.plain_data.clone().read().unwrap()).records.len(), - ); - Ok(()) - } - - fn permute_hash_to_bytes(&self) -> Result { - match self.plain_data.clone().read() { - Ok(pdata) => { - #[cfg(not(target_arch="wasm32"))] let t = timer::Timer::new_silent("u_partner"); - let plain_keys = pdata.get_plain_keys(); - let mut u = self - .ec_cipher - .hash_encrypt_to_bytes(&plain_keys.as_slice(), &self.private_keys.0); - #[cfg(not(target_arch="wasm32"))] t.qps("encryption", u.len()); - - self.permutation - .clone() - .read() - .map(|pm| { - permute(&pm, &mut u); - #[cfg(not(target_arch="wasm32"))] t.qps("permutation", pm.len()); - u - }) - .map_err(|err| { - error!("error in permute {}", err); - ProtocolError::ErrorEncryption("unable to encrypt data".to_string()) - }) - } +impl LoadData for PartnerCrossPsi { + fn load_data(&self, input_path: T) + where + T: AsRef, + { + load_data_with_features( + input_path, + self.plaintext_keys.clone(), + self.plaintext_features.clone(), + self.self_num_features.clone(), + self.self_num_records.clone(), + ) + } +} - Err(e) => { - error!("Unable to encrypt plain_data: {}", e); - Err(ProtocolError::ErrorEncryption( - "unable to encrypt data".to_string(), - )) - } +impl ShareableEncKey for PartnerCrossPsi { + fn get_he_public_key(&self) -> EncryptionKey { + (*self.he_cipher.enc_key.clone()).clone() + } +} + +impl PartnerCrossPsiProtocol for PartnerCrossPsi { + fn get_permuted_keys(&self) -> TPayload { + timer::Builder::new() + .label("u_partner") + .extra_label("keys EC enc") + .size(self.get_self_num_records()) + .build(); + + if let (Ok(perm), Ok(mut text)) = ( + self.self_permutation.clone().read(), + self.plaintext_keys.clone().write(), + ) { + common::permutations::permute(perm.as_slice(), &mut text); + self.ec_cipher + .hash_encrypt_to_bytes(text.as_slice(), &self.ec_key) + } else { + panic!("Could not permute and encrypt keys"); } } - //TODO: return result - fn encrypt_permute(&self, company: TPayload) -> (TPayload, TPayload) { - #[cfg(not(target_arch="wasm32"))] let t = timer::Timer::new_silent("encrypt_permute_company"); - let mut encrypt_company = self - .ec_cipher - .to_points_encrypt(&company, &self.private_keys.0); - #[cfg(not(target_arch="wasm32"))] t.qps("encrypt_company", encrypt_company.len()); - let v_company = self - .ec_cipher - .encrypt_to_bytes(&encrypt_company, &self.private_keys.1); - #[cfg(not(target_arch="wasm32"))] t.qps("v_company", v_company.len()); - { - let rand_permutation = gen_permute_pattern(encrypt_company.len()); - // TODO: BUG why is this undo_permute - // undo_permute(&rand_permutation, &mut e_company_dsrlz); - permute(&rand_permutation, &mut encrypt_company); + fn get_permuted_features(&self, feature_index: usize) -> TPayload { + let t = timer::Builder::new() + .silent(true) + .label("u_partner") + .size(self.get_self_num_records()) + .build(); + + if let (Ok(perm), Ok(mut features)) = ( + self.self_permutation.clone().read(), + self.plaintext_features.clone().write(), + ) { + let feature_column = &mut features[feature_index]; + common::permutations::permute(perm.as_slice(), feature_column); + let res = self.he_cipher.enc_serialise_u64(&feature_column); + t.qps( + format!("column {} HE enc", feature_index).as_str(), + res.len(), + ); + res + } else { + panic!("Cannot HE encrypt column {} ", feature_index); } - (self.ec_cipher.to_bytes(&encrypt_company), v_company) } - fn encrypt(&self, partner: TPayload) -> Result { - let ep = self - .ec_cipher - .to_points_encrypt(&partner, &self.private_keys.1); - Ok(self.ec_cipher.to_bytes(&ep)) + fn encrypt(&self, keys: TPayload) -> TPayload { + timer::Builder::new() + .label("e_company") + .extra_label("keys EC enc + srlz") + .size(keys.len()) + .build(); + + self.ec_cipher.to_bytes( + self.ec_cipher + .to_points_encrypt(keys.as_slice(), &self.ec_key) + .as_slice(), + ) } - fn create_id_map(&self, partner: TPayload, company: TPayload, na_val: Option<&str>) { - match ( - self.permutation.clone().read(), - self.plain_data.clone().read(), - self.id_map.clone().write(), + fn generate_additive_shares(&self, _: usize, values: TPayload) -> TPayload { + { + *self.additive_mask.clone().write().unwrap() = rand_bigints(values.len()); + } + + if let (Ok(key), Ok(mask)) = ( + self.company_he_public_key.clone().read(), + self.additive_mask.clone().read(), ) { - (Ok(pm), Ok(plain_data), Ok(mut id_map)) => { - let mut partner_encrypt = self - .ec_cipher - .to_points_encrypt(&partner, &self.private_keys.1); - undo_permute(&pm, &mut partner_encrypt); - - for (k, v) in self - .ec_cipher - .to_bytes(&partner_encrypt) - .iter() - .zip(plain_data.get_plain_keys().iter()) - { - let record = plain_data.get_record_with_keys(k.to_string(), &v); - id_map.push(record); - } - - for k in self - .ec_cipher - .to_bytes( - &self - .ec_cipher - .to_points_encrypt(&company, &self.private_keys.1), - ) - .iter() - { - let record = plain_data.get_empty_record_with_key( - k.to_string(), - na_val.map(String::from).as_ref(), - ); - id_map.push(record); - } - - if !plain_data.headers.is_empty() { - id_map.insert(0, plain_data.headers.clone()); - } - } - _ => panic!("Cannot make v"), + self.he_cipher + .subtract_plaintext(key.deref(), values, &mask) + } else { + panic!("Cannot mask with additive shares") } } - fn print_id_map(&self, limit: usize, input_with_headers: bool, use_row_numbers: bool) { - let _ = self - .id_map - .clone() - .read() - .map(|data| { - files::write_vec_to_stdout(&data, limit, input_with_headers, use_row_numbers) - .unwrap() - }) - .map_err(|_| {}); - } - - fn save_id_map( - &self, - path: &str, - input_with_headers: bool, - use_row_numbers: bool, - ) -> Result<(), ProtocolError> { - self.id_map - .clone() - .write() - .map(|mut data| { - files::write_vec_to_csv(&mut data, path, input_with_headers, use_row_numbers) - .unwrap(); - }) - .map_err(|_| ProtocolError::ErrorIO("Unable to write partner view to file".to_string())) - } - - fn stringify_id_map(&self, use_row_numbers: bool) -> String { - files::stringify_id_map(self.id_map.clone(), use_row_numbers) + fn set_self_shares(&self, feature_index: usize, data: TPayload) { + if let Ok(mut shares) = self.self_shares.clone().write() { + info!("Saving self-shares for feature {}", feature_index); + shares.insert(feature_index, self.he_cipher.decrypt(data)); + } else { + panic!("Unable to write shares"); + } + } +} + +impl Reveal for PartnerCrossPsi { + fn reveal>(&self, path: T) { + if let (Ok(indices), Ok(mut self_shares), Ok(mut additive_mask)) = ( + self.company_intersection_indices.clone().read(), + self.self_shares.clone().write(), + self.additive_mask.clone().write(), + ) { + let output_mod = BigInt::one() << 64; + let n = BigInt::one() << 1024; + + let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); + + for index in indices.iter() { + filtered_shares.push(additive_mask[*index].clone()); + } + additive_mask.clear(); + + let company_shares = filtered_shares + .iter() + .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) + .collect::>(); + + let partner_shares = self_shares + .remove(&0) + .unwrap() + .iter() + .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) + .collect::>(); + + let mut out: Vec> = + Vec::with_capacity(self.get_self_num_features() + self.get_company_num_features()); + out.push(partner_shares); + out.push(company_shares); + info!("revealing columns to output file"); + common::files::write_u64cols_to_file(&mut out, path).unwrap(); + } } } From fbfcd00f82c43e3fdc4133a69fe6898fa0df9079 Mon Sep 17 00:00:00 2001 From: zfscgy Date: Thu, 4 Feb 2021 13:23:56 +0800 Subject: [PATCH 3/3] support for multiple features for company/partner --- protocol/src/cross_psi/company.rs | 52 +++++++++++++++------------- protocol/src/cross_psi/partner.rs | 57 ++++++++++++++++--------------- 2 files changed, 57 insertions(+), 52 deletions(-) diff --git a/protocol/src/cross_psi/company.rs b/protocol/src/cross_psi/company.rs index dd13fab..738fa9b 100644 --- a/protocol/src/cross_psi/company.rs +++ b/protocol/src/cross_psi/company.rs @@ -51,7 +51,7 @@ pub struct CompanyCrossPsi { self_intersection_indices: Arc>>, //TODO: WARN: this is single column only (yet) - additive_mask: Arc>>, + additive_mask: Arc>>>, partner_shares: Arc>>, self_shares: Arc>>>, } @@ -207,7 +207,7 @@ impl CompanyCrossPsiProtocol for CompanyCrossPsi { } } - fn generate_additive_shares(&self, feature_id: usize, values: TPayload) { + fn generate_additive_shares(&self, feature_index: usize, values: TPayload) { let t = timer::Builder::new() .label("server") .silent(true) @@ -227,7 +227,7 @@ impl CompanyCrossPsiProtocol for CompanyCrossPsi { // Generate random mask { - *self.additive_mask.clone().write().unwrap() = rand_bigints(filtered_values.len()); + self.additive_mask.clone().write().unwrap().insert(feature_index, rand_bigints(filtered_values.len())); } if let (Ok(key), Ok(mask), Ok(mut partner_shares)) = ( @@ -237,9 +237,9 @@ impl CompanyCrossPsiProtocol for CompanyCrossPsi { ) { let res = self .he_cipher - .subtract_plaintext(&key, filtered_values, &mask); + .subtract_plaintext(&key, filtered_values, &mask[&feature_index]); t.qps("masking values in the intersection", res.len()); - partner_shares.insert(feature_id, res); + partner_shares.insert(feature_index, res); } else { panic!("Unable to add additive shares with the intersection") } @@ -336,29 +336,33 @@ impl Reveal for CompanyCrossPsi { self.additive_mask.clone().read(), self.self_shares.clone().write(), ) { + + let mut out: Vec> = Vec::with_capacity(self.get_self_num_features() + self.get_partner_num_features()); let max_val = BigInt::one() << 64; + for feature_index in 0.. self.get_partner_num_features() + self.get_self_num_features() { + if feature_index < self.get_partner_num_features() { + let partner_shares: Vec = additive_mask[&feature_index] + .iter() + .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) + .collect::>(); + out.push(partner_shares); + } + else { + let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); - let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); + for index in indices.iter() { + filtered_shares.push(self_shares[&(feature_index - self.get_partner_num_features())][*index] + .clone()); + } - for index in indices.iter() { - filtered_shares.push((self_shares[&0][*index]).clone()); + let company_shares = filtered_shares + .iter() + .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) + .collect::>(); + + out.push(company_shares); + } } - self_shares.remove(&0); - - let company_shares = filtered_shares - .iter() - .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) - .collect::>(); - - let partner_shares: Vec = additive_mask - .iter() - .map(|z| (Option::::from(&z.mod_floor(&max_val))).unwrap()) - .collect::>(); - - let mut out: Vec> = - Vec::with_capacity(self.get_self_num_features() + self.get_partner_num_features()); - out.push(partner_shares); - out.push(company_shares); info!("revealing columns to output file"); common::files::write_u64cols_to_file(&mut out, path).unwrap(); } else { diff --git a/protocol/src/cross_psi/partner.rs b/protocol/src/cross_psi/partner.rs index 8e89c97..a9111b3 100644 --- a/protocol/src/cross_psi/partner.rs +++ b/protocol/src/cross_psi/partner.rs @@ -40,7 +40,7 @@ pub struct PartnerCrossPsi { plaintext_features: Arc>, company_permutation: Arc>>, self_permutation: Arc>>, - additive_mask: Arc>>, + additive_mask: Arc>>>, self_shares: Arc>>>, company_intersection_indices: Arc>>, } @@ -215,9 +215,9 @@ impl PartnerCrossPsiProtocol for PartnerCrossPsi { ) } - fn generate_additive_shares(&self, _: usize, values: TPayload) -> TPayload { + fn generate_additive_shares(&self, feature_index: usize, values: TPayload) -> TPayload { { - *self.additive_mask.clone().write().unwrap() = rand_bigints(values.len()); + self.additive_mask.clone().write().unwrap().insert(feature_index, rand_bigints(values.len())); } if let (Ok(key), Ok(mask)) = ( @@ -225,7 +225,7 @@ impl PartnerCrossPsiProtocol for PartnerCrossPsi { self.additive_mask.clone().read(), ) { self.he_cipher - .subtract_plaintext(key.deref(), values, &mask) + .subtract_plaintext(key.deref(), values, &mask[&feature_index]) } else { panic!("Cannot mask with additive shares") } @@ -243,37 +243,38 @@ impl PartnerCrossPsiProtocol for PartnerCrossPsi { impl Reveal for PartnerCrossPsi { fn reveal>(&self, path: T) { - if let (Ok(indices), Ok(mut self_shares), Ok(mut additive_mask)) = ( + if let (Ok(indices), Ok(self_shares), Ok(additive_mask)) = ( self.company_intersection_indices.clone().read(), self.self_shares.clone().write(), self.additive_mask.clone().write(), ) { let output_mod = BigInt::one() << 64; let n = BigInt::one() << 1024; - - let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); - - for index in indices.iter() { - filtered_shares.push(additive_mask[*index].clone()); - } - additive_mask.clear(); - - let company_shares = filtered_shares - .iter() - .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) - .collect::>(); - - let partner_shares = self_shares - .remove(&0) - .unwrap() - .iter() - .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) - .collect::>(); - let mut out: Vec> = - Vec::with_capacity(self.get_self_num_features() + self.get_company_num_features()); - out.push(partner_shares); - out.push(company_shares); + Vec::with_capacity(self.get_self_num_features() + self.get_company_num_features()); + + for feature_index in 0..self.get_self_num_features() + self.get_company_num_features() { + + if feature_index < self.get_self_num_features() { + let partner_shares = self_shares[&feature_index] + .iter() + .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) + .collect::>(); + out.push(partner_shares); + } + else { + let mut filtered_shares: Vec = Vec::with_capacity(indices.len()); + for index in indices.iter() { + filtered_shares.push(additive_mask[&(feature_index - self.get_self_num_features())][*index] + .clone()); + }; + let company_shares = filtered_shares + .iter() + .map(|e| (Option::::from(&mod_sub(e, &n, &output_mod))).unwrap()) + .collect::>(); + out.push(company_shares); + } + } info!("revealing columns to output file"); common::files::write_u64cols_to_file(&mut out, path).unwrap(); }