Skip to content

Commit

Permalink
perf: cache spks + txs to avoid n-plus-one in bdk (#472)
Browse files Browse the repository at this point in the history
* chore: cache spks for faster lookup

* chore: cache txs for faster lookup

* chore: remove accidentily committed file
  • Loading branch information
bodymindarts authored Feb 15, 2024
1 parent 9ce2a3b commit da77c07
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 41 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

132 changes: 94 additions & 38 deletions src/bdk/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ use convert::BdkKeychainKind;
use descriptor_checksum::DescriptorChecksums;
use index::Indexes;
use script_pubkeys::ScriptPubkeys;
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
pub(super) use sync_times::SyncTimes;
pub use transactions::*;
pub use utxos::*;
Expand All @@ -27,9 +31,11 @@ pub struct SqlxWalletDb {
rt: Handle,
pool: PgPool,
keychain_id: KeychainId,
addresses: Option<Vec<(BdkKeychainKind, u32, ScriptBuf)>>,
utxos: Option<Vec<LocalUtxo>>,
txs: Option<Vec<TransactionDetails>>,
cached_spks: Arc<Mutex<HashMap<ScriptBuf, (KeychainKind, u32)>>>,
addresses: HashMap<ScriptBuf, (KeychainKind, u32)>,
cached_txs: Arc<Mutex<HashMap<Txid, TransactionDetails>>>,
txs: HashMap<Txid, TransactionDetails>,
}

impl SqlxWalletDb {
Expand All @@ -38,10 +44,37 @@ impl SqlxWalletDb {
rt: Handle::current(),
keychain_id,
pool,
addresses: None,
utxos: None,
txs: None,
addresses: HashMap::new(),
cached_spks: Arc::new(Mutex::new(HashMap::new())),
txs: HashMap::new(),
cached_txs: Arc::new(Mutex::new(HashMap::new())),
}
}

fn load_all_txs(&self) -> Result<(), bdk::Error> {
let mut txs = self.cached_txs.lock().expect("poisoned txs cache lock");
if txs.is_empty() {
let loaded = self.rt.block_on(async {
let txs = Transactions::new(self.keychain_id, self.pool.clone());
txs.load_all().await
})?;
*txs = loaded;
}
Ok(())
}

fn lookup_tx(&self, txid: &Txid) -> Result<Option<TransactionDetails>, bdk::Error> {
if let Some(tx) = self.txs.get(txid) {
return Ok(Some(tx.clone()));
}
self.load_all_txs()?;
Ok(self
.cached_txs
.lock()
.expect("poisoned txs cache lock")
.get(txid)
.cloned())
}
}

Expand All @@ -52,14 +85,7 @@ impl BatchOperations for SqlxWalletDb {
keychain: KeychainKind,
path: u32,
) -> Result<(), bdk::Error> {
if self.addresses.is_none() {
self.addresses = Some(Vec::new());
}
self.addresses.as_mut().unwrap().push((
BdkKeychainKind::from(keychain),
path,
script.into(),
));
self.addresses.insert(script.into(), (keychain, path));
Ok(())
}

Expand All @@ -76,10 +102,7 @@ impl BatchOperations for SqlxWalletDb {
}

fn set_tx(&mut self, tx: &TransactionDetails) -> Result<(), bdk::Error> {
if self.txs.is_none() {
self.txs = Some(Vec::new());
}
self.txs.as_mut().unwrap().push(tx.clone());
self.txs.insert(tx.txid, tx.clone());
Ok(())
}

Expand Down Expand Up @@ -179,10 +202,14 @@ impl Database for SqlxWalletDb {
}

fn iter_txs(&self, _: bool) -> Result<Vec<TransactionDetails>, bdk::Error> {
self.rt.block_on(async {
let txs = Transactions::new(self.keychain_id, self.pool.clone());
txs.list().await
})
self.load_all_txs()?;
Ok(self
.cached_txs
.lock()
.expect("poisoned txs cache lock")
.values()
.cloned()
.collect())
}

