Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this snippet of `input_formats` is truncated after the 'h'
# entry -- the list literal is never closed and the rest of the function
# (remaining entries, return value) is not visible here.
def input_formats(detected_format):
# Build the list of input-format choices. Each dict carries a one-letter
# shortcut ('key'), a display label ('name') built by input_format_string()
# from a human-readable label, the format constant and `detected_format`,
# and the format constant itself ('value').
# NOTE(review): input_format_string() is not visible here; presumably it
# marks the choice that matches `detected_format` -- confirm.
formats = [{
'key': 'k',
'name': input_format_string('Keras (HDF5)', common.KERAS_MODEL,
detected_format),
'value': common.KERAS_MODEL
}, {
'key': 'e',
'name': input_format_string('Tensorflow Keras Saved Model',
common.KERAS_SAVED_MODEL,
detected_format),
'value': common.KERAS_SAVED_MODEL,
}, {
'key': 's',
'name': input_format_string('Tensorflow Saved Model',
common.TF_SAVED_MODEL,
detected_format),
'value': common.TF_SAVED_MODEL,
}, {
'key': 'h',
# NOTE(review): this snippet of `input_formats` is truncated after the 'h'
# entry -- the list literal is never closed and the rest of the function is
# not visible. In this scrape the same snippet also appears earlier in the
# file, so one copy is likely redundant.
def input_formats(detected_format):
# Build the list of input-format choices. Each dict carries a one-letter
# shortcut ('key'), a display label ('name') built by input_format_string()
# and the format constant itself ('value').
# NOTE(review): input_format_string() receives `detected_format`;
# presumably it highlights the detected choice -- confirm.
formats = [{
'key': 'k',
'name': input_format_string('Keras (HDF5)', common.KERAS_MODEL,
detected_format),
'value': common.KERAS_MODEL
}, {
'key': 'e',
'name': input_format_string('Tensorflow Keras Saved Model',
common.KERAS_SAVED_MODEL,
detected_format),
'value': common.KERAS_SAVED_MODEL,
}, {
'key': 's',
'name': input_format_string('Tensorflow Saved Model',
common.TF_SAVED_MODEL,
detected_format),
'value': common.TF_SAVED_MODEL,
}, {
'key': 'h',
output_format: Output format as a string.
Returns:
A `tuple` of two strings:
(standardized_input_format, standardized_output_format).
"""
# NOTE(review): fragment -- it starts inside a docstring whose opening (and
# the enclosing def) is not visible, and it is truncated at the
# `raise ValueError(` at the bottom. The docstring text suggests this is a
# variant of _standardize_input_output_formats -- confirm.
# https://github.com/tensorflow/tfjs/issues/1292: Remove the logic for the
# explicit error message of the deprecated model type name 'tensorflowjs'
# at version 1.1.0.
if input_format == 'tensorflowjs':
raise ValueError(
'--input_format=tensorflowjs has been deprecated. '
'Use --input_format=tfjs_layers_model instead.')
# Classify the input: the Keras family is HDF5 or Keras SavedModel; the TF
# family here is a TF SavedModel or a TF-Hub module.
input_format_is_keras = (
input_format in [common.KERAS_MODEL, common.KERAS_SAVED_MODEL])
input_format_is_tf = (
input_format in [common.TF_SAVED_MODEL, common.TF_HUB_MODEL])
if output_format is None:
# If no explicit output_format is provided, infer it from input format.
if input_format_is_keras:
output_format = common.TFJS_LAYERS_MODEL
elif input_format_is_tf:
output_format = common.TFJS_GRAPH_MODEL
elif input_format == common.TFJS_LAYERS_MODEL:
output_format = common.KERAS_MODEL
elif output_format == 'tensorflowjs':
# https://github.com/tensorflow/tfjs/issues/1292: Remove the logic for the
# explicit error message of the deprecated model type name 'tensorflowjs'
# at version 1.1.0.
# NOTE(review): snippet truncated below, inside the ValueError call; its
# message and the rest of the function are not visible here.
if input_format_is_keras:
raise ValueError(
# NOTE(review): fragment -- this starts mid if/elif chain; the function header
# and the leading branch are not visible. `detected_input_format` is
# presumably initialised (e.g. to None) before this chain -- confirm.
# Directory input: any entry whose lowercased name ends in 'saved_model.pb'
# marks the directory as a TF SavedModel.
elif os.path.isdir(input_path):
if (any(fname.lower().endswith('saved_model.pb')
for fname in os.listdir(input_path))):
detected_input_format = common.TF_SAVED_MODEL
else:
# Otherwise look for a '*model.json' whose get_tfjs_model_type() reports
# the layers format; the first match rewrites input_path to that file.
# NOTE(review): `fname` is lowercased before being joined into the path,
# so on a case-sensitive filesystem a file such as 'Model.json' yields a
# path that does not exist. Keep the original name for os.path.join.
for fname in os.listdir(input_path):
fname = fname.lower()
if fname.endswith('model.json'):
filename = os.path.join(input_path, fname)
if get_tfjs_model_type(filename) == common.TFJS_LAYERS_MODEL_FORMAT:
input_path = os.path.join(input_path, fname)
detected_input_format = common.TFJS_LAYERS_MODEL
break
# Single-file input: HDF5 -> Keras model; '...saved_model.pb' -> TF
# SavedModel; a layers-format '...model.json' -> TF.js Layers model.
# NOTE(review): the 'saved_model.pb' suffix check here is case-sensitive,
# unlike the directory check above, and input_path is left pointing at the
# .pb file itself rather than its directory -- confirm this is intended.
elif os.path.isfile(input_path):
if h5py.is_hdf5(input_path):
detected_input_format = common.KERAS_MODEL
elif input_path.endswith('saved_model.pb'):
detected_input_format = common.TF_SAVED_MODEL
elif (input_path.endswith('model.json') and
get_tfjs_model_type(input_path) == common.TFJS_LAYERS_MODEL_FORMAT):
detected_input_format = common.TFJS_LAYERS_MODEL
# Return the detected format together with the (possibly rewritten) path.
return detected_input_format, input_path
def _standardize_input_output_formats(input_format, output_format):
  """Normalize an (input format, output format) pair.

  An explicitly supplied output format is returned untouched. When the
  output format is None, a default is derived from the input format:

    * Keras HDF5 / Keras SavedModel      -> TF.js Layers model
    * TF SavedModel / frozen model / Hub -> TF.js Graph model
    * TF.js Layers model                 -> Keras HDF5 model

  Any other input format leaves the output format as None.

  Args:
    input_format: Input format as a string.
    output_format: Output format as a string, or None to infer one.

  Returns:
    A `tuple` of two strings:
      (standardized_input_format, standardized_output_format).
  """
  keras_inputs = (common.KERAS_MODEL, common.KERAS_SAVED_MODEL)
  tf_inputs = (common.TF_SAVED_MODEL, common.TF_FROZEN_MODEL,
               common.TF_HUB_MODEL)

  # Caller already chose a target format: nothing to infer.
  if output_format is not None:
    return (input_format, output_format)

  if input_format in keras_inputs:
    inferred_output = common.TFJS_LAYERS_MODEL
  elif input_format in tf_inputs:
    inferred_output = common.TFJS_GRAPH_MODEL
  elif input_format == common.TFJS_LAYERS_MODEL:
    inferred_output = common.KERAS_MODEL
  else:
    # Unrecognized input: leave the output format unset.
    inferred_output = None
  return (input_format, inferred_output)
# NOTE(review): fragment -- the enclosing function header is not visible.
# `answers` is indexed with common.INPUT_FORMAT, so it presumably holds the
# user's earlier prompt answers -- confirm against the caller.
# Returns the output-format choices offered for the chosen input format:
# Keras SavedModel -> Graph or Layers model; TF.js Layers model -> Keras HDF5
# or Layers model; any other input -> an empty list.
input_format = answers[common.INPUT_FORMAT]
if input_format == common.KERAS_SAVED_MODEL:
return [{
'key': 'g', # shortcut key for the option
'name': 'Tensorflow.js Graph Model',
'value': common.TFJS_GRAPH_MODEL,
}, {
'key': 'l',
# NOTE(review): user-visible typo -- 'TensoFlow.js' should read
# 'TensorFlow.js' (same typo in the Layers entry further below).
'name': 'TensoFlow.js Layers Model',
'value': common.TFJS_LAYERS_MODEL,
}]
if input_format == common.TFJS_LAYERS_MODEL:
return [{
'key': 'k',
'name': 'Keras Model (HDF5)',
'value': common.KERAS_MODEL,
}, {
'key': 'l',
'name': 'TensoFlow.js Layers Model',
'value': common.TFJS_LAYERS_MODEL,
}]
# No alternative output formats are offered for other input formats.
return []
# NOTE(review): fragment -- the enclosing def and the condition guarding this
# raise are not visible. Judging by its message, it presumably fires when
# input_format is the deprecated name 'tensorflowjs' -- confirm.
raise ValueError(
'--input_format=tensorflowjs has been deprecated. '
'Use --input_format=tfjs_layers_model instead.')
# Classify the input: the Keras family is HDF5 or Keras SavedModel; the TF
# family here is a TF SavedModel or a TF-Hub module (no frozen model).
input_format_is_keras = (
input_format in [common.KERAS_MODEL, common.KERAS_SAVED_MODEL])
input_format_is_tf = (
input_format in [common.TF_SAVED_MODEL, common.TF_HUB_MODEL])
if output_format is None:
# If no explicit output_format is provided, infer it from input format.
if input_format_is_keras:
output_format = common.TFJS_LAYERS_MODEL
elif input_format_is_tf:
output_format = common.TFJS_GRAPH_MODEL
elif input_format == common.TFJS_LAYERS_MODEL:
output_format = common.KERAS_MODEL
elif output_format == 'tensorflowjs':
# https://github.com/tensorflow/tfjs/issues/1292: Remove the logic for the
# explicit error message of the deprecated model type name 'tensorflowjs'
# at version 1.1.0.
# Point Keras inputs at tfjs_layers_model and TF inputs at
# tfjs_graph_model. For any other input, 'tensorflowjs' falls through and
# is returned unchanged.
if input_format_is_keras:
raise ValueError(
'--output_format=tensorflowjs has been deprecated under '
'--input_format=%s. Use --output_format=tfjs_layers_model '
'instead.' % input_format)
if input_format_is_tf:
raise ValueError(
'--output_format=tensorflowjs has been deprecated under '
'--input_format=%s. Use --output_format=tfjs_graph_model '
'instead.' % input_format)
# output_format may still be None here if the input format was not
# recognized above.
return (input_format, output_format)
output_format == common.TFJS_GRAPH_MODEL):
# NOTE(review): fragment -- the function header and the first half of this
# condition are not visible. From the call below it presumably tests for a
# TF SavedModel input -- confirm. `quantization_dtype` and
# `weight_shard_size_bytes` are locals computed outside this view.
tf_saved_model_conversion_v2.convert_tf_saved_model(
args.input_path, args.output_path,
signature_def=args.signature_name,
saved_model_tags=args.saved_model_tags,
quantization_dtype=quantization_dtype,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
# TF-Hub module -> TF.js Graph model (signature name and tags passed
# positionally).
elif (input_format == common.TF_HUB_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_hub_module(
args.input_path, args.output_path, args.signature_name,
args.saved_model_tags, skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
# TF.js Layers model -> Keras HDF5.
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.KERAS_MODEL):
dispatch_tensorflowjs_to_keras_h5_conversion(args.input_path,
args.output_path)
# TF.js Layers model -> Keras SavedModel.
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.KERAS_SAVED_MODEL):
dispatch_tensorflowjs_to_keras_saved_model_conversion(args.input_path,
args.output_path)
# TF.js Layers model -> TF.js Layers model. Given the arguments, this
# presumably re-quantizes and/or re-shards the weights.
# NOTE(review): this branch recomputes the dtype from args.quantization_bytes
# instead of using the `quantization_dtype` local from the SavedModel branch.
# Confirm the two are intended to match.
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.TFJS_LAYERS_MODEL):
dispatch_tensorflowjs_to_tensorflowjs_conversion(
args.input_path, args.output_path,
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
weight_shard_size_bytes=weight_shard_size_bytes)
# TF.js Layers model -> TF.js Graph model.
# NOTE(review): snippet truncated mid-call; the remaining arguments are not
# visible here.
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
args.input_path, args.output_path,
output_format == common.TFJS_GRAPH_MODEL):
# NOTE(review): fragment -- the first half of this condition and the
# function header are not visible. It presumably selects a TF SavedModel
# input, judging by the call below -- confirm. In this scrape the same
# dispatch snippet appears twice.
tf_saved_model_conversion_v2.convert_tf_saved_model(
args.input_path, args.output_path,
signature_def=args.signature_name,
saved_model_tags=args.saved_model_tags,
quantization_dtype=quantization_dtype,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
# TF-Hub module -> TF.js Graph model.
elif (input_format == common.TF_HUB_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_hub_module(
args.input_path, args.output_path, args.signature_name,
args.saved_model_tags, skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
# TF.js Layers model -> Keras HDF5.
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.KERAS_MODEL):
dispatch_tensorflowjs_to_keras_h5_conversion(args.input_path,
args.output_path)
# TF.js Layers model -> Keras SavedModel.
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.KERAS_SAVED_MODEL):
dispatch_tensorflowjs_to_keras_saved_model_conversion(args.input_path,
args.output_path)
# TF.js Layers model -> TF.js Layers model, passing a quantization dtype
# parsed from args.quantization_bytes and the weight shard size.
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.TFJS_LAYERS_MODEL):
dispatch_tensorflowjs_to_tensorflowjs_conversion(
args.input_path, args.output_path,
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
weight_shard_size_bytes=weight_shard_size_bytes)
# TF.js Layers model -> TF.js Graph model.
# NOTE(review): snippet truncated mid-call; remaining arguments not visible.
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
args.input_path, args.output_path,
# NOTE(review): fragment -- this starts mid if/elif chain; the function header
# and leading branch are not visible. `detected_input_format` is presumably
# initialised before this chain -- confirm.
# Directory input containing a '...saved_model.pb' entry (case-insensitive):
# the concrete SavedModel flavour is decided by detect_saved_model().
# NOTE(review): detect_saved_model() is not visible; presumably it
# distinguishes a Keras SavedModel from a plain TF SavedModel -- confirm.
elif os.path.isdir(input_path):
if (any(fname.lower().endswith('saved_model.pb')
for fname in os.listdir(input_path))):
detected_input_format = detect_saved_model(input_path)
else:
# Otherwise pick the first '*model.json' whose get_tfjs_model_type()
# reports the layers format, and rewrite input_path to that file.
# NOTE(review): `fname` is lowercased before being joined into the path,
# so a file such as 'Model.json' yields a non-existent path on a
# case-sensitive filesystem. Keep the original name for os.path.join.
for fname in os.listdir(input_path):
fname = fname.lower()
if fname.endswith('model.json'):
filename = os.path.join(input_path, fname)
if get_tfjs_model_type(filename) == common.TFJS_LAYERS_MODEL_FORMAT:
input_path = os.path.join(input_path, fname)
detected_input_format = common.TFJS_LAYERS_MODEL
break
# Single-file input: HDF5 -> Keras model; a '...saved_model.pb' file is
# replaced by its containing directory before detect_saved_model() runs; a
# layers-format '...model.json' -> TF.js Layers model.
elif os.path.isfile(input_path):
if h5py.is_hdf5(input_path):
detected_input_format = common.KERAS_MODEL
elif input_path.endswith('saved_model.pb'):
input_path = os.path.dirname(input_path)
detected_input_format = detect_saved_model(input_path)
elif (input_path.endswith('model.json') and
get_tfjs_model_type(input_path) == common.TFJS_LAYERS_MODEL_FORMAT):
detected_input_format = common.TFJS_LAYERS_MODEL
# Return the detected format and the (possibly rewritten) input path.
return detected_input_format, input_path