opt.src_vocab_size = 46930
opt.tgt_vocab_size = 23094
#========= Preparing Model =========#
if opt.embs_share_weight:
    assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
        'The src/tgt word2idx tables are different but sharing word embeddings was requested.'
print(opt)
use_horovod = int(os.environ.get("USE_HOROVOD", "0"))
if use_horovod > 0:
    # byteps.torch exposes a Horovod-compatible API, so it is imported under the familiar hvd alias.
    import byteps.torch as hvd
    hvd.init()
    device = torch.device('cuda', hvd.local_rank())
else:
    device = torch.device('cuda' if opt.cuda else 'cpu')
transformer = Transformer(
    opt.src_vocab_size,
    opt.tgt_vocab_size,
    opt.max_token_seq_len,
    tgt_emb_prj_weight_sharing=opt.proj_share_weight,
    emb_src_tgt_weight_sharing=opt.embs_share_weight,
    d_k=opt.d_k,
    d_v=opt.d_v,
    d_model=opt.d_model,
    d_word_vec=opt.d_word_vec,
    d_inner=opt.d_inner_hid,
    n_layers=opt.n_layers,
    n_head=opt.n_head,
    dropout=opt.dropout).to(device)

#========= BytePS benchmark snippet: arguments, GPU pinning, timing helpers =========#
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA')
parser.add_argument('--no-wait', type=bool, default=True,
                    help='wait for other worker request first')
parser.add_argument('--gpu', type=int, default=-1,
                    help='use a specified gpu')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
bps.init()
# BytePS: pin GPU to local rank.
if args.gpu >= 0:
    torch.cuda.set_device(args.gpu)
else:
    torch.cuda.set_device(bps.local_rank())
cudnn.benchmark = True
def log(s, nl=True):
    if bps.rank() != 0:
        return
    print(s, end='\n' if nl else '')
def benchmark(tensor, average, name):
    if not args.no_wait and bps.rank() == 0:
        time.sleep(0.01)
    start = time.time()
    handle = push_pull_async_inplace(tensor, average, name)
    while True:
        # poll until the asynchronous push-pull finishes (assumes poll/synchronize are imported from byteps.torch.ops)
        if poll(handle):
            synchronize(handle)
            break
    return time.time() - start
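# Usage sketch (not from the original snippet): time a few push-pull rounds on a
# random tensor. The tensor shape and iteration count are illustrative; depending
# on the BytePS version the tensor name may first need to be registered via
# byteps.torch.ops.declare.
data = torch.randn(1024, 1024)
if args.cuda:
    data = data.cuda()
elapsed = [benchmark(data, average=True, name='benchmark.0') for _ in range(10)]
log('average push-pull time: %.3f ms' % (1000.0 * sum(elapsed) / len(elapsed)))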
#========= BytePS training snippet: arguments, seeding, checkpoint resume =========#
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
pushpull_batch_size = args.batch_size * args.batches_per_pushpull
bps.init()
torch.manual_seed(args.seed)
if args.cuda:
    # BytePS: pin GPU to local rank.
    torch.cuda.set_device(bps.local_rank())
    torch.cuda.manual_seed(args.seed)
cudnn.benchmark = True
# If set > 0, will resume training from a given checkpoint.
resume_from_epoch = 0
for try_epoch in range(args.epochs, 0, -1):
    if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)):
        resume_from_epoch = try_epoch
        break
# BytePS: broadcast resume_from_epoch from rank 0 (which will have
# checkpoints) to other ranks.
# resume_from_epoch = bps.broadcast(torch.tensor(resume_from_epoch), root_rank=0,
#                                   name='resume_from_epoch').item()
#========= BERT fine-tuning snippet: arguments and distributed setup =========#
parser.add_argument('--fp16',
                    action='store_true',
                    help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
                    type=float, default=0,
                    help="Loss scaling to improve fp16 numeric stability. Only used when fp16 is set.\n"
                         "0 (default value): dynamic loss scaling.\n"
                         "Positive power of 2: static loss scaling value.\n")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args()
use_horovod = int(os.getenv("USE_HOROVOD", "0"))
print("env variable USE_HOROVOD:", use_horovod)
if use_horovod == 1:
    hvd.init()
    args.local_rank = hvd.local_rank()
    print("use horovod, local rank:", args.local_rank)
    # Give each local rank its own output directory.
    args.output_dir = args.output_dir + "_" + str(args.local_rank)
if args.server_ip and args.server_port:
    # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
    import ptvsd
    print("Waiting for debugger attach")
    ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
    ptvsd.wait_for_attach()
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
}
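# Usage sketch (illustrative): select the processor for the requested task and
# read its label list. The --task_name flag is assumed here, as in typical BERT
# fine-tuning scripts; it is not shown in the snippet above.
task_name = args.task_name.lower()
if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)
processor = processors[task_name]()
label_list = processor.get_labels()
num_labels = len(label_list)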
#========= BytePS synthetic benchmark snippet: model and optimizer setup =========#
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--profiler', action='store_true', default=False,
                    help='enables the profiler')
parser.add_argument('--partition', type=int, default=None,
                    help='partition size')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
bps.init()
if args.cuda:
    # BytePS: pin GPU to local rank.
    torch.cuda.set_device(bps.local_rank())
    cudnn.benchmark = True
# Set up standard model.
model = getattr(models, args.model)(num_classes=args.num_classes)
if args.cuda:
    # Move model to GPU.
    model.cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# BytePS: (optional) compression algorithm.
compression = bps.Compression.fp16 if args.fp16_pushpull else bps.Compression.none
# BytePS: wrap optimizer with DistributedOptimizer.
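# Sketch of the wrapping step the comment above refers to, using the
# Horovod-style API that byteps.torch mirrors (DistributedOptimizer,
# broadcast_parameters, broadcast_optimizer_state); exact keyword support may
# vary across BytePS versions.
optimizer = bps.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=compression)
# Broadcast initial parameters and optimizer state from rank 0 so every worker
# starts from the same weights.
bps.broadcast_parameters(model.state_dict(), root_rank=0)
bps.broadcast_optimizer_state(optimizer, root_rank=0)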
#========= Transformer training snippet (continued): model, logging, optimizer =========#
transformer = Transformer(
    opt.src_vocab_size,
    opt.tgt_vocab_size,
    opt.max_token_seq_len,
    tgt_emb_prj_weight_sharing=opt.proj_share_weight,
    emb_src_tgt_weight_sharing=opt.embs_share_weight,
    d_k=opt.d_k,
    d_v=opt.d_v,
    d_model=opt.d_model,
    d_word_vec=opt.d_word_vec,
    d_inner=opt.d_inner_hid,
    n_layers=opt.n_layers,
    n_head=opt.n_head,
    dropout=opt.dropout).to(device)
print("src_vocab_size:",opt.src_vocab_size,",tgt_vocab_size:",opt.tgt_vocab_size,",share_weight:",(opt.proj_share_weight,opt.embs_share_weight))
if use_horovod == 0 or hvd.local_rank() == 0:
    # Dump per-parameter sizes once (only on local rank 0 when distributed;
    # the original guard referenced hvd even when it was never imported).
    fo = open("transformer_model.csv", "w")
    for name, p in transformer.named_parameters():
        if p.requires_grad:
            size = 1
            for s in list(p.size()):
                size = size * s
            print("name:", name, ", size:", size)
            fo.write(name + ", " + str(size) + "\n")
    fo.close()
torch_optimizer = optim.Adam(filter(lambda x: x.requires_grad, transformer.parameters()), betas=(0.9, 0.98), eps=1e-09)
if use_horovod > 0:
    torch_optimizer = hvd.DistributedOptimizer(torch_optimizer, named_parameters=transformer.named_parameters())
    hvd.broadcast_parameters(transformer.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(torch_optimizer, root_rank=0)
#print("finish hvd preparation")
def log(s, nl=True):
    if bps.local_rank() != 0:
        return
    print(s, end='\n' if nl else '')
    sys.stdout.flush()
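# Usage sketch: only local rank 0 prints, e.g. to report the BytePS world size
# (bps.size() mirrors the Horovod API).
log('Number of BytePS workers: %d' % bps.size())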