Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_ExpectedNumInstanceSampler():
    """Check that a Chain holding an InstanceSplitter that samples with
    ExpectedNumInstanceSampler passes ``assert_serializable``.
    """
    N = 6
    train_length = 2
    pred_length = 1

    # Dataset fixture; it is not referenced again in the visible part of
    # this test, but the call is kept so the test does the same work.
    ds = make_dataset(N, train_length)

    # Assemble the pieces bottom-up: sampler -> splitter -> chain.
    sampler = transform.ExpectedNumInstanceSampler(num_instances=4)
    splitter = transform.InstanceSplitter(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        train_sampler=sampler,
        past_length=train_length,
        future_length=pred_length,
        pick_incomplete=True,
    )
    chain = transform.Chain(trans=[splitter])

    assert_serializable(chain)
def test_InstanceSplitter(
start, target, is_train: bool, pick_incomplete: bool
):
train_length = 100
pred_length = 13
t = transform.InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
train_sampler=transform.UniformSplitSampler(p=1.0),
past_length=train_length,
future_length=pred_length,
time_series_fields=["some_time_feature"],
pick_incomplete=pick_incomplete,
)
assert_serializable(t)
other_feat = np.arange(len(target) + 100)
data = {
"start": start,
def test_BucketInstanceSampler():
    """Check that a Chain holding an InstanceSplitter that samples with
    BucketInstanceSampler passes ``assert_serializable``.

    The sampler is constructed from the ``scale_histogram`` attribute of the
    statistics computed over the generated dataset.
    """
    N = 6
    train_length = 2
    pred_length = 1

    ds = make_dataset(N, train_length)
    dataset_stats = calculate_dataset_statistics(ds)

    # Assemble the pieces bottom-up: sampler -> splitter -> chain.
    sampler = transform.BucketInstanceSampler(dataset_stats.scale_histogram)
    splitter = transform.InstanceSplitter(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        train_sampler=sampler,
        past_length=train_length,
        future_length=pred_length,
        pick_incomplete=True,
    )
    chain = transform.Chain(trans=[splitter])

    assert_serializable(chain)
target_field=FieldName.TARGET,
output_field="age",
pred_length=pred_length,
log_scale=True,
),
transform.AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field="observed_values",
convert_nans=False,
),
transform.VstackFeatures(
output_field="dynamic_feat",
input_fields=["age", "time_feat"],
drop_inputs=True,
),
transform.InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
train_sampler=transform.ExpectedNumInstanceSampler(
num_instances=4
),
past_length=train_length,
future_length=pred_length,
time_series_fields=["dynamic_feat", "observed_values"],
output_NTC=False,
),
]
)
assert_serializable(t)
def create_transformation(self):
    """Build the model-specific input transformation.

    Returns an ``InstanceSplitter`` whose ``past_length`` is
    ``self.context_length`` and whose ``future_length`` is
    ``self.prediction_length``.
    """
    # Model-specific input transform: training samples are selected at
    # random from all series.
    splitter_kwargs = dict(
        target_field=FieldName.TARGET,
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        train_sampler=ExpectedNumInstanceSampler(num_instances=1),
        past_length=self.context_length,
        future_length=self.prediction_length,
    )
    return InstanceSplitter(**splitter_kwargs)
time_features=self.time_features,
pred_length=self.prediction_length,
),
VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME],
),
SetFieldIfNotPresent(
field=FieldName.FEAT_STATIC_CAT, value=[0.0]
),
TargetDimIndicator(
field_name="target_dimension_indicator",
target_field=FieldName.TARGET,
),
AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
train_sampler=ExpectedNumInstanceSampler(num_instances=1),
past_length=self.history_length,
future_length=self.prediction_length,
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
pick_incomplete=self.pick_incomplete,
),
use_marginal_transformation(self.use_marginal_transformation),
]
def create_transformation(self) -> Transformation:
return Chain(
[
InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
train_sampler=ExpectedNumInstanceSampler(num_instances=1),
past_length=self.context_length,
future_length=self.prediction_length,
time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL]
)
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
dtype=self.dtype,
),
VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
+ (
[FieldName.FEAT_DYNAMIC_REAL]
if self.use_feat_dynamic_real
else []
),
),
InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
train_sampler=ExpectedNumInstanceSampler(num_instances=1),
past_length=self.history_length,
future_length=self.prediction_length,
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
pred_length=self.prediction_length,
),
AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
),
VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE],
),
SetFieldIfNotPresent(
field=FieldName.FEAT_STATIC_CAT, value=[0.0]
),
AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
train_sampler=ExpectedNumInstanceSampler(num_instances=1),
past_length=self.context_length,
future_length=pred_length,
output_NTC=False,
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
),
QuantizeScaled(
bin_edges=bin_edges.tolist(),
future_target="future_target",
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=time_features_from_frequency_str(self.freq),
pred_length=self.prediction_length,
),
transform.VstackFeatures(
output_field=FieldName.FEAT_DYNAMIC_REAL,
input_fields=[FieldName.FEAT_TIME],
),
transform.SetFieldIfNotPresent(
field=FieldName.FEAT_STATIC_CAT, value=[0.0]
),
transform.AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT, expected_ndim=1
),
transform.InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
train_sampler=ExpectedNumInstanceSampler(num_instances=1),
past_length=self.context_length,
future_length=self.prediction_length,
time_series_fields=[FieldName.FEAT_DYNAMIC_REAL],
),