Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_load():
    """Smoke-test the training input pipeline by printing batch field shapes.

    Builds the options object, wraps the ``'train'`` split of
    ``CenterCOCODataset`` in a shuffling ``gluon.data.DataLoader`` and, for
    every batch the loader yields, prints the batch index followed by the
    ``shape`` of each of its six fields.
    """
    from opts import opts

    opt = opts().init()
    batch_size = 16
    num_workers = 2

    # One label per batch field, in field order; the batchify function below
    # gets one Stack() per label (image, heatmaps, scale, offset, ind, mask).
    field_labels = (
        "image batch shape: ",
        "heatmap batch shape",
        "scale batch shape",
        "offset batch shape",
        "indices batch shape",
        "mask batch shape",
    )
    batchify_fn = Tuple(*(Stack() for _ in field_labels))

    train_dataset = CenterCOCODataset(opt, split='train')
    train_loader = gluon.data.DataLoader(
        train_dataset,
        batch_size,
        shuffle=True,
        batchify_fn=batchify_fn,
        last_batch='rollover',
        num_workers=num_workers,
    )

    # Contexts parsed from the comma-separated ``opt.gpus_str``; an empty list
    # falls back to a single CPU context. The result is not used further down.
    gpu_tokens = [tok for tok in opt.gpus_str.split(',') if tok.strip()]
    ctx = [mx.gpu(int(tok)) for tok in gpu_tokens] or [mx.cpu()]

    for batch_idx, batch in enumerate(train_loader):
        print("{} Batch".format(batch_idx))
        for field_idx, label in enumerate(field_labels):
            print(label, batch[field_idx].shape)
def get_dataloader(train_dataset, data_shape, batch_size, num_workers, ctx):
    """Get dataloader.

    Wrap ``train_dataset`` in a shuffling ``gluon.data.DataLoader`` whose
    batchify function applies one ``Stack()`` to each of 11 sample fields.

    Parameters
    ----------
    train_dataset
        Dataset handed directly to the ``DataLoader``.
    data_shape
        Not used by this function; kept so existing callers keep working.
    batch_size
        Batch size forwarded to the ``DataLoader``.
    num_workers
        Worker count forwarded to the ``DataLoader``.
    ctx
        Not used by this function; kept so existing callers keep working.

    Returns
    -------
    The training ``DataLoader`` (shuffled, ``last_batch='rollover'``).
    """
    # Number of per-sample fields, one Stack() each.
    # NOTE(review): presumably this must equal the number of fields each
    # dataset sample yields -- confirm against the dataset.
    num_fields = 11
    batchify_fn = Tuple(*(Stack() for _ in range(num_fields)))
    return gluon.data.DataLoader(
        train_dataset,
        batch_size,
        True,
        batchify_fn=batchify_fn,
        last_batch='rollover',
        num_workers=num_workers,
    )
def batchify_val_fn():
    """Build the validation batchify function.

    Returns a two-field ``Tuple``: the first field is batched with ``Stack()``
    and the second with ``Pad(pad_val=-1)``.
    """
    first_field_fn = Stack()
    second_field_fn = Pad(pad_val=-1)
    return Tuple(first_field_fn, second_field_fn)
def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_workers, ctx):
    """Get dataloader.

    Build the training and validation ``DataLoader`` pair for square inputs of
    side ``data_shape``.

    The network is called once, in training mode, on an all-zero
    ``(1, 3, data_shape, data_shape)`` array created on ``ctx``; the third
    value it returns is taken as the anchors, moved to CPU and handed to
    ``SSDDefaultTrainTransform``.

    Returns
    -------
    tuple
        ``(train_loader, val_loader)``.
    """
    height = width = data_shape

    # Fake input: only the anchors produced by the forward pass are needed,
    # so that training targets can be generated against fixed anchors.
    with autograd.train_mode():
        dummy_input = mx.nd.zeros((1, 3, height, width), ctx)
        _, _, anchors = net(dummy_input)
    anchors = anchors.as_in_context(mx.cpu())

    # Training side: stack image, cls_targets, box_targets.
    train_batchify_fn = Tuple(Stack(), Stack(), Stack())
    train_transform = SSDDefaultTrainTransform(width, height, anchors)
    train_loader = gluon.data.DataLoader(
        train_dataset.transform(train_transform),
        batch_size,
        True,
        batchify_fn=train_batchify_fn,
        last_batch='rollover',
        num_workers=num_workers,
    )

    # Validation side: first field stacked, second field padded with -1;
    # no shuffling and the final partial batch is kept.
    val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
    val_transform = SSDDefaultValTransform(width, height)
    val_loader = gluon.data.DataLoader(
        val_dataset.transform(val_transform),
        batch_size,
        False,
        batchify_fn=val_batchify_fn,
        last_batch='keep',
        num_workers=num_workers,
    )

    return train_loader, val_loader
