diff --git a/tetragono/tetragono/sampling_neural_state/observer.py b/tetragono/tetragono/sampling_neural_state/observer.py
index 51fab1de..ac7d772c 100644
--- a/tetragono/tetragono/sampling_neural_state/observer.py
+++ b/tetragono/tetragono/sampling_neural_state/observer.py
@@ -18,9 +18,39 @@
 import numpy as np
 import torch
-from ..utility import allreduce_buffer, allreduce_number, show, showln
+from ..utility import allreduce_buffer, allreduce_number, show, showln, mpi_comm
 from .state import Configuration, index_tensor_element
+opt = None
+
+
+def torch_tensor_allgather(tensor):
+    from mpi4py import MPI
+    # Get the device of the input tensor
+    device = tensor.device
+
+    # Convert torch tensor to numpy array
+    np_array = tensor.cpu().detach().numpy()
+
+    # Get the global MPI communicator
+    comm = mpi_comm
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+
+    counts = comm.allgather(np_array.size)
+    first = comm.allgather(np_array.shape[0])
+    total_length = sum(first)
+    # Create a buffer to hold all gathered numpy arrays
+    gathered_np_arrays = np.empty((total_length, *np_array.shape[1:]), dtype=np_array.dtype)
+
+    # Perform allgather
+    comm.Allgatherv(np_array, [gathered_np_arrays, counts])
+
+    # Convert gathered numpy arrays back to torch tensor
+    gathered_tensor = torch.from_numpy(gathered_np_arrays).to(device)
+
+    return gathered_tensor
+

 class Observer():
     """
@@ -66,8 +96,7 @@ def __enter__(self):
         if self._enable_gradient:
             self._Delta = None
             self._EDelta = None
-            if self._enable_natural:
-                self._Deltas = []
+            self._Deltas = []  # Temporarily reused for a different purpose (see __exit__)

     def __exit__(self, exc_type, exc_val, exc_tb):
         """
@@ -114,6 +143,25 @@
             allreduce_buffer(self._Delta)
             allreduce_buffer(self._EDelta)

+        cs = torch.stack([c for c, e in self._Deltas])
+        es = torch.tensor([e for c, e in self._Deltas], dtype=torch.complex128, device=cs.device)
+        cs = torch_tensor_allgather(cs)
+        es = torch.view_as_complex(torch_tensor_allgather(torch.view_as_real(es)))
+        es = es - es.mean()  # This is the quantity used for sampling; other factors such as Delta may be multiplied in later
+        with torch.enable_grad():
+            global opt
+            if opt is None:
+                opt = torch.optim.Adam(self.owner.network.es.parameters(), 1e-2)
+            for _ in range(100):
+                hes = self.owner.network.es(cs)
+                error = hes / hes.norm() - es / es.norm()
+                error = (error.abs()**2).mean()
+                show(error.item())
+                opt.zero_grad()
+                error.backward()
+                opt.step()
+        showln("es error", error.item())
+
     def __init__(
         self,
         owner,
@@ -395,6 +443,9 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
                             name].imag * reweight
                     if name == "energy" and self._enable_gradient:
                         Es = whole_result[batch_index][name]
+                        # Record (configuration, local energy) pairs here;
+                        # they are used to train self.es in __exit__.
+                        self._Deltas.append((configurations[batch_index], Es))
                         if self.owner.Tensor.is_real:
                             Es = Es.real
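For context, `torch_tensor_allgather` concatenates per-rank tensors along their first axis even when each rank holds a different number of rows, which is what the `es` training step in `__exit__` relies on. A minimal standalone sketch of the same pattern (the function name `allgather_rows` and the toy shapes are illustrative, not part of the patch):

```python
# Run with e.g.: mpiexec -n 4 python allgather_sketch.py
import numpy as np
import torch
from mpi4py import MPI


def allgather_rows(tensor, comm=MPI.COMM_WORLD):
    # Each rank contributes a tensor with a possibly different first dimension.
    np_array = tensor.cpu().detach().numpy()
    counts = comm.allgather(np_array.size)      # elements per rank, for Allgatherv
    rows = comm.allgather(np_array.shape[0])    # rows per rank, to size the receive buffer
    gathered = np.empty((sum(rows), *np_array.shape[1:]), dtype=np_array.dtype)
    comm.Allgatherv(np_array, [gathered, counts])
    return torch.from_numpy(gathered).to(tensor.device)


if __name__ == "__main__":
    rank = MPI.COMM_WORLD.Get_rank()
    local = torch.full([rank + 1, 3], float(rank))   # rank r holds r + 1 rows
    print(rank, allgather_rows(local).shape)         # every rank sees the full stack
```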
diff --git a/tetragono/tetragono/sampling_neural_state/state.py b/tetragono/tetragono/sampling_neural_state/state.py
index 342e8bee..9f17f6b2 100644
--- a/tetragono/tetragono/sampling_neural_state/state.py
+++ b/tetragono/tetragono/sampling_neural_state/state.py
@@ -348,16 +348,28 @@ def holes(self, value):
         if self.Tensor.is_complex:
             with torch_grad(True):
                 value.real.backward(retain_graph=True)
-            real = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            real = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
             with torch_grad(True):
                 value.imag.backward()
-            imag = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            imag = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
             result = (real + 1j * imag)
         else:
             value.backward()
-            result = torch.cat([param.grad.reshape([-1]) for param in self.network.parameters() if param.requires_grad])
+            result = torch.cat([
+                param.grad.reshape([-1])
+                for param in self.network.parameters()
+                if param.requires_grad and param.grad is not None
+            ])
             self.network.zero_grad()
         result = result / value
         return result.detach_()
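The `param.grad is not None` guard added to `holes` matters because parameters that take no part in a given backward pass keep `grad == None`, and `torch.cat` would raise on them. A small self-contained sketch of the pattern (toy model, illustrative only):

```python
import torch

model = torch.nn.ModuleDict({
    "used": torch.nn.Linear(2, 1),
    "unused": torch.nn.Linear(2, 1),   # never touched by the forward pass below
})

x = torch.ones(4, 2)
loss = model["used"](x).sum()
loss.backward()

# Parameters of "unused" still have grad == None, so they must be skipped.
flat = torch.cat([
    p.grad.reshape([-1])
    for p in model.parameters()
    if p.requires_grad and p.grad is not None
])
print(flat.shape)  # only the weight and bias of "used": 2 + 1 = 3 entries
```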
diff --git a/tetraku/tetraku/networks/naqs/reweight.py b/tetraku/tetraku/networks/naqs/reweight.py
new file mode 100644
index 00000000..589fb2cf
--- /dev/null
+++ b/tetraku/tetraku/networks/naqs/reweight.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright (C) 2024 Hao Zhang
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+
+import torch
+
+
+class FakeLinear(torch.nn.Module):
+
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.bias = torch.nn.Parameter(torch.zeros([dim_out]))
+
+    def forward(self, x):
+        shape = x.shape[:-1]
+        prod = torch.tensor(shape).prod()
+        return self.bias.view([1, -1]).expand([prod, -1]).view([*shape, -1])
+
+
+def Linear(dim_in, dim_out):
+    if dim_in == 0:
+        return FakeLinear(dim_in, dim_out)
+    else:
+        return torch.nn.Linear(dim_in, dim_out)
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim_input, dim_output, hidden_size):
+        super().__init__()
+        self.dim_input = dim_input
+        self.dim_output = dim_output
+        self.hidden_size = hidden_size
+        self.depth = len(hidden_size)
+
+        self.model = torch.nn.Sequential(*(Linear(
+            dim_input if i == 0 else hidden_size[i - 1],
+            dim_output if i == self.depth else hidden_size[i],
+        ) if j == 0 else torch.nn.SiLU() for i in range(self.depth + 1) for j in range(2) if i != self.depth or j != 1))
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class WaveFunction(torch.nn.Module):
+
+    def __init__(
+        self,
+        *,
+        L1,
+        L2,
+        orbit_num,
+        physical_dim,
+        is_complex,
+        spin_up,
+        spin_down,
+        hidden_size,
+        ordering,
+    ):
+        super().__init__()
+        self.L1 = L1
+        self.L2 = L2
+        self.orbit_num = orbit_num
+        self.sites = L1 * L2 * orbit_num // 2
+        assert physical_dim == 2
+        assert is_complex == True
+        self.spin_up = spin_up
+        self.spin_down = spin_down
+        self.hidden_size = tuple(hidden_size)
+
+        self.amplitude = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
+        self.phase = torch.nn.ModuleList([MLP(i * 2, 4, self.hidden_size) for i in range(self.sites)])
+
+        if isinstance(ordering, int) and ordering == +1:
+            ordering = list(range(self.sites))
+        if isinstance(ordering, int) and ordering == -1:
+            ordering = list(reversed(range(self.sites)))
+        self.register_buffer('ordering', torch.tensor(ordering, dtype=torch.int64), persistent=True)
+        ordering_bak = torch.zeros(self.sites, dtype=torch.int64)
+        ordering_bak.scatter_(0, self.ordering, torch.arange(self.sites))
+        self.register_buffer('ordering_bak', ordering_bak, persistent=True)
+
+    def mask(self, x):
+        # x : batch * i * 2
+        i = x.size(1)
+        # number : batch * 2
+        number = x.sum(dim=1)
+
+        up_electron = number[:, 0]
+        down_electron = number[:, 1]
+        up_hole = i - up_electron
+        down_hole = i - down_electron
+
+        add_up_electron = up_electron < self.spin_up
+        add_down_electron = down_electron < self.spin_down
+        add_up_hole = up_hole < self.sites - self.spin_up
+        add_down_hole = down_hole < self.sites - self.spin_down
+
+        add_up = torch.stack([add_up_hole, add_up_electron], dim=-1).unsqueeze(-1)
+        add_down = torch.stack([add_down_hole, add_down_electron], dim=-1).unsqueeze(-2)
+        add = torch.logical_and(add_up, add_down)
+        return add
+
+    def normalize_amplitude(self, x):
+        param = -(2 * x).exp().sum(dim=[1, 2]).log() / 2
+        x = x + param.unsqueeze(-1).unsqueeze(-1)
+        return x
+
+    def forward(self, x):
+        device = next(self.parameters()).device
+        dtype = next(self.parameters()).dtype
+
+        batch_size = x.size(0)
+        x = x.reshape([batch_size, self.sites, 2])
+        x = torch.index_select(x, 1, self.ordering_bak)
+
+        xf = x.to(dtype=dtype)
+        arange = torch.arange(batch_size, device=device)
+        total_amplitude = 0
+        total_phase = 0
+        for i in range(self.sites):
+            amplitude = self.amplitude[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
+            phase = self.phase[i](xf[:, :i].reshape([batch_size, 2 * i])).reshape([batch_size, 2, 2])
+            amplitude = amplitude + torch.where(self.mask(x[:, :i]), 0, -torch.inf)
+            amplitude = self.normalize_amplitude(amplitude)
+            amplitude = amplitude[arange, x[:, i, 0], x[:, i, 1]]
+            phase = phase[arange, x[:, i, 0], x[:, i, 1]]
+            total_amplitude = total_amplitude + amplitude
+            total_phase = total_phase + phase
+        return (total_amplitude + 1j * total_phase).exp()
+
+    def binomial(self, count, possibility):
+        possibility = torch.clamp(possibility, min=0, max=1)
+        possibility = torch.where(count == 0, 0, possibility)
+        dist = torch.distributions.binomial.Binomial(count, possibility)
+        result = dist.sample()
+        result = result.to(dtype=torch.int64)
+        # Numerical error since result was cast to float.
+        return torch.clamp(result, min=torch.zeros_like(count), max=count)
+
+    def generate(self, batch_size, alpha=1):
+        # https://arxiv.org/pdf/2109.12606
+        device = next(self.parameters()).device
+        dtype = next(self.parameters()).dtype
+        assert alpha == 1
+
+        x = torch.empty([1, 0, 2], device=device, dtype=torch.int64)
+        multiplicity = torch.tensor([batch_size], dtype=torch.int64, device=device)
+        amplitude_phase = torch.tensor([0], dtype=dtype.to_complex(), device=device)
+        for i in range(self.sites):
+            local_batch_size = x.size(0)
+
+            xf = x.to(dtype=dtype)
+            amplitude = self.amplitude[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
+            phase = self.phase[i](xf.reshape([local_batch_size, 2 * i])).reshape([local_batch_size, 2, 2])
+            amplitude = amplitude + torch.where(self.mask(x), 0, -torch.inf)
+            amplitude = self.normalize_amplitude(amplitude)
+            delta_amplitude_phase = (amplitude + 1j * phase).reshape([local_batch_size, 4])
+            probability = (2 * amplitude).exp().reshape([local_batch_size, 4])
+            probability = probability / probability.sum(dim=-1).unsqueeze(-1)
+
+            sample0123 = multiplicity
+            prob23 = probability[:, 2] + probability[:, 3]
+            prob01 = probability[:, 0] + probability[:, 1]
+            sample23 = self.binomial(sample0123, prob23)
+            sample3 = self.binomial(sample23, probability[:, 3] / prob23)
+            sample2 = sample23 - sample3
+            sample01 = sample0123 - sample23
+            sample1 = self.binomial(sample01, probability[:, 1] / prob01)
+            sample0 = sample01 - sample1
+
+            x0 = torch.cat([x, torch.tensor([[0, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x1 = torch.cat([x, torch.tensor([[0, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x2 = torch.cat([x, torch.tensor([[1, 0]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+            x3 = torch.cat([x, torch.tensor([[1, 1]], device=device).expand(local_batch_size, -1, -1)], dim=1)
+
+            new_x = torch.cat([x0, x1, x2, x3])
+            new_multiplicity = torch.cat([sample0, sample1, sample2, sample3])
+            new_amplitude_phase = (amplitude_phase.unsqueeze(0) + delta_amplitude_phase.permute(1, 0)).reshape([-1])
+
+            selected = new_multiplicity != 0
+            x = new_x[selected]
+            multiplicity = new_multiplicity[selected]
+            amplitude_phase = new_amplitude_phase[selected]
+
+        real_amplitude = amplitude_phase.exp()
+        real_probability = (real_amplitude.conj() * real_amplitude).real
+        x = torch.index_select(x, 1, self.ordering)
+        return x.reshape([x.size(0), self.L1, self.L2, self.orbit_num]), real_amplitude, torch.ones_like(real_probability), torch.ones_like(multiplicity)
+
+
+class ReweightWaveFunction(torch.nn.Module):
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        self.psi = WaveFunction(*args, **kwargs)
+        # Keep es inside a tuple so it is not registered as a submodule:
+        # its parameters stay out of self.parameters() and self.state_dict().
+        self._es = WaveFunction(*args, **kwargs).cuda(),
+        self.es.load_state_dict(self.psi.state_dict())
+        self.es.cuda()
+
+    @property
+    def es(self):
+        return self._es[0]
+
+    def forward(self, x):
+        return self.psi(x)
+
+    def generate(self, batch_size, alpha=1):
+        configurations, _, weights, multiplicities = self.es.generate(batch_size, alpha)
+        amplitudes = self(configurations)
+        return configurations, amplitudes, weights, multiplicities
+
+
+def network(state, spin_up, spin_down, hidden_size, ordering=+1):
+    max_orbit_index = max(orbit for [l1, l2, orbit], edge in state.physics_edges)
+    max_physical_dim = max(edge.dimension for [l1, l2, orbit], edge in state.physics_edges)
+    network = ReweightWaveFunction(
+        L1=state.L1,
+        L2=state.L2,
+        orbit_num=max_orbit_index + 1,
+        physical_dim=max_physical_dim,
+        is_complex=state.Tensor.is_complex,
+        spin_up=spin_up,
+        spin_down=spin_down,
+        hidden_size=hidden_size,
+        ordering=ordering,
+    ).double()
+    return network
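In `WaveFunction.generate`, the `multiplicity` of each partial configuration is split over the four local outcomes (hole/hole, hole/particle, particle/hole, particle/particle) with a chain of binomial draws rather than one four-way multinomial; the two are equivalent in distribution. A quick standalone check of that splitting (the probabilities below are made up for illustration, not taken from the patch):

```python
import torch

torch.manual_seed(0)
count = torch.tensor([100000.])
prob = torch.tensor([0.1, 0.2, 0.3, 0.4])  # hypothetical outcome probabilities


def binomial(count, p):
    p = torch.clamp(p, min=0, max=1)
    p = torch.where(count == 0, torch.zeros_like(p), p)
    return torch.distributions.Binomial(count, p).sample()


# First split {2, 3} vs {0, 1}, then split inside each pair, as in generate().
prob23 = prob[2] + prob[3]
prob01 = prob[0] + prob[1]
sample23 = binomial(count, prob23)
sample3 = binomial(sample23, prob[3] / prob23)
sample2 = sample23 - sample3
sample01 = count - sample23
sample1 = binomial(sample01, prob[1] / prob01)
sample0 = sample01 - sample1

counts = torch.stack([sample0, sample1, sample2, sample3]).squeeze()
print(counts / count)  # approximately tensor([0.1, 0.2, 0.3, 0.4]) up to sampling noise
```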