Use torchcompat to work on other devices #384

Draft · wants to merge 1 commit into main
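
As context for review: the diff below replaces each torch.cuda-specific call with its torchcompat.core counterpart so the same script can target whichever accelerator backend torchcompat detects. A minimal sketch of the pattern follows; it uses only the accelerator calls that actually appear in this diff (is_available, manual_seed_all, device_count, fetch_device, synchronize), and the seed and tensor shapes are placeholders.

# Sketch only (not part of the diff): the torch.cuda -> torchcompat.core mapping
# that this PR applies throughout dlrm_s_pytorch.py.
import torch
import torchcompat.core as accelerator

use_gpu = accelerator.is_available()        # was: torch.cuda.is_available()
if use_gpu:
    accelerator.manual_seed_all(1234)       # was: torch.cuda.manual_seed_all(seed)
    ndevices = accelerator.device_count()   # was: torch.cuda.device_count()
    device = accelerator.fetch_device(0)    # was: torch.device("cuda", 0)
else:
    ndevices = 1
    device = torch.device("cpu")

x = torch.randn(8, 4).to(device)            # move work to the selected device
if use_gpu:
    accelerator.synchronize()               # was: torch.cuda.synchronize()
print(ndevices, device, x.device)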
dlrm_s_pytorch.py: 31 changes (16 additions, 15 deletions)
@@ -82,6 +82,7 @@
 # pytorch
 import torch
 import torch.nn as nn
+import torchcompat.core as accelerator

 # dataloader
 try:
@@ -122,13 +123,13 @@

 def time_wrap(use_gpu):
     if use_gpu:
-        torch.cuda.synchronize()
+        accelerator.synchronize()
     return time.time()


 def dlrm_wrap(X, lS_o, lS_i, use_gpu, device, ndevices=1):
     with record_function("DLRM forward"):
-        if use_gpu:  # .cuda()
+        if use_gpu:
             # lS_i can be either a list of tensors or a stacked tensor.
             # Handle each case below:
             if ndevices == 1:
@@ -636,7 +637,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i):
             t_list = []
             w_list = []
             for k, emb in enumerate(self.emb_l):
-                d = torch.device("cuda:" + str(k % ndevices))
+                d = accelerator.fetch_device(k % ndevices)
                 t_list.append(emb.to(d))
                 if self.weighted_pooling == "learned":
                     w_list.append(Parameter(self.v_W_l[k].to(d)))
@@ -662,7 +663,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i):
         t_list = []
         i_list = []
         for k, _ in enumerate(self.emb_l):
-            d = torch.device("cuda:" + str(k % ndevices))
+            d = accelerator.fetch_device(k % ndevices)
             t_list.append(lS_o[k].to(d))
             i_list.append(lS_i[k].to(d))
         lS_o = t_list
@@ -695,7 +696,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i):

         t_list = []
         for k, _ in enumerate(self.emb_l):
-            d = torch.device("cuda:" + str(k % ndevices))
+            d = accelerator.fetch_device(k % ndevices)
             y = scatter(ly[k], device_ids, dim=0)
             t_list.append(y)
         # adjust the list to be ordered per device
@@ -802,8 +803,8 @@ def inference(
         ### gather the distributed results on each rank ###
         # For some reason it requires explicit sync before all_gather call if
         # tensor is on GPU memory
         if Z_test.is_cuda:
-            torch.cuda.synchronize()
+            accelerator.synchronize()

         (_, batch_split_lengths) = ext_dist.get_split_lengths(X_test.size(0))
         if ext_dist.my_size > 1:
             Z_test = ext_dist.all_gather(Z_test, batch_split_lengths)
@@ -1072,22 +1073,22 @@ def run():
         # if the parameter is not set, use the same parameter for training
         args.test_num_workers = args.num_workers

-    use_gpu = args.use_gpu and torch.cuda.is_available()
+    use_gpu = args.use_gpu and accelerator.is_available()

     if not args.debug_mode:
         ext_dist.init_distributed(
             local_rank=args.local_rank, use_gpu=use_gpu, backend=args.dist_backend
         )

     if use_gpu:
-        torch.cuda.manual_seed_all(args.numpy_rand_seed)
+        accelerator.manual_seed_all(args.numpy_rand_seed)
         torch.backends.cudnn.deterministic = True
         if ext_dist.my_size > 1:
             ngpus = 1
-            device = torch.device("cuda", ext_dist.my_local_rank)
+            device = accelerator.fetch_device(ext_dist.my_local_rank)
         else:
-            ngpus = torch.cuda.device_count()
-            device = torch.device("cuda", 0)
+            ngpus = accelerator.device_count()
+            device = accelerator.fetch_device(0)
         print("Using {} GPU(s)...".format(ngpus))
     else:
         device = torch.device("cpu")
@@ -1318,15 +1319,15 @@ def run():
         # Custom Model-Data Parallel
         # the mlps are replicated and use data parallelism, while
         # the embeddings are distributed and use model parallelism
-        dlrm = dlrm.to(device)  # .cuda()
+        dlrm = dlrm.to(device)
         if dlrm.ndevices > 1:
             dlrm.emb_l, dlrm.v_W_l = dlrm.create_emb(
                 m_spa, ln_emb, args.weighted_pooling
             )
         else:
             if dlrm.weighted_pooling == "fixed":
                 for k, w in enumerate(dlrm.v_W_l):
-                    dlrm.v_W_l[k] = w.cuda()
+                    dlrm.v_W_l[k] = w.to(device)

     # distribute data parallel mlps
     if ext_dist.my_size > 1:
@@ -1412,7 +1413,7 @@ def run():
                 # note that the call to .to(device) has already happened
                 ld_model = torch.load(
                     args.load_model,
-                    map_location=torch.device("cuda"),
+                    map_location=torch.device(accelerator.fetch_device(0)),
                     # map_location=lambda storage, loc: storage.cuda(0)
                 )
             else:
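
A quick way to sanity-check the map_location change in the last hunk, again relying only on calls shown in the diff; the checkpoint path "model.pt" is a hypothetical placeholder, not something from this PR.

# Sketch only (not from the PR): reload a checkpoint with the same map_location
# pattern as the final hunk, falling back to CPU when no accelerator is present.
import torch
import torchcompat.core as accelerator

ckpt_path = "model.pt"  # hypothetical path
if accelerator.is_available():
    # Map saved storages onto the first accelerator device, as the diff does.
    ld_model = torch.load(ckpt_path, map_location=torch.device(accelerator.fetch_device(0)))
else:
    ld_model = torch.load(ckpt_path, map_location=torch.device("cpu"))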