Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
pytorch_model = load_orig_imagenet_model(arch_name='resnet50')
# load the class label
label_map = load_imagenet_label_map()
elif args.dataset == 'places365':
pytorch_model = load_orig_places365_model(arch_name='resnet50')
# load the class label
label_map = load_class_label()
else:
print('Invalid datasest!!')
exit(0)
# Build the LIME image explainer; its random state is taken from the command-line arguments.
pytorch_explainer = lime_image.LimeImageExplainer(random_state=args.lime_explainer_seed)
# SLIC settings: the segment count comes from the arguments, while compactness and sigma are fixed here.
slic_parameters = {'n_segments': args.lime_superpixel_num, 'compactness': 30, 'sigma': 3}
segmenter = SegmentationAlgorithm('slic', **slic_parameters)
# NOTE(review): presumably the PIL-side resize/crop transform -- confirm against get_pil_transform.
pill_transf = get_pil_transform()
#########################################################
# Function to compute probabilities
# Pytorch
# This transform is applied to each image before pytorch_batch_predict stacks them into a batch.
pytorch_preprocess_transform = get_pytorch_preprocess_transform()
def pytorch_batch_predict(images):
batch = torch.stack(tuple(pytorch_preprocess_transform(i) for i in images), dim=0)
batch = batch.to('cuda')
if args.if_pre == 1:
logits = pytorch_model(batch)
probs = F.softmax(logits, dim=1)
def explain_raw(model, img, topLabels, numSamples, numFeatures, hideRest, hideColor, positiveOnly):
    """Explain the model's top prediction for an image with LIME and return it as JPEG bytes.

    The input is passed through ``transform_img_fn``, scaled by 1/255, and
    handed to a fresh ``LimeImageExplainer`` that uses ``model.predict`` as
    the classifier function. The region mask for the label chosen by
    ``getTopPrediction`` is outlined on the image and encoded as JPEG.

    Args:
        model: Object exposing ``predict``; it is called once directly and
            again by LIME on perturbed samples.
        img: Image input accepted by ``transform_img_fn``.
        topLabels: Forwarded to ``explain_instance`` as ``top_labels``.
        numSamples: Forwarded to ``explain_instance`` as ``num_samples``.
        numFeatures: Forwarded to ``get_image_and_mask`` as ``num_features``.
        hideRest: Forwarded to ``get_image_and_mask`` as ``hide_rest``.
        hideColor: Forwarded to ``explain_instance`` as ``hide_color``.
        positiveOnly: Forwarded to ``get_image_and_mask`` as ``positive_only``.

    Returns:
        bytes: The JPEG-encoded image with the explanation boundaries drawn.
    """
    # transform_img_fn returns a pair; only the first element is used here.
    prepared, _unused_original = transform_img_fn(img)
    scaled = prepared * (1. / 255)

    # Predict on the un-squeezed array first; the first row of the result
    # selects which label gets explained below.
    prediction = model.predict(scaled)

    lime_explainer = lime_image.LimeImageExplainer()
    explanation = lime_explainer.explain_instance(
        np.squeeze(scaled),
        model.predict,
        top_labels=topLabels,
        hide_color=hideColor,
        num_samples=numSamples,
    )

    top_label = getTopPrediction(prediction[0])
    highlighted, mask = explanation.get_image_and_mask(
        top_label,
        positive_only=positiveOnly,
        num_features=numFeatures,
        hide_rest=hideRest,
    )

    # The outlined image is multiplied by 255 and cast to uint8 before it is
    # handed to PIL for encoding.
    outlined = mark_boundaries(highlighted, mask)
    jpeg_buffer = io.BytesIO()
    Image.fromarray(np.uint8(outlined * 255)).save(jpeg_buffer, format='JPEG')
    return jpeg_buffer.getvalue()