Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, cell_name, prev_labels, channels):
    """Create the two searchable choices that make up this node.

    Parameters
    ----------
    cell_name
        Prefix for the mutable keys; the input choice is keyed
        ``cell_name + "_input"`` and the operation choice ``cell_name + "_op"``.
    prev_labels
        Labels handed to the input choice as its ``choose_from`` candidates.
    channels
        Used as both the input and the output channel count of the
        separable-convolution candidates.
    """
    super().__init__()

    # Exactly one of the previous labels is selected; the choice is asked
    # to return its mask as well (return_mask=True).
    input_key = cell_name + "_input"
    self.input_choice = mutables.InputChoice(
        choose_from=prev_labels,
        n_chosen=1,
        return_mask=True,
        key=input_key,
    )

    # Five candidate operations, in a fixed order.
    # NOTE(review): the trailing numeric arguments are presumably kernel
    # size / stride / padding -- confirm against SepConvBN and Pool.
    candidate_ops = [
        SepConvBN(channels, channels, 3, 1),
        SepConvBN(channels, channels, 5, 2),
        Pool("avg", 3, 1, 1),
        Pool("max", 3, 1, 1),
        nn.Identity(),
    ]
    op_key = cell_name + "_op"
    self.op_choice = mutables.LayerChoice(candidate_ops, key=op_key)
{ key_name: {"_type": "input_choice",
"_value": {"candidates": ["in1", "in2"],
"n_chosen": 1}} }
Returns
-------
dict
the generated search space
"""
search_space = {}
for mutable in self.mutables:
# for now we only generate flattened search space
if isinstance(mutable, LayerChoice):
key = mutable.key
val = [repr(choice) for choice in mutable.choices]
search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
elif isinstance(mutable, InputChoice):
key = mutable.key
search_space[key] = {"_type": INPUT_CHOICE,
"_value": {"candidates": mutable.choose_from,
"n_chosen": mutable.n_chosen}}
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return search_space
def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction):
    """Assemble a cell: two calibration layers, a chain of nodes and final weights.

    Parameters
    ----------
    num_nodes
        Number of ``Node`` modules to create.
    in_channels_pp, in_channels_p
        Input channel counts given to the two ``Calibration`` layers
        (``preproc0`` and ``preproc1`` respectively).
    out_channels
        Output channel count for the calibration layers, every node,
        the final weight tensor and the batch norm.
    reduction
        When truthy, node labels are prefixed with ``"reduce"``;
        otherwise with ``"normal"``.
    """
    super().__init__()
    self.preproc0 = Calibration(in_channels_pp, out_channels)
    self.preproc1 = Calibration(in_channels_p, out_channels)
    self.num_nodes = num_nodes

    if reduction:
        scope = "reduce"
    else:
        scope = "normal"

    self.nodes = nn.ModuleList()
    # The first two candidate inputs carry no key of their own.
    seen_labels = [mutables.InputChoice.NO_KEY] * 2
    for idx in range(num_nodes):
        label = "{}_node_{}".format(scope, idx)
        # Each node may choose among everything labelled before it; it
        # receives its own copy of the label list as it stands right now.
        self.nodes.append(Node(label, list(seen_labels), out_channels))
        seen_labels.append(label)

    # Zero-initialised weights with one slice per candidate source
    # (num_nodes nodes + 2 extra entries).
    weight_shape = (out_channels, self.num_nodes + 2, out_channels, 1, 1)
    self.final_conv_w = nn.Parameter(torch.zeros(*weight_shape), requires_grad=True)
    self.bn = nn.BatchNorm2d(out_channels, affine=False)
    self.reset_parameters()
if memo is None:
memo = set()
if root is None:
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
if nested_detection is not None:
raise RuntimeError("Cannot have nested search space. Error at {} in {}"
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
.format(k, module.key))
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
nested_detection=nested_detection)
return root
def sample_final(self):
    """Derive the final discrete decision for every mutable.

    Two passes are made over ``self.mutables``:

    1. For each ``LayerChoice`` the softmax of ``self.choices[key]`` is
       taken, its last entry (along the first dimension) is sliced off, and
       the arg-max of what remains is stored as a boolean one-hot mask.
       The winning softmax value is remembered per key.
    2. For each ``InputChoice`` with a non-``None`` ``n_chosen`` the
       candidate sources are ranked by the value remembered in pass 1
       (sources without one count as ``0.``) and the top ``n_chosen`` are
       kept. Sources that lose and have a mask from pass 1 get that mask
       replaced by all zeros.

    Returns
    -------
    dict
        Maps mutable key to a boolean tensor mask.
    """
    decisions = {}
    strongest = {}

    # Pass 1: pick the best operation on every layer choice.
    for layer in self.mutables:
        if not isinstance(layer, LayerChoice):
            continue
        probabilities = F.softmax(self.choices[layer.key], dim=-1)
        # NOTE(review): the last entry is excluded from the arg-max --
        # presumably a "no-op" candidate; confirm against self.choices.
        top_value, top_index = torch.max(probabilities[:-1], 0)
        strongest[layer.key] = top_value
        one_hot = F.one_hot(top_index, num_classes=layer.length)
        decisions[layer.key] = one_hot.view(-1).bool()

    # Pass 2: keep only the strongest incoming edges of every input choice.
    for chooser in self.mutables:
        if not isinstance(chooser, InputChoice) or chooser.n_chosen is None:
            continue

        edge_weights = []
        for src_key in chooser.choose_from:
            if src_key not in strongest:
                _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", chooser.key)
            edge_weights.append(strongest.get(src_key, 0.))
        edge_weights = torch.tensor(edge_weights)  # pylint: disable=not-callable
        _, kept_positions = torch.topk(edge_weights, chooser.n_chosen)

        multihot = []
        for position, src_key in enumerate(chooser.choose_from):
            is_kept = position in kept_positions
            if not is_kept and src_key in decisions:
                # clear this choice to optimize calc graph
                decisions[src_key] = torch.zeros_like(decisions[src_key])
            multihot.append(is_kept)
        decisions[chooser.key] = torch.tensor(multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable

    return decisions
def __init__(self, key, prev_labels, in_filters, out_filters):
    """Set up a searchable layer with an optional skip-connection choice.

    Parameters
    ----------
    key
        Forwarded unchanged to the parent class constructor.
    prev_labels
        Labels offered as skip-connection candidates. When empty, no
        input choice is created and ``self.skipconnect`` is ``None``.
    in_filters, out_filters
        Stored on the instance and passed to every candidate branch;
        ``out_filters`` also sizes the non-affine batch norm.
    """
    super().__init__(key)
    self.in_filters = in_filters
    self.out_filters = out_filters

    # Four convolution branches followed by two pooling branches, in the
    # order: 3/plain, 3/separable, 5/plain, 5/separable, avg pool, max pool.
    # NOTE(review): the varying numbers are presumably kernel size and
    # padding -- confirm against ConvBranch.
    branches = [
        ConvBranch(in_filters, out_filters, kernel, 1, padding, separable=separable)
        for kernel, padding in ((3, 1), (5, 2))
        for separable in (False, True)
    ]
    for pool_kind in ('avg', 'max'):
        branches.append(PoolBranch(pool_kind, in_filters, out_filters, 3, 1, 1))
    self.mutable = mutables.LayerChoice(branches)

    # n_chosen=None leaves the number of selected skip inputs unspecified.
    self.skipconnect = (
        mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
        if len(prev_labels) > 0
        else None
    )
    self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
if root is None:
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
if nested_detection is not None:
raise RuntimeError("Cannot have nested search space. Error at {} in {}"
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
.format(k, module.key))
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
nested_detection=nested_detection)
return root
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(
[
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
],
key=choice_keys[-1]))
self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))