Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
Returns:
(new_edge1, new_edge2): The new `Edge` objects of
`self.node1` and `self.node2`
"""
if self.is_dangling():
raise ValueError("Cannot break dangling edge {}.".format(self))
if not edge1_name:
edge1_name = '__disconnected_edge1_of_{}__'.format(self.name)
if not edge2_name:
edge2_name = '__disconnected_edge2_of_{}__'.format(self.name)
node1 = self.node1
node2 = self.node2
new_edge1 = Edge(node1=node1, axis1=self.axis1, name=edge1_name)
new_edge2 = Edge(node1=node2, axis1=self.axis2, name=edge2_name)
node1.add_edge(new_edge1, self.axis1, override=True)
node2.add_edge(new_edge2, self.axis2, override=True)
return new_edge1, new_edge2
edge2_name: A name for the new dangling edge at `self.node2`
Returns:
(new_edge1, new_edge2): The new `Edge` objects of
`self.node1` and `self.node2`
"""
if self.is_dangling():
raise ValueError("Cannot break dangling edge {}.".format(self))
if not edge1_name:
edge1_name = '__disconnected_edge1_of_{}__'.format(self.name)
if not edge2_name:
edge2_name = '__disconnected_edge2_of_{}__'.format(self.name)
node1 = self.node1
node2 = self.node2
new_edge1 = Edge(node1=node1, axis1=self.axis1, name=edge1_name)
new_edge2 = Edge(node1=node2, axis1=self.axis2, name=edge2_name)
node1.add_edge(new_edge1, self.axis1, override=True)
node2.add_edge(new_edge2, self.axis2, override=True)
return new_edge1, new_edge2
new_edge = Edge(node_dict[node1], axis1, edge.name)
node_dict[node1].add_edge(new_edge, axis1)
edge_dict[edge] = new_edge
continue
node2 = edge.node2
axis2 = edge.node2.get_axis_number(edge.axis2)
# copy node2 but not node1
if node1 not in node_dict:
new_edge = Edge(node_dict[node2], axis2, edge.name)
node_dict[node2].add_edge(new_edge, axis2)
edge_dict[edge] = new_edge
continue
# both nodes should be copied
new_edge = Edge(node_dict[node1], axis1, edge.name, node_dict[node2], axis2)
new_edge.set_signature(edge.signature)
node_dict[node2].add_edge(new_edge, axis2)
node_dict[node1].add_edge(new_edge, axis1)
edge_dict[edge] = new_edge
return node_dict, edge_dict
if edge1 is edge2:
raise ValueError("Cannot connect and edge '{}' to itself.".format(edge1))
if edge1.dimension != edge2.dimension:
raise ValueError("Cannot connect edges of unequal dimension. "
"Dimension of edge '{}': {}, "
"Dimension of edge '{}': {}.".format(
edge1, edge1.dimension, edge2, edge2.dimension))
#edge1 and edge2 are always dangling in this case
node1 = edge1.node1
node2 = edge2.node1
axis1_num = node1.get_axis_number(edge1.axis1)
axis2_num = node2.get_axis_number(edge2.axis1)
new_edge = Edge(
node1=node1, axis1=axis1_num, name=name, node2=node2, axis2=axis2_num)
node1.add_edge(new_edge, axis1_num, override=True)
node2.add_edge(new_edge, axis2_num, override=True)
return new_edge
backend = node.backend
# Permute until edge axes to be split are at the back and reshape.
perm_back = [min(edge.axis1, edge.axis2)]
perm_back += [max(edge.axis1, edge.axis2)]
perm_front = set(range(len(node.edges))) - set(perm_back)
perm_front = sorted(perm_front)
node.reorder_axes(perm_front + perm_back)
unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
new_shape = backend.concat([unaffected_shape, shape, shape], axis=-1)
node.tensor = backend.reshape(node.tensor, new_shape)
# Trim edges and add placeholder edges for new axes.
node.edges = node.edges[:len(perm_front)] + 2 * len(shape) * [None]
# Create new dangling edges and connect them to each other.
new_edges = []
for idx in range(len(shape)):
edge1 = Edge(node1=node, axis1=len(perm_front) + idx)
edge2 = Edge(node1=node, axis1=len(perm_front) + len(shape) + idx)
node.edges[len(perm_front) + idx] = edge1
node.edges[len(perm_front) + len(shape) + idx] = edge2
new_edges.append(
connect(edge1, edge2,
new_edge_names[idx] if new_edge_names is not None else None))
# pylint: disable=expression-not-assigned
edge.disable() # disable old edge!
return new_edges
if node is None:
continue
axis_names = node.axis_names
# Permute until edge axes to be split are at the back and reshape.
perm_back = [node.edges.index(edge)]
perm_front = set(range(len(node.edges))) - set(perm_back)
perm_front = sorted(perm_front)
node.reorder_axes(perm_front + perm_back)
unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
new_shape = backend.concat([unaffected_shape, shape], axis=-1)
node.tensor = backend.reshape(node.tensor, new_shape) # in-place update
# Trim edges.
node.edges = node.edges[:len(perm_front)]
# Create new dangling edges.
for idx in range(len(shape)):
new_dangling_edge = Edge(
node1=node,
axis1=len(perm_front) + idx,
name=new_edge_names[idx] if new_edge_names is not None else None)
node.edges += [new_dangling_edge]
new_dangling_edges.append(new_dangling_edge)
# TODO: Allow renaming of new axes (possibly distinct from new_edge_names).
if axis_names:
new_axis_names = [axis_names[n] for n in range(len(unaffected_shape))]
if new_edge_names:
new_axis_names.extend(new_edge_names)
else:
new_axis_names.extend(
[str(n) for n in range(len(unaffected_shape), len(node.edges))])
node.axis_names = new_axis_names
else:
node.axis_names = [str(n) for n in range(len(node.edges))]
name=node.name,
axis_names=node.axis_names,
backend=node.backend)
else:
node_dict[node] = Node(
node.tensor,
name=node.name,
axis_names=node.axis_names,
backend=node.backend)
edge_dict = {}
for edge in get_all_edges(nodes):
node1 = edge.node1
axis1 = edge.node1.get_axis_number(edge.axis1)
# edge dangling or node2 does not need to be copied
if edge.is_dangling() or edge.node2 not in node_dict:
new_edge = Edge(node_dict[node1], axis1, edge.name)
node_dict[node1].add_edge(new_edge, axis1)
edge_dict[edge] = new_edge
continue
node2 = edge.node2
axis2 = edge.node2.get_axis_number(edge.axis2)
# copy node2 but not node1
if node1 not in node_dict:
new_edge = Edge(node_dict[node2], axis2, edge.name)
node_dict[node2].add_edge(new_edge, axis2)
edge_dict[edge] = new_edge
continue
# both nodes should be copied
new_edge = Edge(node_dict[node1], axis1, edge.name, node_dict[node2], axis2)
new_edge.set_signature(edge.signature)
"""
node = edges[0].node1 # We are in the trace case, so this is the only node.
backend = node.backend
# Flatten all of the edges' axes into a single list.
perm_back = [min(e.axis1, e.axis2) for e in edges]
perm_back += [max(e.axis1, e.axis2) for e in edges]
perm_front = set(range(len(node.edges))) - set(perm_back)
perm_front = sorted(perm_front)
perm = perm_front + perm_back
new_dim = backend.prod([backend.shape(node.tensor)[e.axis1] for e in edges])
node.reorder_axes(perm)
unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
new_shape = backend.concat([unaffected_shape, [new_dim, new_dim]], axis=-1)
node.tensor = backend.reshape(node.tensor, new_shape)
edge1 = Edge(node1=node, axis1=len(perm_front), name="TraceFront")
edge2 = Edge(node1=node, axis1=len(perm_front) + 1, name="TraceBack")
node.edges = node.edges[:len(perm_front)] + [edge1, edge2]
new_edge = connect(edge1, edge2, new_edge_name)
# pylint: disable=expression-not-assigned
[edge.disable() for edge in edges] #disable edges!
return new_edge
The new edge that represents the flattening of the given edges.
"""
node = edges[0].node1 # We are in the trace case, so this is the only node.
backend = node.backend
# Flatten all of the edges' axes into a single list.
perm_back = [min(e.axis1, e.axis2) for e in edges]
perm_back += [max(e.axis1, e.axis2) for e in edges]
perm_front = set(range(len(node.edges))) - set(perm_back)
perm_front = sorted(perm_front)
perm = perm_front + perm_back
new_dim = backend.prod([backend.shape(node.tensor)[e.axis1] for e in edges])
node.reorder_axes(perm)
unaffected_shape = backend.shape(node.tensor)[:len(perm_front)]
new_shape = backend.concat([unaffected_shape, [new_dim, new_dim]], axis=-1)
node.tensor = backend.reshape(node.tensor, new_shape)
edge1 = Edge(node1=node, axis1=len(perm_front), name="TraceFront")
edge2 = Edge(node1=node, axis1=len(perm_front) + 1, name="TraceBack")
node.edges = node.edges[:len(perm_front)] + [edge1, edge2]
new_edge = connect(edge1, edge2, new_edge_name)
# pylint: disable=expression-not-assigned
[edge.disable() for edge in edges] #disable edges!
return new_edge