Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def init_weights(self, pretrained=None):
    """Initialise this module's weights from a checkpoint or from scratch.

    Args:
        pretrained: A string that is handed to ``load_checkpoint``, or
            ``None`` to run the default initialisation below.

    Raises:
        TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
    """
    if pretrained is not None and not isinstance(pretrained, str):
        raise TypeError('pretrained must be a str or None')

    if pretrained is not None:
        root_logger = get_root_logger()
        load_checkpoint(self, pretrained, strict=False, logger=root_logger)
        return

    # Default initialisation for convolution and normalisation layers.
    for layer in self.modules():
        if isinstance(layer, nn.Conv2d):
            kaiming_init(layer)
        elif isinstance(layer, (_BatchNorm, nn.GroupNorm)):
            constant_init(layer, 1)

    if not self.zero_init_residual:
        return

    # Deliberately a separate pass that runs after the generic
    # initialisation above, so these constants are the last ones applied.
    for layer in self.modules():
        if isinstance(layer, Bottleneck):
            constant_init(layer.norm3, 0)
        elif isinstance(layer, BasicBlock):
            constant_init(layer.norm2, 0)
def init_weights(self):
    """Apply ``kaiming_init`` to ``self.conv_logits``."""
    logits_layer = self.conv_logits
    kaiming_init(logits_layer)
def init_weights(self, pretrained=None):
    """Initialise weights, honouring a per-layer ``zero_init`` flag.

    ``nn.Conv2d`` layers that carry a truthy ``zero_init`` attribute are
    set to constant 0 in both modes, i.e. also after a checkpoint has
    been loaded.

    Args:
        pretrained: A string that is handed to ``load_checkpoint``, or
            ``None`` to initialise from scratch.

    Raises:
        TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
    """
    if pretrained is not None and not isinstance(pretrained, str):
        raise TypeError('pretrained must be a str or None')

    if pretrained is not None:
        checkpoint_logger = logging.getLogger()
        load_checkpoint(
            self, pretrained, strict=False, logger=checkpoint_logger)
        # Flagged convolutions are zeroed after the checkpoint load.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d) and getattr(
                    layer, 'zero_init', False):
                constant_init(layer, 0)
        return

    for layer in self.modules():
        if isinstance(layer, nn.Conv2d):
            if getattr(layer, 'zero_init', False):
                constant_init(layer, 0)
            else:
                kaiming_init(layer)
            continue
        if isinstance(layer, nn.BatchNorm2d):
            constant_init(layer, 1)
def init_weights(self):
    """Initialise conv/linear layers, transposed convs and ``deconv2`` bias.

    ``nn.Conv2d`` and ``nn.Linear`` modules get ``kaiming_init``,
    ``nn.ConvTranspose2d`` modules get ``normal_init`` with ``std=0.001``,
    and finally ``self.deconv2.bias`` is filled with
    ``-log(0.99 / 0.01)``.
    """
    # First pass: convolution and fully connected layers.
    for layer in self.modules():
        # TODO: compare mode = "fan_in" or "fan_out"
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            kaiming_init(layer)

    # Second pass, kept separate so the order of initialisation calls
    # stays the same: transposed convolutions.
    for layer in self.modules():
        if isinstance(layer, nn.ConvTranspose2d):
            normal_init(layer, std=0.001)

    deconv2_bias_value = -np.log(0.99 / 0.01)
    nn.init.constant_(self.deconv2.bias, deconv2_bias_value)
# NOTE(review): this is a fragment -- the enclosing ``def`` and the opening
# ``if`` of this chain are not visible here. Presumably it is the body of a
# loop over named modules that copies parameters from a 2D network
# (``resnet2d``) into 3D layers; confirm against the full file.
# Conv3d whose name also resolves on ``resnet2d``: build a 3D weight from the
# 2D one.
elif isinstance(module, nn.Conv3d) and rhasattr(resnet2d, name):
# Insert a new axis at index 2, expand to the Conv3d weight's shape, then
# divide by the size of that axis (presumably so the repeated copies sum to
# the 2D weight -- confirm).
new_weight = rgetattr(resnet2d, name).weight.data.unsqueeze(2).expand_as(module.weight) / module.weight.data.shape[2]
module.weight.data.copy_(new_weight)
logging.info("{}.weight loaded from weights file into {}".format(name, new_weight.shape))
# The bias is copied as-is, and only when the 3D module actually has one.
if hasattr(module, 'bias') and module.bias is not None:
new_bias = rgetattr(resnet2d, name).bias.data
module.bias.data.copy_(new_bias)
logging.info("{}.bias loaded from weights file into {}".format(name, new_bias.shape))
# BatchNorm3d whose name also resolves on ``resnet2d``: the four listed
# attributes are rebound with setattr (not copied in place as in the Conv3d
# branch above).
elif isinstance(module, nn.BatchNorm3d) and rhasattr(resnet2d, name):
for attr in ['weight', 'bias', 'running_mean', 'running_var']:
logging.info("{}.{} loaded from weights file into {}".format(name, attr, getattr(rgetattr(resnet2d, name), attr).shape))
setattr(module, attr, getattr(rgetattr(resnet2d, name), attr))
# ``self.pretrained`` is None: kaiming_init for Conv3d, constant_init(m, 1)
# for BatchNorm3d.
elif self.pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv3d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm3d):
constant_init(m, 1)
# Fallback of the chain; NOTE(review): the matching ``if`` is out of view,
# presumably a check that ``self.pretrained`` is a str -- confirm.
else:
raise TypeError('pretrained must be a str or None')
def init_weights(self):
    """Run ``kaiming_init`` on the ``conv_logits`` layer of this module."""
    target_layer = self.conv_logits
    kaiming_init(target_layer)
