Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fixed missing parameters in "dynamic" #56

Merged
merged 5 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
exclude: |
(?x)^(
tests/helpers/|
tests/utils.py
)

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ local_scheme = "no-local-version"
[tool.ruff]
# Same as Black.
line-length = 100

[tool.pytest.ini_options]
norecursedirs = "tests/helpers"
81 changes: 71 additions & 10 deletions src/caustics/sims/state_dict.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
from datetime import datetime as dt
from collections import OrderedDict
from typing import Any, Dict
from typing import Any, Dict, Optional

from torch import Tensor
import torch
from .._version import __version__
from ..namespace_dict import NamespaceDict, NestedNamespaceDict

from safetensors.torch import save

IMMUTABLE_ERR = TypeError("'StateDict' cannot be modified after creation.")
STATIC_PARAMS = "static"
PARAM_KEYS = ["dynamic", "static"]


def _sanitize(tensors_dict: Dict[str, Optional[Tensor]]) -> Dict[str, Tensor]:
"""
Sanitize the input dictionary of tensors by
replacing Nones with tensors of size 0.

Parameters
----------
tensors_dict : dict
A dictionary of tensors, including None.

Returns
-------
dict
A dictionary of tensors, with empty tensors
replaced by tensors of size 0.
"""
return {
k: v if isinstance(v, Tensor) else torch.ones(0)
for k, v in tensors_dict.items()
}


class StateDict(OrderedDict):
Expand Down Expand Up @@ -41,6 +64,15 @@ def __setitem__(self, key: str, value: Any) -> None:
raise IMMUTABLE_ERR
super().__setitem__(key, value)

def __repr__(self) -> str:
state_dict_list = [
(k, v) if v.nelement() > 0 else (k, None) for k, v in self.items()
]
class_name = self.__class__.__name__
if not state_dict_list:
return "%s()" % (class_name,)
return "%s(%r)" % (class_name, state_dict_list)

@classmethod
def from_params(cls, params: "NestedNamespaceDict | NamespaceDict"):
"""Class method to create a StateDict
Expand All @@ -59,26 +91,55 @@ def from_params(cls, params: "NestedNamespaceDict | NamespaceDict"):
StateDict
A state dictionary object
"""
if isinstance(params, NestedNamespaceDict) and STATIC_PARAMS in params:
params: NamespaceDict = params[STATIC_PARAMS].flatten()
tensors_dict: Dict[str, Tensor] = {k: v.value for k, v in params.items()}
if not isinstance(params, (NamespaceDict, NestedNamespaceDict)):
raise TypeError("params must be a NamespaceDict or NestedNamespaceDict")

if isinstance(params, NestedNamespaceDict):
# In this case, params is the full parameters
# with both "static" and "dynamic" keys
if sorted(params.keys()) != PARAM_KEYS:
raise ValueError(f"params must have keys {PARAM_KEYS}")

# Extract the "static" and "dynamic" parameters
param_dicts = list(params.values())

# Extract the "static" and "dynamic" parameters
# to a single merged dictionary
final_dict = NestedNamespaceDict()
for pdict in param_dicts:
for k, v in pdict.items():
if k not in final_dict:
final_dict[k] = v
else:
final_dict[k] = {**final_dict[k], **v}

# Flatten the dictionary to a single level
params: NamespaceDict = final_dict.flatten()

tensors_dict: Dict[str, Tensor] = _sanitize(
{k: v.value for k, v in params.items()}
)
return cls(tensors_dict)

def to_params(self) -> NamespaceDict:
def to_params(self) -> NestedNamespaceDict:
"""
Convert the state dict to a dictionary of parameters.
Convert the state dict to
a nested dictionary of parameters.

Returns
-------
NamespaceDict
A dictionary of 'static' parameters.
NestedNamespaceDict
A nested dictionary of parameters.
"""
from ..parameter import Parameter

params = NamespaceDict()
for k, v in self.items():
if v.nelement() == 0:
# Set to None if the tensor is empty
v = None
params[k] = Parameter(v)
return params
return NestedNamespaceDict(params)

def _to_safetensors(self) -> bytes:
return save(self, metadata=self._metadata)
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys
import os

# Add the helpers directory to the path so we can import the helpers
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
Empty file added tests/helpers/__init__.py
Empty file.
43 changes: 43 additions & 0 deletions tests/helpers/sims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from caustics.namespace_dict import NestedNamespaceDict
from caustics.sims.state_dict import _sanitize


def extract_tensors(params, include_params=False):
# Extract the "static" and "dynamic" parameters
param_dicts = list(params.values())

# Extract the "static" and "dynamic" parameters
# to a single merged dictionary
final_dict = NestedNamespaceDict()
for pdict in param_dicts:
for k, v in pdict.items():
if k not in final_dict:
final_dict[k] = v
else:
final_dict[k] = {**final_dict[k], **v}

# flatten function only exists for NestedNamespaceDict
all_params = final_dict.flatten()

tensors_dict = _sanitize({k: v.value for k, v in all_params.items()})
if include_params:
return tensors_dict, all_params
return tensors_dict


def isEquals(a, b):
# Go through each key and values
# change empty torch to be None
# since we can't directly compare
# empty torch
truthy = []
for k, v in a.items():
if k not in b:
return False
kv = b[k]
if (v.nelement() == 0) or (kv.nelement() == 0):
v = None
kv = None
truthy.append(v == kv)

return all(truthy)
6 changes: 3 additions & 3 deletions tests/sims/test_simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from caustics.sims.state_dict import StateDict
from helpers.sims import extract_tensors, isEquals


@pytest.fixture
Expand All @@ -10,8 +11,7 @@ def state_dict(simple_common_sim):

@pytest.fixture
def expected_tensors(simple_common_sim):
static_params = simple_common_sim.params["static"].flatten()
return {k: v.value for k, v in static_params.items()}
return extract_tensors(simple_common_sim.params)


class TestSimulator:
Expand All @@ -28,4 +28,4 @@ def test_state_dict(self, state_dict, expected_tensors):
assert "created_time" in state_dict._metadata

# Check params
assert dict(state_dict) == expected_tensors
assert isEquals(dict(state_dict), expected_tensors)
15 changes: 6 additions & 9 deletions tests/sims/test_state_dict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Dict
import pytest
import torch
from safetensors.torch import save, load
Expand All @@ -8,6 +7,8 @@
from caustics.sims.state_dict import StateDict, IMMUTABLE_ERR
from caustics import __version__

from helpers.sims import extract_tensors, isEquals


class TestStateDict:
simple_tensors = {"var1": torch.as_tensor(1.0), "var2": torch.as_tensor(2.0)}
Expand Down Expand Up @@ -43,21 +44,17 @@ def test_delitem(self, simple_state_dict):
def test_from_params(self, simple_common_sim):
params: NestedNamespaceDict = simple_common_sim.params

# flatten function only exists for NestedNamespaceDict
static_params: NamespaceDict = params["static"].flatten()
tensors_dict: Dict[str, torch.Tensor] = {
k: v.value for k, v in static_params.items()
}
tensors_dict, all_params = extract_tensors(params, True)

expected_state_dict = StateDict(tensors_dict)

# Full parameters
state_dict = StateDict.from_params(params)
assert state_dict == expected_state_dict
assert isEquals(state_dict, expected_state_dict)

# Static only
state_dict = StateDict.from_params(static_params)
assert state_dict == expected_state_dict
state_dict = StateDict.from_params(all_params)
assert isEquals(state_dict, expected_state_dict)

def test_to_params(self, simple_state_dict):
params = simple_state_dict.to_params()
Expand Down