Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Result bundle returned by a user-supplied `tuner_fn`: the configured
# tuner instance plus the datasets used for the search and for evaluation.
TunerFnResult = NamedTuple(
    'TunerFnResult',
    [
        ('tuner', kerastuner.Tuner),
        ('train_dataset', tf.data.Dataset),
        ('eval_dataset', tf.data.Dataset),
    ])
# TODO(jyzhao): move to tfx/types/standard_component_specs.py.
class TunerSpec(ComponentSpec):
  """ComponentSpec for the TFX Tuner component."""

  PARAMETERS = {
      # Either a module-file path or a 'tuner_fn' name may be given; both
      # are declared optional at the spec level.
      'module_file': ExecutionParameter(type=(str, Text), optional=True),
      'tuner_fn': ExecutionParameter(type=(str, Text), optional=True),
  }

  INPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
      'schema': ChannelParameter(type=standard_artifacts.Schema),
  }

  OUTPUTS = {
      'model_export_path': ChannelParameter(type=standard_artifacts.Model),
      'study_best_hparams_path':
          ChannelParameter(type=standard_artifacts.HyperParameters),
  }

  # TODO(b/139281215): these input / output names will be renamed in the
  # future.  These compatibility aliases are provided for forwards
  # compatibility.
  _OUTPUT_COMPATIBILITY_ALIASES = {
      'model': 'model_export_path',
      'best_hparams': 'study_best_hparams_path',
  }
# NOTE(review): garbled fragment.  The `Tuner` class header below is
# immediately followed by docstring text that is missing its opening
# triple-quotes (so this span is not valid Python as-is), and the
# docstring/constructor content actually belongs to the BulkInferrer
# component (it references BulkInferrerSpec and bulk_inferrer_pb2).
# Left byte-identical pending reconstruction from the upstream file.
class Tuner(base_component.BaseComponent):
Trainer component.
model_blessing: A Channel of 'ModelBlessingPath' type, usually produced by
Model Validator component.
data_spec: bulk_inferrer_pb2.DataSpec instance that describes data
selection.
model_spec: bulk_inferrer_pb2.ModelSpec instance that describes model
specification.
inference_result: Channel of `InferenceResult` to store the inference
results.
instance_name: Optional name assigned to this specific instance of
BulkInferrer. Required only if multiple BulkInferrer components are
declared in the same pipeline.
"""
# Default the output channel when the caller did not supply one.
inference_result = inference_result or types.Channel(
type=standard_artifacts.InferenceResult,
artifacts=[standard_artifacts.InferenceResult()])
# Unset proto parameters fall back to empty default protos.
spec = BulkInferrerSpec(
examples=examples,
model=model,
model_blessing=model_blessing,
data_spec=data_spec or bulk_inferrer_pb2.DataSpec(),
model_spec=model_spec or bulk_inferrer_pb2.ModelSpec(),
inference_result=inference_result)
super(BulkInferrer, self).__init__(spec=spec, instance_name=instance_name)
# NOTE(review): garbled fragment.  This span is the tail of a docstring
# (its opening triple-quotes are not visible) plus the constructor body of
# `_QueryBasedExampleGen`; the enclosing `def __init__` line is missing,
# so the span is not valid Python as-is.  Left byte-identical pending
# reconstruction from the upstream file.
output_config: An example_gen_pb2.Output instance, providing output
configuration. If unset, default splits will be 'train' and 'eval' with
size 2:1.
custom_config: An optional example_gen_pb2.CustomConfig instance,
providing custom configuration for executor.
component_name: Name of the component, should be unique per component
class. Default to 'ExampleGen', can be overwritten by sub-classes.
example_artifacts: Optional channel of 'ExamplesPath' for output train and
eval examples.
name: Unique name for every component class instance.
"""
# Configure outputs.
output_config = output_config or utils.make_default_output_config(
input_config)
# One Examples artifact per configured output split.
example_artifacts = example_artifacts or channel_utils.as_channel([
standard_artifacts.Examples(split=split_name)
for split_name in utils.generate_output_split_names(
input_config, output_config)
])
spec = QueryBasedExampleGenSpec(
input_config=input_config,
output_config=output_config,
custom_config=custom_config,
examples=example_artifacts)
super(_QueryBasedExampleGen, self).__init__(spec=spec, name=name)
class EvaluatorSpec(ComponentSpec):
  """Evaluator component spec."""

  PARAMETERS = {
      'feature_slicing_spec':
          ExecutionParameter(type=evaluator_pb2.FeatureSlicingSpec),
      # Experimental: this parameter's interface and functionality may
      # change at any time.
      'fairness_indicator_thresholds':
          ExecutionParameter(type=List[float], optional=True),
  }

  INPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
      # TODO(b/139281215): this will be renamed to 'model' in the future.
      'model_exports': ChannelParameter(type=standard_artifacts.Model),
  }

  OUTPUTS = {
      'output': ChannelParameter(type=standard_artifacts.ModelEvaluation),
  }

  # TODO(b/139281215): these input / output names will be renamed in the
  # future.  These compatibility aliases are provided for forwards
  # compatibility.
  _INPUT_COMPATIBILITY_ALIASES = {'model': 'model_exports'}
  _OUTPUT_COMPATIBILITY_ALIASES = {'evaluation': 'output'}
