tests update
plutonium-239 committed Apr 20, 2024
1 parent b016002 commit e32ac8f
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions test/test_layers.py
@@ -1,6 +1,6 @@
"""Tests for the whole memsave_torch.nn module (all layers and the convert function)"""

from typing import Type
from typing import List, Type, Union

import pytest
import torch
@@ -13,36 +13,40 @@

 x = torch.rand(7, 3, 12, 12)
 cases = [
-    [torch.nn.Linear(3, 5), x[:, :, 0, 0]],
-    [torch.nn.Linear(3, 5), x[:, :, :, 0].permute(0, 2, 1)],  # weight sharing
-    [torch.nn.Linear(3, 5), x.permute(0, 2, 3, 1)],  # weight sharing
-    [torch.nn.Conv2d(3, 5, 3), x],
-    [torch.nn.Conv1d(3, 5, 3), x[:, :, 0]],
-    [torch.nn.BatchNorm2d(3), x],
-    [torch.nn.LayerNorm(normalized_shape=[3, 12, 12]), x],
-    [torch.nn.MaxPool2d(3), x],
-    [torch.nn.ReLU(), x],
+    [torch.nn.Linear, [3, 5], x[:, :, 0, 0]],
+    [torch.nn.Linear, [3, 5], x[:, :, :, 0].permute(0, 2, 1)],  # weight sharing
+    [torch.nn.Linear, [3, 5], x.permute(0, 2, 3, 1)],  # weight sharing
+    [torch.nn.Conv2d, [3, 5, 3], x],
+    [torch.nn.Conv1d, [3, 5, 3], x[:, :, 0]],
+    [torch.nn.BatchNorm2d, [3], x],
+    [torch.nn.LayerNorm, [[3, 12, 12]], x],
+    [torch.nn.MaxPool2d, [3], x],
+    [torch.nn.ReLU, [], x],
 ]


 @pytest.mark.quick
 @pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("layer,x", cases)
-def test_single_layer(layer: torch.nn.Module, x: torch.Tensor, device: str) -> bool:
-    """Runs tests for the layer defined by `layer`
+@pytest.mark.parametrize("layer_cls,layer_init,x", cases)
+def test_single_layer(
+    layer_cls: Type[torch.nn.Module],
+    layer_init: List[Union[int, List[int]]],
+    x: torch.Tensor,
+    device: str,
+):
+    """Runs tests for the layer defined by `layer_cls`

     This tests for equality of outputs on forward pass and equality of the gradients
     on backward pass, for all parameters and input as well.

     Args:
-        layer (torch.nn.Module): torch.nn layer to test it's memsave counterpart
+        layer_cls (Type[torch.nn.Module]): torch.nn layer class to test its memsave counterpart
+        layer_init (List[Union[int, List[int]]]): layer initialization parameters
         x (torch.Tensor): Input tensor (B, C, H, W); will be reshaped properly based on layer
         device (str): device
-
-    Returns:
-        bool: Description
     """
     x = x.to(device)
+    layer = layer_cls(*layer_init)
     layer.to(device)
     memsave_layer = memsave_torch.nn.convert_to_memory_saving(layer, clone_params=True)
     # clone_params is needed here because we want to backprop through both layer and memsave_layer
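For reference, here is a minimal sketch of how one of the new-style cases is consumed after this change. The rest of the test body is truncated in the diff above, so the equality checks below are assumed from the docstring (equal outputs on the forward pass, equal gradients for the input and all parameters on the backward pass); only layer_cls(*layer_init) and convert_to_memory_saving(layer, clone_params=True) are confirmed by the diff itself.

import torch

import memsave_torch.nn

# One case in the new format: layer class, init args, input tensor
layer_cls, layer_init = torch.nn.Conv2d, [3, 5, 3]
x = torch.rand(7, 3, 12, 12)

layer = layer_cls(*layer_init)  # the layer is now instantiated inside the test
memsave_layer = memsave_torch.nn.convert_to_memory_saving(layer, clone_params=True)

x1 = x.clone().requires_grad_(True)
x2 = x.clone().requires_grad_(True)
y1, y2 = layer(x1), memsave_layer(x2)
assert torch.allclose(y1, y2)  # forward outputs agree

y1.sum().backward()
y2.sum().backward()
assert torch.allclose(x1.grad, x2.grad)  # input gradients agree
for p1, p2 in zip(layer.parameters(), memsave_layer.parameters()):
    assert torch.allclose(p1.grad, p2.grad)  # parameter gradients agree

One effect of parametrizing over (layer_cls, layer_init) instead of pre-built modules is that every device/case combination now gets a freshly initialized layer, so parameters and gradients cannot leak between parametrized runs.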

0 comments on commit e32ac8f
