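# Fetch this module's cached 'input_buffer' from the incremental decoding
# state (returns None when nothing has been cached yet).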
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(self, incremental_state, 'input_buffer')
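# Attention-module variant: the cached state lives under the 'attn_state' key
# and falls back to an empty dict on the first call.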
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(
self,
incremental_state,
'attn_state',
) or {}
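# Unpack the tuple produced by the encoder before decoding.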
(
    encoder_outs,
    final_hidden,
    final_cell,
    src_lengths,
    src_tokens,
    _,
) = encoder_out
# embed tokens
x = self.embed_tokens(input_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# initialize previous states (or get from cache during incremental generation)
cached_state = utils.get_incremental_state(
self, incremental_state, "cached_state"
)
input_feed = None
if cached_state is not None:
prev_hiddens, prev_cells, input_feed = cached_state
else:
# first time step, initialize previous states
init_prev_states = self._init_prev_states(encoder_out)
prev_hiddens = []
prev_cells = []
# init_prev_states may or may not include initial attention context
for (h, c) in zip(init_prev_states[0::2], init_prev_states[1::2]):
prev_hiddens.append(h)
prev_cells.append(c)
if self.attention.context_dim:
    # assumed completion of the truncated excerpt: the trailing entry of
    # init_prev_states holds the initial attention context used as input feed
    input_feed = init_prev_states[-1]
def _get_input_buffer(self, incremental_state, name):
return utils.get_incremental_state(self, incremental_state, name)
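# LSTM decoder step: restore the cached recurrent state during incremental
# generation, or zero-initialize it on the first call.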
encoder_outs = encoder_out[0]
srclen = encoder_outs.size(0)
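# during incremental generation only the most recent output token needs to be
# fed through the decoder; earlier positions are already cached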
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
bsz, seqlen = prev_output_tokens.size()
# embed tokens
x = self.embed_tokens(prev_output_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# initialize previous states (or get from cache during incremental generation)
cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
if cached_state is not None:
prev_hiddens, prev_cells, input_feed = cached_state
else:
num_layers = len(self.layers)
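# first call: start from zero hidden/cell states for every layer, and from a
# zero input-feed vector when attention is used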
prev_hiddens = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)]
prev_cells = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)]
input_feed = x.new_zeros(bsz, self.encoder_output_units) \
if self.attention is not None else None
if self.attention is not None:
attn_scores = x.new_zeros(srclen, seqlen, bsz)
outs = []
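# unroll the decoder one target position at a time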
for j in range(seqlen):
# input feeding: concatenate context vector from previous time step
input = torch.cat((x[j, :, :], input_feed), dim=1) \
if input_feed is not None else x[j, :, :]
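# Selectively copy entries of `another_state` into this module's cached
# 'encoder_out' state according to `mask`.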
def masked_copy_incremental_state(self, incremental_state, another_state, mask):
state = utils.get_incremental_state(self, incremental_state, 'encoder_out')
if state is None:
assert another_state is None
return
def mask_copy_state(state, another_state):
if isinstance(state, list):
assert isinstance(another_state, list) and len(state) == len(another_state)
return [
mask_copy_state(state_i, another_state_i)
for state_i, another_state_i in zip(state, another_state)
]
if state is not None:
    assert state.size(0) == mask.size(0) and another_state is not None and \
        state.size() == another_state.size()
    # broadcast the batch mask across the remaining state dimensions
    mask_unsqueezed = mask
    for _ in range(1, len(state.size())):
        mask_unsqueezed = mask_unsqueezed.unsqueeze(-1)
    # keep the current entry where the mask is set, otherwise take the one from
    # another_state (assumed completion of the truncated excerpt)
    return torch.where(mask_unsqueezed, state, another_state)
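# assumed completion: apply the masked copy to the cached state and write the
# merged result back into the incremental state
new_state = mask_copy_state(state, another_state)
utils.set_incremental_state(self, incremental_state, 'encoder_out', new_state)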
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
if cached_state is None:
return
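# recursively reorder every tensor in the cached state along the batch
# (beam) dimension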
def reorder_state(state):
if isinstance(state, list):
return [reorder_state(state_i) for state_i in state]
return state.index_select(0, new_order)
new_state = tuple(map(reorder_state, cached_state))
utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)