def create_e2e_components(pipeline_root: Text, csv_input_location: Text,
                          taxi_module_file: Text) -> List[BaseComponent]:
  """Creates components for a simple Chicago Taxi TFX pipeline for testing.

  Args:
    pipeline_root: The root of the pipeline output.
    csv_input_location: The location of the input data directory.
    taxi_module_file: The location of the module file for Transform/Trainer.

  Returns:
    A list of TFX components that constitutes an end-to-end test pipeline.
  """
  examples = dsl_utils.external_input(csv_input_location)
  example_gen = CsvExampleGen(input_base=examples)
  statistics_gen = StatisticsGen(input_data=example_gen.outputs['examples'])
  infer_schema = SchemaGen(
      stats=statistics_gen.outputs['output'], infer_feature_shape=False)
  validate_stats = ExampleValidator(
      stats=statistics_gen.outputs['output'],
      schema=infer_schema.outputs['output'])
  transform = Transform(
      input_data=example_gen.outputs['examples'],
      schema=infer_schema.outputs['output'],
      module_file=taxi_module_file)
  trainer = Trainer(
      module_file=taxi_module_file,
      transformed_examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['output'],
      # Assumed completion: the excerpt breaks off here, so the remaining
      # arguments and step counts are representative placeholders.
      transform_output=transform.outputs['transform_output'],
      train_args=trainer_pb2.TrainArgs(num_steps=10),
      eval_args=trainer_pb2.EvalArgs(num_steps=5))
  return [
      example_gen, statistics_gen, infer_schema, validate_stats, transform,
      trainer
  ]
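A minimal sketch of running these components end to end, assuming the same old-style TFX API as the excerpt above; the pipeline name, paths, and metadata location are placeholders.

from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

# Placeholder paths; point these at real data and module files.
components = create_e2e_components(
    pipeline_root='/tmp/taxi_pipeline',
    csv_input_location='/tmp/data/simple',
    taxi_module_file='/tmp/taxi_utils.py')
BeamDagRunner().run(
    pipeline.Pipeline(
        pipeline_name='taxi_e2e_test',
        pipeline_root='/tmp/taxi_pipeline',
        components=components,
        enable_cache=True,
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
            '/tmp/taxi_pipeline/metadata.db')))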
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
  """Implements the Chicago Taxi pipeline with TFX."""
  examples = external_input(data_root)
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input=examples)
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=[example_gen],
      enable_cache=True,
      metadata_connection_config=metadata.sqlite_metadata_connection_config(
          metadata_path),
      additional_pipeline_args={},
  )
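In the TFX examples, a factory like this is typically handed straight to a runner; a usage sketch with placeholder names and paths:

from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

if __name__ == '__main__':
  BeamDagRunner().run(
      _create_pipeline(
          pipeline_name='chicago_taxi_beam',  # placeholder values
          pipeline_root='/tmp/pipelines/chicago_taxi_beam',
          data_root='/tmp/data/simple',
          metadata_path='/tmp/pipelines/metadata.db'))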
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
  """Implements the CIFAR-10 pipeline with TFX."""
  examples = external_input(data_root)
  input_split = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
      example_gen_pb2.Input.Split(name='eval', pattern='test.tfrecord')
  ])
  example_gen = ImportExampleGen(input=examples, input_config=input_split)
  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
  # Generates schema based on statistics files.
  infer_schema = SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=True)
  # Performs anomaly detection based on statistics and data schema.
  validate_stats = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=infer_schema.outputs['schema'])
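Split patterns in the Input proto are glob-style, so the same mechanism covers sharded files; a sketch assuming the data lives under train/ and test/ subdirectories:

from tfx.proto import example_gen_pb2

# Assumed layout: data_root/train/*.tfrecord and data_root/test/*.tfrecord.
input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train/*.tfrecord'),
    example_gen_pb2.Input.Split(name='eval', pattern='test/*.tfrecord'),
])
example_gen = ImportExampleGen(input=examples, input_config=input_config)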
# Uses TFMA to compute evaluation statistics over features of a model.
model_analyzer = Evaluator(
    examples=training_example_gen.outputs['examples'],
    model_exports=trainer.outputs['output'],
    feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
        evaluator_pb2.SingleSlicingSpec(
            column_for_slicing=['trip_start_hour'])
    ]))
# Performs quality validation of a candidate model (compared to a baseline).
model_validator = ModelValidator(
    examples=training_example_gen.outputs['examples'],
    model=trainer.outputs['output'])
# Brings inference data into the pipeline.
inference_examples = external_input(inference_data_root)
inference_example_gen = CsvExampleGen(
    input_base=inference_examples,
    output_config=example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(
            splits=[example_gen_pb2.SplitConfig.Split(
                name='unlabelled', hash_buckets=100)])),
    instance_name='inference_example_gen')
# Performs offline batch inference over inference examples.
bulk_inferrer = BulkInferrer(
    examples=inference_example_gen.outputs['examples'],
    model_export=trainer.outputs['output'],
    model_blessing=model_validator.outputs['blessing'],
    # Empty data_spec.example_splits will result in using all splits.
    # (data_spec/model_spec below are an assumed completion of this call.)
    data_spec=bulk_inferrer_pb2.DataSpec(),
    model_spec=bulk_inferrer_pb2.ModelSpec())
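The single 'unlabelled' split above keeps every inference example in one split; the same SplitConfig mechanism yields a proportional train/eval split, sketched here with an assumed 2:1 ratio:

output_config = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ]))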
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text) -> pipeline.Pipeline:
  """Implements the Chicago Taxi pipeline with TFX."""
  examples = external_input(data_root)
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input=examples)
  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
  # Generates schema based on statistics files.
  infer_schema = SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=False)
  # Performs anomaly detection based on statistics and data schema.
  validate_stats = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=infer_schema.outputs['schema'])
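The schema and anomalies these components emit are ordinary TFDV artifacts on disk, so they can be inspected outside the pipeline; a sketch with an assumed artifact path (the exact layout under pipeline_root varies by TFX version):

import tensorflow_data_validation as tfdv

schema = tfdv.load_schema_text(
    '/tmp/pipeline_root/SchemaGen/schema/1/schema.pbtxt')  # assumed path
tfdv.display_schema(schema)  # renders the schema as a table in a notebook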
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:
  """Implements the Chicago Taxi pipeline with TFX."""
  examples = external_input(data_root)
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input=examples)
  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
  # Generates schema based on statistics files.
  infer_schema = SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=False)
  # Performs anomaly detection based on statistics and data schema.
  validate_stats = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=infer_schema.outputs['schema'])
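The extra direct_num_workers argument exists to parallelize Beam's local DirectRunner; the usual pattern is to forward it as a Beam pipeline argument, sketched here:

beam_pipeline_args = ['--direct_num_workers=%d' % direct_num_workers]
# Passed through as pipeline.Pipeline(..., beam_pipeline_args=beam_pipeline_args)
# so every Beam-powered component runs with that many local workers.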
def _create_pipeline(
    pipeline_name: Text, pipeline_root: Text, data_root: Text,
    module_file: Text, serving_model_dir: Text,
    metadata_connection_config: metadata_store_pb2.ConnectionConfig
) -> pipeline.Pipeline:
  """Implements the Chicago Taxi pipeline with TFX."""
  examples = external_input(data_root)
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input_base=examples)
  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(input_data=example_gen.outputs['examples'])
  # Generates schema based on statistics files.
  infer_schema = SchemaGen(stats=statistics_gen.outputs['output'])
  # Performs anomaly detection based on statistics and data schema.
  validate_stats = ExampleValidator(
      stats=statistics_gen.outputs['output'],
      schema=infer_schema.outputs['output'])
  # Performs transformations and feature engineering in training and serving.
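Unlike the earlier variants, this signature takes a raw MLMD ConnectionConfig rather than a sqlite path; a sketch of constructing one for a local SQLite store (the filename is a placeholder):

from ml_metadata.proto import metadata_store_pb2

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = '/tmp/mlmd/metadata.db'
connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)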
def _create_test_pipeline(
    pipeline_root: Text, csv_input_location: Text, taxi_module_file: Text,
    enable_cache: bool
):
  """Creates a simple Kubeflow-based Chicago Taxi TFX pipeline.

  Args:
    pipeline_root: The root of the pipeline output.
    csv_input_location: The location of the input data directory.
    taxi_module_file: The location of the module file for Transform/Trainer.
    enable_cache: Whether to enable cache or not.

  Returns:
    A logical TFX pipeline.Pipeline object.
  """
  examples = external_input(csv_input_location)
  example_gen = CsvExampleGen(input=examples)
  statistics_gen = StatisticsGen(input_data=example_gen.outputs['examples'])
  infer_schema = SchemaGen(
      stats=statistics_gen.outputs['statistics'],
      infer_feature_shape=False,
  )
  validate_stats = ExampleValidator(
      stats=statistics_gen.outputs['statistics'],
      schema=infer_schema.outputs['schema'],
  )
  transform = Transform(
      input_data=example_gen.outputs['examples'],
      schema=infer_schema.outputs['schema'],
      module_file=taxi_module_file,
  )
  # Assumed completion: the excerpt stops here, but the docstring promises a
  # pipeline.Pipeline, so assemble the components defined above.
  return pipeline.Pipeline(
      pipeline_name='chicago_taxi_pipeline_simple',
      pipeline_root=pipeline_root,
      components=[
          example_gen, statistics_gen, infer_schema, validate_stats, transform
      ],
      enable_cache=enable_cache,
  )
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:
  """Implements the Chicago Taxi pipeline with TFX and Kubeflow Pipelines."""
  examples = external_input(data_root)
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = CsvExampleGen(input=examples)
  # Computes statistics over data for visualization and example validation.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
  # Generates schema based on statistics files.
  infer_schema = SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=False)
  # Performs anomaly detection based on statistics and data schema.
  validate_stats = ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=infer_schema.outputs['schema'])
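A Kubeflow Pipelines variant like this is typically compiled into a pipeline package rather than run locally; a sketch using TFX's KubeflowDagRunner with placeholder arguments:

from tfx.orchestration.kubeflow import kubeflow_dag_runner

# All names and GCS paths below are placeholders.
kubeflow_dag_runner.KubeflowDagRunner().run(
    _create_pipeline(
        pipeline_name='chicago_taxi_kubeflow',
        pipeline_root='gs://my-bucket/pipelines/chicago_taxi',
        data_root='gs://my-bucket/data/simple',
        module_file='gs://my-bucket/taxi_utils.py',
        serving_model_dir='gs://my-bucket/serving_model',
        direct_num_workers=1))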