diff --git a/rt_utils/image_helper.py b/rt_utils/image_helper.py
index b030987..87b9ef6 100644
--- a/rt_utils/image_helper.py
+++ b/rt_utils/image_helper.py
@@ -1,7 +1,8 @@
 import os
-from typing import List
+from typing import List, Union
 from enum import IntEnum

+import SimpleITK
 import cv2 as cv
 import numpy as np
 from pydicom import dcmread
@@ -60,7 +61,11 @@ def get_contours_coords(roi_data: ROIData, series_data):
             mask_slice = create_pin_hole_mask(mask_slice, roi_data.approximate_contours)

         # Get contours from mask
-        contours, _ = find_mask_contours(mask_slice, roi_data.approximate_contours)
+        contours, _ = find_mask_contours(mask_slice,
+                                         roi_data.approximate_contours,
+                                         scaling_factor=roi_data.scaling_factor)
+        if not contours:
+            continue
         validate_contours(contours)

         # Format for DICOM
@@ -82,7 +87,8 @@
     return series_contours


-def find_mask_contours(mask: np.ndarray, approximate_contours: bool):
+
+def find_mask_contours(mask: np.ndarray, approximate_contours: bool, scaling_factor: int):
     approximation_method = (
         cv.CHAIN_APPROX_SIMPLE if approximate_contours else cv.CHAIN_APPROX_NONE
     )
@@ -93,8 +99,12 @@ def find_mask_contours(mask: np.ndarray, approximate_contours: bool):
     contours = list(
         contours
     )  # Open-CV updated contours to be a tuple so we convert it back into a list here
+
+    # Rescale coordinates back to the original image grid by dividing by the scaling factor
     for i, contour in enumerate(contours):
-        contours[i] = [[pos[0][0], pos[0][1]] for pos in contour]
+        contours[i] = [[pos[0][0] / scaling_factor, pos[0][1] / scaling_factor]
+                       for pos in contour]
+
     hierarchy = hierarchy[0]  # Format extra array out of data

     return contours, hierarchy
@@ -126,7 +136,7 @@ def create_pin_hole_mask(mask: np.ndarray, approximate_contours: bool):


 def draw_line_upwards_from_point(
-    mask: np.ndarray, start, fill_value: int
+        mask: np.ndarray, start, fill_value: int
 ) -> np.ndarray:
     line_width = 2
     end = (start[0], start[1] - 1)
@@ -196,7 +206,7 @@ def get_patient_to_pixel_transformation_matrix(series_data):


 def apply_transformation_to_3d_points(
-    points: np.ndarray, transformation_matrix: np.ndarray
+        points: np.ndarray, transformation_matrix: np.ndarray
 ):
     """
     * Augment each point with a '1' as the fourth coordinate to allow translation
@@ -219,7 +229,7 @@ def get_slice_directions(series_slice: Dataset):
     slice_direction = np.cross(row_direction, column_direction)

     if not np.allclose(
-        np.dot(row_direction, column_direction), 0.0, atol=1e-3
+            np.dot(row_direction, column_direction), 0.0, atol=1e-3
     ) or not np.allclose(np.linalg.norm(slice_direction), 1.0, atol=1e-3):
         raise Exception("Invalid Image Orientation (Patient) attribute")

@@ -263,7 +273,7 @@ def get_slice_contour_data(series_slice: Dataset, contour_sequence: Sequence):


 def get_slice_mask_from_slice_contour_data(
-    series_slice: Dataset, slice_contour_data, transformation_matrix: np.ndarray
+        series_slice: Dataset, slice_contour_data, transformation_matrix: np.ndarray
 ):
     # Go through all contours in a slice, create polygons in correct space and with a correct format
     # and append to polygons array (appropriate for fillPoly)
@@ -275,9 +285,10 @@ def get_slice_mask_from_slice_contour_data(
         polygon = np.array(polygon).squeeze()
         polygons.append(polygon)
     slice_mask = create_empty_slice_mask(series_slice).astype(np.uint8)
-    cv.fillPoly(img=slice_mask, pts = polygons, color = 1)
+    cv.fillPoly(img=slice_mask, pts=polygons, color=1)
     return slice_mask

+
 def create_empty_series_mask(series_data):
     ref_dicom_image = series_data[0]
     mask_dims = (
diff --git a/rt_utils/rtstruct.py b/rt_utils/rtstruct.py
index dfe82be..55c204d 100644
--- a/rt_utils/rtstruct.py
+++ b/rt_utils/rtstruct.py
@@ -1,11 +1,11 @@
-from typing import List, Union
+from typing import List, Union, Dict

 import numpy as np
 from pydicom.dataset import FileDataset

 from rt_utils.utils import ROIData
-from . import ds_helper, image_helper
-
+from . import ds_helper, image_helper, smoothing
+from typing import Tuple

 class RTStruct:
     """
@@ -35,6 +35,14 @@ def add_roi(
         use_pin_hole: bool = False,
         approximate_contours: bool = True,
         roi_generation_algorithm: Union[str, int] = 0,
+        apply_smoothing: Union[str, None] = None,  # "2d" or "3d", or another key understood by a custom smoothing function
+        smoothing_function=smoothing.pipeline,  # Any callable with the signature
+        # smoothing_function(mask=mask, apply_smoothing=apply_smoothing,
+        #                    smoothing_parameters=smoothing_parameters) -> np.ndarray
+        # The returned np.ndarray may be upscaled by any integer factor in x and y relative to the DICOM image,
+        # but the z direction must not be scaled. For instance, with CT_image.shape == (512, 512, 150),
+        # the smoothed array may be (1024, 1024, 150) or (5120, 5120, 150), though your RAM will suffer with the latter.
+        smoothing_parameters: Union[Dict, None] = None,
     ):
         """
         Add a ROI to the rtstruct given a 3D binary mask for the ROI's at each slice
@@ -42,6 +50,13 @@
         Optionally input a color or name for the ROI
         If use_pin_hole is set to true, will cut a pinhole through ROI's with holes in them so that they are represented with one contour
         If approximate_contours is set to False, no approximation will be done when generating contour data, leading to much larger amount of contour data
         """
+        if apply_smoothing:
+            mask = smoothing_function(mask=mask, apply_smoothing=apply_smoothing,
+                                      smoothing_parameters=smoothing_parameters)
+
+        # If the mask was upscaled, contour coordinates must be scaled back to the DICOM grid
+        rows = self.series_data[0][0x00280010].value
+        scaling_factor = int(mask.shape[0] / rows)
         # TODO test if name already exists
         self.validate_mask(mask)
@@ -56,6 +71,7 @@
             use_pin_hole,
             approximate_contours,
             roi_generation_algorithm,
+            scaling_factor
         )

         self.ds.ROIContourSequence.append(
diff --git a/rt_utils/smoothing.py b/rt_utils/smoothing.py
new file mode 100644
index 0000000..2af1c9c
--- /dev/null
+++ b/rt_utils/smoothing.py
@@ -0,0 +1,172 @@
+import SimpleITK as sitk
+import numpy as np
+from scipy import ndimage, signal
+from typing import List, Union, Tuple, Dict
+import logging
+
+# A set of parameters that is known to work well
+default_smoothing_parameters_2d = {
+    "scaling_iterations": 2,
+    "filter_iterations": 3,
+    "crop_margins": [20, 20, 1],
+    "np_kron": {"scaling_factor": 3},
+    "ndimage_gaussian_filter": {"sigma": 2,
+                                "radius": 3},
+    "threshold": {"threshold": 0.5},
+}
+
+def kron_upscale(mask: np.ndarray, params):
+    """
+    This function upscales masks like so:
+
+    1|2     1|1|2|2
+    3|4 --> 1|1|2|2
+            3|3|4|4
+            3|3|4|4
+
+    Scaling is applied only in the x and y directions.
+    """
+
+    scaling_array = (params["scaling_factor"], params["scaling_factor"], 1)
+
+    return np.kron(mask, np.ones(scaling_array))
+
+def gaussian_blur(mask: np.ndarray, params):
+    return ndimage.gaussian_filter(mask, **params)
+
+def binary_threshold(mask: np.ndarray, params):
+    return mask > params["threshold"]
+
+def get_new_margin(column, margin, column_length):
+    """
+    This function takes a column of x, y, or z coordinates and adds a margin.
+    If the margin exceeds the mask bounds, it is clamped to the most extreme valid values.
+    """
+    new_min = column.min() - margin
+    if new_min < 0:
+        new_min = 0
+
+    new_max = column.max() + margin
+    if new_max > column_length:
+        new_max = column_length
+
+    return new_min, new_max
+
+def crop_mask(mask: np.ndarray, crop_margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    This function crops a mask to its non-zero pixels, padded by crop_margins.
+    Returns (cropped mask, bounding box).
+    """
+    x, y, z = np.nonzero(mask)
+
+    x_min, x_max = get_new_margin(x, crop_margins[0], mask.shape[0])
+    y_min, y_max = get_new_margin(y, crop_margins[1], mask.shape[1])
+    z_min, z_max = get_new_margin(z, crop_margins[2], mask.shape[2])
+
+    bbox = np.array([x_min, x_max, y_min, y_max, z_min, z_max])
+
+    return mask[bbox[0]: bbox[1],
+                bbox[2]: bbox[3],
+                bbox[4]: bbox[5]], bbox
+
+def restore_mask_dimensions(cropped_mask: np.ndarray, new_shape, bbox):
+    """
+    This function restores the mask dimensions to the given shape.
+    """
+    new_mask = np.zeros(new_shape)
+
+    new_mask[bbox[0]: bbox[1], bbox[2]: bbox[3], bbox[4]: bbox[5]] = cropped_mask
+    return new_mask.astype(bool)
+
+def iteration_2d(mask: np.ndarray, np_kron, ndimage_gaussian_filter, threshold, filter_iterations):
+    """
+    This is the actual set of filters, applied iteratively over the z direction.
+    """
+    cropped_mask = kron_upscale(mask=mask, params=np_kron)
+
+    for filter_iteration in range(filter_iterations):
+        for z_idx in range(cropped_mask.shape[2]):
+            slice = cropped_mask[:, :, z_idx]
+            slice = gaussian_blur(mask=slice, params=ndimage_gaussian_filter)
+            slice = binary_threshold(mask=slice, params=threshold)
+
+            cropped_mask[:, :, z_idx] = slice
+
+    return cropped_mask
+
+def iteration_3d(mask: np.ndarray, np_kron, ndimage_gaussian_filter, threshold, filter_iterations):
+    """
+    These are the same filters applied iteratively in 3D.
+    """
+    for filter_iteration in range(filter_iterations):
+        cropped_mask = kron_upscale(mask=mask, params=np_kron)
+        cropped_mask = gaussian_blur(mask=cropped_mask, params=ndimage_gaussian_filter)
+        cropped_mask = binary_threshold(mask=cropped_mask, params=threshold)
+
+    return cropped_mask
+
+def pipeline(mask: np.ndarray,
+             apply_smoothing: str,
+             smoothing_parameters: Union[Dict, None]):
+    """
+    This is the entry point for smoothing a mask.
+    """
+    if not smoothing_parameters:
+        smoothing_parameters = default_smoothing_parameters_2d
+
+    scaling_iterations = smoothing_parameters["scaling_iterations"]
+    filter_iterations = smoothing_parameters["filter_iterations"]
+
+    crop_margins = np.array(smoothing_parameters["crop_margins"])
+    np_kron = smoothing_parameters["np_kron"]
+    ndimage_gaussian_filter = smoothing_parameters["ndimage_gaussian_filter"]
+    threshold = smoothing_parameters["threshold"]
+
+    logging.info(f"Original mask shape {mask.shape}")
+    logging.info("Cropping mask to non-zero")
+    cropped_mask, bbox = crop_mask(mask, crop_margins=crop_margins)
+    final_shape, final_bbox = get_final_mask_shape_and_bbox(mask=mask,
+                                                            scaling_factor=np_kron["scaling_factor"],
+                                                            scaling_iterations=scaling_iterations,
+                                                            bbox=bbox)
+    logging.info(f"Final scaling with factor of {np_kron['scaling_factor']} for {scaling_iterations} scaling iterations")
+    for i in range(scaling_iterations):
+        logging.info(f"Iteration {i+1} out of {scaling_iterations}")
+        logging.info("Applying filters")
+        if apply_smoothing == "2d":
+            cropped_mask = iteration_2d(cropped_mask,
+                                        np_kron=np_kron,
+                                        ndimage_gaussian_filter=ndimage_gaussian_filter,
+                                        threshold=threshold,
+                                        filter_iterations=filter_iterations)
+        elif apply_smoothing == "3d":
+            cropped_mask = iteration_3d(cropped_mask,
+                                        np_kron=np_kron,
+                                        ndimage_gaussian_filter=ndimage_gaussian_filter,
+                                        threshold=threshold,
+                                        filter_iterations=filter_iterations)
+        else:
+            raise Exception("Wrong dimension parameter. Use '2d' or '3d'.")
+
+    # Restore dimensions
+    logging.info("Restoring original mask shape")
+    mask = restore_mask_dimensions(cropped_mask, final_shape, final_bbox)
+    return mask
+
+def get_final_mask_shape_and_bbox(mask, bbox, scaling_factor, scaling_iterations):
+    """
+    This function scales the image shape and the bounding box to be used for the final mask.
+    """
+
+    final_scaling_factor = pow(scaling_factor, scaling_iterations)
+
+    final_shape = np.array(mask.shape)
+    final_shape[:2] *= final_scaling_factor
+
+    bbox[:4] *= final_scaling_factor  # Scale bounding box to the final shape
+    bbox[:4] -= round(final_scaling_factor * 0.5)  # Shift the bounding box to account for the offset introduced by the scaling
+    logging.info(f"Final shape: {final_shape}")
+    logging.info(f"Final bbox: {bbox}")
+    return final_shape, bbox
+
+
diff --git a/rt_utils/utils.py b/rt_utils/utils.py
index f04089e..3afb5c6 100644
--- a/rt_utils/utils.py
+++ b/rt_utils/utils.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import List, Union, Tuple
 from random import randrange
 from pydicom.uid import PYDICOM_IMPLEMENTATION_UID
 from dataclasses import dataclass
@@ -41,7 +41,6 @@ class SOPClassUID:
 @dataclass
 class ROIData:
     """Data class to easily pass ROI data to helper methods."""
-
     mask: str
     color: Union[str, List[int]]
     number: int
@@ -51,6 +50,10 @@
     use_pin_hole: bool = False
     approximate_contours: bool = True
     roi_generation_algorithm: Union[str, int] = 0
+    scaling_factor: int = 1
+    smooth_radius: Union[int, None] = None
+    smooth_scale: Union[int, None] = None
+
     def __post_init__(self):
         self.validate_color()
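
Usage sketch for the new smoothing options (not part of the patch). The DICOM series path, ROI names, and mask contents below are hypothetical placeholders, and the custom parameter dict simply mirrors default_smoothing_parameters_2d from rt_utils/smoothing.py. RTStructBuilder.create_new, add_roi, and save are the existing rt-utils entry points; apply_smoothing, smoothing_function, and smoothing_parameters are the keyword arguments introduced by this diff.

import numpy as np
from rt_utils import RTStructBuilder
from rt_utils import smoothing  # new module added by this patch

# Load an existing DICOM series (placeholder path)
rtstruct = RTStructBuilder.create_new(dicom_series_path="./dicom_series")

# Placeholder boolean mask; its z size must match the number of slices in the loaded series
mask = np.zeros((512, 512, 150), dtype=bool)
mask[200:300, 200:300, 40:60] = True

# Default behaviour: no smoothing, contours are generated on the original grid
rtstruct.add_roi(mask=mask, name="GTV raw")

# 2D smoothing with the defaults from smoothing.default_smoothing_parameters_2d;
# the mask is upscaled 3**2 = 9x in x and y, and add_roi derives scaling_factor from
# mask.shape[0] / Rows so the contours are written back in DICOM pixel coordinates
rtstruct.add_roi(mask=mask, name="GTV smoothed 2d", apply_smoothing="2d")

# Explicit parameters (values mirror the defaults above)
custom_parameters = {
    "scaling_iterations": 2,
    "filter_iterations": 3,
    "crop_margins": [20, 20, 1],
    "np_kron": {"scaling_factor": 3},
    "ndimage_gaussian_filter": {"sigma": 2, "radius": 3},
    "threshold": {"threshold": 0.5},
}
rtstruct.add_roi(
    mask=mask,
    name="GTV smoothed 3d",
    apply_smoothing="3d",
    smoothing_function=smoothing.pipeline,
    smoothing_parameters=custom_parameters,
)

rtstruct.save("rtstruct_with_smoothing.dcm")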