# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_dataloader(net, val_dataset, val_transform, batch_size, num_workers):
    """Build the validation DataLoader.

    Applies ``val_transform`` (parameterized by the network's short side,
    max size and base stride) and batches with a 4-field Append batchify.
    """
    # When net.short is a multi-scale tuple/list, validate at the largest scale.
    if isinstance(net.short, (tuple, list)):
        short = net.short[-1]
    else:
        short = net.short
    batchify_fn = batchify.Tuple(*(batchify.Append() for _ in range(4)))
    transformed = val_dataset.transform(
        val_transform(short, net.max_size, net.base_stride))
    return mx.gluon.data.DataLoader(
        transformed, batch_size, shuffle=False, batchify_fn=batchify_fn,
        last_batch='keep', num_workers=num_workers)
def get_dataloader(net, val_dataset, batch_size, num_workers):
    """Build the validation DataLoader (dataset already transformed)."""
    batchify_fn = batchify.Tuple(*(batchify.Append() for _ in range(2)))
    return mx.gluon.data.DataLoader(
        val_dataset,
        batch_size,
        shuffle=False,
        batchify_fn=batchify_fn,
        last_batch='keep',
        num_workers=num_workers)
def get_dataloader(net, train_dataset, val_dataset, train_transform, val_transform, batch_size,
                   num_shards_per_process, args):
    """Build the training and validation DataLoaders.

    Training uses a bucketed sampler over image aspect ratios, split across
    Horovod ranks when ``args.horovod`` is set. Validation batches
    ``num_shards_per_process`` samples, i.e. one sample per device.
    """
    train_batchify = batchify.MaskRCNNTrainBatchify(net, num_shards_per_process)
    num_parts = hvd.size() if args.horovod else 1
    part_index = hvd.rank() if args.horovod else 0
    sampler = gcv.nn.sampler.SplitSortedBucketSampler(
        train_dataset.get_im_aspect_ratio(),
        batch_size,
        num_parts=num_parts,
        part_index=part_index,
        shuffle=True)
    transformed_train = train_dataset.transform(
        train_transform(net.short, net.max_size, net, ashape=net.ashape,
                        multi_stage=args.use_fpn))
    train_loader = mx.gluon.data.DataLoader(
        transformed_train, batch_sampler=sampler, batchify_fn=train_batchify,
        num_workers=args.num_workers)
    # When net.short is a multi-scale tuple/list, validate at the largest scale.
    short = net.short[-1] if isinstance(net.short, (tuple, list)) else net.short
    val_batchify = batchify.Tuple(*(batchify.Append() for _ in range(2)))
    # validation use 1 sample per device
    val_loader = mx.gluon.data.DataLoader(
        val_dataset.transform(val_transform(short, net.max_size)),
        num_shards_per_process, shuffle=False, batchify_fn=val_batchify,
        last_batch='keep', num_workers=args.num_workers)
    return train_loader, val_loader
