Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
:param saliency: an optional function that takes the model and the processed input and returns a 2-d array
"""
if isinstance(inputs, str):
    # A string names a registered input interface; instantiate it with the
    # preprocessing function(s). Unknown names raise KeyError from the lookup.
    self.input_interface = gradio.inputs.registry[inputs.lower()](
        preprocessing_fns
    )
elif isinstance(inputs, gradio.inputs.AbstractInput):
    # An already-built interface object is used as-is.
    self.input_interface = inputs
else:
    raise ValueError("Input interface must be of type `str` or `AbstractInput`")
if isinstance(outputs, str):
    # Same resolution scheme for outputs, via gradio.outputs.registry.
    self.output_interface = gradio.outputs.registry[outputs.lower()](
        postprocessing_fns
    )
elif isinstance(outputs, gradio.outputs.AbstractOutput):
    self.output_interface = outputs
else:
    raise ValueError(
        "Output interface must be of type `str` or `AbstractOutput`"
    )
self.model_obj = model
if model_type is None:
    # No explicit type given: infer it from the model object.
    model_type = self._infer_model_type(model)
    if verbose:
        print(
            "Model type not explicitly identified, inferred to be: {}".format(
                self.VALID_MODEL_TYPES[model_type]
            )
        )
elif model_type.lower() not in self.VALID_MODEL_TYPES:
    # Previously the ValueError was constructed but never raised, so an
    # invalid model_type was silently accepted.
    raise ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
"""
"""
return prediction
def rebuild_flagged(self, dir, msg):
    """
    Decode a flagged base64 image and save it as a PNG file inside ``dir``.

    :param dir: directory the PNG file is written into.
    :param msg: base64-encoded image, decoded with
        ``preprocessing_utils.decode_base64_to_image``.
    :return: the file name (not the full path) of the PNG that was written.

    The file is named ``output_<YYYY-mm-dd-HH-MM-SS>.png``. If a file with
    that name already exists in ``dir``, a numeric suffix (``_1``, ``_2``, ...)
    is appended instead of overwriting the existing file.
    """
    import os  # local import keeps this block self-contained

    im = preprocessing_utils.decode_base64_to_image(msg)
    timestamp = datetime.datetime.now()
    stem = f'output_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}'
    filename = f'{stem}.png'
    # The timestamp has only one-second resolution. Without this loop, a
    # second flag within the same second would silently overwrite the first.
    suffix = 1
    while os.path.exists(os.path.join(dir, filename)):
        filename = f'{stem}_{suffix}.png'
        suffix += 1
    im.save(os.path.join(dir, filename), 'PNG')
    return filename
# Lookup table from lower-cased class name to output class. String output
# names are resolved through a registry like this (e.g.
# ``gradio.outputs.registry[outputs.lower()]``). Note that
# ``__subclasses__()`` lists only *direct* subclasses that exist at the time
# this line runs.
registry = dict(
    (subclass.__name__.lower(), subclass)
    for subclass in AbstractOutput.__subclasses__()
)
})
prediction[prediction.argmax()] = 0
elif isinstance(prediction, str):
response[Label.LABEL_KEY] = prediction
else:
raise ValueError("Unable to post-process model prediction.")
return json.dumps(response)
def rebuild_flagged(self, dir, msg):
    """
    Rebuild a flagged label output by JSON-decoding ``msg``.

    :param dir: unused by this implementation.
    :param msg: JSON text describing the flagged output.
    :return: the decoded Python object.
    """
    decoded = json.loads(msg)
    return decoded
class Textbox(AbstractOutput):
    """Textbox output interface; its ``postprocess`` passes predictions through unchanged."""

    def get_name(self):
        """Return the identifier string of this output interface."""
        interface_name = 'textbox'
        return interface_name

    def postprocess(self, prediction):
        """
        Return ``prediction`` unchanged; no postprocessing is applied.
        """
        postprocessed = prediction
        return postprocessed

    def rebuild_flagged(self, dir, msg):
        """
        Rebuild a flagged textbox output by JSON-decoding ``msg``.

        :param dir: unused by this implementation.
        :param msg: JSON text of the flagged output.
        :return: the decoded Python object.

        NOTE(review): ``postprocess`` does not JSON-encode its result, so this
        assumes the caller sends ``msg`` already JSON-encoded. Confirm this,
        because plain text would make ``json.loads`` raise.
        """
        decoded = json.loads(msg)
        return decoded
def get_name(self):
    """Return the identifier string ``'textbox'`` of this output interface."""
    name = 'textbox'
    return name
def postprocess(self, prediction):
    """
    Return ``prediction`` unchanged (identity postprocessing).
    """
    result = prediction
    return result
def rebuild_flagged(self, dir, msg):
    """
    Rebuild a flagged output by JSON-decoding ``msg``.

    :param dir: unused by this implementation.
    :param msg: JSON text of the flagged output.
    :return: the decoded Python object.
    """
    parsed = json.loads(msg)
    return parsed
class Image(AbstractOutput):
def get_name(self):
    """Return the identifier string ``'image'`` of this output interface."""
    name = 'image'
    return name
def postprocess(self, prediction):
    """
    Return ``prediction`` unchanged; no image postprocessing is done here.
    """
    result = prediction
    return result
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
timestamp = datetime.datetime.now()
filename = f'output_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
:param saliency: an optional function that takes the model and the processed input and returns a 2-d array
"""
if isinstance(inputs, str):
    # A string names a registered input interface; instantiate it with the
    # preprocessing function(s). Unknown names raise KeyError from the lookup.
    self.input_interface = gradio.inputs.registry[inputs.lower()](
        preprocessing_fns
    )