# NOTE(review): garbled fragment.  The `ExampleValidatorSpec` header below
# is followed by docstring text (missing its opening triple-quotes) and a
# constructor body that belong to the Evaluator component, so this span is
# not valid Python as-is.  Left byte-identical pending reconstruction from
# the upstream file.
class ExampleValidatorSpec(ComponentSpec):
"""ExampleValidator component spec."""
FeatureSlicingSpec proto message.
fairness_indicator_thresholds: Optional list of float (or
RuntimeParameter) threshold values for use with TFMA fairness
indicators. Experimental functionality: this interface and
functionality may change at any time. TODO(b/142653905): add a link to
additional documentation for TFMA fairness indicators here.
output: Channel of `ModelEvalPath` to store the evaluation results.
model_exports: Backwards compatibility alias for the `model` argument.
instance_name: Optional name assigned to this specific instance of
Evaluator. Required only if multiple Evaluator components are declared
in the same pipeline. Either `model_exports` or `model` must be present
in the input arguments.
"""
# Accept `model` as a fallback for the legacy `model_exports` argument.
model_exports = model_exports or model
# Default the output channel when the caller did not supply one.
output = output or types.Channel(
type=standard_artifacts.ModelEvaluation,
artifacts=[standard_artifacts.ModelEvaluation()])
spec = EvaluatorSpec(
examples=examples,
model_exports=model_exports,
feature_slicing_spec=(feature_slicing_spec or
evaluator_pb2.FeatureSlicingSpec()),
fairness_indicator_thresholds=fairness_indicator_thresholds,
output=output)
super(Evaluator, self).__init__(spec=spec, instance_name=instance_name)
from tfx.types.component_spec import ComponentSpec
from tfx.types.component_spec import ExecutionParameter
class BulkInferrerSpec(ComponentSpec):
  """BulkInferrer component spec."""

  PARAMETERS = {
      'model_spec': ExecutionParameter(
          type=bulk_inferrer_pb2.ModelSpec, optional=True),
      'data_spec': ExecutionParameter(
          type=bulk_inferrer_pb2.DataSpec, optional=True),
  }

  INPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
      'model': ChannelParameter(
          type=standard_artifacts.Model, optional=True),
      'model_blessing': ChannelParameter(
          type=standard_artifacts.ModelBlessing, optional=True),
  }

  OUTPUTS = {
      'inference_result': ChannelParameter(
          type=standard_artifacts.InferenceResult),
  }
# NOTE(review): garbled fragment.  The `EvaluatorSpec` header below opens
# a PARAMETERS dict that is never populated; the content that follows
# (docstring and dicts) belongs to `BulkInferrerSpec`, and the first
# `PARAMETERS = {` on the third line is left unclosed, so this span is not
# valid Python as-is.  It also duplicates the BulkInferrerSpec defined
# earlier in this file.  Left byte-identical pending reconstruction.
class EvaluatorSpec(ComponentSpec):
"""Evaluator component spec."""
PARAMETERS = {
"""BulkInferrer component spec."""
PARAMETERS = {
'model_spec':
ExecutionParameter(type=bulk_inferrer_pb2.ModelSpec, optional=True),
'data_spec':
ExecutionParameter(type=bulk_inferrer_pb2.DataSpec, optional=True),
}
INPUTS = {
'examples':
ChannelParameter(type=standard_artifacts.Examples),
'model':
ChannelParameter(type=standard_artifacts.Model, optional=True),
'model_blessing':
ChannelParameter(
type=standard_artifacts.ModelBlessing, optional=True),
}
OUTPUTS = {
'inference_result':
ChannelParameter(type=standard_artifacts.InferenceResult),
}
# NOTE(review): garbled fragment.  This `EvaluatorSpec` duplicate starts
# correctly but its PARAMETERS dict is cut mid-entry (after the
# 'fairness_indicator_thresholds' key) and runs into docstring text and
# constructor code belonging to the BulkInferrer component, so the span is
# not valid Python as-is.  Left byte-identical pending reconstruction.
class EvaluatorSpec(ComponentSpec):
"""Evaluator component spec."""
PARAMETERS = {
'feature_slicing_spec':
ExecutionParameter(type=evaluator_pb2.FeatureSlicingSpec),
# This parameter is experimental: its interface and functionality may
# change at any time.
'fairness_indicator_thresholds':
model: A Channel of 'ModelExportPath' type, usually produced by
Trainer component.
model_blessing: A Channel of 'ModelBlessingPath' type, usually produced by
Model Validator component.
data_spec: bulk_inferrer_pb2.DataSpec instance that describes data
selection.
model_spec: bulk_inferrer_pb2.ModelSpec instance that describes model
specification.
inference_result: Channel of `InferenceResult` to store the inference
results.
instance_name: Optional name assigned to this specific instance of
BulkInferrer. Required only if multiple BulkInferrer components are
declared in the same pipeline.
"""
# Default the output channel when the caller did not supply one.
inference_result = inference_result or types.Channel(
type=standard_artifacts.InferenceResult,
artifacts=[standard_artifacts.InferenceResult()])
# Unset proto parameters fall back to empty default protos.
spec = BulkInferrerSpec(
examples=examples,
model=model,
model_blessing=model_blessing,
data_spec=data_spec or bulk_inferrer_pb2.DataSpec(),
model_spec=model_spec or bulk_inferrer_pb2.ModelSpec(),
inference_result=inference_result)
super(BulkInferrer, self).__init__(spec=spec, instance_name=instance_name)
class FileBasedExampleGenSpec(ComponentSpec):
  """File-based ExampleGen component spec."""

  PARAMETERS = {
      'input_config': ExecutionParameter(type=example_gen_pb2.Input),
      'output_config': ExecutionParameter(type=example_gen_pb2.Output),
      'custom_config': ExecutionParameter(
          type=example_gen_pb2.CustomConfig, optional=True),
  }

  INPUTS = {
      # TODO(b/139281215): this will be renamed to 'input' in the future.
      'input_base':
          ChannelParameter(type=standard_artifacts.ExternalArtifact),
  }

  OUTPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
  }

  # TODO(b/139281215): these input names will be renamed in the future.
  # These compatibility aliases are provided for forwards compatibility.
  _INPUT_COMPATIBILITY_ALIASES = {'input': 'input_base'}
