(x_train, y_train), (x_test, y_test) = \
    keras.datasets.mnist.load_data('MNIST-data-%d' % bps.rank())
# The downloaded data has shape (-1, 28, 28), so we need to reshape it
# into (-1, 784) to feed it into our network. We also need to normalize
# the features to the range [0, 1].
x_train = np.reshape(x_train, (-1, 784)) / 255.0
x_test = np.reshape(x_test, (-1, 784)) / 255.0
# Build model...
with tf.name_scope('input'):
    image = tf.placeholder(tf.float32, [None, 784], name='image')
    label = tf.placeholder(tf.float32, [None], name='label')
predict, loss = conv_model(image, label, tf.estimator.ModeKeys.TRAIN)
# BytePS: adjust learning rate based on number of GPUs.
opt = tf.train.RMSPropOptimizer(0.001 * bps.size())
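# The MNIST fragment above assumes BytePS for TensorFlow has already been
# imported and initialized. A minimal sketch of that setup, matching the
# `bps`, `np`, `tf`, and `keras` names the fragment uses (the `conv_model`
# and `train_input_generator` helpers are assumed to be defined elsewhere in
# the same script):
import numpy as np
import tensorflow as tf
from tensorflow import keras
import byteps.tensorflow as bps

# Initialize BytePS before calling bps.rank(), bps.size(), or bps.local_rank().
bps.init()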
processor = processors[task_name]()
num_labels = num_labels_task[task_name]
label_list = processor.get_labels()
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
train_examples = None
num_train_optimization_steps = None
if args.do_train:
    train_examples = processor.get_train_examples(args.data_dir)
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        if use_horovod == 1:
            num_train_optimization_steps = num_train_optimization_steps // hvd.size()
        else:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
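# A small worked example of the step arithmetic above, with hypothetical
# numbers: 10,000 training examples, batch size 32, no gradient accumulation,
# 3 epochs, and 4 distributed workers.
steps = int(10000 / 32 / 1) * 3   # 312 optimizer steps per epoch, 936 in total
steps_per_worker = steps // 4     # 234 steps once the work is split 4 ways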
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                       cache_dir=cache_dir,
                                                       num_labels=num_labels)
if args.fp16:
    model.half()
model.to(device)
if args.local_rank == 0:
    # Record each trainable parameter's name and element count for inspection.
    fo = open("bert_model.txt", "w")
    for name, p in model.named_parameters():
        if p.requires_grad:
            size = 1
            for s in p.size():
                size = size * s
            fo.write("%s %d\n" % (name, size))
    fo.close()
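# The fragment above only prepares the model. Because it branches on
# `use_horovod`, the optimizer is presumably wrapped the usual Horovod way
# before training; a minimal sketch of that step, assuming `optimizer` has
# already been built from the model parameters:
import horovod.torch as hvd

hvd.init()
# All-reduce (average) gradients across workers on every optimizer step.
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters())
# Start every worker from rank 0's parameters so they stay consistent.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)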
# BytePS: add BytePS Distributed Optimizer.
opt = bps.DistributedOptimizer(opt)
global_step = tf.train.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)
hooks = [
    # BytePS: BroadcastGlobalVariablesHook broadcasts initial variable states
    # from rank 0 to all other processes. This is necessary to ensure consistent
    # initialization of all workers when training is started with random weights
    # or restored from a checkpoint.
    bps.BroadcastGlobalVariablesHook(0),
    # BytePS: adjust number of steps based on number of GPUs.
    tf.train.StopAtStepHook(last_step=200000 // bps.size()),
    tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
                               every_n_iter=10),
]
# BytePS: pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(bps.local_rank())
# BytePS: save checkpoints only on worker 0 to prevent other workers from
# corrupting them.
checkpoint_dir = './checkpoints' if bps.rank() == 0 else None
training_batch_generator = train_input_generator(x_train, y_train,
                                                 batch_size=100)
# The MonitoredTrainingSession takes care of session initialization, restoring
# from a checkpoint, saving to a checkpoint, and closing when done or when an
# error occurs.
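# A sketch of the training loop that typically follows, tying together the
# hooks, config, checkpoint_dir, and batch generator defined above; this
# mirrors the standard BytePS/Horovod MNIST example rather than quoting the
# exact source.
with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                       hooks=hooks,
                                       config=config) as mon_sess:
    while not mon_sess.should_stop():
        # Run one synchronous training step; the DistributedOptimizer wrapper
        # averages gradients across workers inside opt.minimize().
        image_, label_ = next(training_batch_generator)
        mon_sess.run(train_op, feed_dict={image: image_, label: label_})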
print("comparison", bps.rank(), tensor == root_tensor)
assert not same(tensor.asnumpy(), root_tensor.asnumpy()), \
'bps.broadcast modifies source tensor'
if not same(broadcast_tensor.asnumpy(), root_tensor.asnumpy()):
print("broadcast", count, dtype, dim)
print("broadcast_tensor", bps.rank(), broadcast_tensor)
print("root_tensor", bps.rank(), root_tensor)
print("comparison", bps.rank(),
broadcast_tensor == root_tensor)
assert same(broadcast_tensor.asnumpy(), root_tensor.asnumpy()), \
'bps.broadcast produces incorrect broadcasted tensor'
if __name__ == '__main__':
    mxtest = MXTest()
    bps.init()
    mxtest.test_byteps_push_pull()
    mxtest.test_byteps_trainer_param_order()
    # mxtest.test_byteps_broadcast()
def _current_context(self):
    if has_gpu:
        return mx.gpu(bps.local_rank())
    else:
        return mx.current_context()
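# The test methods below also rely on the usual imports, a module-level
# `has_gpu` flag, and a `filter_supported_types` helper, none of which are
# part of this excerpt. A minimal sketch of what they presumably look like
# (an assumption, not the verbatim byteps test code):
import itertools
import mxnet as mx
import byteps.mxnet as bps
from mxnet.test_utils import same

has_gpu = mx.context.num_gpus() > 0

def filter_supported_types(self, types):
    # Sketch of the MXTest helper: accept every requested dtype as-is.
    return types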
def test_byteps_push_pull(self):
    """Test that the byteps_push_pull correctly sums 1D, 2D, 3D tensors."""
    size = bps.size()
    dtypes = self.filter_supported_types(['float32'])
    dims = [1]
    ctx = self._current_context()
    count = 100
    shapes = [(), (17)]
    for dtype, dim in itertools.product(dtypes, dims):
        # MXNet uses gpu_id as part of the seed, so to get identical seeds
        # we must set a context.
        mx.random.seed(10 + 10 * bps.rank(), ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim],
                                      ctx=ctx)
        tensor = tensor.astype(dtype)
        print("tensor before push_pull:", tensor)
        bps.byteps_declare_tensor("tensor_" + str(count))
        bps.byteps_push_pull(tensor, name="tensor_" + str(count))
def test_byteps_broadcast(self):
    """Test that the broadcast correctly broadcasts 1D, 2D, 3D tensors."""
    rank = bps.rank()
    size = bps.size()
    # This test does not apply if there is only one worker.
    if size == 1:
        return
    dtypes = ['int32', 'int64',
              'float32', 'float64']
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 300
    shapes = [(), (17), (17, 17), (17, 17, 17)]
    root_ranks = list(range(size))
    for dtype, dim, root_rank in itertools.product(dtypes, dims,
                                                   root_ranks):
        tensor = mx.nd.ones(shapes[dim], ctx=ctx) * rank
        root_tensor = mx.nd.ones(shapes[dim], ctx=ctx) * root_rank
def test_byteps_trainer_param_order(self):
    size = bps.size()
    dtypes = self.filter_supported_types(['float32'])
    dims = [1]
    ctx = self._current_context()
    net = mx.gluon.nn.Sequential()
    # layers may be added in a random order for all workers
    layers = {'ones_': 1, 'zeros_': 0}
    for name, init in layers.items():
        net.add(mx.gluon.nn.Dense(10, in_units=10,
                                  weight_initializer=mx.init.Constant(init),
                                  use_bias=False, prefix=name))
    params = net.collect_params()
    net.initialize()
    trainer = bps.DistributedTrainer(params, 'sgd')
    trainer._init_params()
    # check the result of bps_broadcast
    for name, init in layers.items():
        weight = params[name + 'weight'].data()[0].asnumpy()
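# The loop above is cut off before its check. A sketch of the assertion it
# presumably performs: after trainer._init_params(), every worker should hold
# each layer's constant initializer value as broadcast from rank 0, regardless
# of the order in which the layers were added (numpy assumed available as np):
import numpy as np

for name, init in layers.items():
    weight = params[name + 'weight'].data()[0].asnumpy()
    assert np.allclose(weight, init), \
        'DistributedTrainer does not keep parameters consistent across workers'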
def test_byteps_push_pull_inplace(self):
    """Test that the byteps_push_pull correctly sums 1D, 2D, 3D tensors."""
    size = bps.size()
    dtypes = self.filter_supported_types(['int32', 'int64',
                                          'float32', 'float64'])
    dims = [1, 2, 3]
    ctx = self._current_context()
    count = 200
    shapes = [(), (17), (17, 17), (17, 17, 17)]
    for dtype, dim in itertools.product(dtypes, dims):
        mx.random.seed(1234, ctx=ctx)
        tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim],
                                      ctx=ctx)
        tensor = tensor.astype(dtype)
        multiplied = tensor * size
        bps.byteps_declare_tensor("tensor_" + str(count))
        bps.byteps_push_pull(tensor, name="tensor_" + str(count))
        max_difference = mx.nd.max(mx.nd.subtract(tensor, multiplied))
        count += 1
        # Threshold for floating point equality depends on number of
        # ranks, since we're comparing against precise multiplication.
        if size <= 3 or dtype in ['int32', 'int64']:
            threshold = 0
        elif size < 10:
            threshold = 1e-4
        elif size < 15:
            threshold = 5e-4
        else:
            break
        if max_difference > threshold:
            assert False, 'bps.byteps_push_pull produces incorrect results'