# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): the line above is a stray web-page advertisement banner captured
# during scraping — it is not part of the original module. Commented out so it
# cannot break parsing; safe to delete outright.
def input_formats(detected_format):
    """Build the interactive menu entries for the supported input formats.

    Args:
        detected_format: the format auto-detected from the user's input path
            (one of the ``common.*`` format constants), or None when nothing
            was detected.

    Returns:
        A list of dicts with ``'key'``, ``'name'`` and ``'value'`` entries,
        ordered so that the detected format (if any) appears first.
    """
    formats = [{
        'key': 'k',
        'name': input_format_string('Keras (HDF5)', common.KERAS_MODEL,
                                    detected_format),
        'value': common.KERAS_MODEL
    }, {
        'key': 'e',
        'name': input_format_string('Tensorflow Keras Saved Model',
                                    common.KERAS_SAVED_MODEL,
                                    detected_format),
        'value': common.KERAS_SAVED_MODEL,
    }, {
        'key': 's',
        'name': input_format_string('Tensorflow Saved Model',
                                    common.TF_SAVED_MODEL,
                                    detected_format),
        'value': common.TF_SAVED_MODEL,
    }, {
        'key': 'h',
        'name': input_format_string('TFHub Module',
                                    common.TF_HUB_MODEL,
                                    detected_format),
        'value': common.TF_HUB_MODEL,
    }, {
        # NOTE(review): 'TensoFlow.js' looks like a typo for 'TensorFlow.js',
        # but it is a user-facing string — confirm before changing it.
        'key': 'l',
        'name': input_format_string('TensoFlow.js Layers Model',
                                    common.TFJS_LAYERS_MODEL,
                                    detected_format),
        'value': common.TFJS_LAYERS_MODEL,
    }]
    # Stable sort on a boolean key: False (value == detected_format) sorts
    # before True, so the detected format floats to the top while all other
    # entries keep their original relative order.
    formats.sort(key=lambda x: x['value'] != detected_format)
    # BUG FIX: the original computed the sorted list and then discarded it
    # (implicitly returning None); callers need the menu entries.
    return formats
def detect_saved_model(input_path):
    """Classify a SavedModel directory as Keras-flavoured or plain TF.

    Args:
        input_path: path to a SavedModel directory.

    Returns:
        common.KERAS_SAVED_MODEL when the directory looks like a Keras
        SavedModel, otherwise common.TF_SAVED_MODEL.
    """
    # A Keras SavedModel ships a saved_model.json under its assets/ folder.
    keras_marker = os.path.join(input_path, 'assets', 'saved_model.json')
    if os.path.exists(keras_marker):
        return common.KERAS_SAVED_MODEL
    # Otherwise parse the SavedModel proto and inspect the object graph of
    # the first meta-graph: a root node whose user-object identifier mentions
    # 'tf_keras' also marks a Keras model.
    parsed = loader_impl.parse_saved_model(input_path)
    object_graph = parsed.meta_graphs[0].object_graph_def
    if object_graph.nodes and (
            'tf_keras' in object_graph.nodes[0].user_object.identifier):
        return common.KERAS_SAVED_MODEL
    return common.TF_SAVED_MODEL
