Commit: Support gpu
hzhangxyz committed Nov 27, 2023
1 parent 27a1c4d commit 808a938
Showing 4 changed files with 16 additions and 10 deletions.
9 changes: 9 additions & 0 deletions tetragono/tetragono/__init__.py
@@ -16,6 +16,15 @@
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 #

+from mpi4py import MPI
+import torch
+
+if torch.cuda.device_count() != 0:
+    torch.set_default_tensor_type(torch.cuda.FloatTensor)
+    torch.cuda.set_device(
+        MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, MPI.COMM_WORLD.Get_rank()).Get_rank() %
+        torch.cuda.device_count())
+
 # States
 from .abstract_state import AbstractState
 from .exact_state import ExactState
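With the CUDA default tensor type set, every tensor tetragono creates lands on the GPU, and the Split_type call pins each MPI process to one device: splitting MPI.COMM_WORLD by MPI.COMM_TYPE_SHARED groups the ranks that share a node, so the sub-communicator's rank is a node-local index, mapped round-robin onto the visible GPUs by the modulo. A minimal standalone sketch of the same mapping (hypothetical file name; assumes mpi4py plus a CUDA build of torch, launched under mpirun):

    # map_gpus.py -- pick one GPU per MPI rank, round-robin within each node
    from mpi4py import MPI
    import torch

    world = MPI.COMM_WORLD
    # Ranks on the same node fall into the same shared-memory communicator,
    # so local.Get_rank() is a node-local index.
    local = world.Split_type(MPI.COMM_TYPE_SHARED, world.Get_rank())
    if torch.cuda.device_count() != 0:
        device = local.Get_rank() % torch.cuda.device_count()
        torch.cuda.set_device(device)
        print(f"world rank {world.Get_rank()} -> cuda:{device}")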
5 changes: 3 additions & 2 deletions tetragono/tetragono/sampling_lattice/observer.py
@@ -102,7 +102,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             buffer.append(self._whole_result_square_reweight_square[name])
         buffer.append(self._total_imaginary_energy_reweight)

-        buffer = np.array(buffer)
+        import torch
+        buffer = torch.tensor(buffer).cpu()
         allreduce_buffer(buffer)
         buffer = buffer.tolist()
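Replacing np.array with torch.tensor(...).cpu() stages the statistics buffer in host memory: with the CUDA default tensor type from __init__.py, a fresh tensor would otherwise live on the GPU, while the buffer-based MPI collective behind allreduce_buffer needs host data unless the MPI build is CUDA-aware. A sketch of the staging pattern (assumes mpi4py; the .numpy() view is illustrative, not tetragono's code):

    # sum a torch buffer across ranks: ensure host memory first
    import torch
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    buffer = torch.tensor([1.0, 2.0, 3.0]).cpu()  # host copy if default is CUDA
    comm.Allreduce(MPI.IN_PLACE, buffer.numpy())  # in-place sum over all ranks
    print(comm.Get_rank(), buffer.tolist())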

@@ -663,7 +664,7 @@ def _delta_to_array(self, delta):
         # Both delta and result array is in bra space
         result = []
         for l1, l2 in self.owner.sites():
-            result.append(delta[l1][l2].transpose(self._Delta[l1][l2].names).copy().storage)
+            result.append(delta[l1][l2].transpose(self._Delta[l1][l2].names).copy().storage.cpu())
         result = np.concatenate(result)
         return result

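Here .storage exposes the tensor's flat data (a torch tensor once the backend runs on GPU, by the look of this commit), and the appended .cpu() copies it to the host so that np.concatenate, which only accepts host arrays, can assemble the result vector. The same constraint on bare torch tensors:

    import numpy as np
    import torch

    parts = [torch.randn(4), torch.randn(3)]                 # imagine these on a GPU
    flat = np.concatenate([p.cpu().numpy() for p in parts])  # host copies first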
2 changes: 1 addition & 1 deletion tetragono/tetragono/sampling_lattice/sampling.py
@@ -320,7 +320,7 @@ def __call__(self):
                 .transpose(["I", "O"]))
             hole_edge = hole.edge_by_name("O")
             # Calculate rho for all the segments of the physics edge of this orbit
-            rho = hole.data.diagonal()
+            rho = hole.data.diagonal().cpu()
             rho = np.array(rho).real
             rho = np.maximum(rho, 0)  # Sometimes there is some negative value because of numeric error.
             if np.sum(rho) == 0:
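The diagonal of the hole tensor serves as an unnormalized density rho over the segments of the physical edge, and the following lines hand it straight to NumPy (np.array, np.maximum), so it must be copied off the device first, hence the added .cpu(). Downstream, rho is presumably normalized and sampled from; a hedged sketch of that step (np.random.choice stands in for whatever tetragono actually uses):

    import numpy as np

    rho = np.maximum(np.array([0.3, -1e-12, 0.7]), 0)  # clamp numerical noise
    rho = rho / np.sum(rho)                            # normalize to probabilities
    segment = np.random.choice(len(rho), p=rho)        # draw one segment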
10 changes: 3 additions & 7 deletions tetragono/tetragono/utility.py
@@ -68,10 +68,8 @@ def allreduce_buffer(buffer):


 def allreduce_iterator_buffer(iterator):
-    requests = []
     for tensor in iterator:
-        requests.append(mpi_comm.Iallreduce(MPI.IN_PLACE, tensor))
-    MPI.Request.Waitall(requests)
+        mpi_comm.Allreduce(MPI.IN_PLACE, tensor)


 def allreduce_lattice_buffer(lattice):
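The non-blocking Iallreduce/Waitall pair becomes one blocking Allreduce per tensor, so each collective finishes before the next is issued; a plausible motivation is that queueing many overlapping in-place collectives is fragile once buffers may originate on a GPU, though the commit does not say. The new pattern in isolation (NumPy buffers, run under mpirun):

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    tensors = [np.full(4, comm.Get_rank(), dtype=np.float64) for _ in range(3)]
    for t in tensors:
        comm.Allreduce(MPI.IN_PLACE, t)  # blocking, in-place sum across ranks
    # every element now equals sum(range(comm.Get_size()))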
@@ -92,10 +90,8 @@ def bcast_buffer(buffer, root=0):


 def bcast_iterator_buffer(iterator, root=0):
-    requests = []
     for tensor in iterator:
-        requests.append(mpi_comm.Ibcast(tensor, root=root))
-    MPI.Request.Waitall(requests)
+        mpi_comm.Bcast(tensor, root=root)


 def bcast_lattice_buffer(lattice, root=0):
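The broadcast helper gets the same treatment, Ibcast/Waitall collapsing into a plain blocking Bcast per tensor. For example (sketch, run under mpirun):

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    data = np.arange(4, dtype=np.float64) if comm.Get_rank() == 0 else np.empty(4)
    comm.Bcast(data, root=0)  # every rank now holds rank 0's values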
@@ -190,7 +186,7 @@ def lattice_conjugate(tensor):

 @np.vectorize
 def lattice_dot(tensor_1, tensor_2):
-    return tensor_1.contract(tensor_2, {(name, name) for name in tensor_1.names}).storage[0]
+    return tensor_1.contract(tensor_2, {(name, name) for name in tensor_1.names}).storage.cpu().item()


 def lattice_prod_sum(tensors_1, tensors_2):
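Indexing the flat storage with [0] returns a 0-d device tensor once the storage is CUDA-backed (assuming the storage is a torch tensor here), whereas callers expect a plain Python scalar; .cpu().item() makes the host copy explicit and then extracts the number. The idiom on a bare tensor:

    import torch

    x = torch.tensor([3.5])  # lands on the GPU when the CUDA default type is set
    value = x.cpu().item()   # explicit host copy, then a plain Python float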
