Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def input_formats(detected_format):
formats = [{
'key': 'k',
'name': input_format_string('Keras (HDF5)', common.KERAS_MODEL,
detected_format),
'value': common.KERAS_MODEL
}, {
'key': 'e',
'name': input_format_string('Tensorflow Keras Saved Model',
common.KERAS_SAVED_MODEL,
detected_format),
'value': common.KERAS_SAVED_MODEL,
}, {
'key': 's',
'name': input_format_string('Tensorflow Saved Model',
common.TF_SAVED_MODEL,
detected_format),
'value': common.TF_SAVED_MODEL,
}, {
'key': 'h',
'name': input_format_string('TFHub Module',
common.TF_HUB_MODEL,
detected_format),
'value': common.TF_HUB_MODEL,
}, {
'key': 'l',
def input_formats(detected_format):
formats = [{
'key': 'k',
'name': input_format_string('Keras (HDF5)', common.KERAS_MODEL,
detected_format),
'value': common.KERAS_MODEL
}, {
'key': 'e',
'name': input_format_string('Tensorflow Keras Saved Model',
common.KERAS_SAVED_MODEL,
detected_format),
'value': common.KERAS_SAVED_MODEL,
}, {
'key': 's',
'name': input_format_string('Tensorflow Saved Model',
common.TF_SAVED_MODEL,
detected_format),
'value': common.TF_SAVED_MODEL,
}, {
'key': 'h',
'name': input_format_string('TFHub Module',
common.TF_HUB_MODEL,
detected_format),
'value': common.TF_HUB_MODEL,
}, {
'key': 'l',
def is_saved_model(input_format):
    """Tell whether `input_format` names one of the SavedModel formats.

    Args:
        input_format: input model format.

    Returns:
        bool: True when `input_format` equals `common.TF_SAVED_MODEL` or
            `common.KERAS_SAVED_MODEL`, False otherwise.
    """
    saved_model_formats = (common.TF_SAVED_MODEL, common.KERAS_SAVED_MODEL)
    return input_format in saved_model_formats
'name': 'Tensorflow.js Graph Model',
'value': common.TFJS_GRAPH_MODEL,
}, {
'key': 'l',
'name': 'TensoFlow.js Layers Model',
'value': common.TFJS_LAYERS_MODEL,
}]
if input_format == common.TFJS_LAYERS_MODEL:
return [{
'key': 'k', # shortcut key for the option
'name': 'Keras Model',
'value': common.KERAS_MODEL,
}, {
'key': 's',
'name': 'Keras Saved Model',
'value': common.KERAS_SAVED_MODEL,
}]
return []
def detect_saved_model(input_path):
    """Classify the SavedModel found at `input_path`.

    Args:
        input_path: path of the SavedModel directory to inspect.

    Returns:
        `common.KERAS_SAVED_MODEL` when `assets/saved_model.json` exists under
        `input_path`, or when the identifier of the first node in the first
        meta graph's object graph contains 'tf_keras';
        `common.TF_SAVED_MODEL` otherwise.
    """
    # A saved_model.json under assets/ is treated as a Keras SavedModel
    # without parsing the model itself.
    keras_marker = os.path.join(input_path, 'assets', 'saved_model.json')
    if os.path.exists(keras_marker):
        return common.KERAS_SAVED_MODEL

    parsed = loader_impl.parse_saved_model(input_path)
    nodes = parsed.meta_graphs[0].object_graph_def.nodes
    # Only the first node of the object graph is examined.
    if nodes and 'tf_keras' in nodes[0].user_object.identifier:
        return common.KERAS_SAVED_MODEL
    return common.TF_SAVED_MODEL
# TODO(cais, piyu): More conversion logic can be added as additional
# branches below.
if (input_format == common.KERAS_MODEL and
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_h5_to_tfjs_layers_model_conversion(
args.input_path, output_dir=args.output_path,
quantization_dtype=quantization_dtype,
split_weights_by_layer=args.split_weights_by_layer)
elif (input_format == common.KERAS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_keras_h5_to_tfjs_graph_model_conversion(
args.input_path, output_dir=args.output_path,
quantization_dtype=quantization_dtype,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
elif (input_format == common.KERAS_SAVED_MODEL and
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_saved_model_to_tensorflowjs_conversion(
args.input_path, args.output_path,
quantization_dtype=quantization_dtype,
split_weights_by_layer=args.split_weights_by_layer)
elif (input_format == common.TF_SAVED_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_saved_model(
args.input_path, args.output_path,
signature_def=args.signature_name,
saved_model_tags=args.saved_model_tags,
quantization_dtype=quantization_dtype,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
elif (input_format == common.TF_HUB_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
type=str,
help='Path to the input file or directory. For input format "keras", '
'an HDF5 (.h5) file is expected. For input format "tensorflow", '
'a SavedModel directory, frozen model file, '
'or TF-Hub module is expected.')
parser.add_argument(
common.OUTPUT_PATH,
nargs='?',
type=str,
help='Path for all output artifacts.')
parser.add_argument(
'--%s' % common.INPUT_FORMAT,
type=str,
required=False,
default=common.TF_SAVED_MODEL,
choices=set([common.KERAS_MODEL, common.KERAS_SAVED_MODEL,
common.TF_SAVED_MODEL, common.TF_HUB_MODEL,
common.TFJS_LAYERS_MODEL, common.TF_FROZEN_MODEL]),
help='Input format. '
'For "keras", the input path can be one of the two following formats:\n'
' - A topology+weights combined HDF5 (e.g., generated with'
' `keras.model.save_model()` method).\n'
' - A weights-only HDF5 (e.g., generated with Keras Model\'s '
' `save_weights()` method). \n'
'For "keras_saved_model", the input_path must point to a subfolder '
'under the saved model folder that is passed as the argument '
'to tf.contrib.save_model.save_keras_model(). '
'The subfolder is generated automatically by tensorflow when '
'saving keras model in the SavedModel format. It is usually named '
'as a Unix epoch time (e.g., 1542212752).\n'
'For "tf" formats, a SavedModel, frozen model, '
' or TF-Hub module is expected.')
def detect_saved_model(input_path):
    """Classify the SavedModel found at `input_path`.

    Args:
        input_path: path of the SavedModel directory to inspect.

    Returns:
        `common.KERAS_SAVED_MODEL` when `assets/saved_model.json` exists under
        `input_path`, or when the identifier of the first node in the first
        meta graph's object graph contains 'tf_keras';
        `common.TF_SAVED_MODEL` otherwise.
    """
    # A saved_model.json under assets/ is treated as a Keras SavedModel
    # without parsing the model itself.
    keras_marker = os.path.join(input_path, 'assets', 'saved_model.json')
    if os.path.exists(keras_marker):
        return common.KERAS_SAVED_MODEL

    parsed = loader_impl.parse_saved_model(input_path)
    nodes = parsed.meta_graphs[0].object_graph_def.nodes
    # Only the first node of the object graph is examined.
    if nodes and 'tf_keras' in nodes[0].user_object.identifier:
        return common.KERAS_SAVED_MODEL
    return common.TF_SAVED_MODEL
def input_formats(detected_format):
formats = [{
'key': 'k',
'name': input_format_string('Keras (HDF5)', common.KERAS_MODEL,
detected_format),
'value': common.KERAS_MODEL
}, {
'key': 'e',
'name': input_format_string('Tensorflow Keras Saved Model',
common.KERAS_SAVED_MODEL,
detected_format),
'value': common.KERAS_SAVED_MODEL,
}, {
'key': 's',
'name': input_format_string('Tensorflow Saved Model',
common.TF_SAVED_MODEL,
detected_format),
'value': common.TF_SAVED_MODEL,
}, {
'key': 'h',
'name': input_format_string('TFHub Module',
common.TF_HUB_MODEL,
detected_format),
'value': common.TF_HUB_MODEL,
}, {
'key': 'l',
'name': input_format_string('TensoFlow.js Layers Model',
common.TFJS_LAYERS_MODEL,