def __init__(self, key, prev_labels, in_filters, out_filters):
    super().__init__(key)
    self.in_filters = in_filters
    self.out_filters = out_filters
    # one LayerChoice over six candidate branches: 3x3/5x5 convs (plain and separable), avg pool, max pool
    self.mutable = mutables.LayerChoice([
        ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
        ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
        ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
        ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
        PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
        PoolBranch('max', in_filters, out_filters, 3, 1, 1)
    ])
    if len(prev_labels) > 0:
        # optionally connect to any subset of previous layers
        self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
    else:
        self.skipconnect = None
    self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
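A forward pass for a layer like this could combine the chosen branch with the optional skip connections roughly as follows. This is a minimal sketch, assuming prev_layers holds the outputs of earlier layers and that self.skipconnect reduces the selected inputs to a single tensor (or returns None when nothing is chosen):

def forward(self, prev_layers):
    out = self.mutable(prev_layers[-1])                   # run the branch picked by the LayerChoice
    if self.skipconnect is not None:
        connection = self.skipconnect(prev_layers[:-1])   # reduction of the selected skip inputs, or None
        if connection is not None:
            out = out + connection
    return self.batch_norm(out)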
def sample_final(self):
    assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
        "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                   self._chosen_arch.keys())
    result = dict()
    for mutable in self.mutables:
        assert mutable.key in self._chosen_arch, \
            "Expected '{}' in chosen arch, but not found.".format(mutable.key)
        data = self._chosen_arch[mutable.key]
        assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
            "'{}' is not a valid choice.".format(data)
        value = data["_value"]
        idx = data["_idx"]
        search_space_item = self._search_space[mutable.key]["_value"]
        if isinstance(mutable, LayerChoice):
            result[mutable.key] = self._sample_layer_choice(mutable, idx, value, search_space_item)
        elif isinstance(mutable, InputChoice):
            result[mutable.key] = self._sample_input_choice(mutable, idx, value, search_space_item)
        else:
            raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
    return result
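For context, the _chosen_arch dict that sample_final validates follows the same "_value"/"_idx" convention enforced by the asserts above; a hypothetical entry (key names and values are illustrative only) might look like:

chosen_arch = {
    "layer_2": {"_value": "ConvBranch(...)", "_idx": 1},     # LayerChoice: picked candidate and its index
    "layer_2_skip": {"_value": ["layer_0"], "_idx": [0]},    # InputChoice: picked inputs and their indices
}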
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(
[
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
],
key=choice_keys[-1]))
self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
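A node like this would typically apply one candidate op per predecessor, pass each result through drop-path regularisation, and let the InputChoice keep two of them. A plausible sketch, not necessarily the exact implementation:

def forward(self, prev_nodes):
    assert len(self.ops) == len(prev_nodes)
    out = [op(node) for op, node in zip(self.ops, prev_nodes)]         # one LayerChoice per predecessor
    out = [self.drop_path(o) if o is not None else None for o in out]
    return self.input_switch(out)                                      # reduces the two chosen inputs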
def __init__(self, cell_name, prev_labels, channels):
    super().__init__()
    self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
                                             key=cell_name + "_input")
    self.op_choice = mutables.LayerChoice([
        SepConvBN(channels, channels, 3, 1),
        SepConvBN(channels, channels, 5, 2),
        Pool("avg", 3, 1, 1),
        Pool("max", 3, 1, 1),
        nn.Identity()
    ], key=cell_name + "_op")
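Because the InputChoice is constructed with return_mask=True, a forward pass of this cell can hand the chosen-input mask back to the caller. A minimal sketch under that assumption:

def forward(self, prev_layers):
    chosen_input, chosen_mask = self.input_choice(prev_layers)   # selected tensor plus its selection mask
    cell_out = self.op_choice(chosen_input)
    return cell_out, chosen_mask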
def _sample(self, tree):
    mutable = tree.mutable
    if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
        self._choices[mutable.key] = self._sample_layer_choice(mutable)
    elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
        self._choices[mutable.key] = self._sample_input_choice(mutable)
    for child in tree.children:
        self._sample(child)
    if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
        if self.cell_exit_extra_step:
            self._lstm_next_step()
        self._mark_anchor(mutable.key)
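The recursion above is usually kicked off from the mutator's sampling entry point on the root of the mutable tree. A sketch of such a driver; the helper name _initialize and its exact responsibilities are assumptions here, not confirmed API:

def sample_search(self):
    self._initialize()             # assumed helper: clears self._choices / self._anchors_hid, resets controller state
    self._sample(self.mutables)    # walk the whole mutable tree from the root
    return self._choices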
Here is the search space format:

::

    { key_name: {"_type": "layer_choice",
                 "_value": ["conv1", "conv2"]} }

    { key_name: {"_type": "input_choice",
                 "_value": {"candidates": ["in1", "in2"],
                            "n_chosen": 1}} }

Returns
-------
dict
    the generated search space
"""
search_space = {}
for mutable in self.mutables:
    # for now we only generate a flattened search space
    if isinstance(mutable, LayerChoice):
        key = mutable.key
        val = [repr(choice) for choice in mutable.choices]
        search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
    elif isinstance(mutable, InputChoice):
        key = mutable.key
        search_space[key] = {"_type": INPUT_CHOICE,
                             "_value": {"candidates": mutable.choose_from,
                                        "n_chosen": mutable.n_chosen}}
    else:
        raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return search_space
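Applied to a model such as the ENAS macro layer at the top of this page, the generated flattened search space would look roughly like the following; the keys and repr strings are illustrative, not taken from a real run:

{
    "layer_3": {
        "_type": "layer_choice",
        "_value": ["ConvBranch(...)", "ConvBranch(...)", "PoolBranch(...)"]   # repr() of each candidate op
    },
    "layer_3_skip": {
        "_type": "input_choice",
        "_value": {"candidates": ["layer_0", "layer_1", "layer_2"], "n_chosen": None}
    }
}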
def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
    self.pdarts_epoch_index = pdarts_epoch_index
    self.pdarts_num_to_drop = pdarts_num_to_drop
    # avoid a mutable default argument; start with empty switches when none are passed in
    if switches is None:
        self.switches = {}
    else:
        self.switches = switches
    super(PdartsMutator, self).__init__(model)

    # this loop goes through mutables with different keys;
    # it mainly updates the length of choices.
    for mutable in self.mutables:
        if isinstance(mutable, LayerChoice):
            switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
            choices = self.choices[mutable.key]
            operations_count = np.sum(switches)
            # +1 and -1 are caused by the zero operation in the DARTS network:
            # the zero operation is not in the network's choices list, but its weight is,
            # so it needs one extra weight and switch.
            self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
            self.switches[mutable.key] = switches

    # update LayerChoice instances in the model;
    # this physically removes dropped choice operations.
    for module in self.model.modules():
        if isinstance(module, LayerChoice):
            switches = self.switches.get(module.key)
            choices = self.choices[module.key]
            if len(module.choices) > len(choices):
                # iterate from last to first so that removing one entry
                # does not shift the indices still to be visited
                for index in range(len(switches) - 1, -1, -1):
                    if not switches[index]:
                        del module.choices[index]
                        module.length -= 1
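The reverse iteration when deleting is deliberate: removing entries from the end keeps the indices of the entries not yet visited valid. A small standalone illustration of the same pattern (the operation names here are made up):

ops = ["max_pool", "avg_pool", "skip_connect", "sep_conv_3x3"]
switches = [True, False, True, False]        # False marks an op dropped by P-DARTS
for index in range(len(switches) - 1, -1, -1):
    if not switches[index]:
        del ops[index]
print(ops)                                   # ['max_pool', 'skip_connect']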