
Commit

Implement SVD decomposition manually.
hzhangxyz committed Nov 20, 2023
1 parent 1e2a7ac commit 4506c75
Showing 3 changed files with 295 additions and 8 deletions.
248 changes: 248 additions & 0 deletions tat/_svd.py
@@ -0,0 +1,248 @@
"""
This module implements SVD decomposition without Householder reflection.
"""

import typing
import torch
from ._qr import _normalize_diagonal, _givens_parameter

# pylint: disable=invalid-name


@torch.jit.script
def _svd(A: torch.Tensor, error: float) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # pylint: disable=too-many-locals
    # pylint: disable=too-many-branches
    # pylint: disable=too-many-statements

    # see https://web.stanford.edu/class/cme335/lecture6.pdf
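    # The routine works in two stages: (1) reduce A to an upper bidiagonal matrix B
    # using Givens rotations accumulated into U and V, then (2) diagonalize B with
    # an implicitly shifted QR iteration acting on its diagonal S and superdiagonal F.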
    m, n = A.shape
    trans = False
    if m < n:
        trans = True
        A = A.transpose(0, 1)
        m, n = n, m
    U = torch.eye(m, dtype=A.dtype, device=A.device)
    V = torch.eye(n, dtype=A.dtype, device=A.device)

    # Make bidiagonal matrix
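    # For each column i: left rotations (accumulated into U) zero the entries below
    # the diagonal, then right rotations (accumulated into V) zero the entries of
    # row i to the right of the superdiagonal.  The commented-out blocks below are
    # the Householder-reflection variant that this module deliberately avoids.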
    B = A.clone(memory_format=torch.contiguous_format)
    for i in range(n):
        # (i:, i)
        for j in range(m - 1, i, -1):
            col = i
            # Rotate inside col i
            # Rotate from row j to j-1
            c, s = _givens_parameter(B[j - 1, col], B[j, col])
            U[j], U[j - 1] = -s * U[j - 1] + c * U[j], c.conj() * U[j - 1] + s.conj() * U[j]
            B[j], B[j - 1] = -s * B[j - 1] + c * B[j], c.conj() * B[j - 1] + s.conj() * B[j]
        # x = B[i:, i]
        # v = torch.zeros_like(x)
        # v[0] = _reflect_target(x)
        # delta = _normalize_delta(v - x)
        # B[i:, :] -= 2 * torch.outer(delta, delta.conj() @ B[i:, :])
        # U[i:, :] -= 2 * torch.outer(delta, delta.conj() @ U[i:, :])

        # (i, i+1:)/H
        if i == n - 1:
            break
        for j in range(n - 1, i + 1, -1):
            row = i
            # Rotate inside row i
            # Rotate from col j to j-1
            c, s = _givens_parameter(B[row, j - 1], B[row, j])
            V[j], V[j - 1] = -s * V[j - 1] + c * V[j], c.conj() * V[j - 1] + s.conj() * V[j]
            B[:, j], B[:, j - 1] = -s * B[:, j - 1] + c * B[:, j], c.conj() * B[:, j - 1] + s.conj() * B[:, j]
        # x = B[i, i + 1:]
        # v = torch.zeros_like(x)
        # v[0] = _reflect_target(x)
        # delta = _normalize_delta(v - x)
        # B[:, i + 1:] -= 2 * torch.outer(B[:, i + 1:] @ delta.conj(), delta)
        # V[i + 1:, :] -= 2 * torch.outer(delta, delta.conj() @ V[i + 1:, :])
    B = B[:n]
    U = U[:n]
    # print(B)
    # error_decomp = torch.max(torch.abs(U.H @ B @ V.H.T - A)).item()
    # assert error_decomp < 1e-4

    # QR iteration with implicit Q
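    # Work directly on the diagonal S and superdiagonal F of B.  The stack holds
    # index ranges (low, high) of blocks still to be diagonalized; whenever a
    # superdiagonal entry is negligible the block is split (deflated) in two.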
    S = torch.diagonal(B).clone(memory_format=torch.contiguous_format)
    F = torch.diagonal(B, offset=1).clone(memory_format=torch.contiguous_format)
    F.resize_(S.size(0))
    F[-1] = 0
    X = F[-1]
    stack: list[tuple[int, int]] = [(0, n - 1)]
    while stack:
        # B.zero_()
        # B.diagonal()[:] = S
        # B.diagonal(offset = 1)[:] = F[:-1]
        # error_decomp = torch.max(torch.abs(U.H @ B @ V.H.T - A)).item()
        # assert error_decomp < 1e-4

        low = stack[-1][0]
        high = stack[-1][1]

        if low == high:
            stack.pop()
            continue

        b = int(torch.argmin(torch.abs(F[low:high]))) + low
        if torch.abs(F[b]) < torch.max(torch.abs(S)) * error:
            F[b] = 0
            stack.pop()
            stack.append((b + 1, high))
            stack.append((low, b))
            continue

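        # Wilkinson-style shift: mu is the eigenvalue of the 2x2 block of B^H B at
        # rows/columns b and b+1 that is closest to its last diagonal entry.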
        tdn = (S[b + 1].conj() * S[b + 1] + F[b].conj() * F[b]).real
        tdn_1 = (S[b].conj() * S[b] + F[b - 1].conj() * F[b - 1]).real
        tsn_1 = F[b].conj() * S[b]
        d = (tdn_1 - tdn) / 2
        mu = tdn + d - torch.sign(d) * torch.sqrt(d**2 + tsn_1.conj() * tsn_1)
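        # Bulge chase: alternately apply right rotations (accumulated into V) and
        # left rotations (accumulated into U) down the block so that B stays
        # bidiagonal; X carries the off-bidiagonal "bulge" element being chased out.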
        for i in range(low, high):
            if i == low:
                c, s = _givens_parameter(S[low].conj() * S[low] - mu, S[low].conj() * F[low])
            else:
                c, s = _givens_parameter(F[i - 1], X)
            V[i + 1], V[i] = -s * V[i] + c * V[i + 1], c.conj() * V[i] + s.conj() * V[i + 1]
            if i != low:
                F[i - 1] = c.conj() * F[i - 1] + s.conj() * X
            F[i], S[i] = -s * S[i] + c * F[i], c.conj() * S[i] + s.conj() * F[i]
            S[i + 1], X = c * S[i + 1], s.conj() * S[i + 1]

            c, s = _givens_parameter(S[i], X)
            U[i + 1], U[i] = -s * U[i] + c * U[i + 1], c.conj() * U[i] + s.conj() * U[i + 1]

            S[i] = c.conj() * S[i] + s.conj() * X
            S[i + 1], F[i] = -s * F[i] + c * S[i + 1], c.conj() * F[i] + s.conj() * S[i + 1]
            if i != high - 1:
                F[i + 1], X = c * F[i + 1], s.conj() * F[i + 1]

    # Make diagonal positive
    c = _normalize_diagonal(S).conj()
    V *= c.unsqueeze(1)  # U is larger than V
    S *= c
    S = S.real

    # Sort
    S, order = torch.sort(S, descending=True)
    U = U[order]
    V = V[order]

    # pylint: disable=no-else-return
    if trans:
        return V.H, S, U.H.T
    else:
        return U.H, S, V.H.T


@torch.jit.script
def _skew(A: torch.Tensor) -> torch.Tensor:
    return A - A.H


@torch.jit.script
def _svd_backward(
    U: torch.Tensor,
    S: torch.Tensor,
    Vh: torch.Tensor,
    gU: typing.Optional[torch.Tensor],
    gS: typing.Optional[torch.Tensor],
    gVh: typing.Optional[torch.Tensor],
) -> typing.Optional[torch.Tensor]:
    # pylint: disable=too-many-locals
    # pylint: disable=too-many-branches
    # pylint: disable=too-many-arguments

    # See pytorch torch/csrc/autograd/FunctionsManual.cpp:svd_backward
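    # Assemble the gradient of A from the gradients of U, S and Vh using the
    # standard first-order perturbation of the SVD; differences of squared singular
    # values enter through the matrix E constructed below.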
    if gS is None and gU is None and gVh is None:
        return None

    m = U.size(0)
    n = Vh.size(1)

    if gU is None and gVh is None:
        assert gS is not None
        # pylint: disable=no-else-return
        if m >= n:
            return U @ (gS.unsqueeze(1) * Vh)
        else:
            return (U * gS.unsqueeze(0)) @ Vh

    is_complex = torch.is_complex(U)

    UhgU = _skew(U.H @ gU) if gU is not None else None
    VhgV = _skew(Vh @ gVh.H) if gVh is not None else None

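    # E[i, j] = S[j]^2 - S[i]^2; its diagonal is set to 1 only to make the divisions
    # below safe, since the diagonal contributions are handled separately.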
    S2 = S * S
    E = S2.unsqueeze(0) - S2.unsqueeze(1)
    E.diagonal()[:] = 1

    if gU is not None:
        if gVh is not None:
            assert UhgU is not None
            assert VhgV is not None
            gA = (UhgU * S.unsqueeze(0) + S.unsqueeze(1) * VhgV) / E
        else:
            assert UhgU is not None
            gA = (UhgU / E) * S.unsqueeze(0)
    else:
        assert VhgV is not None
        gA = S.unsqueeze(1) * (VhgV / E)

    if gS is not None:
        gA = gA + torch.diag(gS)

    if is_complex and gU is not None and gVh is not None:
        assert UhgU is not None
        gA = gA + torch.diag(UhgU.diagonal() / (2 * S))

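    # For rectangular A the factors are the reduced ("thin") ones, so the part of
    # the gradient living in the orthogonal complement of U (or of Vh) is added here.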
    if m > n and gU is not None:
        gA = U @ gA
        gUSinv = gU / S.unsqueeze(0)
        gA = gA + gUSinv - U @ (U.H @ gUSinv)
        gA = gA @ Vh
    elif m < n and gVh is not None:
        gA = gA @ Vh
        SinvgVh = gVh / S.unsqueeze(1)
        gA = gA + SinvgVh - (SinvgVh @ Vh.H) @ Vh
        gA = U @ gA
    elif m >= n:
        gA = U @ (gA @ Vh)
    else:
        gA = (U @ gA) @ Vh

    return gA


class SVD(torch.autograd.Function):
"""
Compute SVD decomposition without Householder reflection.
"""

# pylint: disable=abstract-method

@staticmethod
def forward( # type: ignore[override]
ctx: torch.autograd.function.FunctionCtx,
A: torch.Tensor,
error: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# pylint: disable=arguments-differ
U, S, V = _svd(A, error)
ctx.save_for_backward(U, S, V)
return U, S, V

@staticmethod
def backward( # type: ignore[override]
ctx: typing.Any,
U_grad: torch.Tensor | None,
S_grad: torch.Tensor | None,
V_grad: torch.Tensor,
) -> tuple[torch.Tensor | None, None]:
# pylint: disable=arguments-differ
U, S, V = ctx.saved_tensors
return _svd_backward(U, S, V, U_grad, S_grad, V_grad), None


svd = SVD.apply
16 changes: 8 additions & 8 deletions tat/tensor.py
@@ -10,6 +10,7 @@
import torch
from . import _utility
from ._qr import givens_qr
+from ._svd import svd as manual_svd
from .edge import Edge

# pylint: disable=too-many-public-methods
@@ -1561,11 +1562,6 @@ def _guess_edge(matrix: torch.Tensor, edge: Edge, arrow: bool) -> Edge:
            parity=edge.parity[argmax],
        )

-    def _ensure_mask(self: Tensor) -> None:
-        assert torch.allclose(torch.where(self.mask, torch.zeros([], dtype=self.dtype), self.data),
-                              torch.zeros([], dtype=self.dtype))
-        self._data = torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype))

    def svd(
        self: Tensor,
        free_names_u: set[str],
@@ -1628,7 +1624,11 @@ def svd(
            names=("SVD_U", "SVD_V") if put_v_right else ("SVD_V", "SVD_U"),
        )

-        data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False)
+        if self.fermion:
+            data_1, data_s, data_2 = manual_svd(tensor.data, 1e-6)
+        else:
+            data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False)

        if cut != -1:
            data_1 = data_1[:, :cut]
            data_s = data_s[:cut]
@@ -1644,7 +1644,7 @@
            dtypes=self.dtypes,
            data=data_1,
        )
-        tensor_1._ensure_mask()  # pylint: disable=protected-access
+        # tensor_1._ensure_mask()
        free_edge_2 = tensor.edges[1]
        common_edge_2 = Tensor._guess_edge(torch.abs(data_2).transpose(0, 1), free_edge_2, False)
        tensor_2 = Tensor(
@@ -1654,7 +1654,7 @@
            dtypes=self.dtypes,
            data=data_2,
        )
-        tensor_2._ensure_mask()  # pylint: disable=protected-access
+        # tensor_2._ensure_mask()
        assert common_edge_1.conjugate() == common_edge_2
        tensor_s = Tensor(
            names=(singular_name_u, singular_name_v) if put_v_right else (singular_name_v, singular_name_u),
39 changes: 39 additions & 0 deletions tests/test_svd.py
@@ -0,0 +1,39 @@
"Test SVD"

import torch
from tat._svd import svd

# pylint: disable=missing-function-docstring
# pylint: disable=invalid-name


def svd_func(A: torch.Tensor) -> torch.Tensor:
    U, S, V = svd(A, 1e-10)
    return U @ torch.diag(S).to(dtype=A.dtype) @ V


def check_svd(A: torch.Tensor) -> None:
    m, n = A.size()
    U, S, V = svd(A, 1e-10)
    assert torch.allclose(U @ torch.diag(S.to(dtype=A.dtype)) @ V, A)
    assert torch.allclose(U.H @ U, torch.eye(min(m, n), dtype=A.dtype, device=A.device))
    assert torch.allclose(V @ V.H, torch.eye(min(m, n), dtype=A.dtype, device=A.device))
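    # gradcheck compares the custom backward against finite differences.  For complex
    # inputs the check goes through the reconstruction U @ diag(S) @ V, since U and V
    # individually are only determined up to per-column phases.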
    grad_check = torch.autograd.gradcheck(
        svd_func if torch.is_complex(A) else lambda x: svd(x, 1e-10),
        A,
        eps=1e-8,
        atol=1e-4,
    )
    assert grad_check


def test_svd_real() -> None:
    check_svd(torch.randn(7, 5, dtype=torch.float64, requires_grad=True))
    check_svd(torch.randn(5, 5, dtype=torch.float64, requires_grad=True))
    check_svd(torch.randn(5, 7, dtype=torch.float64, requires_grad=True))


def test_svd_complex() -> None:
    check_svd(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True))
    check_svd(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True))
    check_svd(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True))
