How to use the byteps.mxnet module in byteps

To help you get started, we've selected a few byteps examples based on popular ways the library is used in public projects.
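Before the excerpts, here is a minimal end-to-end sketch of how byteps.mxnet is typically wired into a Gluon script. It uses only byteps.mxnet API calls (bps.init, bps.local_rank, bps.size, bps.broadcast_parameters, bps.DistributedTrainer); the tiny model and the learning rate are illustrative placeholders, not taken from any one example below.

import mxnet as mx
from mxnet import gluon
import byteps.mxnet as bps

bps.init()                               # start BytePS (one process per GPU)
ctx = mx.gpu(bps.local_rank())           # pin this worker to its local GPU

net = gluon.nn.Dense(10)                 # placeholder model
net.initialize(mx.init.Xavier(), ctx=ctx)
net(mx.nd.ones((1, 20), ctx=ctx))        # dummy forward to materialize params

params = net.collect_params()
bps.broadcast_parameters(params, root_rank=0)   # start from rank 0's weights

# DistributedTrainer wraps gluon.Trainer and push-pulls gradients each step
trainer = bps.DistributedTrainer(params, 'sgd',
                                 {'learning_rate': 0.01 * bps.size()})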


From bytedance/byteps, tests/test_mxnet.py (view on GitHub) — the tail of the broadcast test: it asserts that broadcasting leaves the source tensor on non-root ranks untouched, and that every rank ends up with a copy of the root tensor.
print("comparison", bps.rank(), tensor == root_tensor)
                assert not same(tensor.asnumpy(), root_tensor.asnumpy()), \
                    'bps.broadcast modifies source tensor'
            if not same(broadcast_tensor.asnumpy(), root_tensor.asnumpy()):
                print("broadcast", count, dtype, dim)
                print("broadcast_tensor", bps.rank(), broadcast_tensor)
                print("root_tensor", bps.rank(), root_tensor)
                print("comparison", bps.rank(),
                      broadcast_tensor == root_tensor)
            assert same(broadcast_tensor.asnumpy(), root_tensor.asnumpy()), \
                'bps.broadcast produces incorrect broadcasted tensor'


if __name__ == '__main__':
    mxtest = MXTest()
    bps.init()
    mxtest.test_byteps_push_pull()
    mxtest.test_byteps_trainer_param_order()
    #mxtest.test_byteps_broadcast()
From bytedance/byteps, tests/test_mxnet.py (view on GitHub) — a helper that pins each test to the GPU matching the worker's local rank, falling back to the default context on CPU-only machines.
    def _current_context(self):
        if has_gpu:
            return mx.gpu(bps.local_rank())
        else:
            return mx.current_context()
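The helper assumes a module-level has_gpu flag whose definition is not shown in this excerpt; a plausible definition (an assumption, not confirmed by the snippet) is:

has_gpu = mx.context.num_gpus() > 0   # assumed; mx.context.num_gpus() exists in MXNet >= 1.3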
From bytedance/byteps, tests/test_mxnet.py (view on GitHub) — the in-place push-pull test: it verifies that byteps_push_pull sums tensors of several dtypes and dimensions across all workers.
    def test_byteps_push_pull_inplace(self):
        """Test that byteps_push_pull correctly sums 1D, 2D, 3D tensors in place."""
        size = bps.size()
        dtypes = self.filter_supported_types(['int32',   'int64',
                                              'float32', 'float64'])
        dims = [1, 2, 3]
        ctx = self._current_context()
        count = 200
        shapes = [(), (17,), (17, 17), (17, 17, 17)]
        for dtype, dim in itertools.product(dtypes, dims):
            mx.random.seed(1234, ctx=ctx)
            tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim],
                                          ctx=ctx)
            tensor = tensor.astype(dtype)
            multiplied = tensor * size
            bps.byteps_declare_tensor("tensor_" + str(count))
            bps.byteps_push_pull(tensor, name="tensor_" + str(count))
            # compare by absolute difference so negative deviations also count
            max_difference = mx.nd.max(mx.nd.abs(mx.nd.subtract(tensor, multiplied)))
            count += 1

            # Threshold for floating point equality depends on number of
            # ranks, since we're comparing against precise multiplication.
            if size <= 3 or dtype in ['int32', 'int64']:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                break

            # completion of the truncated excerpt, mirroring the diagnostics
            # and assert pattern of the broadcast test above
            if max_difference > threshold:
                print("self", count, dtype, dim, max_difference, threshold)
                print("tensor", bps.rank(), tensor)
                print("multiplied", bps.rank(), multiplied)
            assert max_difference <= threshold, \
                'bps.byteps_push_pull produces incorrect results'
From bytedance/byteps, example/mxnet/common/fit_byteps.py (view on GitHub) — a learning-rate scheduler factory that divides the per-worker epoch size by bps.size() and supports both polynomial decay and multi-factor step decay.
def _get_lr_scheduler(args):
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)
    epoch_size = args.num_examples / args.batch_size
    if bps.size() > 1:
        epoch_size /= bps.size()
    begin_epoch = args.load_epoch if args.load_epoch else 0
    if 'pow' in args.lr_step_epochs:
        lr = args.lr
        max_up = args.num_epochs * epoch_size
        pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
        poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
        return (lr, poly_sched)
    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
    lr = args.lr
    for s in step_epochs:
        if begin_epoch >= s:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d',
                     lr, begin_epoch)
    # hedged completion: the original excerpt ends with the logging call; as
    # in the stock MXNet fit helper, finish by building the step scheduler
    steps = [int(epoch_size * (x - begin_epoch))
             for x in step_epochs if x - begin_epoch > 0]
    return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps,
                                                     factor=args.lr_factor))
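A plausible call site, sketched here to connect the factory with the optimizer_params excerpt further below (hedged, not verbatim from the repo):

    lr, lr_scheduler = _get_lr_scheduler(args)
    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True}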
From bytedance/byteps, example/mxnet/train_gluon_mnist_byteps.py (view on GitHub) — the evaluation helper and the BytePS setup: the context is pinned to the local rank and the model is built and initialized once per worker.
def evaluate(model, data_iter, context):
    metric = mx.metric.Accuracy()
    for _, batch in enumerate(data_iter):
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output])

    return metric.get()


# Load training and validation data
train_data, val_data, train_size = get_mnist_iterator()

# Initialize BytePS
bps.init()

# BytePS: pin context to local rank
context = mx.cpu(bps.local_rank()) if args.no_cuda else mx.gpu(bps.local_rank())
num_workers = bps.size()

# Build model
model = conv_nets()
model.cast(args.dtype)

# Initialize parameters
model.initialize(mx.init.MSRAPrelu(), ctx=context)
# if bps.rank() == 0:
model.summary(nd.ones((1, 1, 28, 28), ctx=context))  # use the pinned context so --no-cuda also works
model.hybridize()

params = model.collect_params()
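In the full example the setup continues by broadcasting the freshly initialized parameters and wrapping the optimizer in a BytePS trainer. A sketch of those next steps (the exact hyperparameter flag names are assumptions):

# BytePS: make every worker start from the same initial weights
bps.broadcast_parameters(params, root_rank=0)

optimizer_params = {'momentum': args.momentum,            # assumed CLI flags
                    'learning_rate': args.lr * num_workers}
trainer = bps.DistributedTrainer(params, 'sgd', optimizer_params)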
From bytedance/byteps, example/mxnet/train_imagenet_byteps.py (view on GitHub) — the script preamble: imports, BytePS initialization, and ResNet-50/ImageNet-1k argument defaults.

import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet
from common import data_byteps as data
from common import fit_byteps as fit
from common.util import download_file
import byteps.mxnet as bps
import mxnet as mx

if __name__ == '__main__':
    # init byteps
    bps.init()

    # parse args
    parser = argparse.ArgumentParser(description="train imagenet-1k",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    fit.add_fit_args(parser)
    data.add_data_args(parser)
    data.add_data_aug_args(parser)
    # use a large aug level
    data.set_data_aug_level(parser, 3)
    parser.set_defaults(
        # network
        network          = 'resnet',
        num_layers       = 50,
        # data
        num_classes      = 1000,
        num_examples     = 1281167,
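The excerpt stops inside parser.set_defaults; in the stock MXNet image-classification recipe this script follows, the main block typically continues along these lines (a hedged sketch, not verbatim from the repo):

    args = parser.parse_args()

    # load the symbol definition selected by --network
    from importlib import import_module
    net = import_module('symbols.' + args.network)
    sym = net.get_symbol(**vars(args))

    # hand off to the BytePS-aware fit loop with the data iterator factory
    fit.fit(args, sym, data.get_rec_iter)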
From bytedance/byteps, example/mxnet/common/fit_byteps.py (view on GitHub) — building optimizer_params: momentum for the optimizers that support it, plus warm-up bookkeeping for the large-batch optimizers.
    # hedged reconstruction: the excerpt begins mid-assignment; lr and
    # lr_scheduler come from _get_lr_scheduler(args) shown earlier
    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True}

    # Only a limited number of optimizers have 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    monitor = mx.mon.Monitor(
        args.monitor, pattern=".*") if args.monitor > 0 else None

    # A limited number of optimizers have a warmup period
    has_warmup = {'lbsgd', 'lbnag'}
    if args.optimizer in has_warmup:
        if bps.size() > 1:
            nworkers = bps.size()
        else:
            nworkers = 1
        epoch_size = args.num_examples / args.batch_size / nworkers
        if epoch_size < 1:
            epoch_size = 1
        macrobatch_size = args.macrobatch_size
        if macrobatch_size < args.batch_size * nworkers:
            macrobatch_size = args.batch_size * nworkers
        batch_scale = math.ceil(
            float(macrobatch_size) / args.batch_size / nworkers)
        optimizer_params['updates_per_epoch'] = epoch_size
        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
        optimizer_params['batch_scale'] = batch_scale
        optimizer_params['warmup_strategy'] = args.warmup_strategy
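To make the batch_scale arithmetic concrete, a quick worked example with illustrative numbers:

import math

batch_size, nworkers = 256, 8    # per-worker batch and worker count (made up)
macrobatch_size = 4096           # requested effective batch
# ceil(4096 / 256 / 8) = ceil(2.0) = 2 gradient accumulations per update
batch_scale = math.ceil(float(macrobatch_size) / batch_size / nworkers)
assert batch_scale == 2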
From bytedance/byteps, example/mxnet/train_gluon_mnist_byteps.py (view on GitHub) — the training loop: forward/backward under autograd, a trainer step per batch, rank-0 logging, and a final accuracy check.
# hedged reconstruction of the truncated loop head: epoch, tic, i, and batch
# are all referenced further down
for epoch in range(args.epochs):
    tic = time.time()
    metric.reset()
    for i, batch in enumerate(train_data):
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)

        with autograd.record():
            output = model(data)
            loss = loss_fn(output, label)

        loss.backward()
        trainer.step(args.batch_size)
        metric.update([label], [output])

        if i % 100 == 0:
            name, acc = metric.get()
            logging.info('[Epoch %d Batch %d] Training: %s=%f' %
                         (epoch, i, name, acc))

    if bps.rank() == 0:
        elapsed = time.time() - tic
        speed = train_size * num_workers / elapsed
        logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
                     epoch, speed, elapsed)

    # Evaluate model accuracy
    _, train_acc = metric.get()
    name, val_acc = evaluate(model, val_data, context)
    if bps.rank() == 0:
        logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
                     train_acc, name, val_acc)

    if bps.rank() == 0 and epoch == args.epochs - 1:
        assert val_acc > 0.96, \
            "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc