From 0c1dd4f1219d0bf5372e5b1b95980d108a067229 Mon Sep 17 00:00:00 2001 From: Malte Jensen Date: Mon, 27 May 2024 13:12:58 -0700 Subject: [PATCH] fixed error for running on .nii.gz files andn updated aortic_calcification with better artifact detection --- bin/C2C-slurm | 6 +- comp2comp/aortic_calcium/aortic_calcium.py | 359 +++++++++++++++------ comp2comp/io/io_utils.py | 15 +- comp2comp/utils/process.py | 2 - 4 files changed, 270 insertions(+), 112 deletions(-) diff --git a/bin/C2C-slurm b/bin/C2C-slurm index 7e5b392..d5979af 100755 --- a/bin/C2C-slurm +++ b/bin/C2C-slurm @@ -8,7 +8,6 @@ from pathlib import Path exec_file = sys.argv[0].split("-")[0] command = exec_file + " " + " ".join([pipes.quote(s) for s in sys.argv[1:]]) - def submit_command(command): subprocess.run(command.split(" "), check=True, capture_output=False) @@ -43,4 +42,7 @@ def python_submit(command, node=None): os.remove("./slurm.sh") -python_submit(command, node='siena') +python_submit(command, node='roma') + + + diff --git a/comp2comp/aortic_calcium/aortic_calcium.py b/comp2comp/aortic_calcium/aortic_calcium.py index 732bdca..71f20e1 100644 --- a/comp2comp/aortic_calcium/aortic_calcium.py +++ b/comp2comp/aortic_calcium/aortic_calcium.py @@ -9,6 +9,7 @@ import time from pathlib import Path from typing import Union +import matplotlib.pyplot as plt import numpy as np from scipy import ndimage @@ -121,18 +122,6 @@ def __init__(self): super().__init__() def __call__(self, inference_pipeline): - ct = inference_pipeline.medical_volume.get_fdata() - aorta_mask = inference_pipeline.segmentation.get_fdata().astype(np.int8) == 52 - spine_mask = inference_pipeline.spine_segmentation.get_fdata() > 0 - - calcification_results = self.detectCalcifications( - ct, aorta_mask, exclude_mask=spine_mask, remove_size=3, return_dilated_mask=True, - threshold=inference_pipeline.args.threshold - ) - - inference_pipeline.calc_mask = calcification_results['calc_mask'] - inference_pipeline.calcium_threshold = calcification_results['threshold'] - # Set output dirs self.output_dir = inference_pipeline.output_dir self.output_dir_images_organs = os.path.join(self.output_dir, "images/") @@ -144,7 +133,38 @@ def __call__(self, inference_pipeline): os.makedirs(self.output_dir_images_organs) if not os.path.exists(self.output_dir_segmentation_masks): os.makedirs(self.output_dir_segmentation_masks) - + if not os.path.exists(os.path.join(self.output_dir, "metrics/")): + os.makedirs(os.path.join(self.output_dir, "metrics/")) + + ct = inference_pipeline.medical_volume.get_fdata() + aorta_mask = inference_pipeline.segmentation.get_fdata().astype(np.int8) == 52 + spine_mask = inference_pipeline.spine_segmentation.get_fdata() > 0 + + # Determine the target number of pixels + pix_size = np.array(inference_pipeline.medical_volume.header.get_zooms()) + # target: 1 mm + target_aorta_dil = round(1/pix_size[0]) + # target: 3 mm + target_exclude_dil = round(3/pix_size[0]) + # target: 7 mm + target_aorta_erode = round(7/pix_size[0]) + + # Run calcification detection pipeline + calcification_results = self.detectCalcifications( + ct, + aorta_mask, + exclude_mask=spine_mask, + remove_size=3, + return_dilated_mask=True, + threshold=inference_pipeline.args.threshold, + dilation_iteration = target_aorta_dil, + dilation_iteration_exclude = target_exclude_dil, + aorta_erode_iteration = target_aorta_erode, + ) + + inference_pipeline.calc_mask = calcification_results['calc_mask'] + inference_pipeline.calcium_threshold = calcification_results['threshold'] + # save masks inference_pipeline.saveArrToNifti( inference_pipeline.calc_mask, @@ -154,7 +174,12 @@ def __call__(self, inference_pipeline): inference_pipeline.saveArrToNifti( calcification_results['dilated_mask'], os.path.join(inference_pipeline.output_dir_segmentation_masks, - "dilated_aorta.nii.gz") + "dilated_aorta_mask.nii.gz") + ) + inference_pipeline.saveArrToNifti( + aorta_mask, + os.path.join(inference_pipeline.output_dir_segmentation_masks, + "aorta_mask.nii.gz") ) inference_pipeline.saveArrToNifti( ct, @@ -164,7 +189,6 @@ def __call__(self, inference_pipeline): return {} - def detectCalcifications( self, ct, @@ -185,6 +209,7 @@ def detectCalcifications( aorta_erode_iteration=6, threshold = 'adaptive', agatson_failsafe = 100, + generate_plots = True, ): """ Function that takes in a CT image and aorta segmentation (and optionally volumes to use @@ -230,40 +255,16 @@ def detectCalcifications( Will mean a threshold of 130 HU. agatson_failsafe: (int): A fail-safe raising an error if the mean HU of the aorta is too high - to reliably be using the agatson threshold of 130 + to reliably be using the agatson threshold of 130. Defaults to 100 HU. Returns: results: array of only the mask is returned, or dict if other volumes are also returned. """ - - def slicedDilationOrErosion(input_mask, struct, num_iteration, operation): - """ - Perform the dilation on the smallest slice that will fit the - segmentation - """ - margin = 2 if num_iteration is None else num_iteration + 1 - - x_idx = np.where(input_mask.sum(axis=(1, 2)))[0] - x_start, x_end = x_idx[0] - margin, x_idx[-1] + margin - y_idx = np.where(input_mask.sum(axis=(0, 2)))[0] - y_start, y_end = y_idx[0] - margin, y_idx[-1] + margin - - if operation == "dilate": - mask_slice = ndimage.binary_dilation( - input_mask[x_start:x_end, y_start:y_end, :], structure=struct - ).astype(np.int8) - elif operation == "erode": - mask_slice = ndimage.binary_erosion( - input_mask[x_start:x_end, y_start:y_end, :], structure=struct - ).astype(np.int8) - - output_mask = input_mask.copy() - - output_mask[x_start:x_end, y_start:y_end, :] = mask_slice - - return output_mask + ''' + Remove the ascending aorta if present + ''' # remove parts that are not the abdominal aorta labelled_aorta, num_classes = ndimage.label(aorta_mask) if num_classes > 1: @@ -277,23 +278,86 @@ def slicedDilationOrErosion(input_mask, struct, num_iteration, operation): biggest_idx = np.argmax(aorta_vols) + 1 aorta_mask[labelled_aorta != biggest_idx] = 0 + + + ''' + Erode the center aorta to get statistics from the blood pool + ''' + t0 = time.time() + + struct = ndimage.generate_binary_structure(3, 1) + struct = ndimage.iterate_structure(struct, aorta_erode_iteration) + + aorta_eroded = self.slicedDilationOrErosion( + aorta_mask, + struct=struct, + num_iteration=aorta_erode_iteration, + operation="erode", + ) + + eroded_ct_points = ct[aorta_eroded==1] + eroded_ct_points_mean = eroded_ct_points.mean() + eroded_ct_points_std = eroded_ct_points.std() + + if generate_plots: + # save the statistics of the eroded aorta for reference + with open(os.path.join(self.output_dir, 'metrics/eroded_aorta_statistics.csv'), 'w') as f: + f.write('metric,value\n') + f.write('mean,{:.1f}\n'.format(eroded_ct_points_mean)) + f.write('std,{:.1f}\n'.format(eroded_ct_points_std)) + + # save a histogram: + fig, axx = plt.subplots(1) + axx.hist(eroded_ct_points, bins=100) + axx.set_ylabel('Counts') + axx.set_xlabel('HU') + axx.set_title('Histogram of eroded aorta') + axx.grid() + plt.tight_layout() + fig.savefig(os.path.join(self.output_dir, 'images/histogram_eroded_aorta.png')) + + # Perform the fail-safe check if the method is agatson + if threshold == 'agatson' and eroded_ct_points_mean > agatson_failsafe: + raise ValueError('The mean HU in the center aorta is {:.0f}, and the Agatson method will provide unreliable results (fail-safe threshold is {})'.format( + eroded_ct_points_mean, agatson_failsafe + )) + + # calc_mask = calc_mask * (aorta_eroded == 0) + if show_time: + print("exclude center aorta time: {:.2f} sec".format(time.time() - t0)) + + ''' + Choose threshold + ''' - ### Choose the threshold ### if threshold == 'adaptive': + # calc_thres = eroded_ct_points.max() + # Get aortic CT point to set adaptive threshold aorta_ct_points = ct[aorta_mask == 1] - + # equal to one standard deviation to the left of the curve quant = 0.158 quantile_median_dist = np.median(aorta_ct_points) - np.quantile( aorta_ct_points, q=quant ) calc_thres = np.median(aorta_ct_points) + quantile_median_dist * num_std + elif threshold == 'agatson': calc_thres = 130 - # needed for surpressing noise and detecting if theres - # Contrast in the aorta - exclude_center_aorta = True + + counter = self.slicedSizeCount(aorta_eroded, ct, remove_size, calc_thres) + + # if num_features >= 10: + # raise ValueError('Too many pixels above 130 in blood pool, found: {}'.format(num_features)) + + if verbose: + print('{} calc over threshold of {}'.format(counter, remove_size)) + + if generate_plots: + # save the statistics of the eroded aorta for reference + with open(os.path.join(self.output_dir, 'metrics/eroded_aorta_statistics.csv'), 'a') as f: + f.write('num calcification blood pool,{}\n'.format(counter)) else: try: calc_thres = int(threshold) @@ -301,13 +365,15 @@ def slicedDilationOrErosion(input_mask, struct, num_iteration, operation): raise ValueError('Error in threshold value for aortic calcium segmentaiton. \ Should be \'adaptive\', \'agatson\' or int, but got: ' + str(threshold)) + ''' + Dilate aorta before using threshold to segment calcifications + ''' t0 = time.time() - if dilation is not None: struct = ndimage.generate_binary_structure(*dilation) if dilation_iteration is not None: struct = ndimage.iterate_structure(struct, dilation_iteration) - aorta_dilated = slicedDilationOrErosion( + aorta_dilated = self.slicedDilationOrErosion( aorta_mask, struct=struct, num_iteration=dilation_iteration, @@ -316,35 +382,14 @@ def slicedDilationOrErosion(input_mask, struct, num_iteration, operation): if show_time: print("dilation mask time: {:.2f}".format(time.time() - t0)) - + t0 = time.time() + # make threshold calc_mask = np.logical_and(aorta_dilated == 1, ct >= calc_thres) + if show_time: print("find calc time: {:.2f}".format(time.time() - t0)) - if exclude_center_aorta: - t0 = time.time() - - struct = ndimage.generate_binary_structure(3, 1) - struct = ndimage.iterate_structure(struct, aorta_erode_iteration) - - aorta_eroded = slicedDilationOrErosion( - aorta_mask, - struct=struct, - num_iteration=aorta_erode_iteration, - operation="erode", - ) - - # Perform the fail-safe check if the method is agatson - if threshold == 'agatson' and ct[aorta_eroded].mean() > agatson_failsafe: - raise ValueError('The mean HU in the center aorta is {:.0f}, and the Agatson method will provide unreliable results (fail-safe threshold is {})'.format( - ct[aorta_eroded].mean(), agatson_failsafe - )) - - calc_mask = calc_mask * (aorta_eroded == 0) - if show_time: - print("exclude center aorta time: {:.2f} sec".format(time.time() - t0)) - t0 = time.time() if exclude_mask is not None: if dilation_exclude_mask is not None: @@ -356,7 +401,7 @@ def slicedDilationOrErosion(input_mask, struct, num_iteration, operation): struct_exclude, dilation_iteration_exclude ) - exclude_mask = slicedDilationOrErosion( + exclude_mask = self.slicedDilationOrErosion( exclude_mask, struct=struct_exclude, num_iteration=dilation_iteration_exclude, @@ -367,26 +412,33 @@ def slicedDilationOrErosion(input_mask, struct, num_iteration, operation): print("exclude dilation time: {:.2f}".format(time.time() - t0)) t0 = time.time() - calc_mask = calc_mask * (exclude_mask == 0) + calc_mask = calc_mask * (exclude_mask == 0) if show_time: print("exclude time: {:.2f}".format(time.time() - t0)) if remove_size is not None: - t0 = time.time() - - labels, num_features = ndimage.label(calc_mask) - - counter = 0 - for n in range(1, num_features + 1): - idx_tmp = labels == n - if idx_tmp.sum() <= remove_size: - calc_mask[idx_tmp] = 0 - counter += 1 + if verbose: + print("Excluding calcifications under {} pixels".format(remove_size)) + t0 = time.time() + + if calc_mask.sum() != 0: + # perform the exclusion on a slice for speed + arr_slices = self.getSmallestArraySlice(calc_mask, margin = 1) + labels, num_features = ndimage.label(calc_mask[arr_slices]) + + counter = 0 + for n in range(1, num_features + 1): + idx_tmp = labels == n + if idx_tmp.sum() <= remove_size: + labels[idx_tmp] = 0 + counter += 1 + + calc_mask[arr_slices] = labels > 0 + + if show_time: print("Size exclusion time: {:.1f} sec".format(time.time() - t0)) - if verbose: - print("Excluded {} foci under {}".format(counter, remove_size)) if not any([return_dilated_mask, return_dilated_exclude]): return calc_mask.astype(np.int8) @@ -404,6 +456,92 @@ def slicedDilationOrErosion(input_mask, struct, num_iteration, operation): return results + + def slicedDilationOrErosion(self, input_mask, struct, num_iteration, operation): + """ + Perform the dilation on the smallest slice that will fit the + segmentation + """ + + if num_iteration < 1: + return input_mask + + margin = 2 if num_iteration is None else num_iteration + 1 + + x_idx = np.where(input_mask.sum(axis=(1, 2)))[0] + if len(x_idx) > 0: + x_start, x_end = max(x_idx[0] - margin, 0), min(x_idx[-1] + margin, input_mask.shape[0]) + + y_idx = np.where(input_mask.sum(axis=(0, 2)))[0] + if len(y_idx) > 0: + y_start, y_end = max(y_idx[0] - margin, 0), min(y_idx[-1] + margin, input_mask.shape[1]) + + # Don't dilate the aorta at the bifurcation + z_idx = np.where(input_mask.sum(axis=(0, 1)))[0] + z_start, z_end = z_idx[0], z_idx[-1] + + if operation == "dilate": + mask_slice = ndimage.binary_dilation( + input_mask[x_start:x_end, y_start:y_end, :], structure=struct + ).astype(np.int8) + elif operation == "erode": + mask_slice = ndimage.binary_erosion( + input_mask[x_start:x_end, y_start:y_end, :], structure=struct + ).astype(np.int8) + + # copy to not change the originial mask + output_mask = input_mask.copy() + + # insert dilated mask, but restrain to undilated z_start + output_mask[x_start:x_end, y_start:y_end, z_start:] = mask_slice[:,:,z_start:] + + return output_mask + + def slicedSizeCount(self, aorta_eroded, ct, remove_size, calc_thres): + ''' + Counts the number of calcifications over the size threshold in the eroded + aorta on the smallest slice that fits the aorta. + ''' + eroded_calc_mask = np.logical_and(aorta_eroded == 1, ct >= calc_thres) + + if eroded_calc_mask.sum() != 0: + # Perfom the counts on a slice of the aorta for speed + arr_slices = self.getSmallestArraySlice(eroded_calc_mask, margin = 1) + labels, num_features = ndimage.label(eroded_calc_mask[arr_slices]) + counter = 0 + for n in range(1, num_features + 1): + idx_tmp = labels == n + if idx_tmp.sum() > remove_size: + counter += 1 + + return counter + else: + return 0 + + + def getSmallestArraySlice(self, input_mask, margin = 0): + ''' + Generated the smallest slice that will fit the mask plus the given margin + and return a touple of slice objects + ''' + + x_idx = np.where(input_mask.sum(axis=(1, 2)))[0] + if len(x_idx) > 0: + x_start, x_end = max(x_idx[0] - margin, 0), min(x_idx[-1] + margin, input_mask.shape[0]) + + y_idx = np.where(input_mask.sum(axis=(0, 2)))[0] + if len(y_idx) > 0: + y_start, y_end = max(y_idx[0] - margin, 0), min(y_idx[-1] + margin, input_mask.shape[1]) + + z_idx = np.where(input_mask.sum(axis=(0, 1)))[0] + if len(z_idx) > 0: + z_start, z_end = max(z_idx[0] - margin, 0), min(z_idx[-1] + margin, input_mask.shape[2]) + + + return (slice(x_start, x_end), slice(y_start, y_end), slice(z_start, z_end)) + + + class AorticCalciumMetrics(InferenceClass): """Calculate metrics for the aortic calcifications""" @@ -453,7 +591,7 @@ def __call__(self, inference_pipeline): print('WARNNG: could not locate L1, using T12 only..') sep_plane = t12_level[0] else: - raise ValueError('Could not locate spine either T12 or L1, aborting..') + raise ValueError('Could not locate either T12 or L1, aborting..') planes = np.zeros_like(spine_mask, dtype=np.int8) planes[:,:,sep_plane] = 1 @@ -494,27 +632,36 @@ def __call__(self, inference_pipeline): "median_hu": [], "max_hu": [], } - - for j in range(1, num_lesions + 1): - tmp_mask = labelled_calc == j - - tmp_ct_vals = ct[tmp_mask] - - metrics["volume"].append( - len(tmp_ct_vals) * inference_pipeline.vol_per_pixel - ) - metrics["mean_hu"].append(np.mean(tmp_ct_vals)) - metrics["median_hu"].append(np.median(tmp_ct_vals)) - metrics["max_hu"].append(np.max(tmp_ct_vals)) + + if num_lesions == 0: + metrics["volume"].append(0) + metrics["mean_hu"].append(0) + metrics["median_hu"].append(0) + metrics["max_hu"].append(0) + else: + for j in range(1, num_lesions + 1): + tmp_mask = labelled_calc == j + + tmp_ct_vals = ct[tmp_mask] + + metrics["volume"].append( + len(tmp_ct_vals) * inference_pipeline.vol_per_pixel + ) + metrics["mean_hu"].append(np.mean(tmp_ct_vals)) + metrics["median_hu"].append(np.median(tmp_ct_vals)) + metrics["max_hu"].append(np.max(tmp_ct_vals)) # Volume of calcificaitons calc_vol = np.sum(metrics["volume"]) metrics["volume_total"] = calc_vol - metrics["num_calc"] = len(metrics["volume"]) + metrics["num_calc"] = num_lesions if inference_pipeline.args.threshold == 'agatson': - metrics["agatson_score"] = self.CalculateAgatsonScore(calc_mask_region, ct, inference_pipeline.pix_dims) + if num_lesions == 0: + metrics["agatson_score"] = 0 + else: + metrics["agatson_score"] = self.CalculateAgatsonScore(calc_mask_region, ct, inference_pipeline.pix_dims) all_regions[region_names[i]] = metrics @@ -535,8 +682,10 @@ def get_hu_factor(max_hu): factor = 2 elif 300 <= max_hu < 400: factor = 3 - elif max_hu > 400: + elif max_hu >= 400: factor = 4 + else: + raise ValueError('Could determine factor, got: ' + str(max_hu)) return factor @@ -558,6 +707,6 @@ def get_hu_factor(max_hu): if tmp_area <= 1: continue else: - agatson += tmp_area * get_hu_factor(tmp_ct_slice[tmp_mask].max()) + agatson += tmp_area * get_hu_factor(int(tmp_ct_slice[tmp_mask].max())) return agatson \ No newline at end of file diff --git a/comp2comp/io/io_utils.py b/comp2comp/io/io_utils.py index 299b547..e1f9190 100644 --- a/comp2comp/io/io_utils.py +++ b/comp2comp/io/io_utils.py @@ -3,7 +3,7 @@ """ import csv import os - +import nibabel as nib import pydicom @@ -50,12 +50,15 @@ def get_dicom_or_nifti_paths_and_num(path): dicom_nifti_paths = [] if path.endswith(".nii") or path.endswith(".nii.gz"): - dicom_nifti_paths.append([(path, 1)]) + dicom_nifti_paths.append( (path, getNumSlicesNifti(path)) ) elif path.endswith('.txt'): dicom_nifti_paths = [] with open(path, 'r') as f: for dicom_folder_path in f: - dicom_nifti_paths.append( (dicom_folder_path.strip(), len(os.listdir(dicom_folder_path.strip()))) ) + if dicom_folder_path.endswith(".nii") or path.dicom_folder_path(".nii.gz"): + dicom_nifti_paths.append( (dicom_folder_path.strip(), getNumSlicesNifti(dicom_folder_path.strip()))) + else: + dicom_nifti_paths.append( (dicom_folder_path.strip(), len(os.listdir(dicom_folder_path.strip())))) else: for root, dirs, files in os.walk(path): if len(files) > 0: @@ -77,3 +80,9 @@ def write_dicom_metadata_to_csv(ds, csv_filename): continue value = str(element.value) csvwriter.writerow([tag, keyword, value]) + +def getNumSlicesNifti(path): + img = nib.load(path) + img = nib.as_closest_canonical(img) + return img.shape[2] + diff --git a/comp2comp/utils/process.py b/comp2comp/utils/process.py index 52f1564..740a177 100644 --- a/comp2comp/utils/process.py +++ b/comp2comp/utils/process.py @@ -150,5 +150,3 @@ def process_3d(args, pipeline_builder): if len(os.listdir(os.path.dirname(output_dir))) == 0: shutil.rmtree(os.path.dirname(output_dir)) continue - -