return False
if op.type == "TensorArrayReadV3":
# TensorArrayRead reads an element from the TensorArray into the output value.
# The TensorArray's shape can be obtained from TensorArrayScatter.
# So the process is: first find TensorArrayScatter's value shape, then the TensorArray's,
# and finally take its last n-1 dims.
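# e.g. if the value scattered into the array has shape [N, B, C], each element read back has shape [B, C]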
flow_in_op = op.inputs[2].op
if flow_in_op.type != "Enter":
return False
scatter_op = flow_in_op.inputs[0].op
if scatter_op.type != "TensorArrayScatterV3":
return False
value_shape_before_scatter = utils.get_tf_tensor_shape(scatter_op.inputs[2])
if value_shape_before_scatter is None:
return False
new_shape = value_shape_before_scatter[1:]
if new_shape is not None:
op.outputs[0].set_shape(new_shape)
logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
return True
return False
return False
logger.warning("Shapes of Merge %s have different ranks: %s, %s", op.name, len(s1), len(s2))
return False
logger.debug("Inputs of Merge %s have different shapes: %s, %s, but the same rank", op.name, s1, s2)
new_shape = _merge_shapes_for_tf(s1, s2)
op.outputs[0].set_shape(new_shape)
logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
else:
new_shape = s1
op.outputs[0].set_shape(new_shape)
logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
return True
if op.type == "Switch":
new_shape = utils.get_tf_tensor_shape(op.inputs[0])
if new_shape is not None:
op.outputs[0].set_shape(new_shape)
op.outputs[1].set_shape(new_shape)
logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
logger.debug("set [%s] with new shape %s", op.outputs[1].name, new_shape)
return True
return False
if op.type == "Enter":
new_shape = utils.get_tf_tensor_shape(op.inputs[0])
if new_shape is not None:
op.outputs[0].set_shape(new_shape)
logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
return True
return False
return False
if op.type == "TensorArrayReadV3":
# TensorArrayRead reads an element from the TensorArray into output value.
# The TensorArray's shape can be got from TensorArrayScatter.
# So the process is: first find TensorArrayScatter's shape and then TensorArray's
# and finally take its last n-1 dim.
flow_in_op = op.inputs[2].op
if flow_in_op.type != "Enter":
return False
scatter_op = flow_in_op.inputs[0].op
if scatter_op.type != "TensorArrayScatterV3":
return False
value_shape_before_scatter = utils.get_tf_tensor_shape(scatter_op.inputs[2])
if value_shape_before_scatter is None:
return False
new_shape = value_shape_before_scatter[1:]
if new_shape is not None:
op.outputs[0].set_shape(new_shape)
logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
return True
return False
return False
return False
if op.type == "Placeholder":
# if placeholder shape is not found, try to get it from "shape" attribute.
attr_shape = utils.get_tf_shape_attr(op)
if attr_shape is not None:
new_shape = list(attr_shape)
op.outputs[0].set_shape(new_shape)
logger.debug("set placeholder op [%s] with new shape %s", op.outputs[0].name, new_shape)
return True
logger.warning("Shape of placeholder %s is unknown, treated it as a scalar", op.name)
op.outputs[0].set_shape([])
return True
if op.type == "Merge":
s1 = utils.get_tf_tensor_shape(op.inputs[0])
s2 = utils.get_tf_tensor_shape(op.inputs[1])
new_shape = None
if s1 is None and s2 is None:
return False
if s1 is None and s2 is not None:
new_shape = s2
if s1 is not None and s2 is None:
new_shape = s1
if new_shape is not None:
op.inputs[0].set_shape(new_shape)
op.inputs[1].set_shape(new_shape)
op.outputs[0].set_shape(new_shape)
logger.debug("set [%s] with new shape %s", op.outputs[0].name, new_shape)
return True
if op.type in ["All", "Any", "Max", "Min"]:
axis_op = op.inputs[1].op
if not utils.is_tf_const_op(axis_op):
return False
axis = utils.get_tf_const_value(axis_op)
if not isinstance(axis, list):
axis = [axis]
keep_dims = op.get_attr("keep_dims")
shape = utils.get_tf_tensor_shape(op.inputs[0])
for i, _ in enumerate(axis):
if axis[i] < 0:
axis[i] += len(shape)
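# e.g. reducing an input of shape [2, 3, 4] over axis [1] yields [2, 4], or [2, 1, 4] when keep_dims is set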
new_shape = []
for i, _ in enumerate(shape):
if i in axis:
if keep_dims:
new_shape.append(1)
else:
new_shape.append(shape[i])
op.outputs[0].set_shape(new_shape)
logger.debug("set %s op [%s] with new shape %s", op.type, op.outputs[0].name, new_shape)
return True
def set_shape_from_inputs_broadcast(input_tensors, output_tensor):
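"""Infer output_tensor's shape by broadcasting the shapes of the first two input tensors."""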
s1 = utils.get_tf_tensor_shape(input_tensors[0])
s2 = utils.get_tf_tensor_shape(input_tensors[1])
new_shape = broadcast_shape_inference(s1, s2)
if new_shape is not None:
output_tensor.set_shape(new_shape)
logger.debug("set [%s] with new shape %s", output_tensor.name, new_shape)
return True
return False
return False
axis = op.get_attr("axis")
axis = axis if axis >= 0 else axis + len(input_shape)
# the link below says that the rank of each output is "rank(input) - 1";
# it follows that "num" must equal input_shape[axis], otherwise TF throws a runtime error
# https://www.tensorflow.org/api_docs/python/tf/unstack
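# e.g. unstacking an input of shape [2, 3, 4] along axis 1 yields 3 outputs, each of shape [2, 4]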
new_shape = input_shape[:axis] + input_shape[axis + 1:]
for output in op.outputs:
output.set_shape(new_shape)
logger.debug("set %s op [%s] with new shape %s", op.type, output.name, new_shape)
return True
if op.type in ["Minimum", "Maximum"]:
# ops that are elementwise and support broadcasting
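# e.g. input shapes [3, 1] and [4] broadcast to an output shape of [3, 4]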
input_shapes = [utils.get_tf_tensor_shape(inp) for inp in op.inputs]
new_shape = broadcast_shape_inference(*input_shapes)
op.outputs[0].set_shape(new_shape)
return True
return False
if infer_input_shapes(op):
return True
if not has_unknown_output_shape:
return False
# for these ops, we don't need every input shape to be available to infer the output shapes.
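# e.g. Merge can take its output shape from whichever input already has a known shape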
ret = infer_output_shapes_with_partial_inputs(op)
if ret is not None:
return ret
# for the remaining ops, we need all input shapes ready to infer output shapes.
are_all_input_shape_ready = True
no_shape = []
for i in op.inputs:
if utils.get_tf_tensor_shape(i) is None:
are_all_input_shape_ready = False
no_shape.append(i.name)
if not are_all_input_shape_ready:
logger.debug("op %s has inputs don't have shape specified, they are: %s", op.name, no_shape)
return False
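# direct_ops: ops whose output shape is identical to their first input's shape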
if op.type in direct_ops:
return set_shape_from_input(op.inputs[0], op.outputs[0])
if op.type in broadcast_ops:
return set_shape_from_inputs_broadcast(op.inputs, op.outputs[0])
if op.type == "RandomUniform":
shape_op = op.inputs[0].op
if not shape_op or shape_op.type != "Shape":
return False
return set_shape_from_input(shape_op.inputs[0], op.outputs[0])
if op.type == "Gather":
# uses the following link to know how to infer the shape of the output
# https://www.tensorflow.org/api_docs/python/tf/gather
shape_params = utils.get_tf_tensor_shape(op.inputs[0])
shape_indices = utils.get_tf_tensor_shape(op.inputs[1])
# plain Gather has only 2 inputs; a third input, if present, is the axis
# https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather.html
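# e.g. params shape [5, 4, 3], indices shape [2, 2], axis 0 -> output shape [2, 2, 4, 3]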
if len(op.inputs) == 3:
axis_op = op.inputs[2].op
if not utils.is_tf_const_op(axis_op):
return False
axis = utils.get_tf_const_value(axis_op)
else:
axis = 0
shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
op.outputs[0].set_shape(shape)
return True
if op.type in ["All", "Any", "Max", "Min"]:
axis_op = op.inputs[1].op