forked from Ciela-Institute/caustics
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: Fixed missing parameters in "dynamic" (#56)
* fix: Fixed missing parameters in "dynamic" Fixed the missing parameters that are intentionally None in the "dynamic" section. Now StateDict contains dynamic parameters. Also introduced '_sanitize' function to make None to empty tensor of size 0 since safetensors format doesn't except None. * test: Fix state dict comparison with empty torch Fixed the comparison for StateDicts. Since comparing empty tensors doesn't work, I created isEquals internal function to the 'test_from_params' test to compare values one by one. * test: Create utility and update tests * test: Refactored how helpers are setup in tests * refactor: Extract class name to variable
- Loading branch information
Showing
8 changed files
with
132 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
exclude: | | ||
(?x)^( | ||
tests/helpers/| | ||
tests/utils.py | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import sys | ||
import os | ||
|
||
# Add the helpers directory to the path so we can import the helpers | ||
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from caustics.namespace_dict import NestedNamespaceDict | ||
from caustics.sims.state_dict import _sanitize | ||
|
||
|
||
def extract_tensors(params, include_params=False): | ||
# Extract the "static" and "dynamic" parameters | ||
param_dicts = list(params.values()) | ||
|
||
# Extract the "static" and "dynamic" parameters | ||
# to a single merged dictionary | ||
final_dict = NestedNamespaceDict() | ||
for pdict in param_dicts: | ||
for k, v in pdict.items(): | ||
if k not in final_dict: | ||
final_dict[k] = v | ||
else: | ||
final_dict[k] = {**final_dict[k], **v} | ||
|
||
# flatten function only exists for NestedNamespaceDict | ||
all_params = final_dict.flatten() | ||
|
||
tensors_dict = _sanitize({k: v.value for k, v in all_params.items()}) | ||
if include_params: | ||
return tensors_dict, all_params | ||
return tensors_dict | ||
|
||
|
||
def isEquals(a, b): | ||
# Go through each key and values | ||
# change empty torch to be None | ||
# since we can't directly compare | ||
# empty torch | ||
truthy = [] | ||
for k, v in a.items(): | ||
if k not in b: | ||
return False | ||
kv = b[k] | ||
if (v.nelement() == 0) or (kv.nelement() == 0): | ||
v = None | ||
kv = None | ||
truthy.append(v == kv) | ||
|
||
return all(truthy) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters