tests improve readability and reduce tolerances
plutonium-239 committed Apr 27, 2024
1 parent cfc72d8 commit b2b4661
Showing 1 changed file with 40 additions and 23 deletions.
63 changes: 40 additions & 23 deletions test/test_layers.py
@@ -1,6 +1,6 @@
"""Tests for the whole memsave_torch.nn module (all layers and the convert function)"""

from typing import List, Type, Union
from typing import Callable, Dict, Union

import pytest
import torch
@@ -11,27 +11,47 @@
if torch.cuda.is_available():
devices.append("cuda")

x = torch.rand(7, 3, 12, 12)
torch.manual_seed(239)

cases = [
[torch.nn.Linear, [3, 5], x[:, :, 0, 0]],
[torch.nn.Linear, [3, 5], x[:, :, :, 0].permute(0, 2, 1)], # weight sharing
[torch.nn.Linear, [3, 5], x.permute(0, 2, 3, 1)], # weight sharing
[torch.nn.Conv2d, [3, 5, 3], x],
[torch.nn.Conv1d, [3, 5, 3], x[:, :, 0]],
[torch.nn.BatchNorm2d, [3], x],
[torch.nn.LayerNorm, [[3, 12, 12]], x],
[torch.nn.MaxPool2d, [3], x],
[torch.nn.ReLU, [], x],
{"layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 3)},
{
"layer_fn": lambda: torch.nn.Linear(3, 5),
"data_fn": lambda: torch.rand(7, 12, 3), # weight sharing
},
{
"layer_fn": lambda: torch.nn.Linear(3, 5),
"data_fn": lambda: torch.rand(7, 12, 12, 3), # weight sharing
},
{
"layer_fn": lambda: torch.nn.Conv2d(3, 5, 3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{
"layer_fn": lambda: torch.nn.Conv1d(3, 5, 3),
"data_fn": lambda: torch.rand(7, 3, 12),
},
{
"layer_fn": lambda: torch.nn.BatchNorm2d(3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{
"layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{
"layer_fn": lambda: torch.nn.MaxPool2d(3),
"data_fn": lambda: torch.rand(7, 3, 12, 12),
},
{"layer_fn": lambda: torch.nn.ReLU(), "data_fn": lambda: torch.rand(7, 3, 12, 12)},
]


@pytest.mark.quick
@pytest.mark.parametrize("case", cases)
@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("layer_cls,layer_init,x", cases)
def test_single_layer(
layer_cls: Type[torch.nn.Module],
layer_init: List[Union[int, List[int]]],
x: torch.Tensor,
case: Dict[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]],
device: str,
):
"""Runs tests for the layer_cls defined by `layer`
@@ -40,16 +60,13 @@ def test_single_layer(
on backward pass, for all parameters and input as well.
Args:
layer_cls (Type[torch.nn.Module]): torch.nn layer class to test it's memsave counterpart
layer_init (List[Union[int, List[int]]]): layer initialization parameters
x (torch.Tensor): Input tensor (B, C, H, W); will be reshaped properly based on layer
case (Dict[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]]): Case dictionary specifying layer_fn and data_fn
device (str): device
"""
x = x.to(device)
layer = layer_cls(*layer_init)
x = case["data_fn"]().to(device)
layer = case["layer_fn"]()
layer.to(device)
memsave_layer = memsave_torch.nn.convert_to_memory_saving(layer, clone_params=True)
# clone_params is needed here because we want to backprop through both layer and memsave_layer

x1 = x.clone().detach()
x1.requires_grad = True
@@ -65,8 +82,8 @@
atol = 1e-8 # defaults
rtol = 1e-5 # defaults
elif device == "cuda":
atol = 1e-4
rtol = 1e-2
atol = 1e-5
rtol = 1e-4
assert torch.allclose(y1, y2, rtol=rtol, atol=atol)
assert torch.allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
for p1, p2 in zip(layer.parameters(), memsave_layer.parameters()):

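For reference, a minimal standalone sketch of what one entry in the new cases list exercises, assuming memsave_torch is importable as in the test file. The .sum() loss, the x2 setup, and the parameter-gradient loop body are illustrative; the corresponding lines of the actual test are not shown in this diff.

import torch

import memsave_torch.nn

# One case in the new factory-based format: layer_fn and data_fn build
# a fresh layer and a fresh input instead of reusing a shared tensor.
case = {
    "layer_fn": lambda: torch.nn.Conv2d(3, 5, 3),
    "data_fn": lambda: torch.rand(7, 3, 12, 12),
}

torch.manual_seed(239)
layer = case["layer_fn"]()
# clone_params=True gives layer and memsave_layer independent copies of the
# parameters, so both can be backpropagated through and compared afterwards.
memsave_layer = memsave_torch.nn.convert_to_memory_saving(layer, clone_params=True)

x1 = case["data_fn"]()
x2 = x1.clone().detach()
x1.requires_grad = True
x2.requires_grad = True

y1 = layer(x1)
y2 = memsave_layer(x2)
y1.sum().backward()  # illustrative scalar loss
y2.sum().backward()

# CPU comparison with torch.allclose defaults; on CUDA the test now uses
# the tightened atol=1e-5, rtol=1e-4.
assert torch.allclose(y1, y2)
assert torch.allclose(x1.grad, x2.grad)
# Per the docstring, gradients must also match for all parameters.
for p1, p2 in zip(layer.parameters(), memsave_layer.parameters()):
    assert torch.allclose(p1.grad, p2.grad)

Replacing the shared input tensor with per-case data_fn factories is what lets one code path cover Linear with and without weight sharing, the conv and norm layers, and ReLU without any reshape logic in the test body.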