From 4a2358e5596babbf4d99380913373bbd13ed74ba Mon Sep 17 00:00:00 2001
From: Setepenre
Date: Tue, 21 May 2024 11:31:07 -0400
Subject: [PATCH] Use torchcompat to work on other devices

---
 dlrm_s_pytorch.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py
index c4289409..1aeb6a28 100644
--- a/dlrm_s_pytorch.py
+++ b/dlrm_s_pytorch.py
@@ -82,6 +82,7 @@
 # pytorch
 import torch
 import torch.nn as nn
+import torchcompat.core as accelerator
 
 # dataloader
 try:
@@ -122,13 +123,13 @@
 
 def time_wrap(use_gpu):
     if use_gpu:
-        torch.cuda.synchronize()
+        accelerator.synchronize()
     return time.time()
 
 
 def dlrm_wrap(X, lS_o, lS_i, use_gpu, device, ndevices=1):
     with record_function("DLRM forward"):
-        if use_gpu:  # .cuda()
+        if use_gpu:
             # lS_i can be either a list of tensors or a stacked tensor.
             # Handle each case below:
             if ndevices == 1:
@@ -636,7 +637,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i):
             t_list = []
             w_list = []
             for k, emb in enumerate(self.emb_l):
-                d = torch.device("cuda:" + str(k % ndevices))
+                d = accelerator.fetch_device(k % ndevices)
                 t_list.append(emb.to(d))
                 if self.weighted_pooling == "learned":
                     w_list.append(Parameter(self.v_W_l[k].to(d)))
@@ -662,7 +663,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i):
         t_list = []
         i_list = []
         for k, _ in enumerate(self.emb_l):
-            d = torch.device("cuda:" + str(k % ndevices))
+            d = accelerator.fetch_device(k % ndevices)
             t_list.append(lS_o[k].to(d))
             i_list.append(lS_i[k].to(d))
         lS_o = t_list
@@ -695,7 +696,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i):
 
         t_list = []
         for k, _ in enumerate(self.emb_l):
-            d = torch.device("cuda:" + str(k % ndevices))
+            d = accelerator.fetch_device(k % ndevices)
             y = scatter(ly[k], device_ids, dim=0)
             t_list.append(y)
         # adjust the list to be ordered per device
@@ -802,8 +803,8 @@ def inference(
             ### gather the distributed results on each rank ###
             # For some reason it requires explicit sync before all_gather call if
             # tensor is on GPU memory
-            if Z_test.is_cuda:
-                torch.cuda.synchronize()
+            accelerator.synchronize()
+
             (_, batch_split_lengths) = ext_dist.get_split_lengths(X_test.size(0))
             if ext_dist.my_size > 1:
                 Z_test = ext_dist.all_gather(Z_test, batch_split_lengths)
@@ -1072,7 +1073,7 @@ def run():
         # if the parameter is not set, use the same parameter for training
         args.test_num_workers = args.num_workers
 
-    use_gpu = args.use_gpu and torch.cuda.is_available()
+    use_gpu = args.use_gpu and accelerator.is_available()
 
     if not args.debug_mode:
         ext_dist.init_distributed(
@@ -1080,14 +1081,14 @@ def run():
         )
 
     if use_gpu:
-        torch.cuda.manual_seed_all(args.numpy_rand_seed)
+        accelerator.manual_seed_all(args.numpy_rand_seed)
         torch.backends.cudnn.deterministic = True
         if ext_dist.my_size > 1:
             ngpus = 1
-            device = torch.device("cuda", ext_dist.my_local_rank)
+            device = accelerator.fetch_device(ext_dist.my_local_rank)
         else:
-            ngpus = torch.cuda.device_count()
-            device = torch.device("cuda", 0)
+            ngpus = accelerator.device_count()
+            device = accelerator.fetch_device(0)
         print("Using {} GPU(s)...".format(ngpus))
     else:
         device = torch.device("cpu")
@@ -1318,7 +1319,7 @@ def run():
         # Custom Model-Data Parallel
         # the mlps are replicated and use data parallelism, while
         # the embeddings are distributed and use model parallelism
-        dlrm = dlrm.to(device)  # .cuda()
+        dlrm = dlrm.to(device)
         if dlrm.ndevices > 1:
             dlrm.emb_l, dlrm.v_W_l = dlrm.create_emb(
                 m_spa, ln_emb, args.weighted_pooling
@@ -1326,7 +1327,7 @@ def run():
         else:
             if dlrm.weighted_pooling == "fixed":
                 for k, w in enumerate(dlrm.v_W_l):
-                    dlrm.v_W_l[k] = w.cuda()
+                    dlrm.v_W_l[k] = w.to(device)
 
     # distribute data parallel mlps
     if ext_dist.my_size > 1:
@@ -1412,7 +1413,7 @@ def run():
            # note that the call to .to(device) has already happened
            ld_model = torch.load(
                args.load_model,
-                map_location=torch.device("cuda"),
+                map_location=torch.device(accelerator.fetch_device(0)),
                # map_location=lambda storage, loc: storage.cuda(0)
            )
        else: