Skip to content

Commit

Permalink
Add max-all/min-all. (#2616)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Nov 14, 2024
1 parent 06350c3 commit 0ed24b9
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,42 @@ impl Tensor {
&self.op
}

/// Computes the max of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.max_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn max_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.max(0)
}
}

/// Computes the min of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.min_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn min_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.min(0)
}
}

/// Computes the sum of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
Expand Down

0 comments on commit 0ed24b9

Please sign in to comment.