Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Make sure the pipeline has a SchemaGen component before reading its output.
components = tf.io.gfile.listdir(pipeline_root)
if 'SchemaGen' not in components:
  sys.exit(
      'Either SchemaGen component does not exist or pipeline is still running. If pipeline is running, then wait for it to successfully finish.'
  )
# Get the latest SchemaGen output. Output sub-directories are expected to be
# integer-named (the original code sorted them with key=int); the numerically
# largest one is treated as the latest. Entries that are not integer names are
# ignored, and a trailing '/' (which some filesystems may append to directory
# names) is tolerated.
schemagen_output_root = os.path.join(pipeline_root, 'SchemaGen', 'output', '')
schemagen_outputs = [
    name for name in tf.io.gfile.listdir(schemagen_output_root)
    if name.rstrip('/').isdecimal()
]
if not schemagen_outputs:
  # The component directory can exist before SchemaGen has written anything;
  # exit with a readable message instead of letting max() raise ValueError.
  sys.exit(
      'SchemaGen has not produced any output yet in {}. If pipeline is running, then wait for it to successfully finish.'
      .format(schemagen_output_root))
latest_schema_folder = max(
    schemagen_outputs, key=lambda name: int(name.rstrip('/')))
# Copy schema to current dir.
latest_schema_path = os.path.join(pipeline_root, 'SchemaGen', 'output',
                                  latest_schema_folder, 'schema.pbtxt')
curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
io_utils.copy_file(latest_schema_path, curr_dir_path, overwrite=True)
# Print schema and path to schema
click.echo('Path to schema: {}'.format(curr_dir_path))
click.echo('*********SCHEMA FOR {}**********'.format(pipeline_name.upper()))
with open(curr_dir_path, 'r') as f:
  click.echo(f.read())
output_dict: Output dict from key to a list of artifacts, including:
- output: A list of 'ExampleValidationPath' artifact of size one. It
will include a single pbtxt file which contains all anomalies found.
exec_properties: A dict of execution properties. Not used yet.
Returns:
None
"""
self._log_startup(input_dict, output_dict, exec_properties)
tf.logging.info('Validating schema against the computed statistics.')
schema = io_utils.SchemaReader().read(
io_utils.get_only_uri_in_dir(
artifact_utils.get_single_uri(input_dict['schema'])))
stats = tfdv.load_statistics(
io_utils.get_only_uri_in_dir(
artifact_utils.get_split_uri(input_dict['stats'], 'eval')))
output_uri = artifact_utils.get_single_uri(output_dict['output'])
anomalies = tfdv.validate_statistics(stats, schema)
io_utils.write_pbtxt_file(
os.path.join(output_uri, DEFAULT_FILE_NAME), anomalies)
tf.logging.info(
'Validation complete. Anomalies written to {}.'.format(output_uri))
Returns:
None
"""
# TODO(zhitaoli): Move constants between this file and component.py to a
# constants.py.
train_stats_uri = io_utils.get_only_uri_in_dir(
artifact_utils.get_split_uri(input_dict['stats'], 'train'))
output_uri = os.path.join(
artifact_utils.get_single_uri(output_dict['output']),
_DEFAULT_FILE_NAME)
infer_feature_shape = exec_properties['infer_feature_shape']
absl.logging.info('Infering schema from statistics.')
schema = tfdv.infer_schema(
tfdv.load_statistics(train_stats_uri), infer_feature_shape)
io_utils.write_pbtxt_file(output_uri, schema)
absl.logging.info('Schema written to %s.' % output_uri)
def _save_pipeline(self, pipeline_args: Dict[Text, Any]) -> None:
  """Writes the pipeline arguments as JSON into this pipeline's handler folder.

  Any existing folder for the same pipeline name is removed first, so after
  an update the folder holds only the newly written arguments.

  Args:
    pipeline_args: Pipeline arguments keyed by label. Mutated in place: the
      DSL path taken from `self.flags_dict` is stored under
      `labels.PIPELINE_DSL_PATH` before the dict is serialized.
  """
  dsl_path_label = labels.PIPELINE_DSL_PATH
  pipeline_args[dsl_path_label] = self.flags_dict[dsl_path_label]

  # Per-pipeline folder under the handler home directory; the trailing ''
  # component makes the joined path end with a separator.
  pipeline_dir = os.path.join(
      self._handler_home_dir, pipeline_args[labels.PIPELINE_NAME], '')

  # Updating a pipeline replaces its previous folder wholesale.
  if tf.io.gfile.exists(pipeline_dir):
    io_utils.delete_dir(pipeline_dir)
  tf.io.gfile.makedirs(pipeline_dir)

  args_path = os.path.join(pipeline_dir, 'pipeline_args.json')
  with open(args_path, 'w') as args_file:
    json.dump(pipeline_args, args_file)
def _CopyCache(src, dst):
  """Copies the cache directory at `src` to `dst` using io_utils.copy_dir.

  Args:
    src: Source directory path.
    dst: Destination directory path.

  Returns:
    None; the result of io_utils.copy_dir is not propagated.
  """
  # TODO(b/37788560): Make this more efficient.
  io_utils.copy_dir(src, dst)
