diff --git a/metrics/eval.py b/metrics/eval.py index e5c55fe3..dd478cfd 100644 --- a/metrics/eval.py +++ b/metrics/eval.py @@ -43,6 +43,7 @@ def calculate_metrics(nets, args, step, mode): img_size=args.img_size, batch_size=args.val_batch_size, imagenet_normalize=False, + num_workers=args.num_workers, drop_last=True) for src_idx, src_domain in enumerate(src_domains): @@ -50,7 +51,8 @@ def calculate_metrics(nets, args, step, mode): loader_src = get_eval_loader(root=path_src, img_size=args.img_size, batch_size=args.val_batch_size, - imagenet_normalize=False) + imagenet_normalize=False, + num_workers=args.num_workers) task = '%s2%s' % (src_domain, trg_domain) path_fake = os.path.join(args.eval_dir, task) @@ -131,6 +133,7 @@ def calculate_fid_for_all_tasks(args, domains, step, mode): path_fake = os.path.join(args.eval_dir, task) print('Calculating FID for %s...' % task) fid_value = calculate_fid_given_paths( + args=args, paths=[path_real, path_fake], img_size=args.img_size, batch_size=args.val_batch_size) diff --git a/metrics/fid.py b/metrics/fid.py index 31525105..66afe7e0 100644 --- a/metrics/fid.py +++ b/metrics/fid.py @@ -60,11 +60,11 @@ def frechet_distance(mu, cov, mu2, cov2): @torch.no_grad() -def calculate_fid_given_paths(paths, img_size=256, batch_size=50): +def calculate_fid_given_paths(args, paths, img_size=256, batch_size=50): print('Calculating FID given paths %s and %s...' % (paths[0], paths[1])) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') inception = InceptionV3().eval().to(device) - loaders = [get_eval_loader(path, img_size, batch_size) for path in paths] + loaders = [get_eval_loader(path, img_size, batch_size, num_workers=args.num_workers) for path in paths] mu, cov = [], [] for loader in loaders: @@ -84,8 +84,10 @@ def calculate_fid_given_paths(paths, img_size=256, batch_size=50): parser.add_argument('--paths', type=str, nargs=2, help='paths to real and fake images') parser.add_argument('--img_size', type=int, default=256, help='image resolution') parser.add_argument('--batch_size', type=int, default=64, help='batch size to use') + parser.add_argument('--num_workers', type=int, default=4, help='Number of workers used in DataLoader') + args = parser.parse_args() - fid_value = calculate_fid_given_paths(args.paths, args.img_size, args.batch_size) + fid_value = calculate_fid_given_paths(args, args.paths, args.img_size, args.batch_size) print('FID: ', fid_value) # python -m metrics.fid --paths PATH_REAL PATH_FAKE \ No newline at end of file