Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this is the body of an interface constructor whose `def` line
# is not in view; it resolves the input/output interfaces and stores the model.
"""
:param inputs: a string or `AbstractInput` representing the input interface.
:param outputs: a string or `AbstractOutput` representing the output interface.
:param model: the model object, such as a sklearn classifier or keras model.
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
provided.
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
:param saliency: an optional function that takes the model and the processed input and returns a 2-d array
"""
# A string names an input class in the registry (case-insensitive); it is
# instantiated with preprocessing_fns as its first positional argument.
# NOTE(review): an unknown name raises KeyError from the registry dict, not the
# ValueError below. Also, Textbox.__init__ takes `sample_inputs` as its first
# parameter, so a positional preprocessing_fns lands in sample_inputs there —
# consider passing preprocessing_fn=preprocessing_fns by keyword; confirm.
if isinstance(inputs, str):
    self.input_interface = gradio.inputs.registry[inputs.lower()](
        preprocessing_fns
    )
elif isinstance(inputs, gradio.inputs.AbstractInput):
    # Already-constructed interface objects are used as-is.
    self.input_interface = inputs
else:
    raise ValueError("Input interface must be of type `str` or `AbstractInput`")
# Same resolution for the output side, with postprocessing_fns passed positionally.
if isinstance(outputs, str):
    self.output_interface = gradio.outputs.registry[outputs.lower()](
        postprocessing_fns
    )
elif isinstance(outputs, gradio.outputs.AbstractOutput):
    self.output_interface = outputs
else:
    raise ValueError(
        "Output interface must be of type `str` or `AbstractOutput`"
    )
self.model_obj = model
# NOTE(review): the inferred model_type is not stored on self within these
# lines; the remainder of the constructor appears to be cut off.
if model_type is None:
    model_type = self._infer_model_type(model)
# NOTE(review): orphaned tail of a flag-rebuilding method; `im`, `dir` and
# `filename` are bound in lines not present here. Saves the image as a PNG
# under `dir` and returns just the file name (not the full path).
im.save(f'{dir}/{filename}', 'PNG')
return filename
def get_sample_inputs(self):
    """Return the configured sample inputs as base64-encoded images.

    Each sample is optionally reshaped to (image_width, image_height) when
    ``self.flatten`` is set, optionally inverted (``1 - sample``) when
    ``self.invert_colors`` is set, then encoded. Returns an empty list when no
    samples were configured.
    """
    if self.sample_inputs is None:
        return []

    def _encode(sample):
        # Undo flattening first so inversion and encoding see a 2-D image.
        if self.flatten:
            sample = sample.reshape((self.image_width, self.image_height))
        if self.invert_colors:
            sample = 1 - sample
        return preprocessing_utils.encode_array_to_base64(sample)

    return [_encode(sample) for sample in self.sample_inputs]
class Webcam(AbstractInput):
    """Input interface registered under the name ``'webcam'``.

    Stores the expected frame width, height and channel count, and forwards
    the optional preprocessing override to AbstractInput.
    """

    def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
        # Geometry is recorded before delegating to the base constructor.
        self.image_width, self.image_height = image_width, image_height
        self.num_channels = num_channels
        super().__init__(preprocessing_fn=preprocessing_fn)

    def get_validation_inputs(self):
        """Return the base64 color images used for validation."""
        return validation_data.BASE64_COLOR_IMAGES

    def get_name(self):
        """Return the interface identifier ``'webcam'``."""
        return 'webcam'
def preprocess(self, inp):
    """
    Pass ``inp`` to ``preprocessing_utils.decode_base64_to_wav_file`` and
    return the MFCC features that
    ``preprocessing_utils.generate_mfcc_features_from_audio_file`` computes
    from the decoded file's ``name``.

    NOTE(review): the previous docstring described converting a picture to
    black and white at 48x48, but this body processes audio and looks copied
    from the microphone input — confirm which behavior belongs here.
    """
    # NOTE(review): this bare string is a no-op expression (not a docstring),
    # and its "no pre-processing" claim does not match the MFCC step below.
    """
    By default, no pre-processing is applied to a microphone input file
    """
    file_obj = preprocessing_utils.decode_base64_to_wav_file(inp)
    mfcc_array = preprocessing_utils.generate_mfcc_features_from_audio_file(file_obj.name)
    return mfcc_array
def rebuild_flagged(self, dir, msg):
    """Rebuild a flagged CSV input by decoding its JSON-encoded message.

    ``dir`` is accepted for interface compatibility and is not used.
    """
    document = json.loads(msg)
    return document
# Lookup table from lowercased class name to input class, so interfaces can be
# selected by string (e.g. registry["textbox"]).
# NOTE(review): __subclasses__() returns only *direct* subclasses that exist at
# the moment this line runs — classes defined further down are not included.
registry = {
    subclass.__name__.lower(): subclass
    for subclass in AbstractInput.__subclasses__()
}
@abstractmethod
def preprocess(self, inp):
    """Default preprocessing hook that every input interface must define.

    The abstract version has no body beyond this docstring and returns None.
    """
