Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
_nodes = nodes
if n <= 0:
raise ValueError("Cannot contract empty tensor network.")
if n == 1:
if not ignore_edge_order:
if output_edge_order is None:
output_edge_order = list(
(get_all_edges(_nodes) - get_all_nondangling(_nodes)))
if len(output_edge_order) > 1:
raise ValueError("The final node after contraction has more than "
"one dangling edge. In this case `output_edge_order` "
"has to be provided.")
edges = get_all_nondangling(_nodes)
if edges:
final_node = contract_parallel(edges.pop())
else:
final_node = list(_nodes)[0]
final_node.reorder_edges(output_edge_order)
if not ignore_edge_order:
final_node.reorder_edges(output_edge_order)
return final_node
if n < 5:
return optimal(nodes, output_edge_order, memory_limit, ignore_edge_order)
if n < 7:
return branch(nodes, output_edge_order, memory_limit, ignore_edge_order)
if n < 9:
return branch(nodes, output_edge_order, memory_limit, nbranch=2, ignore_edge_order=ignore_edge_order)
if n < 15:
return branch(nodes, output_edge_order, nbranch=1, ignore_edge_order=ignore_edge_order)
return greedy(nodes, output_edge_order, memory_limit, ignore_edge_order)
if len(output_edge_order) > 1:
raise ValueError("The final node after contraction has more than "
"one remaining edge. In this case `output_edge_order` "
"has to be provided.")
if set(output_edge_order) != get_subgraph_dangling(nodes):
raise ValueError(
"output edges are not equal to the remaining "
"non-contracted edges of the final node."
)
for edge in edges:
if not edge.is_disabled: #if its disabled we already contracted it
if edge.is_trace():
nodes_set.remove(edge.node1)
nodes_set.add(contract_parallel(edge))
if len(nodes_set) == 1:
# There's nothing to contract.
if ignore_edge_order:
return list(nodes_set)[0]
return list(nodes_set)[0].reorder_edges(output_edge_order)
# Then apply `opt_einsum`'s algorithm
path, nodes = utils.get_path(nodes_set, algorithm)
for a, b in path:
new_node = contract_between(nodes[a], nodes[b], allow_outer_product=True)
nodes.append(new_node)
nodes = utils.multi_remove(nodes, [a, b])
# if the final node has more than one edge,
# output_edge_order has to be specified
def contract_trace_edges(node: BaseNode) -> BaseNode:
    """Contract the trace edges (self-loops) of `node`.

    The edges of `node` are scanned in order. The first one whose
    `is_trace()` is true is passed to `contract_parallel`, and the node that
    call produces is returned. Only that single edge is handed over. This
    relies on `contract_parallel` also contracting the edges parallel to it
    (for a trace edge, presumably every other self-loop of `node`).
    NOTE(review): confirm against `contract_parallel`.

    Args:
        node: A `BaseNode` object.

    Returns:
        A new `BaseNode` obtained from contracting the trace edges.

    Raises:
        ValueError: If `node` has no trace edges.
    """
    # Lazy scan: `is_trace()` is evaluated edge by edge and stops at the
    # first match.
    first_trace = next((e for e in node.edges if e.is_trace()), None)
    if first_trace is None:
        raise ValueError('`node` has no trace edges')
    return contract_parallel(first_trace)