# NOTE(review): fragment — the opening `parser.add_argument(` and the
# positional-argument name for the input path were cut off above this line.
nargs='?',
type=str,
help='Path to the input file or directory. For input format "keras", '
'an HDF5 (.h5) file is expected. For input format "tensorflow", '
'a SavedModel directory, frozen model file, '
'or TF-Hub module is expected.')
# Optional positional argument: destination for all converted artifacts.
parser.add_argument(
common.OUTPUT_PATH,
nargs='?',
type=str,
help='Path for all output artifacts.')
# --input_format flag: one of the supported source formats; defaults to
# tf_saved_model when omitted.
parser.add_argument(
'--%s' % common.INPUT_FORMAT,
type=str,
required=False,
default=common.TF_SAVED_MODEL,
choices=set([common.KERAS_MODEL, common.KERAS_SAVED_MODEL,
common.TF_SAVED_MODEL, common.TF_HUB_MODEL,
common.TFJS_LAYERS_MODEL, common.TF_FROZEN_MODEL]),
help='Input format. '
'For "keras", the input path can be one of the two following formats:\n'
' - A topology+weights combined HDF5 (e.g., generated with'
' `keras.model.save_model()` method).\n'
' - A weights-only HDF5 (e.g., generated with Keras Model\'s '
' `save_weights()` method). \n'
'For "keras_saved_model", the input_path must point to a subfolder '
'under the saved model folder that is passed as the argument '
'to tf.contrib.save_model.save_keras_model(). '
'The subfolder is generated automatically by tensorflow when '
'saving keras model in the SavedModel format. It is usually named '
'as a Unix epoch time (e.g., 1542212752).\n'
'For "tf" formats, a SavedModel, frozen model, '
# NOTE(review): fragment — the remainder of this help string and the closing
# parenthesis of this add_argument call are truncated below.
# NOTE(review): fragment — interior of the converter's main/convert routine;
# the enclosing `def` is not visible in this chunk.
weight_shard_size_bytes = args.weight_shard_size_bytes
# input_path is an optional positional (nargs='?'), so it may be absent —
# fail fast with a usage hint.
if args.input_path is None:
raise ValueError(
'Error: The input_path argument must be set. '
'Run with --help flag for usage information.')
# Normalize user-supplied format names into the canonical pair.
input_format, output_format = _standardize_input_output_formats(
args.input_format, args.output_format)
# Map --quantization_bytes to a dtype; None disables weight quantization.
quantization_dtype = (
quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
if args.quantization_bytes else None)
# --signature_name only applies to SavedModel / TF-Hub inputs; reject it
# for any other input format.
if (args.signature_name and input_format not in
(common.TF_SAVED_MODEL, common.TF_HUB_MODEL)):
raise ValueError(
'The --signature_name flag is applicable only to "tf_saved_model" and '
'"tf_hub" input format, but the current input format is '
'"%s".' % input_format)
# TODO(cais, piyu): More conversion logics can be added as additional
# branches below.
# Dispatch on the (input_format, output_format) pair.
if (input_format == common.KERAS_MODEL and
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_h5_to_tfjs_layers_model_conversion(
args.input_path, output_dir=args.output_path,
quantization_dtype=quantization_dtype,
split_weights_by_layer=args.split_weights_by_layer)
elif (input_format == common.KERAS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_keras_h5_to_tfjs_graph_model_conversion(
# NOTE(review): fragment — the argument list of this call is truncated below.
def detect_input_format(input_path):
"""Determine the input format from model's input path or file.
Args:
input_path: string of the input model path
returns:
string: detected input format
string: normalized input path
"""
# Trim surrounding whitespace so the suffix/extension checks below behave.
input_path = input_path.strip()
detected_input_format = None
# Case 1: a TF-Hub module URL.
if re.match(TFHUB_VALID_URL_REGEX, input_path):
detected_input_format = common.TF_HUB_MODEL
# Case 2: a directory — a SavedModel if it contains saved_model.pb,
# otherwise scan for a TFJS layers-model model.json.
elif os.path.isdir(input_path):
if (any(fname.lower().endswith('saved_model.pb')
for fname in os.listdir(input_path))):
detected_input_format = common.TF_SAVED_MODEL
else:
for fname in os.listdir(input_path):
fname = fname.lower()
if fname.endswith('model.json'):
filename = os.path.join(input_path, fname)
if get_tfjs_model_type(filename) == common.TFJS_LAYERS_MODEL_FORMAT:
# Normalize the path to point directly at the model.json file.
input_path = os.path.join(input_path, fname)
detected_input_format = common.TFJS_LAYERS_MODEL
break
# Case 3: a single file — HDF5 means Keras; saved_model.pb means a TF
# SavedModel; a layers-type model.json means a TFJS layers model.
elif os.path.isfile(input_path):
if h5py.is_hdf5(input_path):
detected_input_format = common.KERAS_MODEL
elif input_path.endswith('saved_model.pb'):
detected_input_format = common.TF_SAVED_MODEL
elif (input_path.endswith('model.json') and
get_tfjs_model_type(input_path) == common.TFJS_LAYERS_MODEL_FORMAT):
# NOTE(review): fragment — the body of this elif, and the function's
# final `return detected_input_format, input_path` (per the docstring,
# presumably — confirm against the full file), are truncated below.
# NOTE(review): fragment — tail of a wizard question dict: a 'when' predicate
# that enables the question only for SavedModel / TF-Hub input formats. The
# rest of the dict was cut off above.
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL))
}
# NOTE(review): fragment — the `raise ValueError(` opening this message was
# cut off above. This run duplicates the earlier validation/dispatch flow in
# this chunk, with an extra --output_node_names check for frozen models.
'Error: The input_path argument must be set. '
'Run with --help flag for usage information.')
input_format, output_format = _standardize_input_output_formats(
args.input_format, args.output_format)
quantization_dtype = (
quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
if args.quantization_bytes else None)
# Frozen-model conversion cannot proceed without explicit output node names.
if (not args.output_node_names and input_format == common.TF_FROZEN_MODEL):
raise ValueError(
'The --output_node_names flag is required for "tf_frozen_model"')
# --signature_name only applies to SavedModel / TF-Hub inputs.
if (args.signature_name and input_format not in
(common.TF_SAVED_MODEL, common.TF_HUB_MODEL)):
raise ValueError(
'The --signature_name flag is applicable only to "tf_saved_model" and '
'"tf_hub" input format, but the current input format is '
'"%s".' % input_format)
# TODO(cais, piyu): More conversion logics can be added as additional
# branches below.
if (input_format == common.KERAS_MODEL and
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_h5_to_tfjs_layers_model_conversion(
args.input_path, output_dir=args.output_path,
quantization_dtype=quantization_dtype,
split_weights_by_layer=args.split_weights_by_layer)
elif (input_format == common.KERAS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_keras_h5_to_tfjs_graph_model_conversion(
# NOTE(review): fragment — this call's argument list is truncated; the lines
# below belong to yet another wizard 'when' predicate, same shape as above.
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL))
},
def is_saved_model(input_format):
    """Tell whether this conversion reads a TensorFlow SavedModel.

    Args:
        input_format: input model format (one of the common.* constants).

    Returns:
        bool: True exactly when input_format is common.TF_SAVED_MODEL.
    """
    saved_model_formats = (common.TF_SAVED_MODEL,)
    return input_format in saved_model_formats
