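# The snippets below assume the usual repo-level imports: torch, plus the helper utilities
# V (wrap a tensor in an autograd Variable), to_gpu, and assert_dims from the surrounding
# codebase; the exact import paths are not shown in these excerpts and are assumed here.
import torch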
def reparameterize(self, mu, logvar):
    """Sample a latent vector z ~ N(mu, std^2) with the reparameterization trick."""
    if self.training:
        std = torch.exp(0.5 * logvar)
        # draw a noise vector of size latent_dim and broadcast it over the batch
        eps = to_gpu(V(torch.randn(self.latent_dim)))
        return mu + eps * std
    else:
        # at evaluation time return the posterior mean deterministically
        return mu
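# Minimal standalone sketch (assumption, not part of the repo) of the reparameterization trick
# above, written with plain tensors; torch.randn_like draws a separate noise sample for every
# element of the batch.
import torch
mu, logvar = torch.zeros(4, 16), torch.zeros(4, 16)   # dims [bs, latent_dim]
std = torch.exp(0.5 * logvar)
z = mu + torch.randn_like(std) * std                   # z ~ N(mu, std^2)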
def reshape_parent_indices(indices, bs, num_beams):
    """Convert within-beam parent indices to row indices of the flattened [bs * num_beams] dimension."""
    # one offset block per batch element: [0, ..., 0, num_beams, ..., num_beams, 2 * num_beams, ...]
    parent_indices = V((torch.arange(end=bs) * num_beams).unsqueeze_(1).repeat(1, num_beams).view(-1).long())
    return indices + parent_indices
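# Illustrative check (assumption, not part of the repo) of the index arithmetic above, with
# plain tensors: for bs=2 and num_beams=3 the offsets are [0, 0, 0, 3, 3, 3], so within-beam
# parent index i of batch element b maps to row b * num_beams + i of the flattened beam tensors.
import torch
offsets = (torch.arange(2) * 3).unsqueeze(1).repeat(1, 3).view(-1)
print(offsets)                                        # tensor([0, 0, 0, 3, 3, 3])
print(torch.tensor([1, 0, 2, 0, 1, 1]) + offsets)     # tensor([1, 0, 2, 3, 4, 4])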
def _greedy_forward(self, inputs, hidden=None, constraints=None):
    dec_inputs = inputs
    # cap the number of decoding steps so the output does not grow too big to fit in memory
    max_iterations = min(dec_inputs.size(0), self.MAX_STEPS_ALLOWED) if self.training else self.max_iterations
    inputs = V(inputs[:1].data)  # initially the inputs are only the first token [1, bs]
    sl, bs = inputs.size()
    finished = to_gpu(torch.zeros(bs).byte())
    iteration = 0
    self.beam_outputs = inputs.clone()
    final_outputs = []
    while not finished.all() and iteration < max_iterations:
        # output should be List[[sl, bs, layer_dim], ...] with sl equal to one
        # teacher forcing: with probability pr_force feed the ground-truth token instead of the prediction
        if 0 < iteration and self.training and 0. < self.random() < self.pr_force:
            inputs = dec_inputs[iteration].unsqueeze(0)
        output = self.forward(inputs, hidden=hidden, num_beams=0, constraints=constraints)
        hidden = self.decoder_layer.hidden
        final_outputs.append(output)  # dims should be [sl=1, bs, nt]
        # inputs are the argmax indices with dims [1, bs]; repackage the variable to avoid backpropagating through it
        inputs = assert_dims(V(output.data.max(dim=-1)[1]), [1, bs])
        iteration += 1
        self.beam_outputs = assert_dims(torch.cat([self.beam_outputs, inputs], dim=0), [iteration + 1, bs])
        new_finished = (inputs.data == self.eos_token).view(-1)
        finished = finished | new_finished
    self.beam_outputs = self.beam_outputs.view(-1, bs, 1)
    # outputs should be [sl, bs, nt]
    outputs = torch.cat(final_outputs, dim=0)
    return outputs
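# Toy sketch (assumption, not from the repo) of the greedy feedback step used in the loop above:
# the argmax token of the current step's output becomes the next input, and decoding stops once
# every sequence in the batch has emitted the eos token.
import torch
nt, bs, eos_token = 5, 2, 0
step_output = torch.randn(1, bs, nt)              # stand-in for one decoder step, dims [1, bs, nt]
next_inputs = step_output.max(dim=-1)[1]          # argmax indices, dims [1, bs]
finished = (next_inputs == eos_token).view(-1)    # OR-ed into the running `finished` mask each iteration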
def mask_logprobs(self, bs, finished, iteration, logprobs, new_logprobs, num_beams, num_tokens):
    if iteration == 0:
        # only the first beam is considered in the first step, otherwise every beam would give the same result
        new_logprobs = new_logprobs[..., 0, :]
    else:
        # finished beams have to be handled as well:
        # create a mask [1, bs * nb, nt] filled with -inf everywhere
        mask = torch.zeros_like(new_logprobs).fill_(-1e32).view(1, bs * num_beams, num_tokens)
        f = V(finished.unsqueeze(0))
        # set the pad_token position to the last logprob, so a finished beam keeps its score
        mask[..., self.pad_token] = logprobs.view(1, bs * num_beams)
        # mask shape = [1, bs * nb (that are finished), nt]
        mask = mask.masked_select(f.unsqueeze(-1)).view(1, -1, num_tokens)
        # replace the rows of the finished beams with the mask
        new_logprobs.masked_scatter_(f.view(1, bs, num_beams, 1), mask)
    # flatten all beams with the tokens
    new_logprobs = new_logprobs.view(1, bs, -1)
    return new_logprobs
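# Illustrative sketch (assumption, not part of the repo) of the finished-beam masking above,
# written with plain tensors for bs=1: a finished beam may only "extend" with pad_token and
# keeps its running score, while live beams keep their real per-token logprobs.
import torch
num_beams, num_tokens, pad_token = 2, 4, 0
logprobs = torch.tensor([-1.2, -0.7])                  # running score per beam
new_logprobs = torch.zeros(num_beams, num_tokens)      # next-step scores per beam and token
finished = torch.tensor([True, False])                 # beam 0 has already emitted eos

mask = torch.full((num_beams, num_tokens), -1e32)
mask[:, pad_token] = logprobs                          # only pad keeps the old score
new_logprobs[finished] = mask[finished]
# new_logprobs[0] is now [-1.2, -1e32, -1e32, -1e32]; new_logprobs[1] is untouched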