# Typical training usage would use the `all_epochs` approach.
#
if args.all_epochs:
    # Run training across all the epochs before testing for accuracy
    loop_epochs = 1
    reader_epochs = args.epochs
else:
    # Test training accuracy after each epoch
    loop_epochs = args.epochs
    reader_epochs = 1

transform = TransformSpec(_transform_row, removed_fields=['idx'])
# Instantiate a petastorm Reader and DataLoader for each split, using the epoch
# setting chosen above
for epoch in range(1, loop_epochs + 1):
    with DataLoader(make_reader('{}/train'.format(args.dataset_url), num_epochs=reader_epochs,
                                transform_spec=transform),
                    batch_size=args.batch_size) as train_loader:
        train(model, device, train_loader, args.log_interval, optimizer, epoch)

    with DataLoader(make_reader('{}/test'.format(args.dataset_url), num_epochs=reader_epochs,
                                transform_spec=transform),
                    batch_size=args.test_batch_size) as test_loader:
        test(model, device, test_loader)
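# The `TransformSpec(_transform_row, removed_fields=['idx'])` above assumes a
# `_transform_row` helper defined elsewhere in the script. A minimal sketch of such
# a per-row transform for MNIST-style rows (the 'image'/'digit' field names and the
# normalization constants are illustrative assumptions, not taken from this snippet):
from torchvision import transforms

def _transform_row(mnist_row):
    # Reshape the stored 28x28 ndarray to HWC, convert it to a normalized float
    # tensor, and pass the label through unchanged.
    to_tensor = transforms.Compose([
        transforms.Lambda(lambda nd: nd.reshape(28, 28, 1)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    return {'image': to_tensor(mnist_row['image']),
            'digit': mnist_row['digit']}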
from petastorm import make_reader
from petastorm.pytorch import DataLoader

def pytorch_hello_world(dataset_url='file:///tmp/hello_world_dataset'):
    # Read one batch from the dataset and print its 'id' field.
    with DataLoader(make_reader(dataset_url)) as train_loader:
        sample = next(iter(train_loader))
        print(sample['id'])
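# A minimal usage sketch: assuming a hello-world Parquet dataset has already been
# generated at the default URL (file:///tmp/hello_world_dataset), the example can
# be run directly.
if __name__ == '__main__':
    pytorch_hello_world()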
# num_epochs=None creates an infinite iterator over each shard, so ranks can
# train and validate with an unequal number of samples
with make_batch_reader(remote_store.train_data_path,
                       num_epochs=None,
                       cur_shard=hvd.rank(),
                       shard_count=hvd.size(),
                       hdfs_driver=PETASTORM_HDFS_DRIVER,
                       schema_fields=schema_fields) as train_reader:
    with make_batch_reader(remote_store.val_data_path,
                           num_epochs=None,
                           cur_shard=hvd.rank(),
                           shard_count=hvd.size(),
                           hdfs_driver=PETASTORM_HDFS_DRIVER,
                           schema_fields=schema_fields) \
            if should_validate else empty_batch_reader() as val_reader:

        train_loader = DataLoader(train_reader,
                                  batch_size=batch_size,
                                  shuffling_queue_capacity=shuffle_buffer_size)
        train_loader_iter = iter(train_loader)
        def prepare_batch(row):
            inputs = [
                prepare_np_data(
                    row[col].float(), col, metadata).reshape(shape)
                for col, shape in zip(feature_columns, input_shapes)]
            labels = [
                prepare_np_data(
                    row[col].float(), col, metadata)
                for col in label_columns]

            sample_weights = row.get(sample_weight_col, None)
            if cuda_available:
                # Move inputs, labels, and optional sample weights to the GPU
                inputs = [inp.cuda() for inp in inputs]
                labels = [label.cuda() for label in labels]
                if sample_weights is not None:
                    sample_weights = sample_weights.cuda()
            return inputs, labels, sample_weights

        def _train(epoch):
            model.train()
            train_loss = metric_cls('loss', hvd)
            metric_value_groups = construct_metric_value_holders(
                metric_cls, metric_fn_groups, label_columns, hvd)

            # iterate on one epoch
            for batch_idx in range(steps_per_epoch):
                row = next(train_loader_iter)
                inputs, labels, sample_weights = prepare_batch(row)

                outputs, loss = train_minibatch(model, optimizer, transform_outputs,
                                                loss_fn, inputs, labels, sample_weights)
                update_metrics(metric_value_groups, outputs, labels)
                train_loss.update(loss)
                print_metrics(batch_idx, train_loss, metric_value_groups, 'train')

            return aggregate_metrics('train', epoch, train_loss, metric_value_groups)
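        # `steps_per_epoch` used by _train above is assumed to be provided by the
        # enclosing training function. A plausible sketch, mirroring the
        # `validation_steps` computation below (`train_rows` is an assumed variable
        # holding the total number of training rows across all shards):
        #
        #     steps_per_epoch = int(math.ceil(float(train_rows) / batch_size / hvd.size()))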
        if should_validate:
            val_loader = DataLoader(val_reader, batch_size=batch_size)
            val_loader_iter = iter(val_loader)
            if validation_steps_per_epoch is None:
                validation_steps = int(math.ceil(float(val_rows) / batch_size / hvd.size()))
            else:
                validation_steps = validation_steps_per_epoch

            def _validate(epoch):
                model.eval()
                val_loss = metric_cls('loss', hvd)
                metric_value_groups = construct_metric_value_holders(
                    metric_cls, metric_fn_groups, label_columns, hvd)

                # iterate on one epoch
                for batch_idx in range(validation_steps):
                    row = next(val_loader_iter)
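# The `empty_batch_reader()` used in the validation `with` statement above is
# assumed to be a small no-op context manager defined elsewhere, so the same
# `with` block also works when there is no validation data. A minimal sketch of
# such a helper (an assumption, not taken from this snippet):
import contextlib

@contextlib.contextmanager
def empty_batch_reader():
    # Yield None so `val_reader` is simply None when validation is disabled.
    yield None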