Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_postprocessing_1D_array_no_confidences(self):
    """With confidences switched off, a 1-D score vector should produce only the label entry."""
    scores = np.array([0.1, 0.2, 0, 0.7, 0])
    label_output = outputs.Label(show_confidences=False)
    decoded = json.loads(label_output.postprocess(scores))
    # Index 3 holds the largest score (0.7), so it is the expected label.
    self.assertDictEqual(decoded, {outputs.Label.LABEL_KEY: 3})
def test_set_sample_data(self):
    """Check that sample inputs written to the config file are read back unchanged."""
    test_array = ["test1", "test2", "test3"]
    inp = inputs.Sketchpad()
    out = outputs.Label()
    # TemporaryDirectory deletes the scratch directory on exit, even if a call
    # below raises; the bare mkdtemp() used previously was never cleaned up.
    with tempfile.TemporaryDirectory() as temp_dir:
        networking.build_template(temp_dir, inp, out)
        networking.set_sample_data_in_config_file(temp_dir, test_array)
        # We need to come up with a better way so that the config file isn't invalid json unless
        # the following parameters are set... (TODO: abidlabs)
        networking.set_always_flagged_in_config_file(temp_dir, False)
        networking.set_disabled_in_config_file(temp_dir, False)
        config_file = os.path.join(temp_dir, 'static/config.json')
        with open(config_file) as json_file:
            data = json.load(json_file)
    # assertEqual reports the differing elements on failure, unlike assertTrue(a == b).
    self.assertEqual(test_array, data["sample_inputs"])
def test_postprocessing_int(self):
    """A lone value nested inside a (1, 1, 1) array should come back as the label."""
    wrapped_value = np.array([[[3]]])
    payload = outputs.Label().postprocess(wrapped_value)
    self.assertDictEqual(json.loads(payload), {outputs.Label.LABEL_KEY: 3})
def postprocess(self, prediction):
    """
    Convert a model prediction into the JSON string for the label output.

    Parameters:
    prediction (numpy.ndarray | str): a string is used directly as the label.
        An array is squeezed first: a single remaining value is passed to
        ``self.get_label_name``; a 1-D array is treated as per-class scores
        and the index of the largest score is passed to ``self.get_label_name``.
    Returns:
    (str): JSON-encoded dict holding ``Label.LABEL_KEY`` and, for 1-D arrays
        when ``self.show_confidences`` is set, ``Label.CONFIDENCES_KEY`` with up
        to ``self.num_top_classes`` entries ordered from highest to lowest score.
    Raises:
    ValueError: if ``prediction`` is neither a string nor an array that
        squeezes down to a single value or one dimension.
    """
    response = dict()
    # TODO(abidlabs): check if list, if so convert to numpy array
    if isinstance(prediction, np.ndarray):
        prediction = prediction.squeeze()
        if prediction.size == 1:  # if it's single value
            # ndarray.item() replaces np.asscalar, which newer NumPy releases removed.
            response[Label.LABEL_KEY] = self.get_label_name(prediction.item())
        elif prediction.ndim == 1:  # if a 1D
            response[Label.LABEL_KEY] = self.get_label_name(int(prediction.argmax()))
            if self.show_confidences:
                # Rank on a float copy: squeeze() returns a view, so writing into
                # `prediction` would corrupt the caller's array.
                scores = prediction.astype(float)
                # A stable sort on the negated scores gives descending order with
                # ties resolved by lowest index (as argmax does), never repeats a
                # class, and the slice caps the list at the number of classes.
                ranked = np.argsort(-scores, kind="stable")[:self.num_top_classes]
                response[Label.CONFIDENCES_KEY] = [
                    {
                        Label.LABEL_KEY: self.get_label_name(int(idx)),
                        Label.CONFIDENCE_KEY: float(scores[idx]),
                    }
                    for idx in ranked
                ]
        else:
            # Still multi-dimensional after squeezing: fail loudly instead of
            # silently returning an empty JSON object.
            raise ValueError("Unable to post-process model prediction.")
    elif isinstance(prediction, str):
        response[Label.LABEL_KEY] = prediction
    else:
        raise ValueError("Unable to post-process model prediction.")
    return json.dumps(response)