fn get_script_pubkey_from_path(
Expand All @@ -199,13 +226,22 @@ impl Database for SqlxWalletDb {
&self,
script: &Script,
) -> Result<Option<(KeychainKind, u32)>, bdk::Error> {
self.rt.block_on(async {
let script_pubkeys = ScriptPubkeys::new(self.keychain_id, self.pool.clone());
Ok(script_pubkeys
.find_path(&ScriptBuf::from(script))
.await?
.map(|(kind, path)| (kind.into(), path)))
})
let mut cache = self.cached_spks.lock().expect("poisoned spk cache lock");
if cache.is_empty() {
let loaded = self.rt.block_on(async {
let script_pubkeys = ScriptPubkeys::new(self.keychain_id, self.pool.clone());
script_pubkeys.load_all().await
})?;
*cache = loaded;
}

if let Some(res) = cache.get(script) {
Ok(Some(*res))
} else if let Some(res) = self.addresses.get(script) {
Ok(Some(*res))
} else {
Ok(None)
}
}
fn get_utxo(&self, outpoint: &OutPoint) -> Result<Option<LocalUtxo>, bdk::Error> {
self.rt.block_on(async {
Expand All @@ -215,20 +251,15 @@ impl Database for SqlxWalletDb {
})
}
fn get_raw_tx(&self, tx_id: &Txid) -> Result<Option<Transaction>, bdk::Error> {
self.rt.block_on(async {
let txs = Transactions::new(self.keychain_id, self.pool.clone());
Ok(txs.find_by_id(tx_id).await?.and_then(|tx| tx.transaction))
})
self.lookup_tx(tx_id)
.map(|tx| tx.and_then(|tx| tx.transaction))
}
fn get_tx(
&self,
tx_id: &Txid,
_include_raw: bool,
) -> Result<Option<TransactionDetails>, bdk::Error> {
self.rt.block_on(async {
let txs = Transactions::new(self.keychain_id, self.pool.clone());
txs.find_by_id(tx_id).await
})
self.lookup_tx(tx_id)
}
fn get_last_index(&self, kind: KeychainKind) -> Result<std::option::Option<u32>, bdk::Error> {
self.rt.block_on(async {
Expand All @@ -254,23 +285,48 @@ impl BatchDatabase for SqlxWalletDb {
type Batch = Self;

fn begin_batch(&self) -> <Self as BatchDatabase>::Batch {
SqlxWalletDb::new(self.pool.clone(), self.keychain_id)
let mut res = SqlxWalletDb::new(self.pool.clone(), self.keychain_id);
res.cached_spks = Arc::clone(&self.cached_spks);
res.cached_txs = Arc::clone(&self.cached_txs);
res
}

fn commit_batch(
&mut self,
mut batch: <Self as BatchDatabase>::Batch,
) -> Result<(), bdk::Error> {
self.cached_spks
.lock()
.expect("poisoned spk cache lock")
.extend(
batch
.addresses
.iter()
.map(|(s, (k, p))| (s.clone(), (*k, *p))),
);

self.cached_txs
.lock()
.expect("poisoned txs cache lock")
.extend(batch.txs.iter().map(|(id, tx)| (*id, tx.clone())));

self.rt.block_on(async move {
if let Some(addresses) = batch.addresses.take() {
if !batch.addresses.is_empty() {
let addresses: Vec<_> = batch
.addresses
.drain()
.map(|(s, (k, p))| (BdkKeychainKind::from(k), p, s))
.collect();
let repo = ScriptPubkeys::new(batch.keychain_id, batch.pool.clone());
repo.persist_all(addresses).await?;
}

if let Some(utxos) = batch.utxos.take() {
let repo = Utxos::new(batch.keychain_id, batch.pool.clone());
repo.persist_all(utxos).await?;
}
if let Some(txs) = batch.txs.take() {
if !batch.txs.is_empty() {
let txs = batch.txs.drain().map(|(_, tx)| tx).collect();
let repo = Transactions::new(batch.keychain_id, batch.pool.clone());
repo.persist_all(txs).await?;
}
Expand Down
23 changes: 23 additions & 0 deletions src/bdk/pg/script_pubkeys.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use sqlx::{PgPool, Postgres, QueryBuilder};
use std::collections::HashMap;
use tracing::instrument;
use uuid::Uuid;

Expand Down Expand Up @@ -70,6 +71,28 @@ impl ScriptPubkeys {
.map(|row| ScriptBuf::from(row.script)))
}

#[instrument(name = "bdk.script_pubkeys.load_all", skip_all)]
pub async fn load_all(
&self,
) -> Result<HashMap<ScriptBuf, (bdk::KeychainKind, u32)>, bdk::Error> {
let rows = sqlx::query!(
r#"SELECT script, keychain_kind as "keychain_kind: BdkKeychainKind", path FROM bdk_script_pubkeys
WHERE keychain_id = $1"#,
Uuid::from(self.keychain_id),
)
.fetch_all(&self.pool)
.await
.map_err(|e| bdk::Error::Generic(e.to_string()))?;
let mut ret = HashMap::new();
for row in rows {
ret.insert(
ScriptBuf::from(row.script),
(bdk::KeychainKind::from(row.keychain_kind), row.path as u32),
);
}
Ok(ret)
}

#[instrument(name = "bdk.script_pubkeys.find_path", skip_all)]
pub async fn find_path(
&self,
Expand Down
11 changes: 8 additions & 3 deletions src/bdk/pg/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use bdk::{bitcoin::Txid, LocalUtxo, TransactionDetails};
use sqlx::{PgPool, Postgres, QueryBuilder, Transaction};
use tracing::instrument;

use std::collections::HashMap;

use crate::{bdk::error::BdkError, primitives::*};

#[derive(Debug)]
Expand Down Expand Up @@ -97,8 +99,8 @@ impl Transactions {
Ok(tx.map(|tx| serde_json::from_value(tx.details_json).unwrap()))
}

#[instrument(name = "bdk.transactions.list", skip(self), fields(n_rows))]
pub async fn list(&self) -> Result<Vec<TransactionDetails>, bdk::Error> {
#[instrument(name = "bdk.transactions.load_all", skip(self), fields(n_rows))]
pub async fn load_all(&self) -> Result<HashMap<Txid, TransactionDetails>, bdk::Error> {
let txs = sqlx::query!(
r#"
SELECT details_json FROM bdk_transactions WHERE keychain_id = $1 AND deleted_at IS NULL"#,
Expand All @@ -110,7 +112,10 @@ impl Transactions {
tracing::Span::current().record("n_rows", txs.len());
Ok(txs
.into_iter()
.map(|tx| serde_json::from_value(tx.details_json).unwrap())
.map(|tx| {
let tx = serde_json::from_value::<TransactionDetails>(tx.details_json).unwrap();
(tx.txid, tx)
})
.collect())
}

Expand Down

0 comments on commit da77c07

Please sign in to comment.