# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def create_test_cases():
    """Build a resumable-pipeline test case and run an output-transformer scenario.

    NOTE(review): despite its name, this function both constructs a
    ``ResumablePipelineTestCase`` fixture and then executes a checkpointed
    output-transformer scenario with inline asserts — it looks like two test
    helpers merged together; confirm against the original test module.
    """
    # Minimal 1x1 input so each callback step fires exactly once.
    data_inputs = np.ones((1, 1))
    expected_outputs = np.ones((1, 1))
    # NOTE(review): `dc` is never used below — presumably leftover setup; confirm.
    dc = DataContainer(current_ids=range(len(data_inputs)), data_inputs=data_inputs, expected_outputs=expected_outputs)
    # Tapes record the order in which steps were transformed / fitted.
    tape = TapeCallbackFunction()
    tape_fit = TapeCallbackFunction()
    # Case: three callback steps and no checkpoint — all three steps are
    # expected to run, yielding the tape ["1", "2", "3"].
    tape_without_checkpoint_test_arguments = ResumablePipelineTestCase(
        tape,
        data_inputs,
        expected_outputs,
        [
            ("a", FitTransformCallbackStep(tape.callback, tape_fit.callback, ["1"])),
            ("b", FitTransformCallbackStep(tape.callback, tape_fit.callback, ["2"])),
            ("c", FitTransformCallbackStep(tape.callback, tape_fit.callback, ["3"]))
        ],
        ["1", "2", "3"])
    tape2 = TapeCallbackFunction()
    tape2_fit = TapeCallbackFunction()
    initial_data_inputs = [1, 2]
    initial_expected_outputs = [2, 3]
    # Pipeline with a checkpoint between two output transformers.
    # NOTE(review): `tmpdir` is not defined in this function's scope —
    # presumably a pytest fixture in the original module; confirm before running.
    create_pipeline_output_transformer = lambda: ResumablePipeline(
        [
            ('output_transformer_1', MultiplyBy2OutputTransformer()),
            ('pickle_checkpoint', DefaultCheckpoint()),
            ('output_transformer_2', MultiplyBy2OutputTransformer()),
        ], cache_folder=tmpdir)
    # First run populates the checkpoint cache on disk.
    create_pipeline_output_transformer().fit_transform(
        data_inputs=initial_data_inputs, expected_outputs=initial_expected_outputs
    )
    # A fresh pipeline instance should resume from the checkpoint and still
    # produce the fully transformed values.
    transformer = create_pipeline_output_transformer()
    actual_data_container = transformer.handle_transform(
        DataContainer(current_ids=[0, 1], data_inputs=initial_data_inputs, expected_outputs=initial_expected_outputs),
        ExecutionContext(tmpdir)
    )
    # Per these asserts, both streams are doubled by each transformer:
    # data_inputs [1, 2] -> [4, 8]; expected_outputs [2, 3] -> [8, 12].
    assert np.array_equal(actual_data_container.data_inputs, [4, 8])
    assert np.array_equal(actual_data_container.expected_outputs, [8, 12])
def handle_fit(self, data_container: DataContainer, context: ExecutionContext) -> (BaseStep, DataContainer):
    """Fit the wrapped step on the container's expected outputs.

    Output-transformer pattern: the wrapped step is trained with
    ``expected_outputs`` fed in as its ``data_inputs`` (mirroring
    ``handle_transform`` below). The container is then re-hashed so that
    downstream checkpoints see this step as applied.

    :param data_container: container whose ``expected_outputs`` are fitted on
    :param context: execution context; the wrapped step is pushed onto it
    :return: (self, data_container)
    """
    # Bug fix: handle_fit returns a (step, data_container) pair in this
    # codebase (see this method's own return), so assigning the whole
    # return value to self.wrapped stored a tuple instead of the fitted step.
    fitted_wrapped, _ = self.wrapped.handle_fit(
        DataContainer(
            current_ids=data_container.current_ids,
            data_inputs=data_container.expected_outputs,
            expected_outputs=None
        ),
        context.push(self.wrapped)
    )
    self.wrapped = fitted_wrapped
    # Re-hash so the container's ids reflect that this step has run.
    current_ids = self.hash(data_container)
    data_container.set_current_ids(current_ids)
    return self, data_container
