The model is likely memorizing samples #96

Open

karol-szustakowski opened this issue Dec 13, 2024 · 0 comments

I was trying to set up SAM-Med3D and used your example, adapted so that it does not use ground truth.
I took the toy example you supplied:

import numpy as np
import torch
import medim

''' 1. read and pre-process your input data '''
# data_preprocess comes from the SAM-Med3D example code
img_path = "./test_data/kidney_right/AMOS/imagesVal/amos_0013.nii.gz"
gt_path = None  # adapted: no ground truth is used
category_index = 3  # the index of your target category in the gt annotation
output_dir = "./test_data/kidney_right/AMOS/pred/"
roi_image, _, _ = data_preprocess(img_path, gt_path, category_index=category_index)

''' 2. prepare the pre-trained model with local path or huggingface url '''
ckpt_path = "https://huggingface.co/blueyo0/SAM-Med3D/blob/main/sam_med3d_turbo.pth"
# or you can use the local path like: ckpt_path = "./ckpt/sam_med3d_turbo.pth"
model = medim.create_model("SAM-Med3D",
                            pretrained=True,
                            checkpoint_path=ckpt_path)

''' 3. infer with the pre-trained SAM-Med3D model '''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using device", device)
model = model.to(device)
prev_low_res_mask = None
with torch.no_grad():
    input_tensor = roi_image
    image_embeddings = model.image_encoder(input_tensor)

    points_coords = torch.zeros(1, 0, 3).to(device)
    points_labels = torch.zeros(1, 0).to(device)
    
    #new_points_co = torch.Tensor([[0, 0, 0], [59, 51, 38], [59, 51, 77]])

    new_points_co = torch.Tensor([[[90, 5, 110]]]).to(device)
    # labels need the same device as points_labels for the torch.cat below
    new_points_la = torch.Tensor([[1]]).to(torch.int64).to(device)
    

    # new_points_co, new_points_la = torch.Tensor(
    #     [[np.argwhere(gt_img.detach().numpy().squeeze() > 0)[0]]]).to(device), torch.Tensor([[1]]).to(torch.int64)
        
    prev_low_res_mask = (torch.rand(1, 1,
                                    roi_image.shape[2] // 4,
                                    roi_image.shape[3] // 4,
                                    roi_image.shape[4] // 4) > 0.5).float().to(device)

    points_coords = torch.cat([points_coords, new_points_co], dim=1)
    points_labels = torch.cat([points_labels, new_points_la], dim=1)

    sparse_embeddings, dense_embeddings = model.prompt_encoder(
        points=[points_coords, points_labels],
        boxes=None,  # bbox prompts are currently not supported
        masks=prev_low_res_mask,
        # masks=None,
    )

    low_res_masks, _ = model.mask_decoder(
        image_embeddings=image_embeddings,  # (1, 384, 8, 8, 8)
        image_pe=model.prompt_encoder.get_dense_pe(),  # (1, 384, 8, 8, 8)
        sparse_prompt_embeddings=sparse_embeddings,  # (1, 2, 384)
        dense_prompt_embeddings=dense_embeddings,  # (1, 384, 8, 8, 8)
    )

    prev_mask = torch.nn.functional.interpolate(low_res_masks,
                                                size=roi_image.shape[-3:],
                                                mode='trilinear',
                                                align_corners=False)

# convert prob to mask
medsam_seg_prob = torch.sigmoid(prev_mask)  # (1, 1, 64, 64, 64)
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg_mask2 = (medsam_seg_prob > 0.5).astype(np.uint8)
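
(As an aside, output_dir from step 1 is never used in the snippet above. A minimal sketch for writing the result to disk, assuming nibabel is available; the identity affine is a placeholder, since the ROI prediction would still need to be resampled back into the original image geometry:)

import os
import nibabel as nib

os.makedirs(output_dir, exist_ok=True)
# NOTE: np.eye(4) is a placeholder affine; map the ROI prediction back to the
# original image space before using this mask for anything quantitative.
nib.save(nib.Nifti1Image(medsam_seg_mask2, affine=np.eye(4)),
         os.path.join(output_dir, "amos_0013_pred.nii.gz"))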

While playing with this line:
new_points_co, new_points_la = torch.Tensor([[[1, 1, 1]]]).to(device), torch.Tensor([[1]]).to(torch.int64).to(device)
I noticed that no matter the starting point, the result is always nearly the same.

If you supply any given point (e.g. [1, 1, 1], [1e10, 0, 1e4], or [50, 50, 50]), even points that lie inside the kidney, but specify neither a mask nor boxes:

  • you get some artifacts around the position of the kidney:
    [screenshot: kidney1]
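
(For reference, this variant is the same prompt_encoder call as in the main block, just without a mask prompt:)

sparse_embeddings, dense_embeddings = model.prompt_encoder(
    points=[points_coords, points_labels],
    boxes=None,
    masks=None,  # variant (a): point prompt only, no mask prompt
)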

If you supply any given point (as above) or no points at all (points=None), but supply an all-zero mask:
prev_low_res_mask = torch.zeros(1, 1, roi_image.shape[2] // 4, roi_image.shape[3] // 4, roi_image.shape[4] // 4).to(device)

  • you ALWAYS get a segmentation of the kidney, and an almost perfect one, no matter the starting point:
    [screenshot: kidney2]
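
(Concretely, the no-points variant is the same prompt_encoder call as above with points=None; I'm assuming points=None is accepted here, as in the original SAM API:)

sparse_embeddings, dense_embeddings = model.prompt_encoder(
    points=None,  # no point prompt at all (assumed to be accepted, as in the original SAM API)
    boxes=None,
    masks=prev_low_res_mask,  # the all-zero mask from above
)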

If you supply a random binary mask and points:
prev_low_res_mask = (torch.rand(1, 1, roi_image.shape[2] // 4, roi_image.shape[3] // 4, roi_image.shape[4] // 4) > 0.5).float().to(device)

  • you again get the same segmentation of the kidney plus some artifacts (or a very degenerate kidney segmentation).

I also wasn't able to segment any out-of-training images using the code posted above; it seems like the model might be memorizing samples. Could you please assist with that? Maybe I am doing something wrong?
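
(To make "the result is always nearly the same" measurable, one could compare the masks produced from two very different prompts. A minimal sketch, where mask_a and mask_b are hypothetical medsam_seg_mask2 outputs from two runs of the code above with different click coordinates:)

def dice(a: np.ndarray, b: np.ndarray) -> float:
    """Dice overlap between two binary masks; ~1.0 means near-identical."""
    inter = np.logical_and(a > 0, b > 0).sum()
    return 2.0 * inter / max((a > 0).sum() + (b > 0).sum(), 1)

# mask_a, mask_b: hypothetical outputs from prompts [1, 1, 1] and [90, 5, 110];
# a Dice near 1.0 would support the suspicion that the point prompt is ignored.
# print("Dice between prompt A and B:", dice(mask_a, mask_b))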
