diff --git a/tat/_qr.py b/tat/_qr.py new file mode 100644 index 000000000..f330a4a04 --- /dev/null +++ b/tat/_qr.py @@ -0,0 +1,233 @@ +""" +This module implements QR decomposition based on Givens rotation and Householder reflection. +""" + +import typing +import torch + +# pylint: disable=invalid-name + + +@torch.jit.script +def _syminvadj(X: torch.Tensor) -> torch.Tensor: + ret = X + X.H + ret.diagonal().real[:] *= 1 / 2 + return ret + + +@torch.jit.script +def _triliminvadjskew(X: torch.Tensor) -> torch.Tensor: + ret = torch.tril(X - X.H) + if torch.is_complex(X): + ret.diagonal().imag[:] *= 1 / 2 + return ret + + +@torch.jit.script +def _qr_backward( + Q: torch.Tensor, + R: torch.Tensor, + Q_grad: typing.Optional[torch.Tensor], + R_grad: typing.Optional[torch.Tensor], +) -> typing.Optional[torch.Tensor]: + # see https://arxiv.org/pdf/2009.10071.pdf section 4.3 and 4.5 + # see pytorch torch/csrc/autograd/FunctionsManual.cpp:linalg_qr_backward + m = Q.size(0) + n = R.size(1) + + if Q_grad is not None: + if R_grad is not None: + MH = R_grad @ R.H - Q.H @ Q_grad + else: + MH = -Q.H @ Q_grad + else: + if R_grad is not None: + MH = R_grad @ R.H + else: + return None + + # pylint: disable=no-else-return + if m >= n: + # Deep and square matrix + b = Q @ _syminvadj(torch.triu(MH)) + if Q_grad is not None: + b = b + Q_grad + return torch.linalg.solve_triangular(R.H, b, upper=False, left=False) + else: + # Wide matrix + b = Q @ (_triliminvadjskew(-MH)) + result = torch.linalg.solve_triangular(R[:, :m].H, b, upper=False, left=False) + result = torch.cat((result, torch.zeros([m, n - m], dtype=result.dtype, device=result.device)), dim=1) + if R_grad is not None: + result = result + Q @ R_grad + return result + + +class CommonQR(torch.autograd.Function): + """ + Implement the autograd function for QR. 
+ """ + + # pylint: disable=abstract-method + + @staticmethod + def backward( # type: ignore[override] + ctx: typing.Any, + Q_grad: torch.Tensor | None, + R_grad: torch.Tensor | None, + ) -> torch.Tensor | None: + # pylint: disable=arguments-differ + Q, R = ctx.saved_tensors + return _qr_backward(Q, R, Q_grad, R_grad) + + +@torch.jit.script +def _normalize_diagonal(a: torch.Tensor) -> torch.Tensor: + r = torch.sqrt(a.conj() * a) + return torch.where( + r == torch.zeros([], dtype=a.dtype, device=a.device), + torch.ones([], dtype=a.dtype, device=a.device), + a / r, + ) + + +@torch.jit.script +def _givens_parameter(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + r = torch.sqrt(a.conj() * a + b.conj() * b) + return torch.where( + b == torch.zeros([], dtype=a.dtype, device=a.device), + torch.ones([], dtype=a.dtype, device=a.device), + a / r, + ), torch.where( + b == torch.zeros([], dtype=a.dtype, device=a.device), + torch.zeros([], dtype=a.dtype, device=a.device), + b / r, + ) + + +@torch.jit.script +def _givens_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + m, n = A.shape + k = min(m, n) + Q = torch.eye(m, dtype=A.dtype, device=A.device) + R = A.clone() + + # Parallel strategy + # Every row rotated to the nearest row above + for g in range(m - 1, 0, -1): + # rotate R[g, 0], R[g+2, 1], R[g+4, 2], ... + for i, col in zip(range(g, m, 2), range(n)): + j = i - 1 + # Rotate inside column col + # Rotate from row i to row j + c, s = _givens_parameter(R[j, col], R[i, col]) + Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i] + R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i] + for g in range(1, k): + # rotate R[g+1, g], R[g+1+2, g+1], R[g+1+4, g+2], ... 
+ for i, col in zip(range(g + 1, m, 2), range(g, n)): + j = i - 1 + # Rotate inside column col + # Rotate from row i to row j + c, s = _givens_parameter(R[j, col], R[i, col]) + Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i] + R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i] + + # for j in range(n): + # for i in range(j + 1, m): + # col = j + # # Rotate inside column col + # # Rotate from row i to row j + # c, s = _givens_parameter(R[j, col], R[i, col]) + # Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i] + # R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i] + + # Make diagonal positive + c = _normalize_diagonal(R.diagonal()).conj() + Q[:k] *= torch.unsqueeze(c, 1) + R[:k] *= torch.unsqueeze(c, 1) + + Q, R = Q[:k].H, R[:k] + return Q, R + + +class GivensQR(CommonQR): + """ + Compute the reduced QR decomposition using Givens rotation. + """ + + # pylint: disable=abstract-method + + @staticmethod + def forward( # type: ignore[override] + ctx: torch.autograd.function.FunctionCtx, + A: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=arguments-differ + Q, R = _givens_qr(A) + ctx.save_for_backward(Q, R) + return Q, R + + +@torch.jit.script +def _normalize_delta(a: torch.Tensor) -> torch.Tensor: + norm = a.norm() + return torch.where( + norm == torch.zeros([], dtype=a.dtype, device=a.device), + torch.zeros([], dtype=a.dtype, device=a.device), + a / norm, + ) + + +@torch.jit.script +def _reflect_target(x: torch.Tensor) -> torch.Tensor: + return torch.norm(x) * _normalize_diagonal(x[0]) + + +@torch.jit.script +def _householder_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + m, n = A.shape + k = min(m, n) + Q = torch.eye(m, dtype=A.dtype, device=A.device) + R = A.clone() + + for i in range(k): + x = R[i:, i] + v = torch.zeros_like(x) + # For a complex matrix, it requires $\langle v|x\rangle = \langle x|v\rangle$, i.e. v[0] and x[0] have the same or opposite argument.
+ v[0] = _reflect_target(x) + # Reflect x to v + delta = _normalize_delta(v - x) + # H = 1 - 2 |Delta> tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=arguments-differ + Q, R = _householder_qr(A) + ctx.save_for_backward(Q, R) + return Q, R + + +givens_qr = GivensQR.apply +householder_qr = HouseholderQR.apply diff --git a/tat/tensor.py b/tat/tensor.py index c81d832c8..05562e665 100644 --- a/tat/tensor.py +++ b/tat/tensor.py @@ -9,6 +9,7 @@ from multimethod import multimethod import torch from . import _utility +from ._qr import givens_qr from .edge import Edge # pylint: disable=too-many-public-methods @@ -1739,7 +1740,12 @@ def qr( names=("QR_Q", "QR_R"), ) - data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced") + if self.fermion: + # Blocked tensor, use Givens rotation + data_q, data_r = givens_qr(tensor.data) + else: + # Non-blocked tensor, use Householder reflection + data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced") free_edge_q = tensor.edges[0] common_edge_q = Tensor._guess_edge(torch.abs(data_q), free_edge_q, True) @@ -1750,9 +1756,12 @@ def qr( dtypes=self.dtypes, data=data_q, ) - tensor_q._ensure_mask() # pylint: disable=protected-access + # tensor_q._ensure_mask() free_edge_r = tensor.edges[1] - common_edge_r = Tensor._guess_edge(torch.abs(data_r).transpose(0, 1), free_edge_r, False) + # common_edge_r = Tensor._guess_edge(torch.abs(data_r).transpose(0, 1), free_edge_r, False) + # Sometimes the R matrix may be singular, and the guessed edge would then have an arbitrary symmetry, so do not use the guessed edge. + # On the other hand, QR based on Givens rotation always gives a blocked result, which can be trusted.
+ common_edge_r = common_edge_q.conjugate() tensor_r = Tensor( names=(common_name_r, "QR_R"), edges=(common_edge_r, free_edge_r), @@ -1760,7 +1769,7 @@ def qr( dtypes=self.dtypes, data=data_r, ) - tensor_r._ensure_mask() # pylint: disable=protected-access + # tensor_r._ensure_mask() assert common_edge_q.conjugate() == common_edge_r tensor_q = tensor_q.split_edge({"QR_Q": ordered_free_edges_q}, False, set()) diff --git a/tests/test_qr.py b/tests/test_qr.py new file mode 100644 index 000000000..5085cc261 --- /dev/null +++ b/tests/test_qr.py @@ -0,0 +1,59 @@ +"Test QR" + +import torch +from tat._qr import givens_qr, householder_qr + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name + + +def check_givens(A: torch.Tensor) -> None: + m, n = A.size() + Q, R = givens_qr(A) + assert torch.allclose(A, Q @ R) + assert torch.allclose(Q.H @ Q, torch.eye(min(m, n), dtype=A.dtype, device=A.device)) + grad_check = torch.autograd.gradcheck( + givens_qr, + A, + eps=1e-8, + atol=1e-4, + ) + assert grad_check + + +def test_qr_real_givens() -> None: + check_givens(torch.randn(7, 5, dtype=torch.float64, requires_grad=True)) + check_givens(torch.randn(5, 5, dtype=torch.float64, requires_grad=True)) + check_givens(torch.randn(5, 7, dtype=torch.float64, requires_grad=True)) + + +def test_qr_complex_givens() -> None: + check_givens(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True)) + check_givens(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True)) + check_givens(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True)) + + +def check_householder(A: torch.Tensor) -> None: + m, n = A.size() + Q, R = householder_qr(A) + assert torch.allclose(A, Q @ R) + assert torch.allclose(Q.H @ Q, torch.eye(min(m, n), dtype=A.dtype, device=A.device)) + grad_check = torch.autograd.gradcheck( + householder_qr, + A, + eps=1e-8, + atol=1e-4, + ) + assert grad_check + + +def test_qr_real_householder() -> None: + check_householder(torch.randn(7, 5, 
dtype=torch.float64, requires_grad=True)) + check_householder(torch.randn(5, 5, dtype=torch.float64, requires_grad=True)) + check_householder(torch.randn(5, 7, dtype=torch.float64, requires_grad=True)) + + +def test_qr_complex_householder() -> None: + check_householder(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True)) + check_householder(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True)) + check_householder(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True))