@abstractmethod
def rebuild_flagged(self, dir, msg):
    """Hook that rebuilds a flagged input when it is passed back.

    Concrete interfaces must implement it (e.g. rebuilding an image from
    base64). The abstract version returns None.
    """
class Sketchpad(AbstractInput):
    """Drawing input interface registered under the name ``'sketchpad'``.

    ``shape`` supplies (width, height). The remaining options are stored on the
    instance, and ``preprocessing_fn`` is forwarded to AbstractInput.
    """

    def __init__(self, preprocessing_fn=None, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0,
                 dtype='float64', sample_inputs=None):
        self.image_width, self.image_height = shape[0], shape[1]
        self.invert_colors = invert_colors
        self.flatten = flatten
        self.scale, self.shift = scale, shift
        self.dtype = dtype
        self.sample_inputs = sample_inputs
        # The base constructor runs last, after all attributes are in place.
        super().__init__(preprocessing_fn=preprocessing_fn)

    def get_name(self):
        """Return the interface identifier ``'sketchpad'``."""
        return 'sketchpad'
# NOTE(review): orphaned tail of a preprocess method whose head is not in view;
# `im` is bound in lines not shown. Converts the image to a NumPy array,
# flattens it, and reshapes it to (1, image_width, image_height, num_channels).
# Sketchpad's __init__ above never sets `num_channels`, so this likely belongs
# to an interface that does — confirm.
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
return array
def rebuild_flagged(self, dir, msg):
    """Decode a flagged base64 image and save it as a PNG inside ``dir``.

    The payload is read from ``msg['data']['input']``. The file is named
    ``input_<ms>.png``, where ``<ms>`` is the current time in milliseconds.
    Only the file name (not the full path) is returned.
    """
    encoded = msg['data']['input']
    image = preprocessing_utils.decode_base64_to_image(encoded)
    name = f'input_{time.time()*1000}.png'
    image.save(f'{dir}/{name}', 'PNG')
    return name
class Textbox(AbstractInput):
    """Free-text input interface registered under the name ``'textbox'``.

    NOTE(review): unlike the other input classes, ``sample_inputs`` is the
    *first* positional parameter, so ``Textbox(fn)`` binds ``fn`` to
    ``sample_inputs``. Pass ``preprocessing_fn=`` by keyword.
    """

    def __init__(self, sample_inputs=None, preprocessing_fn=None):
        self.sample_inputs = sample_inputs
        super().__init__(preprocessing_fn=preprocessing_fn)

    def get_validation_inputs(self):
        """Return the English sample texts used for validation."""
        return validation_data.ENGLISH_TEXTS

    def get_name(self):
        """Return the interface identifier ``'textbox'``."""
        return 'textbox'

    def preprocess(self, inp):
        """Return the text unchanged; no default preprocessing is applied."""
        text = inp
        return text
# NOTE(review): orphaned tail of a flag-rebuilding method; `im` and `dir` are
# bound in lines not shown. Names the file after the current time in
# milliseconds, saves the image as a PNG under `dir`, and returns the name.
timestamp = time.time()*1000
filename = f'input_{timestamp}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
# TODO(abidlabs): clean this up
def save_to_file(self, dir, img):
    """Save ``img`` as a PNG inside ``dir`` and return the generated file name.

    The name is ``input_<ms>.png``, where ``<ms>`` is the current time in
    milliseconds (a float). Only the name, not the full path, is returned.
    """
    name = 'input_{}.png'.format(time.time() * 1000)
    img.save(f'{dir}/{name}', 'PNG')
    return name
class CSV(AbstractInput):
    """Input interface registered under the name ``'csv'``.

    Defines no ``__init__``, so AbstractInput's constructor is used. Input is
    passed through untouched, and flagged messages are decoded as JSON.
    """

    def get_name(self):
        """Return the interface identifier ``'csv'``."""
        return 'csv'

    def preprocess(self, inp):
        """Return ``inp`` unchanged (TODO:aliabid94 fix this)."""
        return inp

    def rebuild_flagged(self, dir, msg):
        """Decode the JSON-encoded flagged message; ``dir`` is unused."""
        decoded = json.loads(msg)
        return decoded
# NOTE(review): orphaned line — looks like the tail of a duplicated CSV
# get_name whose enclosing definition is not in view.
return 'csv'
def preprocess(self, inp):
    """Identity preprocessing for CSV input: ``inp`` is returned as-is.

    TODO:aliabid94 fix this — no actual CSV handling happens yet.
    """
    passthrough = inp
    return passthrough
def rebuild_flagged(self, dir, msg):
    """Turn a flagged CSV message back into Python data.

    ``msg`` is JSON text decoded with ``json.loads``; ``dir`` is not needed
    for this interface.
    """
    restored = json.loads(msg)
    return restored