include_incomplete_pass=True)
conv_data_inputs = convolved_1d(stride=stride, iterable=self.data_inputs, kernel_size=kernel_size,
include_incomplete_pass=True)
conv_expected_outputs = convolved_1d(stride=stride, iterable=self.expected_outputs, kernel_size=kernel_size,
include_incomplete_pass=True)
for current_ids, data_inputs, expected_outputs in zip(conv_current_ids, conv_data_inputs,
conv_expected_outputs):
for i, (ci, di, eo) in enumerate(zip(current_ids, data_inputs, expected_outputs)):
if di is None:
current_ids = current_ids[:i]
data_inputs = data_inputs[:i]
expected_outputs = expected_outputs[:i]
break
yield DataContainer(
summary_id=self.summary_id,
current_ids=current_ids,
data_inputs=data_inputs,
expected_outputs=expected_outputs
)
def copy(self):
    """Return a new DataContainer with the same ids, inputs and outputs.

    This is a shallow copy: the contained sequences themselves are shared
    with the original, not duplicated.
    """
    fields = {
        'summary_id': self.summary_id,
        'current_ids': self.current_ids,
        'data_inputs': self.data_inputs,
        'expected_outputs': self.expected_outputs,
    }
    return DataContainer(**fields)
expected_outputs = [None] * len(self.data_inputs)
return zip(current_ids, self.data_inputs, expected_outputs)
def __repr__(self):
    """Delegate to ``__str__`` so that repr() and str() render identically."""
    return self.__str__()
def __str__(self):
    """Render as ``ClassName(current_ids=[...], summary_id=...)``.

    Bug fix: the original string was missing the closing ``")"``, producing
    an unbalanced representation.
    """
    return self.__class__.__name__ + "(current_ids=" + repr(list(self.current_ids)) + ", summary_id=" + repr(
        self.summary_id) + ")"
def __len__(self):
    """A container's length is the number of data inputs it holds."""
    contents = self.data_inputs
    return len(contents)
class ExpandedDataContainer(DataContainer):
    """
    Sub class of DataContainer representing a container expanded to an extra
    dimension, which remembers the ids it had before expansion.
    .. seealso::
        :class:`DataContainer`,
    """
    def __init__(self, current_ids, data_inputs, expected_outputs, summary_id, old_current_ids):
        DataContainer.__init__(
            self,
            current_ids=current_ids,
            data_inputs=data_inputs,
            expected_outputs=expected_outputs,
            summary_id=summary_id
        )
        # Bug fix: `old_current_ids` was accepted but silently dropped.
        # Keep it so the pre-expansion ids remain available to callers.
        self.old_current_ids = old_current_ids
def handle_transform(self, data_container: DataContainer, context: ExecutionContext) -> DataContainer:
    """Transform the container's expected outputs with the wrapped step.

    Output-transformer pattern: the wrapped step receives
    ``expected_outputs`` as its ``data_inputs``; its transformed outputs are
    written back into ``data_container.expected_outputs``, and the container
    is re-hashed so downstream checkpoints see this step as applied.

    :param data_container: container whose ``expected_outputs`` are transformed
    :param context: execution context; the wrapped step is pushed onto it
    :return: the same data container, updated in place
    """
    sub_container = DataContainer(
        current_ids=data_container.current_ids,
        data_inputs=data_container.expected_outputs,
        expected_outputs=None
    )
    sub_context = context.push(self.wrapped)
    transformed = self.wrapped.handle_transform(sub_container, sub_context)
    data_container.set_expected_outputs(transformed.data_inputs)
    data_container.set_current_ids(self.hash(data_container))
    return data_container