x * 32,
x * 32,
net,
mixup=args.mixup) for x in range(
10,
20)]
train_loader = RandomTransformDataLoader(
transform_fns,
train_dataset,
batch_size=batch_size,
interval=10,
last_batch='rollover',
shuffle=True,
batchify_fn=batchify_fn,
num_workers=num_workers)
val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
val_loader = gluon.data.DataLoader(
val_dataset.transform(
YOLO3DefaultValTransform(
width,
height)),
batch_size,
False,
batchify_fn=val_batchify_fn,
last_batch='keep',
num_workers=num_workers)
return train_loader, val_loader
def get_dataloader(train_dataset, data_shape, batch_size, num_workers, ctx):
    """Get dataloader.

    Wrap ``train_dataset`` in a shuffling ``gluon.data.DataLoader`` whose
    batchify function applies one ``Stack()`` to each of 12 sample fields.

    Parameters
    ----------
    train_dataset
        Dataset handed directly to the ``DataLoader``.
    data_shape
        Not used by this function; kept so existing callers keep working.
    batch_size
        Batch size forwarded to the ``DataLoader``.
    num_workers
        Worker count forwarded to the ``DataLoader``.
    ctx
        Not used by this function; kept so existing callers keep working.

    Returns
    -------
    The training ``DataLoader`` (shuffled, ``last_batch='rollover'``).
    """
    # Number of per-sample fields, one Stack() each.
    # NOTE(review): presumably this must equal the number of fields each
    # dataset sample yields -- confirm against the dataset.
    num_fields = 12
    batchify_fn = Tuple(*(Stack() for _ in range(num_fields)))
    return gluon.data.DataLoader(
        train_dataset,
        batch_size,
        True,
        batchify_fn=batchify_fn,
        last_batch='rollover',
        num_workers=num_workers,
    )
def get_dataloader(net, train_dataset, data_shape, batch_size, num_workers):
    """Build the training ``DataLoader`` for square inputs of side ``data_shape``.

    The network is called once, in training mode, on an all-zero
    ``(1, 3, data_shape, data_shape)`` array; the third value it returns is
    taken as the anchors and handed to ``SSDDefaultTrainTransform``, which is
    applied to ``train_dataset`` before batching.

    Returns
    -------
    The training ``DataLoader`` (shuffled, ``last_batch='rollover'``).
    """
    from gluoncv.data.batchify import Tuple, Stack, Pad
    from gluoncv.data.transforms.presets.ssd import SSDDefaultTrainTransform

    height = width = data_shape

    # Fake input: only the anchors produced by the forward pass are needed,
    # so that training targets can be generated against fixed anchors.
    with autograd.train_mode():
        dummy_input = mx.nd.zeros((1, 3, height, width))
        _, _, anchors = net(dummy_input)

    # Stack image, cls_targets, box_targets.
    stack_all_fields = Tuple(Stack(), Stack(), Stack())
    transformed_train = train_dataset.transform(
        SSDDefaultTrainTransform(width, height, anchors))
    return gluon.data.DataLoader(
        transformed_train,
        batch_size,
        True,
        batchify_fn=stack_all_fields,
        last_batch='rollover',
        num_workers=num_workers,
    )