diff --git a/README.md b/README.md index ee7f11c..c699ca0 100644 --- a/README.md +++ b/README.md @@ -42,10 +42,10 @@ optimizer = torch.optim.Adam(model.parameters()) + optimizer = DistributedOptimizer(optimizer) dataset = load_dataset('goose') -data = torch.utils.data.DataLoader(dataset, shuffle=True) +dataloader = torch.utils.data.DataLoader(dataset, shuffle=True) for epoch in range(10): - for source, targets in data: + for source, targets in dataloader: - source = source.to(device) - targets = targets.to(device)