Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Added tests for completeness of StateDict and Simulator #59

Merged
merged 7 commits into from
Jan 23, 2024
4 changes: 3 additions & 1 deletion src/caustics/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def _normalize_path(path: "str | Path") -> Path:
# Convert string path to Path object
if isinstance(path, str):
path = Path(path)
return path

# Get absolute path
return path.absolute()


def to_file(
Expand Down
3 changes: 2 additions & 1 deletion src/caustics/sims/state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
from typing import Any, Dict, Optional
from pathlib import Path
import os

from torch import Tensor
import torch
Expand Down Expand Up @@ -184,7 +185,7 @@ def save(self, file_path: Optional[str] = None) -> str:
The final path of the saved file
"""
if not file_path:
file_path = Path(".") / self.__st_file
file_path = Path(os.path.curdir) / self.__st_file
elif isinstance(file_path, str):
file_path = Path(file_path)

Expand Down
39 changes: 39 additions & 0 deletions tests/sims/test_simulator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import pytest
from pathlib import Path
import sys

import torch

from caustics.sims.state_dict import StateDict
from helpers.sims import extract_tensors
Expand Down Expand Up @@ -29,3 +33,38 @@ def test_state_dict(self, state_dict, expected_tensors):

# Check params
assert dict(state_dict) == expected_tensors

def test_set_module_params(self, simple_common_sim):
params = {"param1": torch.as_tensor(1), "param2": torch.as_tensor(2)}
# Call the __set_module_params method
simple_common_sim._Simulator__set_module_params(simple_common_sim, params)

# Check if the module attributes have been set correctly
assert simple_common_sim.param1 == params["param1"]
assert simple_common_sim.param2 == params["param2"]

def test_load_state_dict(self, simple_common_sim):
fpath = simple_common_sim.state_dict().save()
loaded_state_dict = StateDict.load(fpath)

# Change a value in the simulator
simple_common_sim.z_s = 3.0

# Ensure that the simulator has been changed
assert (
loaded_state_dict[f"{simple_common_sim.name}.z_s"]
!= simple_common_sim.z_s.value
)

# Load the state dict form file
simple_common_sim.load_state_dict(fpath)

# Once loaded now the values should be the same
assert (
loaded_state_dict[f"{simple_common_sim.name}.z_s"]
== simple_common_sim.z_s.value
)

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(fpath).unlink(missing_ok=True)
105 changes: 99 additions & 6 deletions tests/sims/test_state_dict.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,45 @@
from pathlib import Path
from tempfile import TemporaryDirectory
import os
import sys

import pytest
import torch
from collections import OrderedDict
from safetensors.torch import save, load
from datetime import datetime as dt
from caustics.parameter import Parameter
from caustics.namespace_dict import NamespaceDict, NestedNamespaceDict
from caustics.sims.state_dict import StateDict, IMMUTABLE_ERR, _sanitize
from caustics.sims.state_dict import ImmutableODict, StateDict, IMMUTABLE_ERR, _sanitize
from caustics import __version__

from helpers.sims import extract_tensors


class TestImmutableODict:
def test_constructor(self):
odict = ImmutableODict(a=1, b=2, c=3)
assert isinstance(odict, OrderedDict)
assert odict == {"a": 1, "b": 2, "c": 3}
assert hasattr(odict, "_created")
assert odict._created is True

def test_setitem(self):
odict = ImmutableODict()
with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)):
odict["key"] = "value"

def test_delitem(self):
odict = ImmutableODict(key="value")
with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)):
del odict["key"]

def test_setattr(self):
odict = ImmutableODict()
with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)):
odict.meta = {"key": "value"}


class TestStateDict:
simple_tensors = {"var1": torch.as_tensor(1.0), "var2": torch.as_tensor(2.0)}

Expand All @@ -33,6 +63,15 @@ def test_constructor(self):
assert sd_ct_str == time_str_now
assert dict(state_dict) == self.simple_tensors

def test_constructor_with_metadata(self):
time_format = "%Y-%m-%dT%H:%M:%S"
time_str_now = dt.utcnow().strftime(time_format)
metadata = {"created_time": time_str_now, "software_version": "0.0.1"}
state_dict = StateDict(metadata=metadata, **self.simple_tensors)

assert isinstance(state_dict._metadata, ImmutableODict)
assert dict(state_dict._metadata) == dict(metadata)

def test_setitem(self, simple_state_dict):
with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)):
simple_state_dict["var1"] = torch.as_tensor(3.0)
Expand All @@ -56,14 +95,26 @@ def test_from_params(self, simple_common_sim):
state_dict = StateDict.from_params(all_params)
assert state_dict == expected_state_dict

def test_to_params(self, simple_state_dict):
params = simple_state_dict.to_params()
# Check for TypeError when passing a NamespaceDict or NestedNamespaceDict
with pytest.raises(TypeError):
StateDict.from_params({"a": 1, "b": 2})

# Check for TypeError when passing a NestedNamespaceDict
# without the "static" and "dynamic" keys
with pytest.raises(ValueError):
StateDict.from_params(NestedNamespaceDict({"a": 1, "b": 2}))

def test_to_params(self):
params_with_none = {"var3": torch.ones(0), **self.simple_tensors}
state_dict = StateDict(**params_with_none)
params = StateDict(**params_with_none).to_params()
assert isinstance(params, NamespaceDict)

for k, v in params.items():
tensor_value = simple_state_dict[k]
assert isinstance(v, Parameter)
assert v.value == tensor_value
tensor_value = state_dict[k]
if tensor_value.nelement() > 0:
assert isinstance(v, Parameter)
assert v.value == tensor_value

def test__to_safetensors(self):
state_dict = StateDict(**self.simple_tensors)
Expand All @@ -78,3 +129,45 @@ def test__to_safetensors(self):
loaded_tensors = load(tensors_bytes)
loaded_expected_tensors = load(expected_bytes)
assert loaded_tensors == loaded_expected_tensors

def test_st_file_string(self, simple_state_dict):
file_format = "%Y%m%dT%H%M%S_caustics.st"
expected_file = simple_state_dict._created_time.strftime(file_format)

assert simple_state_dict._StateDict__st_file == expected_file

def test_save(self, simple_state_dict):
# Check for default save path
expected_fpath = Path(os.path.curdir) / simple_state_dict._StateDict__st_file
default_fpath = simple_state_dict.save()

assert Path(default_fpath).exists()
assert default_fpath == str(expected_fpath.absolute())

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(default_fpath).unlink(missing_ok=True)

# Check for specified save path
with TemporaryDirectory() as tempdir:
tempdir = Path(tempdir)
# Correct extension and path in a tempdir
fpath = tempdir / "test.st"
saved_path = simple_state_dict.save(str(fpath.absolute()))

assert Path(saved_path).exists()
assert saved_path == str(fpath.absolute())

# Wrong extension
wrong_fpath = tempdir / "test.txt"
with pytest.raises(ValueError):
saved_path = simple_state_dict.save(str(wrong_fpath.absolute()))

def test_load(self, simple_state_dict):
fpath = simple_state_dict.save()
loaded_state_dict = StateDict.load(fpath)
assert loaded_state_dict == simple_state_dict

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(fpath).unlink(missing_ok=True)
63 changes: 63 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from pathlib import Path
import tempfile
import struct
import json
import torch
from safetensors.torch import save
from caustics.io import (
_get_safetensors_header,
_normalize_path,
to_file,
from_file,
get_safetensors_metadata,
)


def test_normalize_path():
# Test with a string path
path_str = "/path/to/file.txt"
normalized_path = _normalize_path(path_str)
assert normalized_path == Path(path_str)
assert str(normalized_path), path_str

# Test with a Path object
path_obj = Path("/path/to/file.txt")
normalized_path = _normalize_path(path_obj)
assert normalized_path == path_obj


def test_to_and_from_file():
with tempfile.TemporaryDirectory() as tmpdir:
fpath = Path(tmpdir) / "test.txt"
data = "test data"

# Test to file
ffile = to_file(fpath, data)

assert Path(ffile).exists()
assert ffile == str(fpath.absolute())
assert Path(ffile).read_text() == data

# Test from file
assert from_file(fpath) == data.encode("utf-8")


def test_get_safetensors_metadata():
with tempfile.TemporaryDirectory() as tmpdir:
fpath = Path(tmpdir) / "test.st"
meta_dict = {"meta": "data"}
tensors_bytes = save({"test1": torch.as_tensor(1.0)}, metadata=meta_dict)
fpath.write_bytes(tensors_bytes)

# Manually get header
first_bytes_length = 8
(length_of_header,) = struct.unpack("<Q", tensors_bytes[:first_bytes_length])
expected_header = json.loads(
tensors_bytes[first_bytes_length : first_bytes_length + length_of_header]
)

# Test for get header only
assert _get_safetensors_header(fpath) == expected_header

# Test for get metadata only
assert get_safetensors_metadata(fpath) == meta_dict
Loading