def get_dataloader(net, train_dataset, val_dataset, train_transform, val_transform, batch_size,
                   num_workers, multi_stage):
    """Build the training and validation DataLoaders.

    Training shuffles with a 6-field Append batchify and rolls over the last
    partial batch; validation keeps it, using a 2-field batchify.
    """
    train_batchify = batchify.Tuple(*(batchify.Append() for _ in range(6)))
    transformed_train = train_dataset.transform(
        train_transform(net.short, net.max_size, net, ashape=net.ashape,
                        multi_stage=multi_stage))
    train_loader = mx.gluon.data.DataLoader(
        transformed_train, batch_size, shuffle=True, batchify_fn=train_batchify,
        last_batch='rollover', num_workers=num_workers)
    val_batchify = batchify.Tuple(*(batchify.Append() for _ in range(2)))
    transformed_val = val_dataset.transform(val_transform(net.short, net.max_size))
    val_loader = mx.gluon.data.DataLoader(
        transformed_val, batch_size, shuffle=False, batchify_fn=val_batchify,
        last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader
def get_dataloader(net, train_dataset, val_dataset, batch_size, num_workers):
    """Build Faster R-CNN training and validation DataLoaders.

    Uses the default Faster R-CNN transforms with a 5-field train batchify
    and a 3-field validation batchify.
    """
    train_batchify = batchify.Tuple(*(batchify.Append() for _ in range(5)))
    train_loader = mx.gluon.data.DataLoader(
        train_dataset.transform(
            FasterRCNNDefaultTrainTransform(net.short, net.max_size, net)),
        batch_size, shuffle=True, batchify_fn=train_batchify,
        last_batch='rollover', num_workers=num_workers)
    val_batchify = batchify.Tuple(*(batchify.Append() for _ in range(3)))
    val_loader = mx.gluon.data.DataLoader(
        val_dataset.transform(
            FasterRCNNDefaultValTransform(net.short, net.max_size)),
        batch_size, shuffle=False, batchify_fn=val_batchify,
        last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader
def get_dataloader(net, val_dataset, batch_size, num_workers):
    """Build the Faster R-CNN validation DataLoader."""
    batchify_fn = batchify.Tuple(*(batchify.Append() for _ in range(3)))
    transformed = val_dataset.transform(
        FasterRCNNDefaultValTransform(net.short, net.max_size))
    return mx.gluon.data.DataLoader(
        transformed, batch_size, shuffle=False, batchify_fn=batchify_fn,
        last_batch='keep', num_workers=num_workers)
def get_dataloader(net, val_dataset, batch_size, num_workers):
    """Build the Faster R-CNN validation DataLoader."""
    transform = FasterRCNNDefaultValTransform(net.short, net.max_size)
    batchify_fn = batchify.Tuple(*(batchify.Append() for _ in range(3)))
    return mx.gluon.data.DataLoader(
        val_dataset.transform(transform),
        batch_size,
        shuffle=False,
        batchify_fn=batchify_fn,
        last_batch='keep',
        num_workers=num_workers)
def get_dataloader(net, train_dataset, val_dataset, batch_size, num_workers):
    """Build Faster R-CNN training and validation DataLoaders."""
    # Training: shuffle, 5-field batchify, roll the last partial batch over.
    train_transformed = train_dataset.transform(
        FasterRCNNDefaultTrainTransform(net.short, net.max_size, net))
    train_loader = mx.gluon.data.DataLoader(
        train_transformed, batch_size, shuffle=True,
        batchify_fn=batchify.Tuple(*(batchify.Append() for _ in range(5))),
        last_batch='rollover', num_workers=num_workers)
    # Validation: no shuffle, 3-field batchify, keep the last partial batch.
    val_transformed = val_dataset.transform(
        FasterRCNNDefaultValTransform(net.short, net.max_size))
    val_loader = mx.gluon.data.DataLoader(
        val_transformed, batch_size, shuffle=False,
        batchify_fn=batchify.Tuple(*(batchify.Append() for _ in range(3))),
        last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader
def get_dataloader(net, train_dataset, val_dataset, train_transform, val_transform, batch_size,
                   num_workers):
    """Build training and validation DataLoaders.

    Transforms are parameterized by the network's short side, max size,
    base stride and (for training) valid range; batchify is 5-field for
    training and 4-field for validation.
    """
    train_batchify = batchify.Tuple(*(batchify.Append() for _ in range(5)))
    transformed_train = train_dataset.transform(
        train_transform(net.short, net.max_size, net.base_stride, net.valid_range))
    train_loader = mx.gluon.data.DataLoader(
        transformed_train, batch_size, shuffle=True, batchify_fn=train_batchify,
        last_batch='rollover', num_workers=num_workers)
    # When net.short is a multi-scale tuple/list, validate at the largest scale.
    short = net.short[-1] if isinstance(net.short, (tuple, list)) else net.short
    val_batchify = batchify.Tuple(*(batchify.Append() for _ in range(4)))
    transformed_val = val_dataset.transform(
        val_transform(short, net.max_size, net.base_stride))
    val_loader = mx.gluon.data.DataLoader(
        transformed_val, batch_size, shuffle=False, batchify_fn=val_batchify,
        last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader