Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _str2bool(value):
    """Convert a command-line string into a bool for ``--load_model``.

    argparse hands option values over as raw strings, and every non-empty
    string (including ``"False"`` and ``"0"``) is truthy.  Without this
    converter, ``--load_model False`` would silently enable model loading.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling, so argparse reports a clean usage error.
    """
    import argparse

    if isinstance(value, bool):  # e.g. the ``const=True`` of a bare flag
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False, matching the old truthiness of an empty argument.
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value for --load_model, got %r" % (value,))


# `parser` is created earlier in the script (outside this view).
# nargs='?' + const=True lets a bare `--load_model` mean True, while
# `--load_model False` / `--load_model true` are parsed via _str2bool.
parser.add_argument('--load_model', default=False, type=_str2bool,
                    nargs='?', const=True, help='Load a trained model?')
parser.add_argument('--dir_save_images', default='interpolation_images', help='Dir to save the sequence of images')
args = parser.parse_args()
num_epochs = args.num_epochs  # --num_epochs is declared outside this view
load_model = args.load_model
dir_save_images = args.dir_save_images
# Cache directory for this example's outputs ('reg_inverse_example').
dir_to_save = get_cache_dir('reg_inverse_example')

# MNIST images are single-channel, so Normalize needs exactly one mean and
# one std entry.  A 3-tuple makes torchvision's in-place normalisation try
# to broadcast a (1, H, W) tensor to (3, H, W), which raises a RuntimeError.
# ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
transforms_to_apply = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # Pixel values should be in [-1,1]
])

mnist_dir = get_dataset_dir("MNIST", create=True)
dataset = datasets.MNIST(mnist_dir, train=True, download=True, transform=transforms_to_apply)
# Main training loader.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=True)

# Draw one random batch of 2 images (labels discarded) to use as a fixed
# reference batch, moved to the GPU as float32.
fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
fixed_batch = next(iter(fixed_dataloader))
fixed_batch = fixed_batch[0].float().cuda()
# Scattering transform with J=2 over 28x28 inputs, moved to the GPU.
scattering = Scattering(J=2, shape=(28, 28))
scattering.cuda()

# Scatter the fixed reference batch, then drop axis 1.  That axis is
# presumably the singleton image-channel dim of MNIST -- TODO confirm.
scattering_fixed_batch = scattering(fixed_batch)
scattering_fixed_batch = scattering_fixed_batch.squeeze(1)

# The generator consumes scattering coefficients, so its input width is the
# number of coefficient channels; the hidden width is chosen to match it.
num_input_channels = num_hidden_channels = scattering_fixed_batch.size(1)

generator = Generator(num_input_channels, num_hidden_channels)
generator.cuda()