# NOTE(review): truncated fragment — the PARAMETERS dict opened below is
# never closed (the next line in the file is an unrelated import), so this
# span is not valid Python as-is.  Left byte-identical pending
# reconstruction from the upstream file.
class InfraValidatorSpec(ComponentSpec):
"""InfraValidator component spec."""
PARAMETERS = {
'serving_spec': ExecutionParameter(type=infra_validator_pb2.ServingSpec)
from tfx.types import standard_artifacts
class ExampleAnomaliesVisualization(visualizations.ArtifactVisualization):
  """Renders an ExampleAnomalies artifact via TFDV's anomaly display."""

  ARTIFACT_TYPE = standard_artifacts.ExampleAnomalies

  def display(self, artifact: types.Artifact):
    # Anomalies are stored as a text-format proto under the artifact URI.
    path = os.path.join(artifact.uri, 'anomalies.pbtxt')
    tfdv.display_anomalies(tfdv.load_anomalies_text(path))
class ExampleStatisticsVisualization(visualizations.ArtifactVisualization):
  """Renders an ExampleStatistics artifact via TFDV's statistics viewer."""

  ARTIFACT_TYPE = standard_artifacts.ExampleStatistics

  def display(self, artifact: types.Artifact):
    # Statistics live in a TFRecord file under the artifact URI.
    record_path = os.path.join(artifact.uri, 'stats_tfrecord')
    tfdv.visualize_statistics(tfdv.load_statistics(record_path))
class ModelEvaluationVisualization(visualizations.ArtifactVisualization):
  """Renders a ModelEvaluation artifact via TFMA's slicing-metrics view."""

  ARTIFACT_TYPE = standard_artifacts.ModelEvaluation

  def display(self, artifact: types.Artifact):
    result = tfma.load_eval_result(artifact.uri)
    # TODO(ccy): add comment instructing user to use the TFMA library
    # directly in order to render non-default slicing metric views.
    tfma.view.render_slicing_metrics(result)