Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
mirror_transform = Mirror(axes=np.arange(2, cf.dim+2, 1))
my_transforms.append(mirror_transform)
spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
random_crop=cf.da_kwargs['random_crop'])
my_transforms.append(spatial_transform)
else:
my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=cf.class_specific_seg_flag))
all_transforms = Compose(my_transforms)
# multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
return multithreaded_generator
mirror_transform = Mirror(axes=np.arange(cf.dim))
my_transforms.append(mirror_transform)
spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
random_crop=cf.da_kwargs['random_crop'])
my_transforms.append(spatial_transform)
else:
my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=cf.class_specific_seg_flag))
all_transforms = Compose(my_transforms)
# multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
return multithreaded_generator
def generate_train_batch(self):
    """Build a single-patient batch for the patient at ``self.patient_ix`` and advance the index.

    The patient's array file (``patient['data']``) is opened with ``np.load`` in read-only
    memory-mapped mode. Index 0 of that array is used as the image and index 1 as the
    segmentation (cast to ``uint8``). Both get two leading axes of size 1, so the batch
    holds exactly one patient. The segmentation is converted to bounding-box targets via
    ``ConvertSegToBoundingBoxCoordinates``. Because the batch is the whole patient, the
    batch-level boxes/labels are also stored as the patient-level targets.

    NOTE(review): the converter is called with ``dim=2``, so this presumably assumes
    ``all_data[0]`` is a single 2D image (y, x) -- TODO confirm against the data prep.

    Returns:
        dict: the converter's output for keys 'data', 'seg', 'class_target', 'pid'
        (which includes at least 'bb_target' and 'roi_labels'), plus
        'patient_bb_target', 'patient_roi_labels' and 'original_img_shape'.
    """
    import logging

    pid = self.dataset_pids[self.patient_ix]
    patient = self._data[pid]

    # mmap_mode='r' opens the file read-only without eagerly loading it into memory.
    all_data = np.load(patient['data'], mmap_mode='r')
    data = all_data[0]
    seg = all_data[1].astype('uint8')
    batch_class_targets = np.array([patient['class_target']])

    # Prepend batch and channel axes of size 1: one patient, one channel.
    out_data = data[None, None]
    out_seg = seg[None, None]
    # Previously a bare print() on every call; route it through logging so the
    # diagnostic is opt-in (enable DEBUG) instead of always writing to stdout.
    logging.getLogger(__name__).debug(
        'check patient data loader %s %s', out_data.shape, out_seg.shape)

    batch_2D = {'data': out_data, 'seg': out_seg, 'class_target': batch_class_targets, 'pid': pid}
    converter = ConvertSegToBoundingBoxCoordinates(
        dim=2, get_rois_from_seg_flag=False,
        class_specific_seg_flag=self.cf.class_specific_seg_flag)
    batch_2D = converter(**batch_2D)
    # The batch is exactly this patient, so patient-level targets equal batch-level ones.
    batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                     'patient_roi_labels': batch_2D['roi_labels'],
                     'original_img_shape': out_data.shape})

    # Cycle through the dataset: wrap back to the first patient after the last one.
    self.patient_ix += 1
    if self.patient_ix == len(self.dataset_pids):
        self.patient_ix = 0
    return batch_2D
if self.cf.dim == 2:
if self.cf.n_3D_context is not None:
data = np.transpose(data[:, 0], axes=(0, 3, 1, 2))
else:
# all patches have z dimension 1 (slices). discard dimension
data = data[..., 0]
seg = seg[..., 0]
patch_batch = {'data': data, 'seg': seg, 'class_target': batch_class_targets, 'pid': pid}
patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels']
patch_batch['original_img_shape'] = patient_batch['original_img_shape']
converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=self.cf.class_specific_seg_flag)
patch_batch = converter(**patch_batch)
out_batch = patch_batch
self.patient_ix += 1
if self.patient_ix == len(self.dataset_pids):
self.patient_ix = 0
return out_batch
if self.cf.dim == 2:
out_data = np.transpose(data, axes=(3, 0, 1, 2)) # (z, c, x, y )
out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]
out_targets = np.array(np.repeat(batch_class_targets, out_data.shape[0], axis=0))
# if set to not None, add neighbouring slices to each selected slice in channel dimension.
if self.cf.n_3D_context is not None:
slice_range = range(self.cf.n_3D_context, out_data.shape[0] + self.cf.n_3D_context)
out_data = np.pad(out_data, ((self.cf.n_3D_context, self.cf.n_3D_context), (0, 0), (0, 0), (0, 0)), 'constant', constant_values=0)
out_data = np.array(
[np.concatenate([out_data[ii] for ii in range(
slice_id - self.cf.n_3D_context, slice_id + self.cf.n_3D_context + 1)], axis=0) for slice_id in
slice_range])
batch_2D = {'data': out_data, 'seg': out_seg, 'class_target': out_targets, 'pid': pid}
converter = ConvertSegToBoundingBoxCoordinates(dim=2, get_rois_from_seg_flag=False, class_specific_seg_flag=self.cf.class_specific_seg_flag)
batch_2D = converter(**batch_2D)
if self.cf.merge_2D_to_3D_preds:
batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
'patient_roi_labels': batch_3D['patient_roi_labels'],
'original_img_shape': out_data.shape})
else:
batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
'patient_roi_labels': batch_2D['roi_labels'],
'original_img_shape': out_data.shape})
out_batch = batch_3D if self.cf.dim == 3 else batch_2D
patient_batch = out_batch
# crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
# in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.