def test_byteps_push_pull(self):
    """Test that byteps_push_pull correctly sums 1D tensors."""
    size = bps.size()
    dtypes = self.filter_supported_types(['float32'])
    dims = [1]
    ctx = self._current_context()
    count = 100
    shapes = [(), (17,)]
    for dtype, dim in itertools.product(dtypes, dims):
        # MXNet uses gpu_id as part of the seed, so to get identical seeds
        # we must set a context.
        mx.random.seed(10 + 10 * bps.rank(), ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim],
                                      ctx=ctx)
        tensor = tensor.astype(dtype)
        print("tensor before push_pull:", tensor)
        bps.byteps_declare_tensor("tensor_" + str(count))
        bps.byteps_push_pull(tensor, name="tensor_" + str(count))
        # push_pull is asynchronous; wait for the in-place result.
        mx.nd.waitall()
        print("tensor after push_pull:", tensor)
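
# A standalone usage sketch of the same API outside the test harness.
# Assumptions: bps.init() has already been called, the key "example_tensor"
# is not declared elsewhere, and whether the result is a sum or an average
# depends on the push_pull averaging setting.
import mxnet as mx
import byteps.mxnet as bps

t = mx.nd.ones((17,))
bps.byteps_declare_tensor("example_tensor")      # register the key once per process
bps.byteps_push_pull(t, name="example_tensor")   # in-place reduction across workers
t.wait_to_read()                                 # block until the asynchronous op finishes
print("reduced tensor:", t)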
def test_byteps_push_pull_inplace(self):
    """Test that byteps_push_pull correctly sums 1D, 2D, and 3D tensors in place."""
    size = bps.size()
    dtypes = self.filter_supported_types(['int32', 'int64',
                                          'float32', 'float64'])
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 200
    shapes = [(), (17,), (17, 17), (17, 17, 17)]
    for dtype, dim in itertools.product(dtypes, dims):
        mx.random.seed(1234, ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim],
                                      ctx=ctx)
        tensor = tensor.astype(dtype)
        multiplied = tensor * size
        bps.byteps_declare_tensor("tensor_" + str(count))
        bps.byteps_push_pull(tensor, name="tensor_" + str(count))
        # Every worker uses the same seed, so the in-place result should
        # equal the local tensor multiplied by the number of workers.
        max_difference = mx.nd.max(mx.nd.subtract(tensor, multiplied))
        count += 1
        # Illustrative tolerance: exact for integer types, loose for floats
        # (the precise bound depends on the number of workers).
        threshold = 0 if dtype in ['int32', 'int64'] else 5e-4 * size
        assert max_difference.asscalar() <= threshold, \
            'byteps_push_pull produced incorrect results'
def test_byteps_broadcast(self):
    """Test that the broadcast correctly broadcasts 1D, 2D, 3D tensors."""
    rank = bps.rank()
    size = bps.size()
    # This test does not apply if there is only one worker.
    if size == 1:
        return
    dtypes = ['int32', 'int64',
              'float32', 'float64']
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 300
    shapes = [(), (17,), (17, 17), (17, 17, 17)]
    root_ranks = list(range(size))
    for dtype, dim, root_rank in itertools.product(dtypes, dims,
                                                   root_ranks):
        tensor = mx.nd.ones(shapes[dim], ctx=ctx) * rank
        root_tensor = mx.nd.ones(shapes[dim], ctx=ctx) * root_rank
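
# In practice the broadcast path is usually exercised by broadcasting model
# parameters from one rank to all others. A minimal sketch, assuming
# byteps.mxnet exposes broadcast_parameters(params, root_rank) and that
# bps.init() was called earlier:
net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize(mx.init.Xavier())
bps.broadcast_parameters(net.collect_params(), root_rank=0)
# After the call, every worker holds rank 0's initial weights.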
def test_byteps_trainer_param_order(self):
    size = bps.size()
    dtypes = self.filter_supported_types(['float32'])
    dims = [1]
    ctx = self._current_context()
    net = mx.gluon.nn.Sequential()
    # Layers may be added in a different (random) order on each worker.
    layers = {'ones_': 1, 'zeros_': 0}
    for name, init in layers.items():
        net.add(mx.gluon.nn.Dense(10, in_units=10,
                                  weight_initializer=mx.init.Constant(init),
                                  use_bias=False, prefix=name))
    params = net.collect_params()
    net.initialize()
    trainer = bps.DistributedTrainer(params, 'sgd')
    trainer._init_params()
    # Check the result of the BytePS broadcast: every weight should still
    # hold its constant initializer value after parameter initialization.
    for name, init in layers.items():
        weight = params[name + 'weight'].data()[0].asnumpy()
        assert (weight == init).all(), (name, weight, init)

# Tail of the MNIST example's evaluation loop: score each validation batch
# and return the accumulated metric.
        label = batch[1].as_in_context(context)
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output])
    return metric.get()
# Load training and validation data
train_data, val_data, train_size = get_mnist_iterator()
# Initialize BytePS
bps.init()
# BytePS: pin context to local rank
context = mx.cpu(bps.local_rank()) if args.no_cuda else mx.gpu(bps.local_rank())
num_workers = bps.size()
# Build model
model = conv_nets()
model.cast(args.dtype)
# Initialize parameters
model.initialize(mx.init.MSRAPrelu(), ctx=context)
# Print the model summary using a dummy batch on the training context.
# if bps.rank() == 0:
model.summary(nd.ones((1, 1, 28, 28), ctx=context))
model.hybridize()
params = model.collect_params()
# BytePS: create DistributedTrainer, a subclass of gluon.Trainer
optimizer_params = {'momentum': args.momentum, 'learning_rate': args.lr * num_workers}
trainer = bps.DistributedTrainer(params, "sgd", optimizer_params)
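
# A minimal training-loop sketch showing how the DistributedTrainer is used.
# The loss function, batch layout, and args.epochs below are illustrative
# assumptions, not taken verbatim from the original script.
loss_fn = mx.gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()
for epoch in range(args.epochs):
    metric.reset()
    for batch in train_data:
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)
        with mx.autograd.record():
            output = model(data.astype(args.dtype, copy=False))
            loss = loss_fn(output, label)
        loss.backward()
        # DistributedTrainer pushes/pulls gradients through BytePS here.
        trainer.step(args.batch_size)
        metric.update([label], [output])
    if bps.rank() == 0:
        name, acc = metric.get()
        print('Epoch %d: train %s=%f' % (epoch, name, acc))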

# Base optimizer settings; extended below depending on the chosen optimizer.
optimizer_params = {'lr_scheduler': lr_scheduler,
                    'multi_precision': True}

# Only a limited number of optimizers have a 'momentum' property.
has_momentum = {'sgd', 'dcasgd', 'nag'}
if args.optimizer in has_momentum:
    optimizer_params['momentum'] = args.mom

monitor = mx.mon.Monitor(
    args.monitor, pattern=".*") if args.monitor > 0 else None

# A limited number of optimizers have a warmup period.
has_warmup = {'lbsgd', 'lbnag'}
if args.optimizer in has_warmup:
    if bps.size() > 1:
        nworkers = bps.size()
    else:
        nworkers = 1
    epoch_size = args.num_examples / args.batch_size / nworkers
    if epoch_size < 1:
        epoch_size = 1
    macrobatch_size = args.macrobatch_size
    if macrobatch_size < args.batch_size * nworkers:
        macrobatch_size = args.batch_size * nworkers
    # batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers + 0.4999)
    batch_scale = math.ceil(
        float(macrobatch_size) / args.batch_size / nworkers)
    optimizer_params['updates_per_epoch'] = epoch_size
    optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
    optimizer_params['batch_scale'] = batch_scale
    optimizer_params['warmup_strategy'] = args.warmup_strategy
    optimizer_params['warmup_epochs'] = args.warmup_epochs
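
# Once assembled, the options are handed to the optimizer through the BytePS
# trainer. A sketch assuming Gluon-style parameters (`params`) as in the
# earlier snippet; a Module-based script would instead pass optimizer_params
# to model.fit().
trainer = bps.DistributedTrainer(params, args.optimizer, optimizer_params)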