Skip to content

Commit

Permalink
feat: atan, atan2
Browse files Browse the repository at this point in the history
  • Loading branch information
strasdat committed Nov 10, 2023
1 parent 4476b5e commit 610b3e1
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ pub use storage_traits::{AsArray, CopySlice, TensorFrom, TensorFromVec, TensorTo
pub use storage_traits::{Cache, RandomU64, Storage, Synchronize};
pub use storage_traits::{OnesTensor, SampleTensor, TriangleTensor, ZerosTensor};

pub use tensor_impls::{PutTape, SplitTape, Tensor, Trace, WithEmptyTape};
// Keep re-exported names alphabetical, matching the other export lists in this module.
pub use tensor_impls::{CloneNoTape, PutTape, SplitTape, Tensor, Trace, WithEmptyTape};
pub use tensor_impls::{Tensor0D, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D};

pub(crate) use unique_id::unique_id;
Expand Down
21 changes: 21 additions & 0 deletions dfdx-core/src/tensor/tensor_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,27 @@ impl<S: Shape, E: Clone, D: Storage<E>, T> SplitTape for Tensor<S, E, D, T> {
}
}

/// Clones a tensor's contents while dropping its gradient tape.
///
/// `clone_no_tape` borrows `self`, so — unlike [SplitTape], which consumes the
/// tensor — the original keeps its tape. The returned clone carries no tape and
/// can be re-tagged later via [PutTape].
pub trait CloneNoTape {
    /// The tape type carried by `Self`.
    type Tape;
    /// The tape-less counterpart of `Self`; putting [Self::Tape] back onto it
    /// must reproduce `Self` exactly.
    type NoTape: Clone + PutTape<Self::Tape, Output = Self>;
    /// Returns a tape-free clone of `self`.
    fn clone_no_tape(&self) -> Self::NoTape;
}

impl<S: Shape, E: Clone, D: Storage<E>, T> CloneNoTape for Tensor<S, E, D, T> {
    type Tape = T;
    type NoTape = Tensor<S, E, D>;
    /// Copies every field of the tensor but replaces the tape with [NoneTape].
    fn clone_no_tape(&self) -> Self::NoTape {
        Tensor {
            // NOTE(review): the clone keeps the SAME unique id — presumably
            // intentional so gradients recorded against this tensor still
            // resolve to the clone; confirm against how `id` is used elsewhere.
            id: self.id,
            // NOTE(review): `data` is likely reference-counted, making this a
            // cheap shallow clone — TODO confirm the `Storage` representation.
            data: self.data.clone(),
            shape: self.shape,
            strides: self.strides,
            device: self.device.clone(),
            tape: NoneTape,
        }
    }
}

/// Clones self and inserts a new empty tape into the clone
pub trait WithEmptyTape {
/// Clones self and inserts a new empty tape into the clone
Expand Down
15 changes: 15 additions & 0 deletions dfdx-core/src/tensor_ops/atan/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use crate::tensor_ops::cpu_kernels::UnaryDerivative;

/// CPU forward/backward implementation for the element-wise `atan` op.
///
/// The explicit `std::ops::Mul<Output = F>` bound of the original impl was
/// redundant: `num_traits::Float` already requires `Num`, which implies `Mul`.
impl<F: num_traits::Float> UnaryDerivative<F> for super::AtanKernelOp {
    // The derivative is computed from the input x, not from f(x).
    const DF_USES_FX: bool = false;
    // The derivative varies with x; it is not a constant.
    const HAS_CONST_DF: bool = false;

    /// Forward: `atan(x)`.
    #[inline(always)]
    fn f(&self, x: &F) -> F {
        x.atan()
    }

    /// Backward: d/dx atan(x) = 1 / (1 + x²).
    #[inline(always)]
    fn df(&self, x: &F) -> F {
        // `F::one()` is infallible, unlike the original `F::from(1.0).unwrap()`.
        F::one() / (F::one() + *x * *x)
    }
}
57 changes: 57 additions & 0 deletions dfdx-core/src/tensor_ops/atan/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use super::ops::{try_unary_op, UnaryKernel};
use crate::{shapes::*, tensor::*};

/// Marker op type selecting the element-wise inverse-tangent kernel. See [atan].
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct AtanKernelOp;

/// Element-wise inverse tangent: `atan(t)`, with results in `(-π/2, π/2)`.
///
/// Its derivative is `1 / (1 + x^2)`.
pub fn atan<S: Shape, E: Dtype, D: UnaryKernel<AtanKernelOp, E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
    t.atan()
}

impl<S: Shape, E: Dtype, D: UnaryKernel<AtanKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
    /// See [atan]. Panics if the kernel fails; use [Self::try_atan] to handle
    /// the error instead.
    pub fn atan(self) -> Self {
        self.try_atan().unwrap()
    }
    /// Fallible version of [atan].
    pub fn try_atan(self) -> Result<Self, Error> {
        try_unary_op(AtanKernelOp, self)
    }
}

#[cfg(test)]
mod tests {
    use crate::prelude::storage_traits::AsArray;
    use crate::tests::*;
    use crate::{tensor::*, tensor_ops::*};

