Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
)
# Training driver excerpt: restore state, freeze the teacher model, then
# alternate training, validation and checkpointing until a stop condition hits.
# NOTE(review): this excerpt has lost its indentation; any nesting described
# in the comments below is inferred from statement order - confirm upstream.
# Load the latest checkpoint if one is available
# (exist_ok=True makes the directory creation a no-op when it already exists)
os.makedirs(args.save_dir, exist_ok=True)
load_checkpoint(args, trainer, epoch_itr, teacher_model=True)
# Turn off requires_grad on every parameter of teacher_model.
for name, para in teacher_model.named_parameters():
para.requires_grad = False
# Send a dummy batch to warm the caching allocator (currently disabled)
# dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
# trainer.dummy_train_step(dummy_batch)
# Train until the learning rate is no longer above args.min_lr, or the epoch /
# update limit is reached.
# A falsy args.max_epoch / args.max_update (e.g. 0 or None) means "no limit".
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
lr = trainer.get_lr()
train_meter = StopwatchMeter()
train_meter.start()
# Start with a placeholder so valid_losses[0] is defined before any validation
# has run.
valid_losses = [None]
valid_subsets = args.valid_subset.split(',')
while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch
train(args, trainer, task, epoch_itr)
# Validation only runs on epochs that are a multiple of args.validate_interval.
if epoch_itr.epoch % args.validate_interval == 0:
valid_losses, valid_state = validate(args, trainer, task, epoch_itr, valid_subsets)
# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
# save checkpoint
# NOTE(review): valid_state is only bound by the conditional validate() call
# above and has no default in this excerpt (unlike valid_losses); if a save
# epoch arrives before any validation epoch this name looks unbound - confirm.
if epoch_itr.epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch_itr, valid_losses[0], valid_state)
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
# Training driver excerpt: resume from a loaded checkpoint, warm up the
# trainer with a dummy batch, then train and validate epoch by epoch.
# NOTE(review): this excerpt has lost its indentation; any nesting described
# in the comments below is inferred from statement order - confirm upstream.
# Pick up the epoch number stored in the checkpoint's extra_state.
epoch = extra_state['epoch']
print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
trainer.lr_step(epoch)
# Pull `epoch` items from train_dataloader and discard them - presumably to
# skip the epochs the checkpointed run already consumed; confirm.
for i in range(epoch):
_ = next(train_dataloader)
# NOTE(review): with indentation lost it is unclear whether this increment
# sits inside the skip loop above or after it - confirm.
epoch += 1
# Send a dummy batch to warm the caching allocator
dummy_batch = data.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
trainer.dummy_train_step(dummy_batch)
# Train until the learning rate is no longer above args.min_lr or the epoch
# limit is passed.
# A falsy args.max_epoch / args.max_update (e.g. 0 or None) means "no limit".
# NOTE(review): max_update is computed here but the visible loop condition
# below never checks it.
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
lr = trainer.get_lr()
train_meter = StopwatchMeter()
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch:
# train for one epoch
train(args, trainer, next(train_dataloader), epoch)
# evaluate on validate set
# first_val_loss stays None unless the validation branch below assigns it.
first_val_loss = None
if epoch % args.validate_interval == 0:
# args.valid_subset is a comma-separated list; each entry is validated in turn.
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, trainer, dataset, subset, epoch)
# Remember only the loss of the first listed subset.
if k == 0:
first_val_loss = val_loss
# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch, first_val_loss)
# Scoring excerpt: prepare the models, build a non-shuffled batch iterator
# over args.gen_subset, score it, and report tokens whose score is infinite.
# NOTE(review): this excerpt has lost its indentation; any nesting described
# in the comments below is inferred from statement order - confirm upstream.
# Call make_generation_fast_() on each model, and half() when args.fp16 is set.
for model in models:
model.make_generation_fast_()
if args.fp16:
model.half()
# NOTE(review): `model` below is the loop variable from above; if this
# statement follows the loop it refers to the last entry of `models` - confirm.
# max_sentences falls back to 4 when args.max_sentences is falsy.
itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences or 4,
max_positions=model.max_positions(),
num_shards=args.num_shards,
shard_id=args.shard_id,
ignore_invalid_inputs=True,
).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(models, task.target_dictionary)
if use_cuda:
scorer.cuda()
# Accumulators initialised here; nothing in the visible lines updates them.
score_sum = 0.
count = 0
with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter()
# Only the last element (hypos) of each result tuple is used in the visible code.
for _, src_tokens, __, hypos in results:
for hypo in hypos:
pos_scores = hypo['positional_scores']
# Mask of positions whose score equals +inf or -inf.
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
# When any position is flagged, print the flagged tokens rendered through
# task.target_dictionary.string().
if inf_scores.any():
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
# Create the named meters kept on self.meters. An OrderedDict is used, so the
# meters keep the order in which they are registered here.
# NOTE(review): this excerpt has lost its indentation - confirm nesting upstream.
self.meters = OrderedDict()
self.meters['train_loss'] = AverageMeter()
self.meters['train_nll_loss'] = AverageMeter()
self.meters['valid_loss'] = AverageMeter()
self.meters['valid_nll_loss'] = AverageMeter()
self.meters['wps'] = TimeMeter() # words per second
self.meters['ups'] = TimeMeter() # updates per second
self.meters['wpb'] = AverageMeter() # words per batch
self.meters['bsz'] = AverageMeter() # sentences per batch
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
# The loss_scale meter is registered only when args.fp16 is set.
# NOTE(review): presumably the 'wall' and 'train_wall' meters that follow are
# unconditional - confirm against the original indentation.
if args.fp16:
self.meters['loss_scale'] = AverageMeter() # dynamic loss scale
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
# Generation excerpt: build the evaluation iterator (optionally sharded),
# choose between scoring references and generating translations, and set up
# the BLEU scorer.
# NOTE(review): this excerpt has lost its indentation; any nesting described
# in the comments below is inferred from statement order - confirm upstream.
# Load dataset (possibly sharded)
# Use the smallest max_encoder_positions() across all models as the limit.
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.eval_dataloader(
args.gen_subset,
max_sentences=args.max_sentences,
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
)
# Sharding applies only when more than one shard is requested; the shard id
# must satisfy 0 <= shard_id < num_shards or a ValueError is raised.
if args.num_shards > 1:
if args.shard_id < 0 or args.shard_id >= args.num_shards:
raise ValueError('--shard-id must be between 0 and num_shards')
itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
# Initialize generator
gen_timer = StopwatchMeter()
# args.score_reference selects a SequenceScorer; otherwise a SequenceGenerator
# is configured from the command-line options.
if args.score_reference:
translator = SequenceScorer(models)
else:
# stop_early and normalize_scores are the negations of the no_early_stop and
# unnormalized flags.
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen, sampling=args.sampling)
if use_cuda:
translator.cuda()
# Generate and compute BLEU score
# The scorer is built from the pad / eos / unk values of dataset.dst_dict.
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
num_sentences = 0
has_target = True
with progress_bar.build_progress_bar(args, itr) as t:
if args.score_reference:
# Create the named meters kept on self.meters. An OrderedDict is used, so the
# meters keep the order in which they are registered here.
# NOTE(review): this excerpt has lost its indentation - confirm nesting upstream.
self.meters = OrderedDict()
self.meters['train_loss'] = AverageMeter()
self.meters['train_nll_loss'] = AverageMeter()
self.meters['valid_loss'] = AverageMeter()
self.meters['valid_nll_loss'] = AverageMeter()
self.meters['wps'] = TimeMeter() # words per second
self.meters['ups'] = TimeMeter() # updates per second
self.meters['wpb'] = AverageMeter() # words per batch
self.meters['bsz'] = AverageMeter() # sentences per batch
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
# The loss_scale meter is registered only when args.fp16 is set.
# NOTE(review): presumably the 'wall' and 'train_wall' meters that follow are
# unconditional - confirm against the original indentation.
if args.fp16:
self.meters['loss_scale'] = AverageMeter() # dynamic loss scale
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
# Register the named meters on self.meters, including a 'copy_alpha' meter.
# NOTE(review): the creation of self.meters itself is not visible in this
# excerpt, and the excerpt has lost its indentation - confirm upstream.
self.meters['train_loss'] = AverageMeter()
self.meters['train_nll_loss'] = AverageMeter()
self.meters['valid_loss'] = AverageMeter()
self.meters['valid_nll_loss'] = AverageMeter()
self.meters['copy_alpha'] = AverageMeter()
self.meters['wps'] = TimeMeter() # words per second
self.meters['ups'] = TimeMeter() # updates per second
self.meters['wpb'] = AverageMeter() # words per batch
self.meters['bsz'] = AverageMeter() # sentences per batch
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
# The loss_scale meter is registered only when args.fp16 is set.
# NOTE(review): presumably the 'wall' and 'train_wall' meters that follow are
# unconditional - confirm against the original indentation.
if args.fp16:
self.meters['loss_scale'] = AverageMeter() # dynamic loss scale
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
# Generation excerpt: set up optional oracle scoring, run batched generation
# under a progress bar, and feed each translation to the scorer.
# NOTE(review): this excerpt has lost its indentation; any nesting described
# in the comments below is inferred from statement order - confirm upstream.
# An oracle BLEU scorer is created only when args.report_oracle_bleu is set.
oracle_scorer = None
if args.report_oracle_bleu:
oracle_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())
# rescorer is left as None here and passed through to _iter_translations below.
rescorer = None
num_sentences = 0
translation_samples = []
translation_info_list = []
with progress_bar.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()
gen_timer = StopwatchMeter()
# prefix_size is 1 when is_multilingual_many_to_one(args) is true, otherwise 0.
translations = translator.generate_batched_itr(
t,
maxlen_a=args.max_len_a,
maxlen_b=args.max_len_b,
cuda=use_cuda,
timer=gen_timer,
prefix_size=1
if pytorch_translate_data.is_multilingual_many_to_one(args)
else 0,
)
for trans_info in _iter_translations(
args, task, dataset, translations, align_dict, rescorer, modify_target_dict
):
# Only scorers that expose add_string() receive the target / hypothesis strings.
if hasattr(scorer, "add_string"):
scorer.add_string(trans_info.target_str, trans_info.hypo_str)