Skip to content

Commit

Permalink
Provide a method to allow PTH files with state maps to be loaded. (#2639
Browse files Browse the repository at this point in the history
)

* Provide a method to allow PTH files iwth state maps to be loaded.

* add a line to the doc

* String-. &str
  • Loading branch information
zachcp authored Nov 26, 2024
1 parent c12db59 commit b4deb5c
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion candle-nn/src/var_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> {
let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}

/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
/// similar to [`from_pth`] but requires a `state_key`.
pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
p: P,
dtype: DType,
state_key: &str,
dev: &Device,
) -> Result<Self> {
let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
/// passing the new names to the inner VarBuilder.
///
Expand Down

0 comments on commit b4deb5c

Please sign in to comment.