is consumed by this class. For the full set of parameters supported by
Google Cloud AI Platform, refer to
https://cloud.google.com/ml-engine/docs/tensorflow/deploying-models#creating_a_model_version.
Returns:
None
Raises:
ValueError: if ai_platform_serving_args is not in
exec_properties.custom_config.
RuntimeError: if the Google Cloud AI Platform training job failed.
"""
self._log_startup(input_dict, output_dict, exec_properties)
if not self.CheckBlessing(input_dict, output_dict):
return
model_export = artifact_utils.get_single_instance(
input_dict['model_export'])
model_export_uri = model_export.uri
model_blessing_uri = artifact_utils.get_single_uri(
input_dict['model_blessing'])
model_push = artifact_utils.get_single_instance(output_dict['model_push'])
# TODO(jyzhao): should this be in driver or executor.
if not tf.gfile.Exists(os.path.join(model_blessing_uri, 'BLESSED')):
model_push.set_int_custom_property('pushed', 0)
tf.logging.info('Model on %s was not blessed', model_blessing_uri)
return
exec_properties_copy = exec_properties.copy()
custom_config = exec_properties_copy.pop('custom_config', {})
ai_platform_serving_args = custom_config['ai_platform_serving_args']
# Deploy the model.
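# --- Illustrative sketch (not part of the original executor) ---
# A hedged example of the 'custom_config' an upstream pipeline might supply so
# that 'ai_platform_serving_args' is present in exec_properties here. The keys
# 'model_name' and 'project_id' and the placeholder values are assumptions for
# illustration only; see the Cloud AI Platform documentation linked in the
# docstring above for the parameters it actually accepts.
example_exec_properties = {
    'custom_config': {
        'ai_platform_serving_args': {
            'model_name': 'my_model',          # placeholder
            'project_id': 'my-gcp-project',    # placeholder
        },
    },
}
example_serving_args = example_exec_properties['custom_config'][
    'ai_platform_serving_args']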
def _import_artifacts(self, source_uri: List[Text], reimport: bool,
destination_channel: types.Channel,
split_names: List[Text]) -> List[types.Artifact]:
"""Imports external resource in MLMD."""
results = []
for uri, s in zip(source_uri, split_names):
absl.logging.info('Processing source uri: %s, split: %s', uri,
s or 'NO_SPLIT')
# TODO(ccy): refactor importer to treat split name just like any other
# property.
unfiltered_previous_artifacts = self._metadata_handler.get_artifacts_by_uri(
uri)
# Filter by split name.
desired_split_names = artifact_utils.encode_split_names([s or ''])
previous_artifacts = []
for previous_artifact in unfiltered_previous_artifacts:
previous_split_names = previous_artifact.properties.get('split_names', None)
if previous_split_names and previous_split_names.string_value == desired_split_names:
previous_artifacts.append(previous_artifact)
result = types.Artifact(type_name=destination_channel.type_name)
result.split_names = desired_split_names
result.uri = uri
# If any registered artifact with the same uri also has the same
# fingerprint and user does not ask for re-import, just reuse the latest.
# Otherwise, register the external resource into MLMD using the type info
# in the destination channel.
if bool(previous_artifacts) and not reimport:
absl.logging.info('Reusing existing artifact')
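# --- Illustrative sketch (assumption, not from the original driver) ---
# A standalone restatement of the two decisions made above: filter previously
# registered artifacts by encoded split name, then reuse the latest match
# unless re-import was requested. Artifact properties are simplified to a
# plain dict here purely for illustration.
def _filter_by_split(previous_artifacts, desired_split_names):
  """Keeps only artifacts whose 'split_names' property matches."""
  return [
      a for a in previous_artifacts
      if a.get('split_names') == desired_split_names
  ]


def _should_reuse(previous_artifacts, reimport):
  """Reuses an existing artifact when one matches and re-import is not forced."""
  return bool(previous_artifacts) and not reimport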
input_dict: Input dict from input key to a list of artifacts, including:
- 'stats': A list of 'ExampleStatistics' type which must contain
split 'train'. Stats on other splits are ignored.
- 'statistics': Synonym for 'stats'.
output_dict: Output dict from key to a list of artifacts, including:
- output: A list of 'Schema' artifact of size one.
exec_properties: A dict of execution properties, includes:
- infer_feature_shape: Whether or not to infer the shape of the feature.
Returns:
None
"""
# TODO(zhitaoli): Move constants between this file and component.py to a
# constants.py.
train_stats_uri = io_utils.get_only_uri_in_dir(
artifact_utils.get_split_uri(input_dict['stats'], 'train'))
output_uri = os.path.join(
artifact_utils.get_single_uri(output_dict['output']),
_DEFAULT_FILE_NAME)
infer_feature_shape = exec_properties['infer_feature_shape']
absl.logging.info('Inferring schema from statistics.')
schema = tfdv.infer_schema(
tfdv.load_statistics(train_stats_uri), infer_feature_shape)
io_utils.write_pbtxt_file(output_uri, schema)
absl.logging.info('Schema written to %s.' % output_uri)
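# --- Illustrative usage sketch (not part of this executor) ---
# One way a downstream step could read back the schema pbtxt written above.
# 'schema_path' is a placeholder for the output_uri computed in Do();
# _DEFAULT_FILE_NAME in this module names the pbtxt file.
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2


def load_schema_pbtxt(schema_path):
  """Parses the schema pbtxt produced by SchemaGen into a Schema proto."""
  schema = schema_pb2.Schema()
  with open(schema_path, 'r') as f:
    text_format.Parse(f.read(), schema)
  return schema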
both 'train' and 'eval' splits.
exec_properties: A dict of execution properties. Not used yet.
Returns:
None
"""
self._log_startup(input_dict, output_dict, exec_properties)
split_to_instance = {x.split: x for x in input_dict['input_data']}
with beam.Pipeline(argv=self._get_beam_pipeline_args()) as p:
# TODO(b/126263006): Support more stats_options through config.
stats_options = options.StatsOptions()
for split, instance in split_to_instance.items():
tf.logging.info('Generating statistics for split {}'.format(split))
input_uri = io_utils.all_files_pattern(instance.uri)
output_uri = artifact_utils.get_split_uri(output_dict['output'], split)
output_path = os.path.join(output_uri, _DEFAULT_FILE_NAME)
_ = (
p
| 'ReadData.' + split >>
beam.io.ReadFromTFRecord(file_pattern=input_uri)
| 'DecodeData.' + split >> tf_example_decoder.DecodeTFExample()
| 'GenerateStatistics.' + split >>
stats_api.GenerateStatistics(stats_options)
| 'WriteStatsOutput.' + split >> beam.io.WriteToTFRecord(
output_path,
shard_name_template='',
coder=beam.coders.ProtoCoder(
statistics_pb2.DatasetFeatureStatisticsList)))
tf.logging.info('Statistics for split {} written to {}.'.format(
split, output_uri))
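# --- Illustrative usage sketch (assumption, not from this file) ---
# Reading back the single-file statistics output written above, using the same
# TFDV helper that the SchemaGen executor uses for its 'train' split input.
# 'stats_path' stands in for the output_path computed in the loop above.
import tensorflow_data_validation as tfdv


def load_split_statistics(stats_path):
  """Loads the DatasetFeatureStatisticsList proto written by this executor."""
  return tfdv.load_statistics(stats_path)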
def CheckBlessing(self, input_dict: Dict[Text, List[types.Artifact]],
output_dict: Dict[Text, List[types.Artifact]]) -> bool:
"""Check that model is blessed by upstream ModelValidator, or update output.
Args:
input_dict: Input dict from input key to a list of artifacts:
- model_blessing: model blessing path from model_validator. Pusher looks
for a file named 'BLESSED' to consider the model blessed and safe to
push.
output_dict: Output dict from key to a list of artifacts, including:
- model_push: A list of 'ModelPushPath' artifact of size one.
Returns:
True if the model is blessed by validator.
"""
model_blessing = artifact_utils.get_single_instance(
input_dict['model_blessing'])
model_push = artifact_utils.get_single_instance(output_dict['model_push'])
# TODO(jyzhao): should this be in driver or executor.
if not model_utils.is_model_blessed(model_blessing):
model_push.set_int_custom_property('pushed', 0)
absl.logging.info('Model on %s was not blessed', model_blessing.uri)
return False
return True
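# --- Illustrative sketch (assumption) ---
# The docstring above says the pusher looks for a file named 'BLESSED' under
# the model_blessing uri, and the older inline check earlier in this file does
# exactly that with tf.gfile.Exists. A standalone helper expressing that check
# might look like the following; the TF1-style gfile API mirrors the rest of
# this file and is an assumption about the target TensorFlow version.
import os
import tensorflow as tf


def is_blessed(model_blessing_uri):
  """Returns True if ModelValidator wrote a 'BLESSED' marker for this model."""
  return tf.gfile.Exists(os.path.join(model_blessing_uri, 'BLESSED'))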