Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
raise ValueError("invalid slice_dir passed as argument")
data_aug = data_result[sample_id]
if data_aug.shape[0] == 9:
data_result[sample_id] = rotate_multiple_peaks(data_aug, a_x, a_y, a_z)
elif data_aug.shape[0] == 18:
data_result[sample_id] = rotate_multiple_tensors(data_aug, a_x, a_y, a_z)
else:
raise ValueError("Incorrect number of channels (expected 9 or 18)")
return data_result, seg_result
# This is identical to batchgenerators.transforms.spatial_transforms.SpatialTransform except that it uses a
# different augment_spatial function, which also rotates the peaks when a rotation is applied.
class SpatialTransformPeaks(AbstractTransform):
"""The ultimate spatial transform generator. Rotation, deformation, scaling, cropping: It has all you ever dreamed
of. Computational time scales only with patch_size, not with input patch size or type of augmentations used.
Internally, this transform will use a coordinate grid of shape patch_size to which the transformations are
applied (very fast). Interpolation on the image data will only be done at the very end
Args:
patch_size (tuple/list/ndarray of int): Output patch size
patch_center_dist_from_border (tuple/list/ndarray of int, or int): How far should the center pixel of the
extracted patch be from the image border? Recommended to use patch_size//2.
This only applies when random_crop=True
do_elastic_deform (bool): Whether or not to apply elastic deformation
alpha (tuple of float): magnitude of the elastic deformation; randomly sampled from interval
for name in roi_item_keys:
roi_items[name].append(np.array([]))
if get_rois_from_seg:
data_dict.pop('class_targets', None)
data_dict['bb_target'] = np.array(bb_target)
data_dict['roi_masks'] = np.array(roi_masks)
data_dict['seg'] = out_seg
for name in roi_item_keys:
data_dict[name] = np.array(roi_items[name])
return data_dict
class ConvertSegToBoundingBoxCoordinates(AbstractTransform):
    """Transform wrapper around ``convert_seg_to_bounding_box_coordinates``.

    The constructor arguments are stored unchanged and, on every call, are
    forwarded to the helper together with the batch dictionary.

    Args:
        dim: forwarded as the helper's second argument
            (presumably the spatial dimensionality -- confirm against the helper).
        roi_item_keys: forwarded as the helper's third argument.
        get_rois_from_seg (bool): forwarded as the helper's fourth argument.
        class_specific_seg (bool): forwarded as the helper's fifth argument.
    """

    def __init__(self, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False):
        self.dim, self.roi_item_keys = dim, roi_item_keys
        self.get_rois_from_seg, self.class_specific_seg = get_rois_from_seg, class_specific_seg

    def __call__(self, **data_dict):
        # The keyword arguments are collected into one dict and handed over as a whole;
        # whatever the helper returns is returned to the caller as-is.
        forwarded = (self.dim, self.roi_item_keys, self.get_rois_from_seg, self.class_specific_seg)
        return convert_seg_to_bounding_box_coordinates(data_dict, *forwarded)
elif dim == 2:
# cut if dimension got too long
img_up = img_up[:img.shape[0], :img.shape[1]]
# pad with 0 if dimension too small
img_padded = np.zeros((img.shape[0], img.shape[1]))
img_padded[:img_up.shape[0], :img_up.shape[1]] = img_up
data[sample_idx, channel_idx] = img_padded
else:
raise ValueError("Invalid dimension size")
return data
class ResampleTransformLegacy(AbstractTransform):
    """
    Legacy resampling augmentation, kept here because it is no longer part of batchgenerators.

    Each sample is downsampled (linearly) by a random factor and then upsampled back to its original
    resolution (nearest neighbor).

    Info:
        * Resampling is done with scipy zoom.
        * All dimensions (channels, x, y, z) are resampled with the same downsampling factor
          (like isotropic=True from linear_downsampling_generator_nilearn).
        * CPU usage is always at 100% when using this, but batch_time on the cluster does not get longer (1s).

    Args:
        zoom_range (tuple of float): Range the random downscaling factor is drawn from
            (e.g. 0.5 halves the resolution).
    """

    def __init__(self, zoom_range=(0.5, 1)):
        # The range is only stored here.
        self.zoom_range = zoom_range
data[id, 0] *= -1
data[id, 3] *= -1
data[id, 6] *= -1
elif axis == "y":
data[id, 1] *= -1
data[id, 4] *= -1
data[id, 7] *= -1
elif axis == "z":
data[id, 2] *= -1
data[id, 5] *= -1
data[id, 8] *= -1
return data
class FlipVectorAxisTransform(AbstractTransform):
    """
    Sign-flip augmentation for peak images.

    The input is expected to be an image with 3 3D-vectors at each voxel, encoded as a nine-channel image.
    The sign of one dimension (x, y or z) of all 3 vectors is flipped at random; the actual work is delegated
    to flip_vector_axis.

    Args:
        axes: stored on the instance.
            NOTE(review): this value is not passed on to flip_vector_axis -- confirm whether that is intended.
        data_key (str): key of the entry in data_dict that is transformed and written back.
    """

    def __init__(self, axes=(2, 3, 4), data_key="data"):
        self.axes = axes
        self.data_key = data_key

    def __call__(self, **data_dict):
        key = self.data_key
        flipped = flip_vector_axis(data=data_dict[key])
        # The result replaces the entry under the same key.
        data_dict[key] = flipped
        return data_dict
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from warnings import warn
from batchgenerators.transforms.abstract_transforms import AbstractTransform
class ReorderSegTransform(AbstractTransform):
    """
    Moves the class axis of "seg" from position 1 to the last position.

    For DataAugmentation to work, x&y have to be the last 2 dims and nr_classes must come before them.
    This transform moves it back to (bs, x, y, nr_classes) for easy calculating of f1.

    If data_dict has no "seg" entry (or it is None), a warning is issued and data_dict is returned unmodified.
    """

    def __init__(self):
        pass

    def __call__(self, **data_dict):
        # Guard clause: nothing to reorder.
        if data_dict.get("seg") is None:
            warn("You used ReorderSegTransform but there is no 'seg' key in your data_dict, returning data_dict unmodified", Warning)
            return data_dict

        # (bs, nr_of_classes, x, y) -> (bs, x, y, nr_of_classes)
        data_dict["seg"] = data_dict["seg"].transpose(0, 2, 3, 1)
        return data_dict