How to use the byteps.mxnet.local_rank function in byteps

To help you get started, we’ve selected a few byteps examples based on popular ways the library is used in public projects.

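All of the snippets below share the same basic pattern, so here is a minimal, self-contained sketch of it (assembled from the calls used in these examples, not taken verbatim from the repository): initialize BytePS, then use local_rank() to pin each process to one GPU on its machine, while rank() and size() give the global worker index and worker count.

import mxnet as mx
import byteps.mxnet as bps

bps.init()                          # start BytePS; typically one process is launched per GPU
context = mx.gpu(bps.local_rank())  # local_rank() = this process's GPU index on its own machine
print('worker %d of %d, using local GPU %d'
      % (bps.rank(), bps.size(), bps.local_rank()))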

Example from bytedance/byteps: example/mxnet/train_gluon_mnist_byteps.py (view on GitHub)
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output])

    return metric.get()


# Load training and validation data
train_data, val_data, train_size = get_mnist_iterator()

# Initialize BytePS
bps.init()

# BytePS: pin context to local rank
context = mx.cpu(bps.local_rank()) if args.no_cuda else mx.gpu(bps.local_rank())
num_workers = bps.size()

# Build model
model = conv_nets()
model.cast(args.dtype)

# Initialize parameters
model.initialize(mx.init.MSRAPrelu(), ctx=context)
# if bps.rank() == 0:
model.summary(nd.ones((1, 1, 28, 28), ctx=context))
model.hybridize()

params = model.collect_params()

# BytePS: create DistributedTrainer, a subclass of gluon.Trainer
optimizer_params = {'momentum': args.momentum, 'learning_rate': args.lr * num_workers}
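The excerpt is cut off just before the trainer is built. In the full script it continues roughly as follows (a sketch: the broadcast step and argument order are reconstructed from the byteps.mxnet API, not copied verbatim):

# BytePS: broadcast the freshly initialized parameters from rank 0
# so every worker starts from identical weights
bps.broadcast_parameters(params, root_rank=0)

# DistributedTrainer is a gluon.Trainer subclass that pushes and pulls
# gradients through BytePS instead of a local kvstore
trainer = bps.DistributedTrainer(params, 'sgd', optimizer_params)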
Example from bytedance/byteps: example/mxnet/common/fit_byteps.py (view on GitHub)
    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = _load_model(args, bps.rank())
        if sym is not None:
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = _save_model(args, bps.rank())

    # devices for training
    if args.cpu_train:
        devs = [mx.cpu(bps.local_rank())]
    else:
        logging.info('Launch BytePS process on GPU-%d', bps.local_rank())
        devs = [mx.gpu(bps.local_rank())]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args)

    # create model
    model = mx.mod.Module(
        context=devs,
        symbol=network
    )

    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True}
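The excerpt stops at the optimizer settings. For the symbolic Module API, the usual BytePS pattern is to wrap an ordinary MXNet optimizer in bps.DistributedOptimizer and pass it to model.fit; the sketch below assumes this is roughly what fit_byteps.py does next, with train and val standing in for the DataIters handed to the function and args.num_epochs for the epoch count:

# Wrap a plain MXNet optimizer so its updates are synchronized through BytePS
opt = mx.optimizer.create('sgd', **optimizer_params)
opt = bps.DistributedOptimizer(opt)

model.fit(train,
          eval_data=val,
          optimizer=opt,
          arg_params=arg_params,
          aux_params=aux_params,
          num_epoch=args.num_epochs,
          epoch_end_callback=checkpoint)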