Skip to content

Commit

Permalink
Fix issue milesial#496: Predicting batch of image arrays
Browse files Browse the repository at this point in the history
 add `predict_imgs()` function in predict.py to support predicting batch of image
 arrays.
  • Loading branch information
chuck.py.liu committed May 29, 2024
1 parent 21d7850 commit 1462b46
Showing 1 changed file with 68 additions and 27 deletions.
95 changes: 68 additions & 27 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from unet import UNet
from utils.utils import plot_img_and_mask


# for predicting single PIL image
def predict_img(net,
full_img,
device,
Expand All @@ -33,6 +35,27 @@ def predict_img(net,
return mask[0].long().squeeze().numpy()


# for predicting batch of nd.array images
def predict_imgs(net,
image_arrays,
device,
scale_factor=1,
out_threshold=0.5):
net.eval()
img = torch.from_numpy(image_arrays)
img = img.to(device=device, dtype=torch.float32)

with torch.no_grad():
output = net(img).cpu()
if net.n_classes > 1:
mask = output.argmax(dim=1)
else:
mask = torch.sigmoid(output) > out_threshold

return mask.long().squeeze().numpy()



def get_args():
parser = argparse.ArgumentParser(description='Predict masks from input images')
parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
Expand Down Expand Up @@ -80,38 +103,56 @@ def mask_to_image(mask: np.ndarray, mask_values):
args = get_args()
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

in_files = args.input
out_files = get_output_filenames(args)
image_in_file = False
if image_in_file:
in_files = args.input
out_files = get_output_filenames(args)

net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Loading model {args.model}')
logging.info(f'Using device {device}')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Loading model {args.model}')
logging.info(f'Using device {device}')

net.to(device=device)
state_dict = torch.load(args.model, map_location=device)
mask_values = state_dict.pop('mask_values', [0, 1])
net.load_state_dict(state_dict)
net.to(device=device)
state_dict = torch.load(args.model, map_location=device)
mask_values = state_dict.pop('mask_values', [0, 1])
net.load_state_dict(state_dict)

logging.info('Model loaded!')
logging.info('Model loaded!')

for i, filename in enumerate(in_files):
logging.info(f'Predicting image {filename} ...')
img = Image.open(filename)
for i, filename in enumerate(in_files):
logging.info(f'Predicting image {filename} ...')
img = Image.open(filename)

mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
device=device)
mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
device=device)

if not args.no_save:
out_filename = out_files[i]
result = mask_to_image(mask, mask_values)
result.save(out_filename)
logging.info(f'Mask saved to {out_filename}')
if not args.no_save:
out_filename = out_files[i]
result = mask_to_image(mask, mask_values)
result.save(out_filename)
logging.info(f'Mask saved to {out_filename}')

if args.viz:
logging.info(f'Visualizing results for image {filename}, close to continue...')
plot_img_and_mask(img, mask)
if args.viz:
logging.info(f'Visualizing results for image {filename}, close to continue...')
plot_img_and_mask(img, mask)
else:
batch_size = 128

net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=1.0)
device = torch.device('cpu' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net.to(device=device)
mask_values = [0, 1]

image_arrays = np.random.rand(batch_size, 3, 112, 112)
mask_arrays = predict_imgs(net=net,
image_arrays=image_arrays,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
device=device)
print(mask_arrays.shape)

0 comments on commit 1462b46

Please sign in to comment.