min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)
batch_data.append(data)
batch_segs.append(seg[np.newaxis])
data = np.array(batch_data)
seg = np.array(batch_segs).astype(np.uint8)
class_target = np.array(batch_targets)
return {'data': data, 'seg': seg, 'pid': batch_pids, 'class_target': class_target}
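# -----------------------------------------------------------------------------
# Minimal sketch (not from the repo) of the np.take-based center crop used
# above: crop a (channels, x, y, z) array around a given center, one spatial
# axis at a time. `center` and `pre_crop_size` are hypothetical example values.
import numpy as np

data = np.random.rand(1, 64, 64, 32)           # (channels, x, y, z)
center = [32, 32, 16]                          # crop center in (x, y, z)
pre_crop_size = [48, 48, 24]                   # edge lengths of the crop

for ii in range(len(center)):
    lo = int(center[ii] - pre_crop_size[ii] // 2)
    hi = int(center[ii] + pre_crop_size[ii] // 2)
    # axis ii + 1 skips the leading channel dimension
    data = np.take(data, indices=range(lo, hi), axis=ii + 1)

assert data.shape == (1, 48, 48, 24)
# -----------------------------------------------------------------------------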
class PatientBatchIterator(SlimDataLoaderBase):
"""
creates a test generator that iterates over the entire given dataset, returning one patient per batch.
Can be used for monitoring if cf.val_mode = 'patient_val', i.e., for monitoring closer to the actual evaluation (done in 3D),
if one is willing to accept the speed loss during training.
:return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
batch_size = n_2D_patches in 2D.
"""
def __init__(self, data, cf): #threads in augmenter
super(PatientBatchIterator, self).__init__(data, 0)
self.cf = cf
self.patient_ix = 0
self.dataset_pids = [v['pid'] for v in data.values()]
self.patch_size = cf.patch_size
if len(self.patch_size) == 2:
self.patch_size = self.patch_size + [1]
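# -----------------------------------------------------------------------------
# Hedged usage sketch (not from the repo): a patient-wise iterator like the one
# above is typically stepped through once per patient; SlimDataLoaderBase's
# __next__ delegates to generate_train_batch(). `dataset`, `cf` and
# `run_evaluation` are placeholders.
# patient_iter = PatientBatchIterator(dataset, cf)
# for _ in range(len(patient_iter.dataset_pids)):
#     out_batch = next(patient_iter)   # one patient, batch_size = n patches
#     run_evaluation(out_batch)
# Note: a 2-element patch_size is padded to [x, y, 1] above so the 3D patching
# code path can also serve the 2D case.
# -----------------------------------------------------------------------------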
print_f('{}: {} rois seen ({:.1f}%).'.format(name, count, count / total_count * 100))
total_samples = self.cf.num_epochs * self.cf.num_train_batches * self.cf.batch_size
empties = [
'{}: {} ({:.1f}%)'.format(str(name), self.stats['empty_counts'][tix],
self.stats['empty_counts'][tix]/total_samples*100)
for tix, name in enumerate(self.unique_ts)
]
empties = ", ".join(empties)
print_f('empty samples seen: {}\n'.format(empties))
if plot:
if plot_file is None:
plot_file = os.path.join(self.plot_dir, "train_gen_stats_{}.png".format(self.cf.fold))
os.makedirs(self.plot_dir, exist_ok=True)
plg.plot_batchgen_stats(self.cf, self.stats, empties, self.balance_target, self.unique_ts, plot_file)
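# -----------------------------------------------------------------------------
# Minimal sketch (toy numbers, not repo data) of the bookkeeping above:
# empty-sample counts are reported relative to the total number of samples
# drawn over the whole training run.
num_epochs, num_train_batches, batch_size = 10, 100, 8
total_samples = num_epochs * num_train_batches * batch_size      # 8000
empty_counts = [120, 40]                                         # hypothetical per-target empties
for tix, name in enumerate(["benign", "malignant"]):
    print('{}: {} ({:.1f}%)'.format(name, empty_counts[tix],
                                    empty_counts[tix] / total_samples * 100))
# -----------------------------------------------------------------------------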
class PatientBatchIterator(SlimDataLoaderBase):
"""
creates a val/test generator that steps through the dataset and returns one dictionary per patient.
2D is a special case of 3D patching with patch_size[2] == 1 (slices).
Creates a whole-patient batch and targets and, if necessary, a patch-wise batch and targets.
Patient targets are appended in either case, for evaluation.
For patching, all patches are shifted into the batch dimension; batch_tiling_forward will take care of batch dimensions that exceed the limit.
This iterator / these batches are not intended to go through the MTaugmenter afterwards.
"""
def __init__(self, cf, data):
super(PatientBatchIterator, self).__init__(data, 0)
self.cf = cf
self.dataset_length = len(self._data)
self.dataset_pids = list(self._data.keys())
angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
random_crop=cf.da_kwargs['random_crop'])
my_transforms.append(spatial_transform)
else:
my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=cf.class_specific_seg_flag))
all_transforms = Compose(my_transforms)
# multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
return multithreaded_generator
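# -----------------------------------------------------------------------------
# Hedged sketch of the batchgenerators pipeline pattern used above, with a toy
# loader instead of the repo's data_gen. ToyLoader and its shapes are made up;
# the Compose -> augmenter wiring is the actual batchgenerators API.
import numpy as np
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.spatial_transforms import MirrorTransform

class ToyLoader(SlimDataLoaderBase):
    def generate_train_batch(self):
        # (b, c, x, y) arrays, as the transforms expect
        return {"data": np.random.rand(2, 1, 32, 32).astype(np.float32),
                "seg": np.zeros((2, 1, 32, 32), dtype=np.float32)}

transforms = Compose([MirrorTransform(axes=(0, 1))])
gen = SingleThreadedAugmenter(ToyLoader(None, 2), transforms)
batch = next(gen)   # dict with augmented "data" and "seg"
# -----------------------------------------------------------------------------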
class BatchGenerator(SlimDataLoaderBase):
"""
creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
Actual patch_size is obtained after data augmentation.
:param data: data dictionary as provided by 'load_dataset'.
:param batch_size: number of patients to sample for the batch
:return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
"""
def __init__(self, data, batch_size, cf):
super(BatchGenerator, self).__init__(data, batch_size)
self.cf = cf
def generate_train_batch(self):
batch_data, batch_segs, batch_pids, batch_targets = [], [], [], []
if (cand_rarest_class != rarest_class and np.count_nonzero(class_targets[cand] == rarest_class) > 0) \
or ix < int(batch_size * random_ratio):
break
for c in range(1,num_classes+1):
class_counts[c] += np.count_nonzero(class_targets[pick] == c)
if not ix < int(batch_size * random_ratio) and class_counts[rarest_class] == 0: # searched through the whole set without finding the rarest class
print("Class {} not represented in current dataset.".format(rarest_class))
rarest_class = np.argmin([class_counts[c] for c in range(1, num_classes + 1)]) + 1
batch_patients[ix] = pick
not_picked = not_picked[not_picked != pick] # removes pick
return batch_patients
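# -----------------------------------------------------------------------------
# Minimal sketch (toy data, simplified logic) of the class-balanced patient
# sampling above: a `random_ratio` share of the batch is drawn purely at
# random, after which picks prefer patients containing the rarest class.
import numpy as np

class_targets = {p: np.random.randint(0, 3, size=4) for p in range(20)}  # pid -> roi classes
rarest_class, batch_size, random_ratio = 2, 8, 0.3
not_picked = np.array(list(class_targets.keys()))
batch_patients = []
for ix in range(batch_size):
    pick = np.random.choice(not_picked)
    if ix >= int(batch_size * random_ratio):
        # after the random share, prefer patients containing the rarest class
        candidates = [p for p in not_picked
                      if np.count_nonzero(class_targets[p] == rarest_class) > 0]
        if len(candidates) > 0:
            pick = np.random.choice(candidates)
    batch_patients.append(pick)
    not_picked = not_picked[not_picked != pick]   # removes pick
# -----------------------------------------------------------------------------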
class BatchGenerator(SlimDataLoaderBase):
"""
creates the training/validation batch generator. Randomly samples batch_size patients
from the data set (draws a random slice if 2D), pad-crops them to equal sizes and merges them into an array.
:param data: data dictionary as provided by 'load_dataset'
:param img_modalities: list of strings ['adc', 'b1500'] from config
:param batch_size: number of patients to sample for the batch
:param pre_crop_size: equal size for merging the patients into a single array (before the final random crop in data augmentation)
:return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array.
"""
def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, n_threads=None, seed=0):
if n_threads is None:
n_threads = cf.n_workers
super(BatchGenerator, self).__init__(data, cf.batch_size, number_of_threads_in_multithreaded=n_threads)
self.cf = cf
self.random_count = int(cf.batch_random_ratio * cf.batch_size)
from __future__ import division
from __future__ import print_function
from os.path import join
import random
import numpy as np
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from tractseg.libs.system_config import SystemConfig as C
"""
Info:
Dimensions order for DeepLearningBatchGenerator: (batch_size, channels, x, y, [z])
"""
class SlicesBatchGeneratorNpyImg_fusion(SlimDataLoaderBase):
'''
Returns 2D slices in an ordered way. Takes data in the form of one npy file per image; the npy file is already cropped to the right size.
'''
def __init__(self, *args, **kwargs):
super(self.__class__, self).__init__(*args, **kwargs)
self.Config = None
self.global_idx = 0
def generate_train_batch(self):
subject = self._data[0]
data = np.load(join(C.DATA_PATH, self.Config.DATASET_FOLDER, subject, self.Config.FEATURES_FILENAME + ".npy"), mmap_mode="r")
seg = np.load(join(C.DATA_PATH, self.Config.DATASET_FOLDER, subject, self.Config.LABELS_FILENAME + ".npy"), mmap_mode="r")
if self.Config.SLICE_DIRECTION == "x":
end = data.shape[0]
y = np.nan_to_num(y)
# If we want only CA binary:
# bundles-together order
# x = x[:, (0, 75, 150, 5, 80, 155), :, :]
# y = y[:, (0, 5), :, :]
# mixed order
# x = x[:, (0, 5, 75, 80, 150, 155), :, :]
# y = y[:, (0, 5), :, :]
data_dict = {"data": x, # (batch_size, channels, x, y, [z])
"seg": y} # (batch_size, channels, x, y, [z])
return data_dict
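# -----------------------------------------------------------------------------
# Hedged sketch (toy shapes, hypothetical layout) of ordered slice extraction
# as done above for SLICE_DIRECTION == "x": memory-map the volume with
# np.load(..., mmap_mode="r"), take a contiguous run of slices along x, and
# reorder to (batch_size, channels, y, z).
import numpy as np

vol = np.random.rand(16, 32, 32, 9).astype(np.float32)  # stands in for np.load(path, mmap_mode="r"); (x, y, z, channels)
batch_size, start = 4, 0
x = vol[start:start + batch_size]                # ordered walk along the x axis
x = x.transpose(0, 3, 1, 2)                      # -> (batch_size, channels, y, z)
x = np.nan_to_num(x)
# -----------------------------------------------------------------------------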
class SlicesBatchGeneratorRandomNpyImg_fusionMean(SlimDataLoaderBase):
'''
takes the mean over the x/y/z channels and returns slices of shape (x, y, nrBundles)
'''
def __init__(self, *args, **kwargs):
super(self.__class__, self).__init__(*args, **kwargs)
self.Config = None
def generate_train_batch(self):
subjects = self._data[0]
subject_idx = int(random.uniform(0, len(subjects))) # len(subjects)-1 not needed because int() always rounds down (floors)
data = np.load(join(C.DATA_PATH, self.Config.DATASET_FOLDER, subjects[subject_idx], self.Config.FEATURES_FILENAME + ".npy"), mmap_mode="r")
seg = np.load(join(C.DATA_PATH, self.Config.DATASET_FOLDER, subjects[subject_idx], self.Config.LABELS_FILENAME + ".npy"), mmap_mode="r")
# print("data 1: {}".format(data.shape))
# If we want only CA binary:
# bundles-together order
# x = x[:, (0, 75, 150, 5, 80, 155), :, :]
# y = y[:, (0, 5), :, :]
# mixed order
# x = x[:, (0, 5, 75, 80, 150, 155), :, :]
# y = y[:, (0, 5), :, :]
data_dict = {"data": x, # (batch_size, channels, x, y, [z])
"seg": y} # (batch_size, channels, x, y, [z])
return data_dict
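# -----------------------------------------------------------------------------
# Hedged sketch of the x/y/z-channel mean named in the docstring above. The
# component-major channel layout [x_0..x_44, y_0..y_44, z_0..z_44] is an
# assumption (suggested by the commented index patterns above); shapes are toy.
import numpy as np

x = np.random.rand(4, 135, 16, 16).astype(np.float32)   # 45 bundles * 3 components
b, c, h, w = x.shape
x_mean = x.reshape(b, 3, c // 3, h, w).mean(axis=1)     # -> (4, 45, 16, 16)
# -----------------------------------------------------------------------------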
class SlicesBatchGeneratorRandomNpyImg_fusion(SlimDataLoaderBase):
'''
Randomly samples 2D slices from one npy file per subject.
About 4 s per 54-sample batch (75 bundles, 1.25 mm).
About 2 s per 54-sample batch (45 bundles, 1.25 mm).
'''
def __init__(self, *args, **kwargs):
super(self.__class__, self).__init__(*args, **kwargs)
self.Config = None
def generate_train_batch(self):
subjects = self._data[0]
subject_idx = int(random.uniform(0, len(subjects))) # len(subjects)-1 not needed because int() always rounds down (floors)
# data = np.load(join(C.DATA_PATH, self.Config.DATASET_FOLDER, subjects[subject_idx], self.Config.FEATURES_FILENAME + ".npy"), mmap_mode="r")
self.crop_margin = np.array(self.cf.patch_size) / 8.  # min distance of ROI center to edge of cropped patch
self.p_fg = 0.5
def generate_train_batch(self):
x = pad_nd_image(x, shape_must_be_divisible_by=(16, 16), mode='constant', kwargs={'constant_values': 0})
y = pad_nd_image(y, shape_must_be_divisible_by=(16, 16), mode='constant', kwargs={'constant_values': 0})
# casting to float32 does not slow things down
x = x.astype(np.float32)
y = y.astype(np.float32)
# possible optimization: sample slices from different patients and pad all to same size (size of biggest)
data_dict = {"data": x, # (batch_size, channels, x, y, [z])
"seg": y,
"slice_dir": slice_direction} # (batch_size, channels, x, y, [z])
return data_dict
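# -----------------------------------------------------------------------------
# Minimal sketch of the pad_nd_image call used above: pad the trailing spatial
# axes up to the next multiple of 16 (a typical U-Net downsampling constraint).
# Shapes are toy values.
import numpy as np
from batchgenerators.augmentations.utils import pad_nd_image

x = np.random.rand(8, 9, 87, 87).astype(np.float32)
x = pad_nd_image(x, shape_must_be_divisible_by=(16, 16),
                 mode='constant', kwargs={'constant_values': 0})
assert x.shape[-2:] == (96, 96)   # 87 is padded up to 96
# -----------------------------------------------------------------------------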
class BatchGenerator2D_Npy_random(SlimDataLoaderBase):
"""
Takes the image ID provided via self._data, loads the npy (numpy array) image and randomly samples 2D slices from it.
Needed for fusion training.
Timing:
About 2 s per 54-sample batch (45 bundles, 1.25 mm).
"""
def __init__(self, *args, **kwargs):
super(self.__class__, self).__init__(*args, **kwargs)
self.Config = None
def generate_train_batch(self):
subjects = self._data[0]
subject_idx = int(random.uniform(0, len(subjects)))
import random
import os
from os.path import join
import numpy as np
import nibabel as nib
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from tractseg.libs.system_config import SystemConfig as C
np.random.seed(1337)
class BatchGenerator2D_PrecomputedBatches(SlimDataLoaderBase):
"""
Loads precomputed batches
"""
def __init__(self, *args, **kwargs):
super(self.__class__, self).__init__(*args, **kwargs)
self.Config = None
def generate_train_batch(self):
batch_type = self._data[0]  # renamed from `type` to avoid shadowing the builtin
path = join(C.DATA_PATH, self.Config.DATASET_FOLDER, batch_type)
nr_of_files = len([name for name in os.listdir(path) if os.path.isfile(join(path, name))]) - 1
idx = int(random.uniform(0, int(nr_of_files / 2.)))
data = nib.load(join(path, "batch_" + str(idx) + "_data.nii.gz")).get_data()
seg = nib.load(join(path, "batch_" + str(idx) + "_seg.nii.gz")).get_data()
return {"data": data, "seg": seg}