elif isinstance(inputs, gradio.inputs.AbstractInput):
    # An already-built interface object is used as-is.
    self.input_interface = inputs
else:
    raise ValueError("Input interface must be of type `str` or `AbstractInput`")
if isinstance(outputs, str):
    # Same resolution scheme for outputs, via gradio.outputs.registry.
    self.output_interface = gradio.outputs.registry[outputs.lower()](
        postprocessing_fns
    )
elif isinstance(outputs, gradio.outputs.AbstractOutput):
    self.output_interface = outputs
else:
    raise ValueError(
        "Output interface must be of type `str` or `AbstractOutput`"
    )
self.model_obj = model
if model_type is None:
    # No explicit type given: infer it from the model object.
    model_type = self._infer_model_type(model)
    if verbose:
        print(
            "Model type not explicitly identified, inferred to be: {}".format(
                self.VALID_MODEL_TYPES[model_type]
            )
        )
elif model_type.lower() not in self.VALID_MODEL_TYPES:
    # Previously the ValueError was constructed but never raised, so an
    # invalid model_type was silently accepted.
    raise ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
@abstractmethod
def postprocess(self, prediction):
    """
    Turn a raw model prediction into this interface's output form.

    Abstract: every concrete output interface must override this.
    The base implementation does nothing and returns None.
    """
    ...
@abstractmethod
def rebuild_flagged(self, inp):
    """
    Rebuild a flagged output that was sent back in serialized form
    (for example, an image decoded from base64).

    Abstract: every concrete output interface must override this. The base
    implementation does nothing and returns None.

    NOTE(review): the concrete overrides visible alongside this method are
    declared as ``rebuild_flagged(self, dir, msg)``, not ``(self, inp)``.
    Confirm which signature callers rely on.
    """
    ...
class Label(AbstractOutput):
LABEL_KEY = 'label'
CONFIDENCES_KEY = 'confidences'
CONFIDENCE_KEY = 'confidence'
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True, label_names=None,
             max_label_length=None, max_label_words=None, word_delimiter=" "):
    """
    Store the label-display settings, then initialise the base output.

    :param postprocessing_fn: optional function forwarded unchanged to the
        base-class initializer (``super().__init__``).
    :param num_top_classes: stored as ``self.num_top_classes``.
    :param show_confidences: stored as ``self.show_confidences``.
    :param label_names: stored as ``self.label_names``.
    :param max_label_length: stored as ``self.max_label_length``.
    :param max_label_words: stored as ``self.max_label_words``.
    :param word_delimiter: stored as ``self.word_delimiter``.
    """
    # Which labels to show, and whether confidences are included.
    self.label_names = label_names
    self.num_top_classes = num_top_classes
    self.show_confidences = show_confidences
    # Settings that presumably limit label text length; confirm where they
    # are consumed.
    self.word_delimiter = word_delimiter
    self.max_label_words = max_label_words
    self.max_label_length = max_label_length
    super().__init__(postprocessing_fn=postprocessing_fn)
def get_name(self):
    """Return the identifier string ``'label'`` of this output interface."""
    name = 'label'
    return name