Skip to content

Commit

Permalink
perf: add rayon parallel computation
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobkaufmann committed Apr 11, 2024
1 parent 9c32605 commit 2d4fbe7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 56 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
blst = "0.3.11"
hex = { version = "0.4.3", optional = true }
rand = { version = "0.8.5", optional = true }
rayon = "1.10.0"
serde = { version = "1.0.189", features = ["derive"], optional = true }
serde_json = { version = "1.0.107", optional = true }
serde_yaml = { version = "0.9.25", optional = true }
Expand All @@ -23,4 +24,4 @@ serde = ["dep:hex", "dep:serde", "dep:serde_json", "dep:serde_yaml"]
[[bench]]
name = "kzg"
harness = false
required-features = ["rand", "serde"]
required-features = ["rand", "serde"]
10 changes: 10 additions & 0 deletions src/bls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,16 @@ impl Add for Fr {
}
}

impl core::iter::Sum for Fr {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
let mut out = Self::ZERO;
for summand in iter {
out = out + summand;
}
out
}
}

impl Sub for Fr {
type Output = Self;

Expand Down
77 changes: 42 additions & 35 deletions src/kzg/poly.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};

use crate::bls::{Fr, P1};

use super::{setup::Setup, Proof};
Expand All @@ -11,21 +13,25 @@ impl<'a, const N: usize> Polynomial<'a, N> {
let roots = &setup.roots_of_unity_brp;

// if `point` is a root of a unity, then we have the evaluation available
for i in 0..N {
if point == roots[i] {
return self.0[i];
}
let eval = roots
.par_iter()
.enumerate()
.find_any(|(_, &root)| point == root);
if let Some((index, _)) = eval {
return self.0[index];
}

let mut eval = Fr::ZERO;

// barycentric evaluation summation
for i in 0..N {
let numer = self.0[i] * roots[i];
let denom = point - roots[i];
let term = numer / denom;
eval = eval + term;
}
let eval: Fr = self
.0
.par_iter()
.enumerate()
.map(|(i, &x)| {
let numer = x * roots[i];
let denom = point - roots[i];
numer / denom
})
.sum();

// barycentric evaluation scalar multiplication
let term = (point.pow(&Fr::from(N as u64)) - Fr::ONE) / Fr::from(N as u64);
Expand All @@ -39,31 +45,32 @@ impl<'a, const N: usize> Polynomial<'a, N> {
let eval = self.evaluate(point, setup);

// compute the quotient polynomial
//
// TODO: parallelize (e.g. rayon)
let mut quotient_poly = Vec::with_capacity(N);
for i in 0..N {
let numer = self.0[i] - eval;
let denom = roots[i] - point;
let quotient = if denom != Fr::ZERO {
numer / denom
} else {
let mut quotient = Fr::ZERO;
for j in 0..N {
if j == i {
continue;
}
let quotient_poly: Vec<Fr> = self
.0
.par_iter()
.enumerate()
.map(|(i, &x)| {
let numer = x - eval;
let denom = roots[i] - point;
if denom != Fr::ZERO {
numer / denom
} else {
let mut quotient = Fr::ZERO;
for j in 0..N {
if j == i {
continue;
}

let coefficient = self.0[j] - eval;
let numer = coefficient * roots[j];
let denom = (roots[i] * roots[i]) - (roots[i] * roots[j]);
let term = numer / denom;
quotient = quotient + term;
let coefficient = self.0[j] - eval;
let numer = coefficient * roots[j];
let denom = (roots[i] * roots[i]) - (roots[i] * roots[j]);
let term = numer / denom;
quotient = quotient + term;
}
quotient
}
quotient
};
quotient_poly.push(quotient);
}
})
.collect();

let lincomb = P1::lincomb_pippenger(setup.g1_lagrange_brp.as_slice(), quotient_poly);

Expand Down
46 changes: 26 additions & 20 deletions src/kzg/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::{
#[cfg(feature = "serde")]
use crate::{bytes::Bytes, math};

use rayon::iter::{IntoParallelIterator, ParallelIterator};
#[cfg(feature = "serde")]
use serde::Deserialize;

Expand Down Expand Up @@ -114,10 +115,10 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {

fn verify_proof_batch(
&self,
proofs: impl AsRef<[Proof]>,
commitments: impl AsRef<[Commitment]>,
points: impl AsRef<[Fr]>,
evals: impl AsRef<[Fr]>,
proofs: &[Proof],
commitments: &[Commitment],
points: &[Fr],
evals: &[Fr],
) -> bool {
assert_eq!(proofs.as_ref().len(), commitments.as_ref().len());
assert_eq!(commitments.as_ref().len(), points.as_ref().len());
Expand All @@ -134,22 +135,22 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
data[32..].copy_from_slice(&len);

let r = Fr::hash_to(data);
let mut rpowers = Vec::with_capacity(n);
let mut points_mul_rpowers = Vec::with_capacity(n);
let mut comms_minus_evals = Vec::with_capacity(n);
for i in 0..proofs.as_ref().len() {
let rpower = r.pow(&Fr::from(i as u64));
rpowers.push(rpower);

let point = points.as_ref()[i];
points_mul_rpowers.push(point * rpower);

let commitment = commitments.as_ref()[i];
let eval = evals.as_ref()[i];
comms_minus_evals.push(commitment + (P1::neg_generator() * eval));
}
let (rpowers, (points_mul_rpowers, comms_minus_evals)): (Vec<_>, (Vec<_>, Vec<_>)) = (0..n)
.into_par_iter()
.map(|i| {
let rpower = r.pow(&Fr::from(i as u64));

let point = points[i] * rpower;

let proof_lincomb = P1::lincomb(&proofs, &rpowers);
let commitment = commitments[i];
let eval = evals[i];
let comms_minus_eval = commitment + (P1::neg_generator() * eval);

(rpower, (point, comms_minus_eval))
})
.collect();

let proof_lincomb = P1::lincomb(proofs, &rpowers);
let proof_z_lincomb = P1::lincomb(proofs, points_mul_rpowers);

let comm_minus_eval_lincomb = P1::lincomb(comms_minus_evals, rpowers);
Expand Down Expand Up @@ -241,7 +242,12 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
evaluations.push(eval);
}

self.verify_proof_batch(proofs, commitments, challenges, evaluations)
self.verify_proof_batch(
proofs.as_ref(),
commitments.as_ref(),
challenges.as_slice(),
evaluations.as_slice(),
)
}

pub fn verify_blob_proof_batch<B>(
Expand Down

0 comments on commit 2d4fbe7

Please sign in to comment.