Skip to content

Commit

Permalink
Support tensor storage for compat interface
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 21, 2023
1 parent 16396f5 commit c071053
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
17 changes: 14 additions & 3 deletions tat/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,14 @@ def _(self: T, position: dict[str, tuple[typing.Any, int]]) -> tuple[int, ...]:
# Function renames


def _compat_function(focus_type: type) -> typing.Callable[[typing.Callable], typing.Callable]:
def _compat_function(focus_type: type, name: str | None = None) -> typing.Callable[[typing.Callable], typing.Callable]:

def _result(function: typing.Callable) -> typing.Callable:
name = function.__name__
setattr(focus_type, name, function)
if name is None:
attr_name = function.__name__
else:
attr_name = name
setattr(focus_type, attr_name, function)
return function

return _result
Expand All @@ -303,6 +306,14 @@ def zero(self: T) -> T:
return self.zero_()


@_compat_function(T, name="storage") # type: ignore[misc]
@property
def storage(self: T) -> typing.Any:
"Get the storage of the tensor"
assert self.data.is_contiguous()
return self.data.numpy().reshape([-1])


# Exponential arguments

origin_exponential = T.exponential
Expand Down
2 changes: 1 addition & 1 deletion tat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def copy(self: Tensor) -> Tensor:
edges=self.edges,
fermion=self.fermion,
dtypes=self.dtypes,
data=torch.clone(self.data),
data=torch.clone(self.data, memory_format=torch.contiguous_format),
mask=self.mask,
)

Expand Down

0 comments on commit c071053

Please sign in to comment.