import numpy as np
import mxnet as mx
from mxnet import gluon
from gluonnlp.data import FixedBucketSampler, ShardedDataLoader


def test_sharded_data_loader():
    X = np.random.uniform(size=(100, 20))
    Y = np.random.uniform(size=(100,))
    dataset = gluon.data.ArrayDataset(X, Y)

    # Without a batch sampler, ShardedDataLoader behaves like a regular DataLoader
    loader = ShardedDataLoader(dataset, 2)
    for i, (x, y) in enumerate(loader):
        assert mx.test_utils.almost_equal(x.asnumpy(), X[i*2:(i+1)*2])
        assert mx.test_utils.almost_equal(y.asnumpy(), Y[i*2:(i+1)*2])

    # With num_shards set, each iteration yields one batch per shard
    num_shards = 4
    batch_sampler = FixedBucketSampler(lengths=[X.shape[1]] * X.shape[0],
                                       batch_size=2,
                                       num_buckets=1,
                                       shuffle=False,
                                       num_shards=num_shards)
    for thread_pool in [True, False]:
        for num_workers in [0, 1, 2, 3, 4]:
            loader = ShardedDataLoader(dataset, batch_sampler=batch_sampler,
                                       num_workers=num_workers, thread_pool=thread_pool)
            for i, seqs in enumerate(loader):
                assert len(seqs) == num_shards
                for j in range(num_shards):
                    if i != len(loader) - 1:
                        assert mx.test_utils.almost_equal(
                            seqs[j][0].asnumpy(),
                            X[(i*num_shards+j)*2:(i*num_shards+j+1)*2])
                        assert mx.test_utils.almost_equal(
                            seqs[j][1].asnumpy(),
                            Y[(i*num_shards+j)*2:(i*num_shards+j+1)*2])

    else:
        raise NotImplementedError

    data_lengths = get_data_lengths(data_set)

    if dataset_type == 'train':
        train_batchify_fn = btf.Tuple(btf.Pad(pad_val=0), btf.Pad(pad_val=0),
                                      btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
    else:
        data_lengths = list(map(lambda x: x[-1], data_lengths))
        test_batchify_fn = btf.Tuple(btf.Pad(pad_val=0), btf.Pad(pad_val=0),
                                     btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
                                     btf.Stack())

    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
                                                batch_size=(args.batch_size
                                                            if dataset_type == 'train'
                                                            else args.test_batch_size),
                                                num_buckets=args.num_buckets,
                                                ratio=args.bucket_ratio,
                                                shuffle=(dataset_type == 'train'),
                                                use_average_length=use_average_length,
                                                num_shards=num_shards,
                                                bucket_scheme=bucket_scheme)

    if dataset_type == 'train':
        logging.info('Train Batch Sampler:\n%s', batch_sampler.stats())
        data_loader = nlp.data.ShardedDataLoader(data_set,
                                                 batch_sampler=batch_sampler,
                                                 batchify_fn=train_batchify_fn,
                                                 num_workers=num_workers)

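# The `else: raise NotImplementedError` at the top of the excerpt above is the tail of a
# bucket-scheme selection that was cut off. A minimal sketch of such a selection, assuming
# gluonnlp's ConstWidthBucket, LinearWidthBucket and ExpWidthBucket schemes; the helper
# name pick_bucket_scheme and the flag values are illustrative, not from the excerpt.
import gluonnlp as nlp

def pick_bucket_scheme(name):
    """Map a command-line flag to a gluonnlp bucket-width scheme."""
    if name == 'constant':
        # every bucket spans the same length range
        return nlp.data.ConstWidthBucket()
    elif name == 'linear':
        # bucket widths grow linearly with the bucket index
        return nlp.data.LinearWidthBucket()
    elif name == 'exp':
        # bucket widths grow geometrically, so long sequences share wider buckets
        return nlp.data.ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError

bucket_scheme = pick_bucket_scheme('exp')
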
def __call__(self, dataset):
    """Create data sampler based on the dataset"""
    if isinstance(dataset, nlp.data.NumpyDataset):
        lengths = dataset.get_field('valid_lengths')
    else:
        # dataset is a BERTPretrainDataset
        lengths = dataset.transform(lambda input_ids, segment_ids, masked_lm_positions,
                                           masked_lm_ids, masked_lm_weights,
                                           next_sentence_labels, valid_lengths:
                                    valid_lengths, lazy=False)
    # calculate total batch size for all GPUs
    batch_size = self._batch_size * self._num_ctxes
    sampler = nlp.data.FixedBucketSampler(lengths,
                                          batch_size=batch_size,
                                          num_buckets=self._num_buckets,
                                          ratio=0,
                                          shuffle=self._shuffle)
    logging.debug('Sampler created for a new dataset:\n%s', sampler.stats())
    return sampler

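# A self-contained sketch of the factory pattern used by the __call__ above: the factory
# holds the batching hyper-parameters and is called once per dataset to build a fresh
# FixedBucketSampler. The class and variable names (ToySamplerFactory, toy_lengths) and
# all sizes are made up for illustration.
import numpy as np
import gluonnlp as nlp

class ToySamplerFactory:
    def __init__(self, batch_size, num_buckets, num_ctxes, shuffle=True):
        self._batch_size = batch_size
        self._num_buckets = num_buckets
        self._num_ctxes = num_ctxes
        self._shuffle = shuffle

    def __call__(self, lengths):
        # total batch size across all devices, as in the method above
        batch_size = self._batch_size * self._num_ctxes
        return nlp.data.FixedBucketSampler(lengths,
                                           batch_size=batch_size,
                                           num_buckets=self._num_buckets,
                                           ratio=0,
                                           shuffle=self._shuffle)

toy_lengths = np.random.randint(5, 60, size=500).tolist()
factory = ToySamplerFactory(batch_size=8, num_buckets=5, num_ctxes=2)
sampler = factory(toy_lengths)
print(sampler.stats())   # per-bucket sequence counts and batch sizes
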
def prepare_data_loader(args, dataset, vocab, test=False):
    """
    Read data and build data loader.
    """
    # Preprocess
    dataset = dataset.transform(lambda s1, s2, label: (vocab(s1), vocab(s2), label),
                                lazy=False)
    # Batching
    batchify_fn = btf.Tuple(btf.Pad(pad_val=0), btf.Pad(pad_val=0), btf.Stack(dtype='int32'))
    data_lengths = [max(len(d[0]), len(d[1])) for d in dataset]
    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
                                                batch_size=args.batch_size,
                                                shuffle=(not test))
    data_loader = gluon.data.DataLoader(dataset=dataset,
                                        batch_sampler=batch_sampler,
                                        batchify_fn=batchify_fn)
    return data_loader

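# A self-contained toy run of the same pattern as prepare_data_loader above: bucket by
# max(len(s1), len(s2)), pad both sequences, stack the labels, and inspect the padded
# batch shapes. All data and sizes below are made up for illustration.
import numpy as np
import gluonnlp as nlp
import gluonnlp.data.batchify as btf
from mxnet import gluon

rng = np.random.RandomState(0)
toy = [(rng.randint(1, 100, size=rng.randint(3, 30)).tolist(),
        rng.randint(1, 100, size=rng.randint(3, 30)).tolist(),
        int(rng.randint(0, 3)))
       for _ in range(200)]
dataset = gluon.data.SimpleDataset(toy)

batchify_fn = btf.Tuple(btf.Pad(pad_val=0), btf.Pad(pad_val=0), btf.Stack(dtype='int32'))
lengths = [max(len(s1), len(s2)) for s1, s2, _ in dataset]
batch_sampler = nlp.data.FixedBucketSampler(lengths, batch_size=16,
                                            num_buckets=5, shuffle=True)
loader = gluon.data.DataLoader(dataset, batch_sampler=batch_sampler,
                               batchify_fn=batchify_fn)
for s1, s2, label in loader:
    # sequences in the same bucket are padded to that bucket's length, not the global max
    print(s1.shape, s2.shape, label.shape)
    break
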
valid_dataset, valid_data_lengths = preprocess_dataset(valid_dataset)
test_dataset, test_data_lengths = preprocess_dataset(test_dataset)

# Construct the DataLoader. Pad data and stack label
batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(axis=0, pad_val=0, ret_length=True),
    nlp.data.batchify.Stack(dtype='float32'))
if args.bucket_type is None:
    print('Bucketing strategy is not used!')
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  batchify_fn=batchify_fn)
else:
    if args.bucket_type == 'fixed':
        print('Use FixedBucketSampler')
        batch_sampler = nlp.data.FixedBucketSampler(train_data_lengths,
                                                    batch_size=args.batch_size,
                                                    num_buckets=args.bucket_num,
                                                    ratio=args.bucket_ratio,
                                                    shuffle=True)
        print(batch_sampler.stats())
    elif args.bucket_type == 'sorted':
        print('Use SortedBucketSampler')
        batch_sampler = nlp.data.SortedBucketSampler(train_data_lengths,
                                                     batch_size=args.batch_size,
                                                     mult=args.bucket_mult,
                                                     shuffle=True)
    else:
        raise NotImplementedError
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_sampler=batch_sampler,
                                  batchify_fn=batchify_fn)

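# Because Pad(..., ret_length=True) is used in the batchify function above, each padded
# field of a batch comes back as a (data, valid_lengths) pair, so the training loop
# unpacks batches as ((data, length), label). A tiny self-contained check of that
# behavior on made-up data:
import gluonnlp as nlp

pad = nlp.data.batchify.Pad(axis=0, pad_val=0, ret_length=True)
padded, valid_lengths = pad([[1, 2, 3], [4, 5], [6]])
print(padded.shape)     # (3, 3): shorter rows are padded with 0
print(valid_lengths)    # the true lengths 3, 2, 1, kept for masking
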
batchify_fn = Tuple(Pad(), Pad(), Pad(), Pad(), Stack(), Pad(), Stack())
if self._use_avg_len:
    # sharded data loader
    sampler = nlp.data.FixedBucketSampler(lengths=lengths,
                                          # batch_size per shard
                                          batch_size=self._batch_size,
                                          num_buckets=self._num_buckets,
                                          shuffle=self._shuffle,
                                          use_average_length=True,
                                          num_shards=self._num_ctxes)
    dataloader = nlp.data.ShardedDataLoader(dataset,
                                            batch_sampler=sampler,
                                            batchify_fn=batchify_fn,
                                            num_workers=self._num_ctxes)
else:
    sampler = nlp.data.FixedBucketSampler(lengths,
                                          batch_size=self._batch_size * self._num_ctxes,
                                          num_buckets=self._num_buckets,
                                          ratio=0,
                                          shuffle=self._shuffle)
    dataloader = DataLoader(dataset=dataset,
                            batch_sampler=sampler,
                            batchify_fn=batchify_fn,
                            num_workers=1)
logging.debug('Sampler created for a new dataset:\n%s', sampler.stats())
return dataloader

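# In the use_avg_len branch above, batch_size is an approximate number of tokens per
# shard (use_average_length=True), and ShardedDataLoader then yields one already
# batchified batch per shard per iteration, i.e. one per device. A self-contained toy
# illustration of that shape contract; the data, the two-shard setup and all sizes are
# made up.
import numpy as np
import gluonnlp as nlp
import gluonnlp.data.batchify as btf
from mxnet import gluon

rng = np.random.RandomState(0)
toy = [rng.randint(1, 50, size=rng.randint(4, 40)).tolist() for _ in range(300)]
dataset = gluon.data.SimpleDataset(toy)
lengths = [len(x) for x in dataset]

num_shards = 2   # e.g. two GPUs
sampler = nlp.data.FixedBucketSampler(lengths,
                                      batch_size=256,           # ~tokens per shard
                                      num_buckets=5,
                                      shuffle=True,
                                      use_average_length=True,
                                      num_shards=num_shards)
loader = nlp.data.ShardedDataLoader(dataset,
                                    batch_sampler=sampler,
                                    batchify_fn=btf.Pad(pad_val=0),
                                    num_workers=0)
for shards in loader:
    assert len(shards) == num_shards   # one padded batch per shard/device
    print([shard.shape for shard in shards])
    break
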
def transform(raw_data, params):
    # Define the data conversion interface
    # raw_data --> batch_data
    num_buckets = params.num_buckets
    batch_size = params.batch_size
    responses = raw_data
    batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size,
                                     num_buckets=num_buckets)
    batch = []

    def index(r):
        correct = 0 if r[1] <= 0 else 1
        return r[0] * 2 + correct

    for batch_idx in tqdm(batch_idxes, "batchify"):
        batch_rs = []
        batch_pick_index = []
        batch_labels = []
        for idx in batch_idx:
            batch_rs.append([index(r) for r in responses[idx]])
            if len(responses[idx]) <= 1:
                pick_index, labels = [], []
            else:
                pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1)
                                           for r in responses[idx][1:]])

def transform(raw_data, params):
    # Define the data conversion interface
    # raw_data --> batch_data
    num_buckets = params.num_buckets
    batch_size = params.batch_size
    responses = raw_data
    batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size,
                                     num_buckets=num_buckets)
    batch = []

    def response_index(r):
        correct = 0 if r[1] <= 0 else 1
        return r[0] * 2 + correct

    def question_index(r):
        return r[0]

    for batch_idx in tqdm(batch_idxes, "batchify"):
        batch_qs = []
        batch_rs = []
        batch_labels = []
        for idx in batch_idx:
            batch_qs.append([question_index(r) for r in responses[idx]])
            batch_rs.append([response_index(r) for r in responses[idx]])

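# Both transform() excerpts above stop right before the per-batch padding step. One
# common way to finish a batch like this (a sketch, not the original continuation) is to
# pad every ragged list in the batch to the length of its longest sequence, e.g. with
# gluonnlp's batchify.Pad, while keeping the true lengths for masking. The toy values
# below are made up.
import gluonnlp.data.batchify as btf

toy_batch_rs = [[3, 7, 2], [5, 1], [9, 4, 6, 8]]    # ragged, as built by the loop above
pad = btf.Pad(pad_val=0)
padded = pad(toy_batch_rs)                          # shape (3, 4), zero-padded
true_lengths = [len(rs) for rs in toy_batch_rs]     # kept for masking: [3, 2, 4]
print(padded.shape, true_lengths)
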
def data_loader(self, docs: Sequence[Document], batch_size, shuffle=False,
                label=True, **kwargs) -> DataLoader:
    if label is True and self.label_map is None:
        raise ValueError('Please specify label_map')
    batchify_fn = kwargs.get('batchify_fn', sequence_batchify_fn)
    bucket = kwargs.get('bucket', False)
    num_buckets = kwargs.get('num_buckets', 10)
    ratio = kwargs.get('ratio', 0)
    dataset = SequencesDataset(docs=docs, embs=self.embs, key=self.key,
                               label_map=self.label_map, label=label)
    if bucket is True:
        dataset_lengths = list(map(lambda x: float(len(x[3])), dataset))
        batch_sampler = FixedBucketSampler(dataset_lengths, batch_size=batch_size,
                                           num_buckets=num_buckets, ratio=ratio,
                                           shuffle=shuffle)
        return DataLoader(dataset=dataset, batch_sampler=batch_sampler,
                          batchify_fn=batchify_fn)
    else:
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
                          batchify_fn=batchify_fn)

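# A closing, self-contained comparison of why the bucket=True path above typically pays
# off: bucketing groups similarly sized sequences, so far fewer pad tokens are produced
# than with plain fixed-size batching over unsorted data. All data and sizes below are
# made up for illustration.
import numpy as np
import gluonnlp as nlp

rng = np.random.RandomState(42)
lengths = rng.randint(5, 200, size=2000).tolist()
batch_size = 32

def padded_tokens(batches):
    # each batch is padded up to its longest sequence
    return sum(len(b) * max(lengths[i] for i in b) for b in batches)

sequential = [list(range(i, min(i + batch_size, len(lengths))))
              for i in range(0, len(lengths), batch_size)]
bucketed = list(nlp.data.FixedBucketSampler(lengths, batch_size=batch_size,
                                            num_buckets=10, shuffle=False))

print('padded tokens, sequential batches:', padded_tokens(sequential))
print('padded tokens, bucketed batches:  ', padded_tokens(bucketed))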