"""
Model with a convolution head: more powerful for classification, but more difficult to train on top of a hyperlayer.
"""
hyperlayer = BoxAttentionLayer(
    glimpses=arg.num_glimpses,
    in_size=shape, k=arg.k,
    gadditional=arg.gadditional, radditional=arg.radditional, region=(arg.region, arg.region),
    min_sigma=arg.min_sigma
)
ch1, ch2, ch3 = 16, 32, 64
h = (arg.k // 8) ** 2 * ch3  # flattened feature size after three 2x2 max-pools over the k-by-k glimpses
model = nn.Sequential(
    hyperlayer,
    util.Reshape((arg.num_glimpses * shape[0], arg.k, arg.k)),  # fold glimpses into channels
    nn.Conv2d(arg.num_glimpses * shape[0], ch1, kernel_size=5, padding=2),
    activation,
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(ch1, ch2, kernel_size=5, padding=2),
    activation,
    nn.Conv2d(ch2, ch2, kernel_size=5, padding=2),
    activation,
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(ch2, ch3, kernel_size=5, padding=2),
    activation,
    nn.Conv2d(ch3, ch3, kernel_size=5, padding=2),
    activation,
    nn.MaxPool2d(kernel_size=2),
    util.Flatten(),
    nn.Linear(h, 128),
    activation,
    nn.Linear(128, num_classes)  # assumed final classifier; not shown in the original excerpt
)
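
# Plot the glimpse means and sigmas of one instance over its input image.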
smeans = means[current, :, :, :, :].view(h*w, k, 2)
ssigmas = sigmas[current, :, :, :].view(h*w, k, 2)
color = (torch.arange(numpixels, dtype=torch.float)[:, None].expand(numpixels, k)/numpixels) * 2.0 - 1.0
smeans = smeans[choices, :, :]
ssigmas = ssigmas[choices, :]
ax = plt.subplot(rows, perrow, current+1)
im = np.transpose(ims[current, :, :, :].cpu().numpy(), (1, 2, 0))
im = np.squeeze(im)
ax.imshow(im, interpolation='nearest', extent=(-0.5, wims-0.5, -0.5, hims-0.5), cmap='gray_r')
util.plot(smeans.reshape(1, -1, 2), ssigmas.reshape(1, -1, 2), color.reshape(1, -1), axes=ax, flip_y=hims, tanh=False)
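# (disabled) draw faint lines connecting each glimpse mean to the pixel that produced it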
# for i, ch in enumerate(choices):
# chh, chw = ch//h, ch % w
# chh, chw = chh * scale[0] + scale[0]/2, chw * scale[1] + scale[1]/2
#
# for ik in range(k):
# x, y = smeans[i, ik, :]
#
# l, ml = math.sqrt((y - chw) ** 2 + (x - hims + chh) ** 2), math.sqrt(hims ** 2 + wims ** 2)
#
# ax.add_line(mpl.lines.Line2D([y, chw], [hims-x, hims-chh], linestyle='-', alpha=max(0.0, (1.0 - (l/ml)) - 0.5), color='white', lw=0.5))
plt.gcf()
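
# Histogram of the absolute gradient-estimation error for the uninformed, informed and Gumbel-STE estimators.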
ulab = f'uninformed, var={unump.var():.4}'
ilab = f'informed, var={inump.var():.4}'
glab = f'gumbel STE (t={arg.gumbel}) var={gnump.var():.4}'
# clab = f'Classical STE var={cnump.var():.4}'
plt.hist([unump, inump, gnump], color=['r', 'g', 'b'], label=[ulab, ilab, glab], bins=arg.bins)
plt.axvline(x=unump.mean(), color='r', ls='--')
plt.axvline(x=inump.mean(), color='g', ls='-.')
plt.axvline(x=gnump.mean(), color='b', ls=':')
# plt.axvline(x=cnump.mean(), color='c')
plt.title(f'Absolute error between true gradient and estimate \n over {ind.sum()} parameters with nonzero gradient.')
plt.legend()
util.basic()
if arg.range is not None:
    plt.xlim(*arg.range)

plt.savefig('./bias/histogram.all.pdf')
def hyper(self, x):
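    """
    Computes, for every pixel of x, the parameters that generate the sparse index tuples.
    The tuples are expressed relative to the pixel's own coordinates (mids); how the parameters
    are conditioned on the input is controlled by self.admode.
    """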
    b, c, h, w = x.size()
    k = self.k

    coords = util.coordinates((h, w), cuda=x.is_cuda)

    # the coordinates of the current pixels in parameter space
    # - the index tuples are described relative to these
    hw = torch.tensor((h, w), device=d(x), dtype=torch.float)
    mids = coords[None, :, :, :].expand(b, 2, h, w) * (hw - 1)[None, :, None, None]
    mids = mids.permute(0, 2, 3, 1)
    if self.edges == 'sigmoid':
        mids = util.inv(mids, mx=hw[None, None, None, :])
    mids = mids[:, :, :, None, :].expand(b, h, w, k, 2)

    # add coords to channels
    if self.admode == 'none':
        params = self.params[None, None, None, :].expand(b, h, w, self.nparms)
    else:
        if self.admode == 'full':
            crds = coords[None, :, :, :].expand(b, 2, h, w)
            x = torch.cat([x, crds], dim=1)
        elif self.admode == 'coords':
            x = coords[None, :, :, :].expand(b, 2, h, w)
        elif self.admode == 'inputs':
            pass
        else:
            raise Exception(f'adaptivity mode {self.admode} not recognized')
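
# Data loading: use the provided train/test split for a final run; otherwise hold out part of the training set as validation.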
if arg.final:
    train = torchvision.datasets.ImageFolder(root=arg.data + '/train/', transform=tr)
    test = torchvision.datasets.ImageFolder(root=arg.data + '/test/', transform=tr)

    trainloader = DataLoader(train, batch_size=arg.batch, shuffle=True)
    testloader = DataLoader(test, batch_size=arg.batch, shuffle=True)
else:
    NUM_TRAIN = 45000
    NUM_VAL = 5000
    total = NUM_TRAIN + NUM_VAL

    train = torchvision.datasets.ImageFolder(root=arg.data + '/train/', transform=tr)

    # split the training data into a training and a validation chunk
    trainloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(0, NUM_TRAIN, total))
    testloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total))

# peek at one batch to determine the input shape
for im, labels in trainloader:
    shape = im[0].size()
    break

num_classes = 10
else:
    raise Exception('Task name {} not recognized'.format(arg.task))
activation = nn.ReLU()
hyperlayer = None