Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# (batch_size, sub_seq_len, input_dim)
inputs = F.concat(
input_lags, repeated_index_embeddings, time_feat, dim=-1
)
# unroll encoder
outputs, state = self.rnn.unroll(
inputs=inputs,
length=unroll_length,
layout="NTC",
merge_outputs=True,
begin_state=begin_state,
)
assert_shape(outputs, (-1, unroll_length, self.num_cells))
for s in state:
assert_shape(s, (-1, self.num_cells))
assert_shape(
lags_scaled,
(-1, unroll_length, self.target_dim, len(self.lags_seq)),
)
return outputs, state, lags_scaled, inputs
inputs = F.concat(
input_lags, repeated_index_embeddings, time_feat, dim=-1
)
# unroll encoder
outputs, state = self.rnn.unroll(
inputs=inputs,
length=unroll_length,
layout="NTC",
merge_outputs=True,
begin_state=begin_state,
)
assert_shape(outputs, (-1, unroll_length, self.num_cells))
for s in state:
assert_shape(s, (-1, self.num_cells))
assert_shape(
lags_scaled,
(-1, unroll_length, self.target_dim, len(self.lags_seq)),
)
return outputs, state, lags_scaled, inputs
# (batch_size, sub_seq_len, target_dim, num_lags)
lags_scaled = F.broadcast_div(lags, scale.expand_dims(axis=-1))
assert_shape(
lags_scaled,
(-1, unroll_length, self.target_dim, len(self.lags_seq)),
)
input_lags = F.reshape(
data=lags_scaled,
shape=(-1, unroll_length, len(self.lags_seq) * self.target_dim),
)
# (batch_size, target_dim, embed_dim)
index_embeddings = self.embed(target_dimension_indicator)
assert_shape(index_embeddings, (-1, self.target_dim, self.embed_dim))
# (batch_size, seq_len, target_dim * embed_dim)
repeated_index_embeddings = (
index_embeddings.expand_dims(axis=1)
.repeat(axis=1, repeats=unroll_length)
.reshape((-1, unroll_length, self.target_dim * self.embed_dim))
)
# (batch_size, sub_seq_len, input_dim)
inputs = F.concat(
input_lags, repeated_index_embeddings, time_feat, dim=-1
)
# unroll encoder
outputs, state = self.rnn.unroll(
inputs=inputs,
Returns
-------
outputs
RNN outputs (batch_size, seq_len, num_cells)
states
RNN states. Nested list with (batch_size, num_cells) tensors with
dimensions target_dim x num_layers x (batch_size, num_cells)
lags_scaled
Scaled lags (batch_size, sub_seq_len, target_dim, num_lags)
inputs
inputs to the RNN
"""
# (batch_size, sub_seq_len, target_dim, num_lags)
lags_scaled = F.broadcast_div(lags, scale.expand_dims(axis=-1))
assert_shape(
lags_scaled,
(-1, unroll_length, self.target_dim, len(self.lags_seq)),
)
input_lags = F.reshape(
data=lags_scaled,
shape=(-1, unroll_length, len(self.lags_seq) * self.target_dim),
)
# (batch_size, target_dim, embed_dim)
index_embeddings = self.embed(target_dimension_indicator)
assert_shape(index_embeddings, (-1, self.target_dim, self.embed_dim))
# (batch_size, seq_len, target_dim * embed_dim)
repeated_index_embeddings = (
index_embeddings.expand_dims(axis=1)
),
future_observed_values,
dim=1,
)
# mask the loss at a time step if one or more observations are missing
# in the target dimensions (batch_size, subseq_length, 1)
loss_weights = observed_values.min(axis=-1, keepdims=True)
assert_shape(loss_weights, (-1, seq_len, 1))
loss = weighted_average(
F=F, x=likelihoods, weights=loss_weights, axis=1
)
assert_shape(loss, (-1, -1, 1))
self.distribution = distr
return (loss, likelihoods) + distr_args
# assert_shape(target, (-1, seq_len, self.target_dim))
distr, distr_args = self.distr(
time_features=inputs,
rnn_outputs=rnn_outputs,
scale=scale,
lags_scaled=lags_scaled,
target_dimension_indicator=target_dimension_indicator,
seq_len=self.context_length + self.prediction_length,
)
# we sum the last axis to have the same shape for all likelihoods
# (batch_size, subseq_length, 1)
likelihoods = -distr.log_prob(target).expand_dims(axis=-1)
assert_shape(likelihoods, (-1, seq_len, 1))
past_observed_values = F.broadcast_minimum(
past_observed_values, 1 - past_is_pad.expand_dims(axis=-1)
)
# (batch_size, subseq_length, target_dim)
observed_values = F.concat(
past_observed_values.slice_axis(
axis=1, begin=-self.context_length, end=None
),
future_observed_values,
dim=1,
)
# mask the loss at a time step if one or more observations are missing
# in the target dimensions (batch_size, subseq_length, 1)