From 1ed7043abcb3d6388237e54a2791fbaca9e1f961 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Tue, 21 Nov 2023 11:19:28 +0800 Subject: [PATCH] Support tensor storage for compat interface --- tat/compat.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tat/compat.py b/tat/compat.py index 92639e2d8..f0271f9c2 100644 --- a/tat/compat.py +++ b/tat/compat.py @@ -275,11 +275,14 @@ def _(self: T, position: dict[str, tuple[typing.Any, int]]) -> tuple[int, ...]: # Function renames -def _compat_function(focus_type: type) -> typing.Callable[[typing.Callable], typing.Callable]: +def _compat_function(focus_type: type, name: str | None = None) -> typing.Callable[[typing.Callable], typing.Callable]: def _result(function: typing.Callable) -> typing.Callable: - name = function.__name__ - setattr(focus_type, name, function) + if name is None: + attr_name = function.__name__ + else: + attr_name = name + setattr(focus_type, attr_name, function) return function return _result @@ -303,6 +306,14 @@ def zero(self: T) -> T: return self.zero_() +@_compat_function(T, name="storage") # type: ignore[misc] +@property +def storage(self: T) -> typing.Any: + "Get the storage of the tensor" + assert self.data.is_contiguous() + return self.data.numpy().reshape([-1]) + + # Exponential arguments origin_exponential = T.exponential