Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_get_upsampling_weight():
    """Check that a stride-2 transposed convolution roughly doubles an image.

    The ``skimage`` coffee sample image is fed through a bias-free
    ``torch.nn.ConvTranspose2d`` whose kernel is replaced by the tensor
    returned from ``get_upsampling_weight``. Each of the first two
    dimensions of the output must be within 2 of twice the corresponding
    input dimension.

    Returns:
        The ``(src, dst)`` pair: the source image and the upsampled result
        cast to ``uint8``.
    """
    channels_in = 3
    channels_out = 3
    ksize = 4

    src = skimage.data.coffee()
    # Move the last axis to the front and prepend a batch axis
    # (presumably H x W x C -> 1 x C x H x W -- confirm against skimage).
    batch = src.transpose(2, 0, 1)[np.newaxis, :, :, :]
    inputs = torch.autograd.Variable(torch.from_numpy(batch).float())

    upsampler = torch.nn.ConvTranspose2d(
        channels_in, channels_out, ksize, stride=2, bias=False)
    upsampler.weight.data = get_upsampling_weight(
        channels_in, channels_out, ksize)

    # Drop the batch axis again and move the channel axis back to the end.
    first_output = upsampler(inputs).data.numpy()[0]
    dst = first_output.transpose(1, 2, 0).astype(np.uint8)

    # Both leading dimensions should be about twice the source size.
    for axis in (0, 1):
        assert abs(src.shape[axis] * 2 - dst.shape[axis]) <= 2
    return src, dst
def _initialize_weights(self):
    """Reset the weights of every convolutional sub-module.

    Each ``nn.Conv2d`` found in ``self.modules()`` has its weight zeroed,
    and its bias too when one exists. Each ``nn.ConvTranspose2d`` must have
    a square kernel; its weight is overwritten in place with the tensor
    returned by ``get_upsampling_weight``. The bias of a transposed
    convolution is left untouched.
    """
    for layer in self.modules():
        if isinstance(layer, nn.Conv2d):
            layer.weight.data.zero_()
            bias = layer.bias
            if bias is not None:
                bias.data.zero_()
        if isinstance(layer, nn.ConvTranspose2d):
            kernel = layer.kernel_size
            # Only a single kernel size is passed on, so it must be square.
            assert kernel[0] == kernel[1]
            layer.weight.data.copy_(get_upsampling_weight(
                layer.in_channels, layer.out_channels, kernel[0]))
def _initialize_weights(self):
    """Initialise convolution parameters across all sub-modules.

    Walks ``self.modules()`` once. ``nn.Conv2d`` weights are zeroed, along
    with the bias when it is not ``None``. ``nn.ConvTranspose2d`` weights
    are replaced in place by the output of ``get_upsampling_weight``, after
    asserting that the kernel's two dimensions are equal.
    """
    for module in self.modules():
        if isinstance(module, nn.Conv2d):
            module.weight.data.zero_()
            has_bias = module.bias is not None
            if has_bias:
                module.bias.data.zero_()
        if isinstance(module, nn.ConvTranspose2d):
            side = module.kernel_size[0]
            # The helper receives one size only, so reject non-square kernels.
            assert side == module.kernel_size[1]
            upsampling_kernel = get_upsampling_weight(
                module.in_channels, module.out_channels, side)
            module.weight.data.copy_(upsampling_kernel)