Skip to content

Commit

Permalink
fix: Fixed missing parameters in "dynamic" (#56)
Browse files Browse the repository at this point in the history
* fix: Fixed missing parameters in "dynamic"

Fixed the missing parameters that are intentionally
None in the "dynamic" section. Now StateDict
contains dynamic parameters. Also introduced
'_sanitize' function to make None to empty
tensor of size 0 since safetensors format
doesn't except None.

* test: Fix state dict comparison with empty torch

Fixed the comparison for StateDicts. Since comparing
empty tensors doesn't work, I created isEquals internal
function to the 'test_from_params' test to compare
values one by one.

* test: Create utility and update tests

* test: Refactored how helpers are setup in tests

* refactor: Extract class name to variable
  • Loading branch information
lsetiawan authored Jan 19, 2024
1 parent 40e7612 commit cc94a43
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 22 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
exclude: |
(?x)^(
tests/helpers/|
tests/utils.py
)
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ local_scheme = "no-local-version"
[tool.ruff]
# Same as Black.
line-length = 100

[tool.pytest.ini_options]
norecursedirs = "tests/helpers"
81 changes: 71 additions & 10 deletions src/caustics/sims/state_dict.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
from datetime import datetime as dt
from collections import OrderedDict
from typing import Any, Dict
from typing import Any, Dict, Optional

from torch import Tensor
import torch
from .._version import __version__
from ..namespace_dict import NamespaceDict, NestedNamespaceDict

from safetensors.torch import save

IMMUTABLE_ERR = TypeError("'StateDict' cannot be modified after creation.")
STATIC_PARAMS = "static"
PARAM_KEYS = ["dynamic", "static"]


def _sanitize(tensors_dict: Dict[str, Optional[Tensor]]) -> Dict[str, Tensor]:
"""
Sanitize the input dictionary of tensors by
replacing Nones with tensors of size 0.
Parameters
----------
tensors_dict : dict
A dictionary of tensors, including None.
Returns
-------
dict
A dictionary of tensors, with empty tensors
replaced by tensors of size 0.
"""
return {
k: v if isinstance(v, Tensor) else torch.ones(0)
for k, v in tensors_dict.items()
}


class StateDict(OrderedDict):
Expand Down Expand Up @@ -41,6 +64,15 @@ def __setitem__(self, key: str, value: Any) -> None:
raise IMMUTABLE_ERR
super().__setitem__(key, value)

def __repr__(self) -> str:
state_dict_list = [
(k, v) if v.nelement() > 0 else (k, None) for k, v in self.items()
]
class_name = self.__class__.__name__
if not state_dict_list:
return "%s()" % (class_name,)
return "%s(%r)" % (class_name, state_dict_list)

@classmethod
def from_params(cls, params: "NestedNamespaceDict | NamespaceDict"):
"""Class method to create a StateDict
Expand All @@ -59,26 +91,55 @@ def from_params(cls, params: "NestedNamespaceDict | NamespaceDict"):
StateDict
A state dictionary object
"""
if isinstance(params, NestedNamespaceDict) and STATIC_PARAMS in params:
params: NamespaceDict = params[STATIC_PARAMS].flatten()
tensors_dict: Dict[str, Tensor] = {k: v.value for k, v in params.items()}
if not isinstance(params, (NamespaceDict, NestedNamespaceDict)):
raise TypeError("params must be a NamespaceDict or NestedNamespaceDict")

if isinstance(params, NestedNamespaceDict):
# In this case, params is the full parameters
# with both "static" and "dynamic" keys
if sorted(params.keys()) != PARAM_KEYS:
raise ValueError(f"params must have keys {PARAM_KEYS}")

# Extract the "static" and "dynamic" parameters
param_dicts = list(params.values())

# Extract the "static" and "dynamic" parameters
# to a single merged dictionary
final_dict = NestedNamespaceDict()
for pdict in param_dicts:
for k, v in pdict.items():
if k not in final_dict:
final_dict[k] = v
else:
final_dict[k] = {**final_dict[k], **v}

# Flatten the dictionary to a single level
params: NamespaceDict = final_dict.flatten()

tensors_dict: Dict[str, Tensor] = _sanitize(
{k: v.value for k, v in params.items()}
)
return cls(tensors_dict)

def to_params(self) -> NamespaceDict:
def to_params(self) -> NestedNamespaceDict:
"""
Convert the state dict to a dictionary of parameters.
Convert the state dict to
a nested dictionary of parameters.
Returns
-------
NamespaceDict
A dictionary of 'static' parameters.
NestedNamespaceDict
A nested dictionary of parameters.
"""
from ..parameter import Parameter

params = NamespaceDict()
for k, v in self.items():
if v.nelement() == 0:
# Set to None if the tensor is empty
v = None
params[k] = Parameter(v)
return params
return NestedNamespaceDict(params)

def _to_safetensors(self) -> bytes:
return save(self, metadata=self._metadata)
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys
import os

# Add the helpers directory to the path so we can import the helpers
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
Empty file added tests/helpers/__init__.py
Empty file.
43 changes: 43 additions & 0 deletions tests/helpers/sims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from caustics.namespace_dict import NestedNamespaceDict
from caustics.sims.state_dict import _sanitize


def extract_tensors(params, include_params=False):
# Extract the "static" and "dynamic" parameters
param_dicts = list(params.values())

# Extract the "static" and "dynamic" parameters
# to a single merged dictionary
final_dict = NestedNamespaceDict()
for pdict in param_dicts:
for k, v in pdict.items():
if k not in final_dict:
final_dict[k] = v
else:
final_dict[k] = {**final_dict[k], **v}

# flatten function only exists for NestedNamespaceDict
all_params = final_dict.flatten()

tensors_dict = _sanitize({k: v.value for k, v in all_params.items()})
if include_params:
return tensors_dict, all_params
return tensors_dict


def isEquals(a, b):
# Go through each key and values
# change empty torch to be None
# since we can't directly compare
# empty torch
truthy = []
for k, v in a.items():
if k not in b:
return False
kv = b[k]
if (v.nelement() == 0) or (kv.nelement() == 0):
v = None
kv = None
truthy.append(v == kv)

return all(truthy)
6 changes: 3 additions & 3 deletions tests/sims/test_simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from caustics.sims.state_dict import StateDict
from helpers.sims import extract_tensors, isEquals


@pytest.fixture
Expand All @@ -10,8 +11,7 @@ def state_dict(simple_common_sim):

@pytest.fixture
def expected_tensors(simple_common_sim):
static_params = simple_common_sim.params["static"].flatten()
return {k: v.value for k, v in static_params.items()}
return extract_tensors(simple_common_sim.params)


class TestSimulator:
Expand All @@ -28,4 +28,4 @@ def test_state_dict(self, state_dict, expected_tensors):
assert "created_time" in state_dict._metadata

# Check params
assert dict(state_dict) == expected_tensors
assert isEquals(dict(state_dict), expected_tensors)
15 changes: 6 additions & 9 deletions tests/sims/test_state_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Dict
import pytest
import torch
from safetensors.torch import save, load
Expand All @@ -8,6 +7,8 @@
from caustics.sims.state_dict import StateDict, IMMUTABLE_ERR
from caustics import __version__

from helpers.sims import extract_tensors, isEquals


class TestStateDict:
simple_tensors = {"var1": torch.as_tensor(1.0), "var2": torch.as_tensor(2.0)}
Expand Down Expand Up @@ -43,21 +44,17 @@ def test_delitem(self, simple_state_dict):
def test_from_params(self, simple_common_sim):
params: NestedNamespaceDict = simple_common_sim.params

# flatten function only exists for NestedNamespaceDict
static_params: NamespaceDict = params["static"].flatten()
tensors_dict: Dict[str, torch.Tensor] = {
k: v.value for k, v in static_params.items()
}
tensors_dict, all_params = extract_tensors(params, True)

expected_state_dict = StateDict(tensors_dict)

# Full parameters
state_dict = StateDict.from_params(params)
assert state_dict == expected_state_dict
assert isEquals(state_dict, expected_state_dict)

# Static only
state_dict = StateDict.from_params(static_params)
assert state_dict == expected_state_dict
state_dict = StateDict.from_params(all_params)
assert isEquals(state_dict, expected_state_dict)

def test_to_params(self, simple_state_dict):
params = simple_state_dict.to_params()
Expand Down

0 comments on commit cc94a43

Please sign in to comment.