Path to ephemeral sdist package.
Raises:
RuntimeError: if dist directory has zero or multiple files.
"""
tmp_dir = os.path.join(tempfile.mkdtemp(), 'build', 'tfx')
tfx_root_dir = os.path.dirname(os.path.dirname(__file__))
absl.logging.info('Copying all content from install dir %s to temp dir %s',
tfx_root_dir, tmp_dir)
shutil.copytree(tfx_root_dir, tmp_dir)
# Source directory default permission is 0555 but we need to be able to create
# new setup.py file.
os.chmod(tmp_dir, 0o720)
setup_file = os.path.join(tmp_dir, 'setup.py')
absl.logging.info('Generating a temp setup file at %s', setup_file)
install_requires = dependencies.make_required_install_packages()
io_utils.write_string_file(
setup_file,
_ephemeral_setup_file.format(
version=version.__version__, install_requires=install_requires))
# Create the package
curdir = os.getcwd()
os.chdir(tmp_dir)
cmd = [sys.executable, setup_file, 'sdist']
subprocess.call(cmd)
os.chdir(curdir)
# Return the package dir+filename
dist_dir = os.path.join(tmp_dir, 'dist')
files = tf.io.gfile.listdir(dist_dir)
if not files:
raise RuntimeError('Found no package files in %s' % dist_dir)
# A single uri for the output directory of the serving model.
serving_model_dir=serving_model_dir,
# A list of uris for eval files.
eval_files=eval_files,
# A single uri for schema file.
schema_file=schema_file,
# Number of train steps.
train_steps=train_steps,
# Number of eval steps.
eval_steps=eval_steps,
# Base model that will be used for this training job.
base_model=base_model,
# Additional parameters to pass to trainer function.
**custom_config)
schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
training_spec = trainer_fn(train_fn_args, schema)
# Train the model
absl.logging.info('Training model.')
tf.estimator.train_and_evaluate(training_spec['estimator'],
training_spec['train_spec'],
training_spec['eval_spec'])
absl.logging.info('Training complete. Model written to %s',
serving_model_dir)
# Export an eval savedmodel for TFMA
absl.logging.info('Exporting eval_savedmodel for TFMA.')
tfma.export.export_eval_savedmodel(
estimator=training_spec['estimator'],
export_dir_base=eval_model_dir,
- stats: A list of 'ExampleStatisticsPath' type which should contain
split 'eval'. Stats on other splits are ignored.
- schema: A list of 'SchemaPath' type which should contain a single
schema artifact.
output_dict: Output dict from key to a list of artifacts, including:
- output: A list of 'ExampleValidationPath' artifact of size one. It
will include a single pbtxt file which contains all anomalies found.
exec_properties: A dict of execution properties. Not used yet.
Returns:
None
"""
self._log_startup(input_dict, output_dict, exec_properties)
tf.logging.info('Validating schema against the computed statistics.')
schema = io_utils.SchemaReader().read(
io_utils.get_only_uri_in_dir(
artifact_utils.get_single_uri(input_dict['schema'])))
stats = tfdv.load_statistics(
io_utils.get_only_uri_in_dir(
artifact_utils.get_split_uri(input_dict['stats'], 'eval')))
output_uri = artifact_utils.get_single_uri(output_dict['output'])
anomalies = tfdv.validate_statistics(stats, schema)
io_utils.write_pbtxt_file(
os.path.join(output_uri, DEFAULT_FILE_NAME), anomalies)
tf.logging.info(
'Validation complete. Anomalies written to {}.'.format(output_uri))
- output: A list of 'ExampleStatisticsPath' type. This should contain
both 'train' and 'eval' split.
exec_properties: A dict of execution properties. Not used yet.
Returns:
None
"""
self._log_startup(input_dict, output_dict, exec_properties)
split_to_instance = {x.split: x for x in input_dict['input_data']}
with beam.Pipeline(argv=self._get_beam_pipeline_args()) as p:
# TODO(b/126263006): Support more stats_options through config.
stats_options = options.StatsOptions()
for split, instance in split_to_instance.items():
tf.logging.info('Generating statistics for split {}'.format(split))
input_uri = io_utils.all_files_pattern(instance.uri)
output_uri = artifact_utils.get_split_uri(output_dict['output'], split)
output_path = os.path.join(output_uri, _DEFAULT_FILE_NAME)
_ = (
p
| 'ReadData.' + split >>
beam.io.ReadFromTFRecord(file_pattern=input_uri)
| 'DecodeData.' + split >> tf_example_decoder.DecodeTFExample()
| 'GenerateStatistics.' + split >>
stats_api.GenerateStatistics(stats_options)
| 'WriteStatsOutput.' + split >> beam.io.WriteToTFRecord(
output_path,
shard_name_template='',
coder=beam.coders.ProtoCoder(
statistics_pb2.DatasetFeatureStatisticsList)))
tf.logging.info('Statistics for split {} written to {}.'.format(
split, output_uri))
tuner.search_space_summary()
# TODO(jyzhao): assert v2 behavior as KerasTuner doesn't work in v1.
# TODO(jyzhao): make epochs configurable.
tuner.search(
tuner_spec.train_dataset,
epochs=5,
validation_data=tuner_spec.eval_dataset)
tuner.results_summary()
best_hparams = tuner.oracle.get_best_trials(
1)[0].hyperparameters.get_config()
best_hparams_path = os.path.join(
artifact_utils.get_single_uri(output_dict['study_best_hparams_path']),
_DEFAULT_FILE_NAME)
io_utils.write_string_file(best_hparams_path, json.dumps(best_hparams))
absl.logging.info('Best HParams is written to %s.' % best_hparams_path)