# Shared imports for the snippets below (a sketch: module paths assume
# tensorpack >= 0.9.5, where PrefetchDataZMQ / PrefetchData survive as the
# old names of MultiProcessRunnerZMQ / MultiProcessPrefetchData).
# Repo-local helpers (SpeechCommandsDataFlow, fbresnet_augmentor, the
# _-prefixed functions, CriticTypes, the critic dataflows) are not importable
# from tensorpack and must come from the surrounding project.
import functools
import multiprocessing
import os

import cv2
import numpy as np

from tensorpack.dataflow import (
    AugmentImageComponent, BatchData, MapData, MapDataComponent,
    MultiProcessRunnerZMQ, MultiThreadMapData, PrefetchData,
    PrefetchDataZMQ, dataset, imgaug)
from tensorpack.utils import logger

def get_imagenet_dataflow(datadir, name, batch_size, augmentors,
                          parallel=None, meta_dir=None):
    # (def line is an assumption, inferred from the parameters the body uses)
    assert name in ['train', 'val', 'test']
    assert datadir is not None
    assert isinstance(augmentors, list)
    isTrain = name == 'train'
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count() // 2)  # assuming hyperthreading
    if isTrain:
        ds = dataset.ILSVRC12(datadir, name, meta_dir=meta_dir, shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logger.warn("DataFlow may become the bottleneck when too few processes are used.")
        ds = PrefetchDataZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, name, meta_dir=meta_dir, shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            # Decode and augment in worker threads, one image per datapoint.
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            im = aug.augment(im)
            return im, cls

        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = PrefetchDataZMQ(ds, 1)
    return ds
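
# A hedged usage sketch for the function above. '/data/imagenet' and the
# batch size are placeholders, and fbresnet_augmentor is the standard
# augmentor set from tensorpack's ImageNet example, assumed to be defined
# alongside this code.
df = get_imagenet_dataflow('/data/imagenet', 'val', batch_size=64,
                           augmentors=fbresnet_augmentor(False))
df.reset_state()  # required before manually iterating any DataFlow
for images, labels in df:
    print(images.shape, labels.shape)  # e.g. (64, 224, 224, 3), (64,)
    break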

# (Function-body excerpt: the source cut off above the augmentor list. The
#  color-jitter block below is restored from tensorpack's fb.resnet.torch
#  augmentor, whose tail the surviving lines match exactly; treat the restored
#  constants as an assumption. A random-resized-crop augmentor normally
#  precedes the RandomOrderAug.)
if isTrain:
    augmentors = [
        imgaug.RandomOrderAug(
            [imgaug.BrightnessScale((0.6, 1.4), clip=False),
             imgaug.Contrast((0.6, 1.4), clip=False),
             imgaug.Saturation(0.4, rgb=False),
             # rgb-bgr conversion for the constants copied from fb.resnet.torch
             imgaug.Lighting(0.1,
                             eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                             eigvec=np.array([[-0.5675, 0.7192, 0.4009],
                                              [-0.5808, -0.0045, -0.8140],
                                              [-0.5836, -0.6948, 0.4203]],
                                             dtype='float32')[::-1, ::-1]
                             )]),
        imgaug.Clip(),
        imgaug.Flip(horiz=True),
        imgaug.ToUint8()
    ]
else:
    augmentors = [
        imgaug.ResizeShortestEdge(256),
        imgaug.CenterCrop((input_size, input_size)),
        imgaug.ToUint8()
    ]
ds = AugmentImageComponent(ds, augmentors, copy=False)
if do_multiprocess:
    ds = PrefetchDataZMQ(ds, min(24, multiprocessing.cpu_count()))
ds = BatchData(ds, options.batch_size // options.nr_gpu, remainder=not isTrain)
return ds
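
# BatchData(..., remainder=not isTrain) above drops the last short batch
# during training but keeps it for evaluation, so every sample is seen once.
# A runnable toy sketch of the two modes:
from tensorpack.dataflow import DataFromList
toy = DataFromList([[i] for i in range(10)], shuffle=False)
for remainder in (False, True):
    b = BatchData(toy, 4, remainder=remainder)
    b.reset_state()
    print(remainder, [dp[0].shape[0] for dp in b])  # False: [4, 4]; True: [4, 4, 2]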

def get_imagenet_dataflow(datadir, name, batch_size,
                          augmentors=None, parallel=None):
    # (def line restored to match tensorpack's imagenet_utils; the defaults
    #  are inferred from the None-checks below)
    assert name in ['train', 'val', 'test']
    isTrain = name == 'train'
    assert datadir is not None
    if augmentors is None:
        augmentors = fbresnet_augmentor(isTrain)
    assert isinstance(augmentors, list)
    if parallel is None:
        parallel = min(40, multiprocessing.cpu_count() // 2)  # assuming hyperthreading
    if isTrain:
        ds = dataset.ILSVRC12(datadir, name, shuffle=True)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logger.warn("DataFlow may become the bottleneck when too few processes are used.")
        ds = MultiProcessRunnerZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
        aug = imgaug.AugmentorList(augmentors)

        def mapf(dp):
            fname, cls = dp
            im = cv2.imread(fname, cv2.IMREAD_COLOR)
            im = aug.augment(im)
            return im, cls

        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        ds = BatchData(ds, batch_size, remainder=True)
        ds = MultiProcessRunnerZMQ(ds, 1)
    return ds
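
# MultiThreadMapData, as used in the val branch above, runs the map function
# in threads; strict=True guarantees exactly one output per input datapoint.
# A runnable toy sketch (arguments passed positionally, as in the code above):
from tensorpack.dataflow import DataFromList
squares = MultiThreadMapData(
    DataFromList([[i] for i in range(8)], shuffle=False),
    2, lambda dp: [dp[0] * dp[0]], buffer_size=4, strict=True)
squares.reset_state()
print(sorted(dp[0] for dp in squares))  # [0, 1, 4, 9, 16, 25, 36, 49]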

def get_augmented_speech_commands_data(subset, options,
                                       do_multiprocess=True, shuffle=True):
    isTrain = subset == 'train' and do_multiprocess
    shuffle = shuffle if shuffle is not None else isTrain
    ds = SpeechCommandsDataFlow(os.path.join(options.data_dir, 'speech_commands_v0.02'),
                                subset, shuffle, None)
    if isTrain:
        add_noise_func = functools.partial(_add_noise, noises=ds.noises)
    ds = MapDataComponent(ds, _pad_or_clip_to_desired_sample, index=0)
    ds = MapDataComponent(ds, _to_float, index=0)
    if isTrain:
        ds = MapDataComponent(ds, _time_shift, index=0)
        ds = MapData(ds, add_noise_func)
    ds = BatchData(ds, options.batch_size // options.nr_gpu, remainder=not isTrain)
    if do_multiprocess:
        ds = PrefetchData(ds, 4, 4)
    return ds
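
# Hedged usage sketch: SpeechCommandsDataFlow and the _-prefixed helpers are
# repo-local, and the `options` fields below are assumptions read off the code
# above (per-GPU batch = batch_size // nr_gpu).
import argparse
options = argparse.Namespace(data_dir='/data/speech', batch_size=256, nr_gpu=2)
ds = get_augmented_speech_commands_data('train', options)  # 128 samples per batch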

def critic_dataflow_factory(ctrl, data, is_train):
    """
    Generate a critic dataflow.
    """
    if ctrl.critic_type == CriticTypes.CONV:
        ds = ConvCriticDataFlow(data, shuffle=is_train, max_depth=ctrl.controller_max_depth)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=False)
    elif ctrl.critic_type == CriticTypes.LSTM:
        ds = LSTMCriticDataFlow(data, shuffle=is_train)
        ds = BatchData(ds, ctrl.controller_batch_size, remainder=not is_train, use_list=True)
    return ds
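
# use_list=True in the LSTM branch keeps each batched component as a plain
# Python list (which suits ragged sequences) instead of stacking it into an
# ndarray. A runnable toy sketch of the behaviour:
from tensorpack.dataflow import DataFromList
ragged = DataFromList([[[1] * n] for n in (1, 2, 3, 4)], shuffle=False)
b = BatchData(ragged, 2, use_list=True)
b.reset_state()
for dp in b:
    print(dp[0])  # e.g. [[1], [1, 1]] -- lengths preserved, no stacking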

def get_data():
    def f(dp):
        im = dp[0][:, :, None]
        onehot = np.eye(10)[dp[1]]
        return [im, onehot]
    train = BatchData(MapData(dataset.Mnist('train'), f), 128)
    test = BatchData(MapData(dataset.Mnist('test'), f), 256)
    return train, test
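
# Minimal sketch of consuming the dataflows above (MNIST is downloaded on
# first use; reset_state() is required before manual iteration):
train, test = get_data()
train.reset_state()
im, onehot = next(iter(train))
print(im.shape, onehot.shape)  # (128, 28, 28, 1) (128, 10)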

# (Function-body excerpt of a get_imagenet_dataflow variant whose validation
#  branch converts BGR to RGB; the head of the `if isTrain:` branch was cut
#  off and is restored here by analogy with the versions above -- read those
#  two lines as an assumption.)
if isTrain:
    ds = dataset.ILSVRC12(datadir, "train", shuffle=True)
    ds = AugmentImageComponent(ds, augmentors, copy=False)
    ds = PrefetchDataZMQ(ds, parallel)
    ds = BatchData(ds, batch_size, remainder=False)
else:
    ds = dataset.ILSVRC12Files(datadir, "val", shuffle=False)
    aug = imgaug.AugmentorList(augmentors)

    def mapf(dp):
        fname, cls = dp
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        im = np.flip(im, axis=2)  # BGR -> RGB
        im = aug.augment(im)
        return im, cls

    ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
    # ds = MapData(ds, mapf)  # single-threaded alternative
    ds = BatchData(ds, batch_size, remainder=True)
    ds = PrefetchDataZMQ(ds, 1)
return ds
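
# Tiny check of the channel reversal in mapf above: cv2.imread returns BGR,
# and reversing axis 2 yields RGB.
bgr = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)
rgb = np.flip(bgr, axis=2)
assert (rgb[..., 0] == bgr[..., 2]).all() and (rgb[..., 2] == bgr[..., 0]).all()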