def input_formats(detected_format):
# Builds the interactive menu entries for the supported input formats:
# a list of {'key', 'name', 'value'} dicts, reordered so the detected
# format (if any) appears first.
formats = [{
'key': 'k',
'name': input_format_string('Keras (HDF5)', common.KERAS_MODEL,
detected_format),
'value': common.KERAS_MODEL
}, {
'key': 'e',
'name': input_format_string('Tensorflow Keras Saved Model',
common.KERAS_SAVED_MODEL,
detected_format),
'value': common.KERAS_SAVED_MODEL,
}, {
'key': 's',
'name': input_format_string('Tensorflow Saved Model',
common.TF_SAVED_MODEL,
detected_format),
'value': common.TF_SAVED_MODEL,
}, {
'key': 'h',
'name': input_format_string('TFHub Module',
common.TF_HUB_MODEL,
detected_format),
'value': common.TF_HUB_MODEL,
}, {
# NOTE(review): 'TensoFlow.js' looks like a typo for 'TensorFlow.js' in a
# user-facing string — confirm before changing.
'key': 'l',
'name': input_format_string('TensoFlow.js Layers Model',
common.TFJS_LAYERS_MODEL,
detected_format),
'value': common.TFJS_LAYERS_MODEL,
}]
# Stable sort on a boolean key: entries matching detected_format (False)
# sort before the rest (True), preserving the original relative order.
formats.sort(key=lambda x: x['value'] != detected_format)
# NOTE(review): the visible chunk ends here. If the file truly ends without
# a `return formats`, the sorted list is discarded — confirm against the
# full file (an earlier copy of this function in this chunk has the same
# issue).