def init_weights(self):
    """Initialise ``self.convs``, ``self.fcs`` and ``self.fc_mask_iou``.

    Every entry of ``self.convs`` gets ``kaiming_init`` with its defaults,
    every entry of ``self.fcs`` gets ``kaiming_init`` with the explicit
    settings below, and ``self.fc_mask_iou`` gets ``normal_init`` with
    ``std=0.01``.
    """
    for conv_layer in self.convs:
        kaiming_init(conv_layer)

    # The same keyword set is used for every fully connected layer.
    fc_init_kwargs = dict(
        a=1,
        mode='fan_in',
        nonlinearity='leaky_relu',
        distribution='uniform',
    )
    for fc_layer in self.fcs:
        kaiming_init(fc_layer, **fc_init_kwargs)

    normal_init(self.fc_mask_iou, std=0.01)
def init_weights(self):
    """Initialise weights according to ``self.pretrained``.

    A string value is handed to ``load_checkpoint``; ``None`` triggers
    ``kaiming_init`` on ``nn.Conv2d`` modules and ``constant_init(m, 1)``
    on ``nn.BatchNorm2d`` modules.

    Raises:
        TypeError: If ``self.pretrained`` is neither a ``str`` nor ``None``.
    """
    source = self.pretrained

    if source is not None and not isinstance(source, str):
        raise TypeError('pretrained must be a str or None')

    if source is not None:
        checkpoint_logger = logging.getLogger()
        load_checkpoint(self, source, strict=False, logger=checkpoint_logger)
        return

    for layer in self.modules():
        if isinstance(layer, nn.Conv2d):
            kaiming_init(layer)
        elif isinstance(layer, nn.BatchNorm2d):
            constant_init(layer, 1)
def init_weights(self, pretrained=None):
    """Initialise weights from a checkpoint or with default initialisers.

    Args:
        pretrained: A string that is handed to ``load_checkpoint`` (loaded
            with ``map_location`` set to the CPU device), or ``None`` to
            initialise from scratch.

    Raises:
        TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
    """
    if isinstance(pretrained, str):
        logger = logging.getLogger()
        load_checkpoint(
            self,
            pretrained,
            strict=False,
            logger=logger,
            map_location=torch.device('cpu'))
    elif pretrained is None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                constant_init(m, 1)

        # Separate passes that run after the generic initialisation above,
        # so the constants written here are the ones that remain.
        if self.dcn is not None:
            for m in self.modules():
                if isinstance(m, Bottleneck) and hasattr(m, 'conv2_offset'):
                    constant_init(m.conv2_offset, 0)

        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    constant_init(m.norm3, 0)
                elif isinstance(m, BasicBlock):
                    constant_init(m.norm2, 0)
    else:
        # The original ``else:`` had no body; reject unsupported types
        # explicitly instead of leaving the branch empty.
        raise TypeError('pretrained must be a str or None')
def init_weights(self, pretrained=None):
    """Initialise weights from a checkpoint or with default initialisers.

    Args:
        pretrained: ``None`` to apply ``kaiming_init`` to ``nn.Conv2d``
            modules and ``constant_init(m, 1)`` to ``nn.BatchNorm2d`` /
            ``nn.GroupNorm`` modules, or a string that is handed to
            ``load_checkpoint``. As written here, any other value is
            ignored without an error.
    """
    norm_types = (nn.BatchNorm2d, nn.GroupNorm)

    if pretrained is None:
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                kaiming_init(layer)
            elif isinstance(layer, norm_types):
                constant_init(layer, 1)
    elif isinstance(pretrained, str):
        checkpoint_logger = logging.getLogger()
        load_checkpoint(
            self, pretrained, strict=False, logger=checkpoint_logger)