Commit

change structure - add functional sub-module
plutonium-239 committed Apr 25, 2024
1 parent e32ac8f commit cfc72d8
Showing 16 changed files with 574 additions and 465 deletions.
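This commit moves the torch.autograd.Function implementations and their functional wrappers (batch_normMemSave, conv1dMemSave, conv2dMemSave, layer_normMemSave) out of the individual layer files and into a new memsave_torch.nn.functional sub-module; the layer files below now only import them. A minimal sketch of the resulting user-facing imports, inferred from the import lines added in the diffs below (the internal file layout of the sub-module itself is not shown in this excerpt):

# Assumed public entry points after this commit, inferred from the imports
# added to the layer files below; the sub-module's internal split may differ.
from memsave_torch.nn.functional import (
    batch_normMemSave,   # previously defined in memsave_torch/nn/BatchNorm.py
    conv1dMemSave,       # previously defined in memsave_torch/nn/Conv1d.py
    conv2dMemSave,       # previously defined in memsave_torch/nn/Conv2d.py
    layer_normMemSave,   # previously defined in memsave_torch/nn/LayerNorm.py
)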
108 changes: 2 additions & 106 deletions memsave_torch/nn/BatchNorm.py
@@ -3,9 +3,10 @@
This is done by not saving the inputs/weights if the weights/inputs don't require grad.
"""

import torch
import torch.nn as nn

from memsave_torch.nn.functional import batch_normMemSave


class MemSaveBatchNorm2d(nn.BatchNorm2d):
"""MemSaveBatchNorm2d."""
@@ -79,108 +80,3 @@ def from_nn_BatchNorm2d(cls, bn2d: nn.BatchNorm2d):
obj.running_mean = bn2d.running_mean
obj.running_var = bn2d.running_var
return obj


class _MemSaveBatchNorm(torch.autograd.Function):
@staticmethod
def forward(
ctx, x, running_mean, running_var, weight, bias, training, momentum, eps
):
"""torch.native_batch_norm is the same as torch.ops.aten.native_batch_norm
Not using functional.batch_norm here because we need the `save_mean` and `save_invstd` values
returned by torch ops in the backward pass (it is the forwarded batch's "stable" mean and invstd)
Also, we need to fuse forward and setup_context here because
we dont want to make save_mean and save_invstd as outputs but need to save them in ctx
"""
outputs = torch.native_batch_norm(
x, weight, bias, running_mean, running_var, training, momentum, eps
)

# print('setting up context', ctx.needs_input_grad)
ctx.save_mean = outputs[1]
ctx.save_invstd = outputs[2]
ctx.running_mean = running_mean
ctx.running_var = running_var
ctx.training = training
ctx.momentum = momentum
ctx.eps = eps
ctx.x_shape = x.shape
ctx.weight_shape = weight.shape
ctx.device = x.device

need_grad = []  # tensors that must be saved for the backward pass
if ctx.needs_input_grad[0]:
need_grad.append(weight)
if ctx.needs_input_grad[3]:
need_grad.append(x)
# bias doesn't need any saved tensors for its gradient

ctx.save_for_backward(*need_grad)

return outputs[0]

@staticmethod
def backward(ctx, grad_output):
# print('backward', ctx.needs_input_grad)
x = weight = None
current_idx = 0
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[current_idx]
current_idx += 1
if ctx.needs_input_grad[3]:
x = ctx.saved_tensors[current_idx]
current_idx += 1

if x is None:
x = torch.zeros(ctx.x_shape, device=ctx.device)
if weight is None:
weight = torch.zeros(ctx.weight_shape, device=ctx.device)

# print(current_idx)

grad_x, grad_weight, grad_bias = torch.ops.aten.native_batch_norm_backward(
grad_output,
x,
weight,
ctx.running_mean,
ctx.running_var,
ctx.save_mean,
ctx.save_invstd,
ctx.training,
ctx.eps,
[ctx.needs_input_grad[0], ctx.needs_input_grad[3], ctx.needs_input_grad[4]],
)

return grad_x, None, None, grad_weight, grad_bias, None, None, None


def batch_normMemSave(
input,
running_mean,
running_var,
weight=None,
bias=None,
training=False,
momentum=0.1,
eps=1e-05,
):
"""Functional form of the memory saving batch_norm.
Args:
input (TYPE): Input to the network [B, C, H, W]
running_mean: running_mean
running_var: running_var
weight: weight
bias: bias
training: training
momentum: momentum
eps: eps
Returns:
torch.Tensor: Output of the network [B, C, H, W]
"""
return _MemSaveBatchNorm.apply(
input, running_mean, running_var, weight, bias, training, momentum, eps
)
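For reference, a rough numerical check of the functional batch norm removed above against the stock implementation. It assumes batch_normMemSave keeps the signature shown (input, running_mean, running_var, weight, bias, training, momentum, eps) and is importable from memsave_torch.nn.functional after this commit, as the added import suggests:

import torch
import torch.nn.functional as F
from memsave_torch.nn.functional import batch_normMemSave

x = torch.randn(4, 3, 8, 8, dtype=torch.double, requires_grad=True)
weight = torch.randn(3, dtype=torch.double, requires_grad=True)
bias = torch.randn(3, dtype=torch.double, requires_grad=True)
running_mean = torch.zeros(3, dtype=torch.double)
running_var = torch.ones(3, dtype=torch.double)

# Clone the running statistics so the two calls do not update the same buffers.
y_ref = F.batch_norm(x, running_mean.clone(), running_var.clone(), weight, bias,
                     training=True, momentum=0.1, eps=1e-5)
y_mem = batch_normMemSave(x, running_mean.clone(), running_var.clone(), weight, bias,
                          training=True, momentum=0.1, eps=1e-5)
print(torch.allclose(y_ref, y_mem))  # expected: True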
22 changes: 1 addition & 21 deletions memsave_torch/nn/Conv1d.py
@@ -6,7 +6,7 @@
import torch
import torch.nn as nn

from memsave_torch.nn.Conv2d import _MemSaveConv
from memsave_torch.nn.functional import conv1dMemSave


class MemSaveConv1d(nn.Conv1d):
@@ -100,23 +100,3 @@ def from_nn_Conv1d(cls, conv1d: nn.Conv1d):
obj.weight = conv1d.weight
obj.bias = conv1d.bias
return obj


def conv1dMemSave(
input, weight, bias, stride, padding, dilation, groups
) -> torch.Tensor:
"""Functional form of the memory saving convolution.
Args:
input: input [B, C_in, H, W]
weight: weight
bias: bias
stride: stride
padding: padding
dilation: dilation
groups: groups
Returns:
torch.Tensor: Output of the conv operation [B, C_out, H_out, W_out]
"""
return _MemSaveConv.apply(input, weight, bias, stride, padding, dilation, groups)
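The removed conv1dMemSave delegated to the shared _MemSaveConv autograd function from Conv2d.py; after this commit it is imported from the functional sub-module instead. A quick sanity check against torch.nn.functional.conv1d, passing explicit tuples for stride/padding/dilation because the implementation removed above expected sequence arguments:

import torch
import torch.nn.functional as F
from memsave_torch.nn.functional import conv1dMemSave

x = torch.randn(2, 3, 16, requires_grad=True)   # [B, C_in, L]
w = torch.randn(5, 3, 3, requires_grad=True)    # [C_out, C_in, K]
b = torch.randn(5, requires_grad=True)          # [C_out]

y_ref = F.conv1d(x, w, b, stride=1, padding=1, dilation=1, groups=1)
y_mem = conv1dMemSave(x, w, b, (1,), (1,), (1,), 1)
print(torch.allclose(y_ref, y_mem))  # expected: True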
91 changes: 2 additions & 89 deletions memsave_torch/nn/Conv2d.py
@@ -6,6 +6,8 @@
import torch
import torch.nn as nn

from memsave_torch.nn.functional import conv2dMemSave


class MemSaveConv2d(nn.Conv2d):
"""MemSaveConv2d."""
@@ -98,92 +100,3 @@ def from_nn_Conv2d(cls, conv2d: nn.Conv2d):
obj.weight = conv2d.weight
obj.bias = conv2d.bias
return obj


class _MemSaveConv(torch.autograd.Function):
@staticmethod
def forward(x, weight, bias, stride, padding, dilation, groups):
return torch.ops.aten.convolution(
x,
weight,
bias,
stride,
padding,
dilation,
False,
tuple([0] * len(padding)),
groups,
)

@staticmethod
def setup_context(ctx, inputs, output):
x, weight, bias, stride, padding, dilation, groups = inputs
need_grad = []
if ctx.needs_input_grad[0]:
need_grad.append(weight)
if ctx.needs_input_grad[1]:
need_grad.append(x)
# bias doesn't need any saved tensors for its gradient
ctx.bias_exists = bias is not None
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.x_shape = x.shape
ctx.weight_shape = weight.shape
ctx.device = x.device

ctx.save_for_backward(*need_grad)

@staticmethod
def backward(ctx, grad_output):
x = weight = None

current_idx = 0
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[current_idx]
current_idx += 1
if ctx.needs_input_grad[1]:
x = ctx.saved_tensors[current_idx]
current_idx += 1

if x is None:
x = torch.zeros(ctx.x_shape, device=ctx.device)
if weight is None:
weight = torch.zeros(ctx.weight_shape, device=ctx.device)

grad_x, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
grad_output,
x,
weight,
[weight.shape[0]] if ctx.bias_exists else None,
ctx.stride,
ctx.padding,
ctx.dilation,
False,
[0],
ctx.groups,
ctx.needs_input_grad[:3],
)

return grad_x, grad_weight, grad_bias, None, None, None, None, None


def conv2dMemSave(
input, weight, bias, stride, padding, dilation, groups
) -> torch.Tensor:
"""Functional form of the memory saving convolution.
Args:
input: input [B, C_in, H, W]
weight: weight
bias: bias
stride: stride
padding: padding
dilation: dilation
groups: groups
Returns:
torch.Tensor: Output of the conv operation [B, C_out, H_out, W_out]
"""
return _MemSaveConv.apply(input, weight, bias, stride, padding, dilation, groups)
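The bookkeeping in _MemSaveConv above stores the input only when the weight needs a gradient, and stores the weight only when the input needs a gradient, so the saving is largest when the parameters are frozen. A small usage sketch built on the from_nn_Conv2d converter kept in this file (the import path is assumed to follow the file layout shown):

import torch
import torch.nn as nn
from memsave_torch.nn.Conv2d import MemSaveConv2d

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
memsave_conv = MemSaveConv2d.from_nn_Conv2d(conv)  # shares weight and bias with conv

# With frozen parameters, only the (small) weight is needed for backward,
# so the large input activation does not have to be stored.
for p in memsave_conv.parameters():
    p.requires_grad_(False)

x = torch.randn(2, 3, 32, 32, requires_grad=True)
y = memsave_conv(x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 3, 32, 32])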
76 changes: 2 additions & 74 deletions memsave_torch/nn/LayerNorm.py
@@ -6,6 +6,8 @@
import torch
import torch.nn as nn

from memsave_torch.nn.functional import layer_normMemSave


class MemSaveLayerNorm(nn.LayerNorm):
"""MemSaveLayerNorm."""
@@ -86,77 +88,3 @@ def from_nn_LayerNorm(cls, ln: nn.LayerNorm):
else:
obj.bias = ln.bias
return obj


class _MemSaveLayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, normalized_shape, weight, bias, eps):
"""torch.native_layer_norm is the same as torch.ops.aten.native_layer_norm
Also, we need to fuse forward and setup_context here because
we dont want to make save_mean and save_invstd as outputs but need to save them in ctx
"""
outputs = torch.native_layer_norm(x, normalized_shape, weight, bias, eps)

ctx.mean = outputs[1]
ctx.rstd = outputs[2]
ctx.eps = eps
ctx.x_shape = x.shape
ctx.normalized_shape = normalized_shape
ctx.device = x.device

need_grad = []  # tensors that must be saved for the backward pass
if ctx.needs_input_grad[0]:
need_grad.append(weight)
if ctx.needs_input_grad[2]:
need_grad.append(x)
# bias doesn't need any saved tensors for its gradient

ctx.save_for_backward(*need_grad)

return outputs[0]

@staticmethod
def backward(ctx, grad_output):
x = weight = None
current_idx = 0
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[current_idx]
current_idx += 1
if ctx.needs_input_grad[2]:
x = ctx.saved_tensors[current_idx]
current_idx += 1

if x is None:
x = torch.zeros(ctx.x_shape, device=ctx.device)
if weight is None:
weight = torch.zeros(ctx.normalized_shape, device=ctx.device)
bias = torch.zeros(ctx.normalized_shape, device=ctx.device)

grad_x, grad_weight, grad_bias = torch.ops.aten.native_layer_norm_backward(
grad_output,
x,
ctx.normalized_shape,
ctx.mean,
ctx.rstd,
weight,
bias,
[ctx.needs_input_grad[0], ctx.needs_input_grad[2], ctx.needs_input_grad[3]],
)

return grad_x, None, grad_weight, grad_bias, None


def layer_normMemSave(input, normalized_shape, weight=None, bias=None, eps=1e-05):
"""Functional form of the memory saving layer_norm.
Args:
input: Input to the network [B, C, H, W]
normalized_shape: normalized_shape
weight: weight
bias: bias
eps: eps
Returns:
torch.Tensor: Output of the network [B, C, H, W]
"""
return _MemSaveLayerNorm.apply(input, normalized_shape, weight, bias, eps)
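Similarly for layer norm, a numerical sanity check of the functional form removed above against torch.nn.functional.layer_norm, assuming layer_normMemSave keeps the signature shown and is exposed by the new functional sub-module:

import torch
import torch.nn.functional as F
from memsave_torch.nn.functional import layer_normMemSave

x = torch.randn(2, 5, 10, dtype=torch.double, requires_grad=True)
weight = torch.randn(10, dtype=torch.double, requires_grad=True)
bias = torch.randn(10, dtype=torch.double, requires_grad=True)

y_ref = F.layer_norm(x, (10,), weight, bias, eps=1e-5)
y_mem = layer_normMemSave(x, (10,), weight, bias, eps=1e-5)
print(torch.allclose(y_ref, y_mem))  # expected: True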