class Microphone(AbstractInput):
    # Input interface registered under the name 'microphone'. It defines no
    # __init__, so AbstractInput's constructor is used.
    def get_name(self):
        return 'microphone'
    def preprocess(self, inp):
        """
        Pass ``inp`` to ``preprocessing_utils.decode_base64_to_wav_file`` and
        return the MFCC features that
        ``preprocessing_utils.generate_mfcc_features_from_audio_file`` computes
        from the decoded file's ``name``.

        NOTE(review): the previous docstring said no pre-processing is applied,
        which does not match this body. Whether the decoded temporary file is
        ever cleaned up is not visible here — confirm.
        """
        file_obj = preprocessing_utils.decode_base64_to_wav_file(inp)
        mfcc_array = preprocessing_utils.generate_mfcc_features_from_audio_file(file_obj.name)
        return mfcc_array
    def rebuild_flagged(self, dir, msg):
        """
        Flag-rebuild hook for the microphone input.

        NOTE(review): the body consists only of this docstring, so the method
        returns None. The previous docstring ("Default rebuild method for csv")
        looks copied from CSV, and the implementation appears to be missing —
        confirm.
        """
# NOTE(review): duplicated fragment of an interface constructor body; its
# `def` line is not in view. The string below is a no-op expression here.
"""
:param inputs: a string or `AbstractInput` representing the input interface.
:param outputs: a string or `AbstractOutput` representing the output interface.
:param model: the model object, such as a sklearn classifier or keras model.
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
provided.
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
:param saliency: an optional function that takes the model and the processed input and returns a 2-d array
"""
# A string names an input class in the registry (case-insensitive); it is
# instantiated with preprocessing_fns as its first positional argument.
# NOTE(review): an unknown name raises KeyError from the registry dict, not the
# ValueError below. Also, Textbox.__init__ takes `sample_inputs` as its first
# parameter, so a positional preprocessing_fns lands in sample_inputs there —
# consider passing preprocessing_fn=preprocessing_fns by keyword; confirm.
if isinstance(inputs, str):
    self.input_interface = gradio.inputs.registry[inputs.lower()](
        preprocessing_fns
    )
elif isinstance(inputs, gradio.inputs.AbstractInput):
    # Already-constructed interface objects are used as-is.
    self.input_interface = inputs
else:
    raise ValueError("Input interface must be of type `str` or `AbstractInput`")
# Same resolution for the output side, with postprocessing_fns passed positionally.
if isinstance(outputs, str):
    self.output_interface = gradio.outputs.registry[outputs.lower()](
        postprocessing_fns
    )
elif isinstance(outputs, gradio.outputs.AbstractOutput):
    self.output_interface = outputs
else:
    raise ValueError(
        "Output interface must be of type `str` or `AbstractOutput`"
    )
self.model_obj = model
# NOTE(review): the inferred model_type is not stored on self within these
# lines; the remainder of the constructor appears to be cut off.
if model_type is None:
    model_type = self._infer_model_type(model)
def rebuild_flagged(self, dir, msg):
    """
    Rebuild a flagged text input by writing ``msg`` to a ``.txt`` file in ``dir``.

    :param dir: directory the flagged text is written into.
    :param msg: the flagged text to save.
    :return: the name (not the full path) of the file actually written,
        ``input_<ms>.txt``, where ``<ms>`` is the current time in milliseconds.
    """
    timestamp = time.time()*1000
    # The returned name must match the file on disk. Previously the name ended
    # in ".png" while the data was written to "<name>.txt", so the returned
    # reference never resolved to the saved file.
    filename = f'input_{timestamp}.txt'
    # Explicit UTF-8 so non-ASCII text is saved identically on every platform,
    # instead of depending on the locale's default encoding.
    with open(f'{dir}/{filename}', 'w', encoding='utf-8') as f:
        f.write(msg)
    return filename
def get_sample_inputs(self):
    """Return the sample inputs exactly as supplied, without any encoding."""
    samples = self.sample_inputs
    return samples
class ImageUpload(AbstractInput):
    """Image-upload input interface registered under the name ``'image_upload'``.

    ``shape`` is read as (width, height, channels). ``image_mode``, ``scale``
    and ``shift`` are stored on the instance. ``cropper_aspect_ratio`` is
    stored as the string ``"false"`` when not given.
    """

    def __init__(self, preprocessing_fn=None, shape=(224, 224, 3), image_mode='RGB',
                 scale=1/127.5, shift=-1, cropper_aspect_ratio=None):
        self.image_width, self.image_height, self.num_channels = shape[0], shape[1], shape[2]
        self.image_mode = image_mode
        self.scale, self.shift = scale, shift
        # String sentinel rather than None — presumably consumed by a frontend
        # template; TODO confirm.
        if cropper_aspect_ratio is None:
            self.cropper_aspect_ratio = "false"
        else:
            self.cropper_aspect_ratio = cropper_aspect_ratio
        super().__init__(preprocessing_fn=preprocessing_fn)

    def get_validation_inputs(self):
        """Return the base64 color images used for validation."""
        return validation_data.BASE64_COLOR_IMAGES

    def get_name(self):
        """Return the interface identifier ``'image_upload'``."""
        return 'image_upload'