    #[test]
    fn test_atan() {
        let dev: TestDevice = Default::default();
        let x = dev.tensor([-2.0, -1.0]).to_dtype::<TestDtype>();

        // Forward pass (was previously untested):
        // atan(-2) ≈ -1.1071487, atan(-1) = -π/4 ≈ -0.7853982.
        assert_close_to_literal!(x.clone().atan(), [-1.1071487, -0.7853982]);

        // Backward pass: d/dx atan(x) = 1/(1+x²) → 1/5 at x = -2, 1/2 at x = -1.
        let expected_dx_atanx = [0.2, 0.5];
        let mut a = [0.0, 0.0];
        for i in 0..2 {
            // Select a single output element so each gradient entry is isolated.
            let r = x.clone().leaky_trace().atan();
            let g = r.select(dev.tensor(i)).backward();
            let rr = g.get(&x);
            a[i] = rr[[i]];
        }
        assert_close_to_literal!(dev.tensor(a), expected_dx_atanx);
    }
}
19 changes: 19 additions & 0 deletions dfdx-core/src/tensor_ops/atan2/cpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use crate::tensor_ops::cpu_kernels::{BinaryDerivative};
use num_traits::Float;

/// CPU forward/backward implementation for the element-wise binary `atan2` op.
///
/// BUG FIX: the original `dfdx`/`dfdy` bodies were swapped. For
/// `f(x, y) = atan2(x, y) = atan(x / y)` (quadrant-corrected):
///   ∂f/∂x =  y / (x² + y²)
///   ∂f/∂y = -x / (x² + y²)
/// whereas the original returned -x/(x²+y²) for dfdx and y/(x²+y²) for dfdy.
impl<F: Float> BinaryDerivative<F> for super::BinaryAtan2KernelOp {
    // The derivatives vary with the inputs; they are not constants.
    const HAS_CONST_DF: bool = false;

    /// Forward: `atan2(x, y)` — the angle of the point `(y, x)`.
    #[inline(always)]
    fn f(&self, &x: &F, &y: &F) -> F {
        x.atan2(y)
    }

    /// ∂/∂x atan2(x, y) = y / (x² + y²).
    #[inline(always)]
    fn dfdx(&self, x: &F, y: &F) -> F {
        *y / ((*x) * (*x) + (*y) * (*y))
    }

    /// ∂/∂y atan2(x, y) = -x / (x² + y²).
    #[inline(always)]
    fn dfdy(&self, x: &F, y: &F) -> F {
        -(*x) / ((*x) * (*x) + (*y) * (*y))
    }
}
54 changes: 54 additions & 0 deletions dfdx-core/src/tensor_ops/atan2/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use super::ops::*;
use crate::{shapes::*, tensor::*};

/// Marker op type selecting the element-wise binary `atan2` kernel. See [atan2].
#[repr(C)]
#[derive(Debug, Default, Clone, Copy)]
pub struct BinaryAtan2KernelOp;


/// Element-wise four-quadrant inverse tangent of `lhs / rhs`.
///
/// Follows the semantics of [f64::atan2]: the result is the angle of the
/// point `(rhs, lhs)`, so the signs of both inputs select the quadrant.
///
/// Panics if the kernel fails; use [TryAtan2::try_atan2] to handle errors.
pub fn atan2<S: Shape, E: Dtype, D, T: Tape<E, D> + Merge<R>, R: Default>(
    lhs: Tensor<S, E, D, T>,
    rhs: Tensor<S, E, D, R>,
) -> Tensor<S, E, D, T>
where
    D: BinaryKernel<BinaryAtan2KernelOp, E>,
{
    lhs.try_atan2(rhs).unwrap()
}

/// Fallible version of [atan2].
pub trait TryAtan2<Rhs = Self> {
    /// The tensor type produced by the operation.
    type Output;

    /// Element-wise `atan2(self, rhs)`, returning an [Error] instead of panicking.
    fn try_atan2(self, rhs: Rhs) -> Result<Self::Output, Error>;
}

impl<S: Shape, E: Dtype, D, LhsTape: Tape<E, D>, R> TryAtan2<Tensor<S, E, D, R>>
    for Tensor<S, E, D, LhsTape>
where
    D: BinaryKernel<BinaryAtan2KernelOp, E>,
    LhsTape: Merge<R>,
{
    type Output = Self;

    /// See [atan2]
    fn try_atan2(self, rhs: Tensor<S, E, D, R>) -> Result<Self, Error> {
        try_binary_op(BinaryAtan2KernelOp, self, rhs)
    }
}
4 changes: 4 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ mod accurate_gelu;
mod adam;
mod add;
// Module declarations stay alphabetical, matching the rest of this list.
mod atan;
mod atan2;
mod attention_reshape;
pub(crate) mod axpy;
mod bce;
mod boolean;
Expand Down Expand Up @@ -211,6 +213,8 @@ mod upscale2d;
mod var_to;

// Re-exports stay alphabetical, matching the rest of this list.
pub use abs::abs;
pub use accurate_gelu::accurate_gelu;
pub use adam::AdamConfig;
pub use add::{add, TryAdd};
pub use atan::atan;
pub use atan2::{atan2, TryAtan2};
Expand Down

0 comments on commit 610b3e1

Please sign in to comment.