t_total = num_train_steps
if not self.multi_gpu:
    # Distributed mode: each worker covers its share of the total steps.
    t_total = t_total // torch.distributed.get_world_size()

global_step = 0
pbar = master_bar(range(epochs))
for epoch in pbar:
    self.model.train()
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    for step, batch in enumerate(progress_bar(self.data.train_dl, parent=pbar)):
        batch = tuple(t.to(self.device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids = batch
        if self.is_fp16 and self.multi_label:
            label_ids = label_ids.half()
        loss = self.model(input_ids, segment_ids, input_mask, label_ids)
        if self.multi_gpu:
            loss = loss.mean()  # mean() to average the per-device losses on multi-GPU
        if self.grad_accumulation_steps > 1:
            loss = loss / self.grad_accumulation_steps
        if self.is_fp16:
            self.optimizer.backward(loss)  # fp16 optimizer handles loss scaling
        else:
            loss.backward()
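# A minimal, self-contained sketch of the gradient-accumulation pattern used
# above. The model, optimizer, and step counts here are illustrative, not
# taken from the snippet's class:
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
grad_accumulation_steps = 4

for step in range(16):
    x = torch.randn(8, 4)
    y = torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss = loss / grad_accumulation_steps  # scale so accumulated grads average out
    loss.backward()                        # grads accumulate in .grad across steps
    if (step + 1) % grad_accumulation_steps == 0:
        optimizer.step()       # apply the averaged gradient
        optimizer.zero_grad()  # reset for the next accumulation window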
from concurrent.futures import ProcessPoolExecutor, as_completed

from fastprogress.fastprogress import progress_bar

def parallel(func, arr, max_workers=None):
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(func, o, i) for i, o in enumerate(arr)]
        results = []
        # as_completed yields futures in completion order, so results are unordered.
        for f in progress_bar(as_completed(futures), total=len(arr)):
            results.append(f.result())
    return results
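# Hypothetical usage of `parallel`; `square` is an illustrative worker. Note
# that each worker receives (item, index) because of the submit/enumerate
# order above, and must be a top-level function so it can be pickled.
def square(x, i):
    return i, x * x

if __name__ == "__main__":  # required on platforms that spawn worker processes
    results = parallel(square, [1, 2, 3, 4])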
def validate(self, db: Optional[DataLoader] = None,
             mb=None) -> List[torch.Tensor]:
    "Validation loop, done after every epoch"
    self.mdl.eval()
    if db is None:
        db = self.data.valid_dl
    predicted_box_dict_list = []
    with torch.no_grad():
        val_losses = {k: [] for k in self.loss_keys}
        eval_metrics = {k: [] for k in self.met_keys}
        nums = []  # per-batch sizes, used to weight the final averages
        for batch in progress_bar(db, parent=mb):
            for b in batch.keys():
                batch[b] = batch[b].to(self.device)
            out = self.mdl(batch)
            out_loss = self.loss_fn(out, batch)
            metric = self.eval_fn(out, batch)
            for k in self.loss_keys:
                val_losses[k].append(out_loss[k].detach())
            for k in self.met_keys:
                eval_metrics[k].append(metric[k].detach())
            nums.append(batch[next(iter(batch))].shape[0])
            prediction_dict = {
                'id': metric['idxs'].tolist(),
                'pred_boxes': metric['pred_boxes'].tolist(),
                'pred_scores': metric['pred_scores'].tolist()
            }
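# Not part of the snippet: a sketch of how the per-batch values collected
# above are typically reduced, weighting each batch by its size from `nums`.
import torch

def weighted_mean(per_batch_vals, nums):
    vals = torch.stack(list(per_batch_vals)).float()
    weights = torch.tensor(nums, dtype=torch.float32, device=vals.device)
    return (vals * weights).sum() / weights.sum()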
def create_corpus(text_list, target_path, logger=None):
    # spaCy pipeline loaded for sentence splitting (unused in this excerpt).
    nlp = spacy.load("en_core_web_sm", disable=["tagger", "ner", "textcat"])
    with open(target_path, "w") as f:
        # Split sentences for each document
        if logger is not None:  # guard: logger defaults to None
            logger.info("Formatting corpus for {}".format(target_path))
        for text in progress_bar(text_list):
            text = fix_html(text)
            text = replace_multi_newline(text)
            text = spec_add_spaces(text)
            text = rm_useless_spaces(text)
            text = text.strip()
            f.write(text)
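# Hypothetical call to `create_corpus`, assuming the cleanup helpers above
# (fix_html and friends, which resemble fastai.text's preprocessing rules)
# are importable; the logging setup is illustrative.
import logging

logging.basicConfig(level=logging.INFO)
create_corpus(
    ["First document.\n\nSecond paragraph."],
    "corpus.txt",
    logger=logging.getLogger("corpus"),
)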
# Reconstructed opening (assumed guard and pool attribute): the parallel
# branch dispatches one kernel call per draw to a worker pool; the argument
# tuple mirrors the sequential branch below.
if self.parallel and self.cores > 1:
    results = self.pool.starmap(
        metrop_kernel,
        [
            (
                self.posterior[draw],
                self.tempered_logp[draw],
                self.priors[draw],
                self.likelihoods[draw],
                draw,
                *parameters,
            )
            for draw in range(self.draws)
        ],
    )
else:
    iterator = range(self.draws)
    if self.progressbar:
        iterator = progress_bar(iterator, display=self.progressbar)
    results = [
        metrop_kernel(
            self.posterior[draw],
            self.tempered_logp[draw],
            self.priors[draw],
            self.likelihoods[draw],
            draw,
            *parameters,
        )
        for draw in iterator
    ]
posterior, acc_list, priors, likelihoods = zip(*results)
self.posterior = np.array(posterior)
self.priors = np.array(priors)
self.likelihoods = np.array(likelihoods)
self.acc_per_chain = np.array(acc_list)
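# Self-contained toy showing the same collect-then-unzip pattern, with a
# progress bar over draws; `toy_kernel` stands in for metrop_kernel.
from fastprogress.fastprogress import progress_bar

def toy_kernel(state, draw):
    return state + 1, draw % 2 == 0  # (new state, accepted flag)

states = list(range(10))
results = [toy_kernel(states[draw], draw) for draw in progress_bar(range(10))]
posterior, acc_list = zip(*results)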
def refine(self, n, progressbar=True):
    """Refine the solution using the last compiled step function."""
    if self.state is None:
        raise TypeError("Need to call `.fit` first")
    i, step, callbacks, score = self.state
    if progressbar:
        # progress_bar wraps an iterable, so pass range(n) rather than n itself.
        progress = progress_bar(range(n), display=progressbar)
    else:
        progress = range(n)
    if score:
        state = self._iterate_with_loss(i, n, step, progress, callbacks)
    else:
        state = self._iterate_without_loss(i, n, step, progress, callbacks)
    self.state = state
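# A hedged usage sketch, assuming PyMC3's ADVI interface; the model and step
# counts are illustrative.
import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.2, 0.3])
    inference = pm.ADVI()
    inference.fit(n=5000)   # compiles the step function and stores state
    inference.refine(1000)  # continue optimizing with the compiled step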
def run_profiling(self, n=1000, score=None, **kwargs):
    score = self._maybe_score(score)
    fn_kwargs = kwargs.pop("fn_kwargs", dict())
    fn_kwargs["profile"] = True  # ask the compiled step function to record a profile
    step_func = self.objective.step_function(
        score=score, fn_kwargs=fn_kwargs, **kwargs
    )
    progress = progress_bar(range(n))
    try:
        for _ in progress:
            step_func()
    except KeyboardInterrupt:
        pass
    return step_func.profile
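# Hedged continuation of the ADVI sketch above: the returned object is a
# Theano profile, whose summary() prints per-op timing statistics.
profile = inference.run_profiling(n=100)
profile.summary()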
self.logger.info("Num examples = %d", len(self.data.val_dl.dataset))
self.logger.info("Validation Batch size = %d", self.data.val_batch_size)
all_logits = None
all_labels = None
eval_loss, eval_accuracy = 0, 0
nb_eval_steps = 0
preds = None
out_label_ids = None
validation_scores = {metric['name']: 0. for metric in self.metrics}
for step, batch in enumerate(progress_bar(self.data.val_dl)):
self.model.eval()
batch = batch.to(self.device)
with torch.no_grad():
outputs = self.model(batch, masked_lm_labels=batch)
tmp_eval_loss = outputs[0]
eval_loss += tmp_eval_loss.mean().item()
cpu_device = torch.device('cpu')
batch.to(cpu_device)
torch.cuda.empty_cache()
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
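# Not part of the loop above: masked-LM eval loss is often reported as
# perplexity, the exponential of the mean cross-entropy.
import math

def perplexity(mean_loss: float) -> float:
    return math.exp(mean_loss)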
def _sample(  # reconstructed signature: chain and progressbar are inferred from the body
    chain,
    progressbar,
    random_seed,
    start,
    draws=None,
    step=None,
    trace=None,
    tune=None,
    model=None,
    **kwargs
):
    skip_first = kwargs.get("skip_first", 0)
    sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
    _pbar_data = None
    if progressbar:
        _pbar_data = {"chain": chain, "divergences": 0}
        _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
        sampling = progress_bar(sampling, total=draws, display=progressbar)
        sampling.comment = _desc.format(**_pbar_data)
    try:
        strace = None
        for it, (strace, diverging) in enumerate(sampling):
            if it >= skip_first:
                trace = MultiTrace([strace])
            if diverging and _pbar_data is not None:
                _pbar_data["divergences"] += 1
                sampling.comment = _desc.format(**_pbar_data)
    except KeyboardInterrupt:
        pass
    return strace
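# Self-contained demo of the fastprogress `.comment` hook used above: text
# assigned to `comment` is rendered next to the bar on each update.
import time
from fastprogress.fastprogress import progress_bar

bar = progress_bar(range(20))
for i in bar:
    bar.comment = f"step {i}"
    time.sleep(0.05)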