I was trying to set up SAM-Med3D and used your own example, adapted in such a way that it does not use ground truth.
I took the toy example you supplied:
```python
import medim
import numpy as np
import torch
# data_preprocess is the pre-processing helper from the SAM-Med3D example script;
# note that it still expects a gt_path, which is not defined in this snippet

''' 1. read and pre-process your input data '''
img_path = "./test_data/kidney_right/AMOS/imagesVal/amos_0013.nii.gz"
category_index = 3  # the index of your target category in the gt annotation
output_dir = "./test_data/kidney_right/AMOS/pred/"
roi_image, _, _ = data_preprocess(img_path, gt_path, category_index=category_index)

''' 2. prepare the pre-trained model with local path or huggingface url '''
ckpt_path = "https://huggingface.co/blueyo0/SAM-Med3D/blob/main/sam_med3d_turbo.pth"
# or you can use the local path like: ckpt_path = "./ckpt/sam_med3d_turbo.pth"
model = medim.create_model("SAM-Med3D",
                           pretrained=True,
                           checkpoint_path=ckpt_path)

''' 3. infer with the pre-trained SAM-Med3D model '''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using device", device)
model = model.to(device)

prev_low_res_mask = None
with torch.no_grad():
    input_tensor = roi_image
    image_embeddings = model.image_encoder(input_tensor)

    # start with empty point prompts, then append a single foreground click
    points_coords, points_labels = torch.zeros(1, 0, 3).to(device), torch.zeros(1, 0).to(device)
    # new_points_co = torch.Tensor([[0, 0, 0], [59, 51, 38], [59, 51, 77]])
    new_points_co, new_points_la = torch.Tensor(
        [[[90, 5, 110]]]).to(device), torch.Tensor([[1]]).to(torch.int64)
    # new_points_co, new_points_la = torch.Tensor(
    #     [[np.argwhere(gt_img.detach().numpy().squeeze() > 0)[0]]]).to(device), torch.Tensor([[1]]).to(torch.int64)

    # random binary mask prompt at a quarter of the ROI resolution
    prev_low_res_mask = (torch.rand(1, 1, roi_image.shape[2] // 4,
                                    roi_image.shape[3] // 4,
                                    roi_image.shape[4] // 4) > 0.5).float()

    points_coords = torch.cat([points_coords, new_points_co], dim=1)
    points_labels = torch.cat([points_labels, new_points_la], dim=1)

    sparse_embeddings, dense_embeddings = model.prompt_encoder(
        points=[points_coords, points_labels],
        boxes=None,  # bbox prompts are currently not supported
        masks=prev_low_res_mask,
        # masks=None,
    )
    low_res_masks, _ = model.mask_decoder(
        image_embeddings=image_embeddings,             # (1, 384, 8, 8, 8)
        image_pe=model.prompt_encoder.get_dense_pe(),  # (1, 384, 8, 8, 8)
        sparse_prompt_embeddings=sparse_embeddings,    # (1, 2, 384)
        dense_prompt_embeddings=dense_embeddings,      # (1, 384, 8, 8, 8)
    )
    prev_mask = torch.nn.functional.interpolate(low_res_masks,
                                                size=roi_image.shape[-3:],
                                                mode='trilinear',
                                                align_corners=False)

# convert probabilities to a binary mask
medsam_seg_prob = torch.sigmoid(prev_mask)  # (1, 1, 64, 64, 64)
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg_mask2 = (medsam_seg_prob > 0.5).astype(np.uint8)
```
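For completeness, this is how I dump the result to `output_dir`. It is a minimal sketch of my own using nibabel (not part of the original example), and it simply writes the ROI-space mask with an identity affine rather than mapping it back to the original image geometry:

```python
import os
import nibabel as nib

# save the predicted ROI mask as NIfTI; the affine here is a placeholder,
# so the file is only meant for quick visual inspection of the ROI prediction
os.makedirs(output_dir, exist_ok=True)
out_path = os.path.join(output_dir, "amos_0013_pred.nii.gz")
nib.save(nib.Nifti1Image(medsam_seg_mask2, np.eye(4)), out_path)
print("saved prediction to", out_path)
```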
While playing with this line: `new_points_co, new_points_la = torch.Tensor([[[1, 1, 1]]]).to(device), torch.Tensor([[1]]).to(torch.int64)`
I noticed that no matter which starting point I pick, the result is almost always the same.
If you supply any single point (e.g. `[1, 1, 1]`, `[1e10, 0, 1e4]` or `[50, 50, 50]`), including points that lie inside the kidney, but specify neither a mask nor boxes, you get some artifacts around the position of the kidney.
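Concretely, this is the point-only configuration I mean (adapted from the snippet above, with the mask and box prompts left as `None`; the coordinate is arbitrary):

```python
# point prompt only: no mask, no box
new_points_co = torch.Tensor([[[50, 50, 50]]]).to(device)       # any coordinate behaves the same
new_points_la = torch.Tensor([[1]]).to(torch.int64).to(device)  # 1 = foreground click
points_coords = torch.cat([torch.zeros(1, 0, 3).to(device), new_points_co], dim=1)
points_labels = torch.cat([torch.zeros(1, 0).to(device), new_points_la], dim=1)
sparse_embeddings, dense_embeddings = model.prompt_encoder(
    points=[points_coords, points_labels],
    boxes=None,
    masks=None,
)
```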
If you supply any point (as above) or no points at all (`points=None`), but supply an all-zero mask via `prev_low_res_mask = torch.zeros(1, 1, roi_image.shape[2]//4, roi_image.shape[3]//4, roi_image.shape[4]//4)`, you ALWAYS get a segmentation of the kidney, and an almost perfect one, no matter the starting point.
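For the no-points case, the prompt encoder call looks like this (again a sketch of the variant I mean, with `None` passed instead of the point tensors):

```python
# zero mask only: no point prompt at all
prev_low_res_mask = torch.zeros(1, 1, roi_image.shape[2] // 4,
                                roi_image.shape[3] // 4,
                                roi_image.shape[4] // 4)
sparse_embeddings, dense_embeddings = model.prompt_encoder(
    points=None,
    boxes=None,
    masks=prev_low_res_mask,
)
```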
If you supply a random binary mask together with points, i.e. `prev_low_res_mask = (torch.rand(1, 1, roi_image.shape[2]//4, roi_image.shape[3]//4, roi_image.shape[4]//4) > 0.5).float()` (exactly what the full snippet above does), you again get the same segmentation of the kidney plus some artifacts, or a very degraded kidney segmentation.
I also wasn't able to segment any out-of-training images with the code posted above; it seems as if the model might be memorizing samples. Could you please assist with this? Maybe I am doing something wrong?