Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def initialize(self, model='net-lin', net='squeeze', colorspace='Lab', use_gpu=True, printNet=False):
'''
INPUTS
model - ['net-lin'] for linearly calibrated network
['net'] for off-the-shelf network
['L2'] for L2 distance in Lab colorspace
['SSIM'] for ssim in RGB colorspace
net - ['squeeze','alex','vgg']
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
use_gpu - bool - whether or not to use a GPU
printNet - bool - whether or not to print network architecture out
'''
BaseModel.initialize(self, use_gpu=use_gpu)
self.model = model
self.net = net
self.use_gpu = use_gpu
if self.use_gpu:
self.map_location = None
else:
self.map_location = lambda storage, loc: storage
self.model_name = '%s [%s]'%(model,net)
if(self.model == 'net-lin'): # pretrained net + linear layer
self.net = networks.PNetLin(use_gpu=use_gpu,pnet_type=net,use_dropout=True)
weight_path = os.path.join(os.path.dirname(__file__), 'weights', '%s.pth' % net)
self.net.load_state_dict(torch.load(weight_path,
map_location=self.map_location))
elif(self.model=='net'): # pretrained network