diff --git a/tat/_svd.py b/tat/_svd.py
new file mode 100644
index 000000000..db9965e2a
--- /dev/null
+++ b/tat/_svd.py
@@ -0,0 +1,248 @@
+"""
+This module implements an SVD decomposition built entirely from Givens rotations, without Householder reflection.
+"""
+
+import typing
+import torch
+from ._qr import _normalize_diagonal, _givens_parameter
+
+# pylint: disable=invalid-name
+
+
+@torch.jit.script
+def _svd(A: torch.Tensor, error: float) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    # pylint: disable=too-many-locals
+    # pylint: disable=too-many-branches
+    # pylint: disable=too-many-statements
+
+    # Golub-Kahan bidiagonalization followed by implicit-shift QR iteration,
+    # see https://web.stanford.edu/class/cme335/lecture6.pdf
+    m, n = A.shape
+    trans = False
+    if m < n:
+        trans = True
+        A = A.transpose(0, 1)
+        m, n = n, m
+    U = torch.eye(m, dtype=A.dtype, device=A.device)
+    V = torch.eye(n, dtype=A.dtype, device=A.device)
+
+    # Make bidiagonal matrix
+    B = A.clone(memory_format=torch.contiguous_format)
+    for i in range(n):
+        # Zero out the entries below B[i, i] in column i: (i:, i)
+        for j in range(m - 1, i, -1):
+            col = i
+            # Rotate inside col i
+            # Rotate from row j to j-1
+            c, s = _givens_parameter(B[j - 1, col], B[j, col])
+            U[j], U[j - 1] = -s * U[j - 1] + c * U[j], c.conj() * U[j - 1] + s.conj() * U[j]
+            B[j], B[j - 1] = -s * B[j - 1] + c * B[j], c.conj() * B[j - 1] + s.conj() * B[j]
+        # The equivalent Householder update would be:
+        # x = B[i:, i]
+        # v = torch.zeros_like(x)
+        # v[0] = _reflect_target(x)
+        # delta = _normalize_delta(v - x)
+        # B[i:, :] -= 2 * torch.outer(delta, delta.conj() @ B[i:, :])
+        # U[i:, :] -= 2 * torch.outer(delta, delta.conj() @ U[i:, :])
+
+        # Zero out the entries right of B[i, i + 1] in row i: (i, i+1:)/H
+        if i == n - 1:
+            break
+        for j in range(n - 1, i + 1, -1):
+            row = i
+            # Rotate inside row i
+            # Rotate from col j to j-1
+            c, s = _givens_parameter(B[row, j - 1], B[row, j])
+            V[j], V[j - 1] = -s * V[j - 1] + c * V[j], c.conj() * V[j - 1] + s.conj() * V[j]
+            B[:, j], B[:, j - 1] = -s * B[:, j - 1] + c * B[:, j], c.conj() * B[:, j - 1] + s.conj() * B[:, j]
+        # The equivalent Householder update would be:
+        # x = B[i, i + 1:]
+        # v = torch.zeros_like(x)
+        # v[0] = _reflect_target(x)
+        # delta = _normalize_delta(v - x)
+        # B[:, i + 1:] -= 2 * torch.outer(B[:, i + 1:] @ delta.conj(), delta)
+        # V[i + 1:, :] -= 2 * torch.outer(delta, delta.conj() @ V[i + 1:, :])
+    B = B[:n]
+    U = U[:n]
+    # print(B)
+    # error_decomp = torch.max(torch.abs(U.H @ B @ V.H.T - A)).item()
+    # assert error_decomp < 1e-4
+
+    # QR iteration with implicit Q
+    S = torch.diagonal(B).clone(memory_format=torch.contiguous_format)
+    F = torch.diagonal(B, offset=1).clone(memory_format=torch.contiguous_format)
+    F.resize_(S.size(0))
+    F[-1] = 0
+    X = F[-1]
+    stack: list[tuple[int, int]] = [(0, n - 1)]
+    while stack:
+        # B.zero_()
+        # B.diagonal()[:] = S
+        # B.diagonal(offset=1)[:] = F[:-1]
+        # error_decomp = torch.max(torch.abs(U.H @ B @ V.H.T - A)).item()
+        # assert error_decomp < 1e-4
+
+        low = stack[-1][0]
+        high = stack[-1][1]
+
+        if low == high:
+            stack.pop()
+            continue
+
+        # Deflate at the smallest superdiagonal entry if it is negligible
+        b = int(torch.argmin(torch.abs(F[low:high]))) + low
+        if torch.abs(F[b]) < torch.max(torch.abs(S)) * error:
+            F[b] = 0
+            stack.pop()
+            stack.append((b + 1, high))
+            stack.append((low, b))
+            continue
+
+        # Wilkinson-style shift from the 2x2 block of B.H @ B around index b
+        tdn = (S[b + 1].conj() * S[b + 1] + F[b].conj() * F[b]).real
+        tdn_1 = (S[b].conj() * S[b] + F[b - 1].conj() * F[b - 1]).real
+        tsn_1 = F[b].conj() * S[b]
+        d = (tdn_1 - tdn) / 2
+        mu = tdn + d - torch.sign(d) * torch.sqrt(d**2 + tsn_1.conj() * tsn_1)
+        # Chase the bulge X down the bidiagonal band
+        for i in range(low, high):
+            if i == low:
+                c, s = _givens_parameter(S[low].conj() * S[low] - mu, S[low].conj() * F[low])
+            else:
+                c, s = _givens_parameter(F[i - 1], X)
+            V[i + 1], V[i] = -s * V[i] + c * V[i + 1], c.conj() * V[i] + s.conj() * V[i + 1]
+            if i != low:
+                F[i - 1] = c.conj() * F[i - 1] + s.conj() * X
+            F[i], S[i] = -s * S[i] + c * F[i], c.conj() * S[i] + s.conj() * F[i]
+            S[i + 1], X = c * S[i + 1], s.conj() * S[i + 1]
+
+            c, s = _givens_parameter(S[i], X)
+            U[i + 1], U[i] = -s * U[i] + c * U[i + 1], c.conj() * U[i] + s.conj() * U[i + 1]
+
+            S[i] = c.conj() * S[i] + s.conj() * X
+            S[i + 1], F[i] = -s * F[i] + c * S[i + 1], c.conj() * F[i] + s.conj() * S[i + 1]
+            if i != high - 1:
+                F[i + 1], X = c * F[i + 1], s.conj() * F[i + 1]
+
+    # Make diagonal real and positive
+    c = _normalize_diagonal(S).conj()
+    V *= c.unsqueeze(1)  # U is larger than V, so absorb the phase into V
+    S *= c
+    S = S.real
+
+    # Sort singular values in descending order
+    S, order = torch.sort(S, descending=True)
+    U = U[order]
+    V = V[order]
+
+    # pylint: disable=no-else-return
+    if trans:
+        return V.H, S, U.H.T
+    else:
+        return U.H, S, V.H.T
+
+
+@torch.jit.script
+def _skew(A: torch.Tensor) -> torch.Tensor:
+    return A - A.H
+
+
+@torch.jit.script
+def _svd_backward(
+    U: torch.Tensor,
+    S: torch.Tensor,
+    Vh: torch.Tensor,
+    gU: typing.Optional[torch.Tensor],
+    gS: typing.Optional[torch.Tensor],
+    gVh: typing.Optional[torch.Tensor],
+) -> typing.Optional[torch.Tensor]:
+    # pylint: disable=too-many-locals
+    # pylint: disable=too-many-branches
+    # pylint: disable=too-many-arguments
+
+    # See pytorch torch/csrc/autograd/FunctionsManual.cpp:svd_backward
+    if gS is None and gU is None and gVh is None:
+        return None
+
+    m = U.size(0)
+    n = Vh.size(1)
+
+    if gU is None and gVh is None:
+        assert gS is not None
+        # Only the singular values are differentiated: gA = U diag(gS) Vh
+        # pylint: disable=no-else-return
+        if m >= n:
+            return U @ (gS.unsqueeze(1) * Vh)
+        else:
+            return (U * gS.unsqueeze(0)) @ Vh
+
+    is_complex = torch.is_complex(U)
+
+    UhgU = _skew(U.H @ gU) if gU is not None else None
+    VhgV = _skew(Vh @ gVh.H) if gVh is not None else None
+
+    # E[i, j] = S[j]**2 - S[i]**2; the diagonal is set to 1 to avoid dividing
+    # by zero (the skew matrices above have zero diagonal anyway)
+    S2 = S * S
+    E = S2.unsqueeze(0) - S2.unsqueeze(1)
+    E.diagonal()[:] = 1
+
+    if gU is not None:
+        if gVh is not None:
+            assert UhgU is not None
+            assert VhgV is not None
+            gA = (UhgU * S.unsqueeze(0) + S.unsqueeze(1) * VhgV) / E
+        else:
+            assert UhgU is not None
+            gA = (UhgU / E) * S.unsqueeze(0)
+    else:
+        assert VhgV is not None
+        gA = S.unsqueeze(1) * (VhgV / E)
+
+    if gS is not None:
+        gA = gA + torch.diag(gS)
+
+    # Extra diagonal term from the phase (gauge) freedom of complex singular vectors
+    if is_complex and gU is not None and gVh is not None:
+        assert UhgU is not None
+        gA = gA + torch.diag(UhgU.diagonal() / (2 * S))
+
+    if m > n and gU is not None:
+        gA = U @ gA
+        gUSinv = gU / S.unsqueeze(0)
+        gA = gA + gUSinv - U @ (U.H @ gUSinv)
+        gA = gA @ Vh
+    elif m < n and gVh is not None:
+        gA = gA @ Vh
+        SinvgVh = gVh / S.unsqueeze(1)
+        gA = gA + SinvgVh - (SinvgVh @ Vh.H) @ Vh
+        gA = U @ gA
+    elif m >= n:
+        gA = U @ (gA @ Vh)
+    else:
+        gA = (U @ gA) @ Vh
+
+    return gA
+
+
+class SVD(torch.autograd.Function):
+    """
+    Compute SVD decomposition without Householder reflection.
+    """
+
+    # pylint: disable=abstract-method
+
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx: torch.autograd.function.FunctionCtx,
+        A: torch.Tensor,
+        error: float,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # pylint: disable=arguments-differ
+        U, S, V = _svd(A, error)
+        ctx.save_for_backward(U, S, V)
+        return U, S, V
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: typing.Any,
+        U_grad: torch.Tensor | None,
+        S_grad: torch.Tensor | None,
+        V_grad: torch.Tensor | None,
+    ) -> tuple[torch.Tensor | None, None]:
+        # pylint: disable=arguments-differ
+        U, S, V = ctx.saved_tensors
+        return _svd_backward(U, S, V, U_grad, S_grad, V_grad), None
+
+
+svd = SVD.apply
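A minimal usage sketch of the new routine (an illustration only, not part of the patch; it assumes nothing beyond the interface defined above: svd(A, error) returns (U, S, Vh) with S real and sorted in descending order, and error is the relative deflation threshold for the superdiagonal):

    import torch
    from tat._svd import svd

    A = torch.randn(6, 4, dtype=torch.float64, requires_grad=True)
    U, S, Vh = svd(A, 1e-10)  # U: 6x4, S: (4,), Vh: 4x4
    assert torch.allclose(U @ torch.diag(S) @ Vh, A)  # reconstruction holds
    S.sum().backward()  # gradients flow through the custom _svd_backward
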
+ """ + + # pylint: disable=abstract-method + + @staticmethod + def forward( # type: ignore[override] + ctx: torch.autograd.function.FunctionCtx, + A: torch.Tensor, + error: float, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # pylint: disable=arguments-differ + U, S, V = _svd(A, error) + ctx.save_for_backward(U, S, V) + return U, S, V + + @staticmethod + def backward( # type: ignore[override] + ctx: typing.Any, + U_grad: torch.Tensor | None, + S_grad: torch.Tensor | None, + V_grad: torch.Tensor, + ) -> tuple[torch.Tensor | None, None]: + # pylint: disable=arguments-differ + U, S, V = ctx.saved_tensors + return _svd_backward(U, S, V, U_grad, S_grad, V_grad), None + + +svd = SVD.apply diff --git a/tat/tensor.py b/tat/tensor.py index 05562e665..c2b9128dc 100644 --- a/tat/tensor.py +++ b/tat/tensor.py @@ -10,6 +10,7 @@ import torch from . import _utility from ._qr import givens_qr +from ._svd import svd as manual_svd from .edge import Edge # pylint: disable=too-many-public-methods @@ -1561,11 +1562,6 @@ def _guess_edge(matrix: torch.Tensor, edge: Edge, arrow: bool) -> Edge: parity=edge.parity[argmax], ) - def _ensure_mask(self: Tensor) -> None: - assert torch.allclose(torch.where(self.mask, torch.zeros([], dtype=self.dtype), self.data), - torch.zeros([], dtype=self.dtype)) - self._data = torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype)) - def svd( self: Tensor, free_names_u: set[str], @@ -1628,7 +1624,11 @@ def svd( names=("SVD_U", "SVD_V") if put_v_right else ("SVD_V", "SVD_U"), ) - data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False) + if self.fermion: + data_1, data_s, data_2 = manual_svd(tensor.data, 1e-6) + else: + data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False) + if cut != -1: data_1 = data_1[:, :cut] data_s = data_s[:cut] @@ -1644,7 +1644,7 @@ def svd( dtypes=self.dtypes, data=data_1, ) - tensor_1._ensure_mask() # pylint: disable=protected-access + # tensor_1._ensure_mask() free_edge_2 = tensor.edges[1] common_edge_2 = Tensor._guess_edge(torch.abs(data_2).transpose(0, 1), free_edge_2, False) tensor_2 = Tensor( @@ -1654,7 +1654,7 @@ def svd( dtypes=self.dtypes, data=data_2, ) - tensor_2._ensure_mask() # pylint: disable=protected-access + # tensor_2._ensure_mask() assert common_edge_1.conjugate() == common_edge_2 tensor_s = Tensor( names=(singular_name_u, singular_name_v) if put_v_right else (singular_name_v, singular_name_u), diff --git a/tests/test_svd.py b/tests/test_svd.py new file mode 100644 index 000000000..77d99f4f4 --- /dev/null +++ b/tests/test_svd.py @@ -0,0 +1,39 @@ +"Test SVD" + +import torch +from tat._svd import svd + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name + + +def svd_func(A: torch.Tensor) -> torch.Tensor: + U, S, V = svd(A, 1e-10) + return U @ torch.diag(S).to(dtype=A.dtype) @ V + + +def check_svd(A: torch.Tensor) -> None: + m, n = A.size() + U, S, V = svd(A, 1e-10) + assert torch.allclose(U @ torch.diag(S.to(dtype=A.dtype)) @ V, A) + assert torch.allclose(U.H @ U, torch.eye(min(m, n), dtype=A.dtype, device=A.device)) + assert torch.allclose(V @ V.H, torch.eye(min(m, n), dtype=A.dtype, device=A.device)) + grad_check = torch.autograd.gradcheck( + svd_func if torch.is_complex(A) else lambda x: svd(x, 1e-10), + A, + eps=1e-8, + atol=1e-4, + ) + assert grad_check + + +def test_svd_real() -> None: + check_svd(torch.randn(7, 5, dtype=torch.float64, requires_grad=True)) + check_svd(torch.randn(5, 5, dtype=torch.float64, requires_grad=True)) + 
diff --git a/tests/test_svd.py b/tests/test_svd.py
new file mode 100644
index 000000000..77d99f4f4
--- /dev/null
+++ b/tests/test_svd.py
@@ -0,0 +1,39 @@
+"Test SVD"
+
+import torch
+from tat._svd import svd
+
+# pylint: disable=missing-function-docstring
+# pylint: disable=invalid-name
+
+
+def svd_func(A: torch.Tensor) -> torch.Tensor:
+    U, S, V = svd(A, 1e-10)
+    return U @ torch.diag(S).to(dtype=A.dtype) @ V
+
+
+def check_svd(A: torch.Tensor) -> None:
+    m, n = A.size()
+    U, S, V = svd(A, 1e-10)
+    assert torch.allclose(U @ torch.diag(S.to(dtype=A.dtype)) @ V, A)
+    assert torch.allclose(U.H @ U, torch.eye(min(m, n), dtype=A.dtype, device=A.device))
+    assert torch.allclose(V @ V.H, torch.eye(min(m, n), dtype=A.dtype, device=A.device))
+    grad_check = torch.autograd.gradcheck(
+        svd_func if torch.is_complex(A) else lambda x: svd(x, 1e-10),
+        A,
+        eps=1e-8,
+        atol=1e-4,
+    )
+    assert grad_check
+
+
+def test_svd_real() -> None:
+    check_svd(torch.randn(7, 5, dtype=torch.float64, requires_grad=True))
+    check_svd(torch.randn(5, 5, dtype=torch.float64, requires_grad=True))
+    check_svd(torch.randn(5, 7, dtype=torch.float64, requires_grad=True))
+
+
+def test_svd_complex() -> None:
+    check_svd(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True))
+    check_svd(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True))
+    check_svd(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True))
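Beyond gradcheck, one further sanity check is possible (a sketch under the same interface assumptions, not part of the test suite above): compare gradients of the singular values against torch.linalg.svd. The gradient of S is gauge-invariant, so it is comparable even though the singular vectors of the two routines may differ by signs or phases:

    import torch
    from tat._svd import svd

    A = torch.randn(6, 4, dtype=torch.float64)
    A1 = A.clone().requires_grad_(True)
    A2 = A.clone().requires_grad_(True)

    _, S1, _ = svd(A1, 1e-10)
    S1.sum().backward()

    _, S2, _ = torch.linalg.svd(A2, full_matrices=False)
    S2.sum().backward()

    assert torch.allclose(S1, S2)
    assert torch.allclose(A1.grad, A2.grad)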