Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def pytorch_to_hls(yamlConfig):
    """Translate the PyTorch model described by ``yamlConfig`` into an optimized HLSModel.

    The layer topology is recovered by parsing ``repr(reader.torch_model)``:
    every line of the form ``(<index>): <LayerType>(<args>)`` becomes one
    layer dictionary in ``layer_list``.

    - Only the types in ``supported_layers`` are accepted.
    - ``Dropout`` and ``Flatten`` entries are dropped from the topology.
    - An ``InputLayer`` sized from the first ``Linear`` layer is prepended.

    Args:
        yamlConfig: configuration passed unchanged to ``PyTorchDataReader``
            and ``HLSModel``.

    Returns:
        The ``HLSModel`` built from the layer list, after ``optimize_model``
        has been called with the pass names listed in ``optimizers``.

    Raises:
        ValueError: if the model contains an unsupported layer type, a
            ``Linear`` layer whose arguments cannot be parsed, or no
            ``Linear`` layer from which the input size can be taken.
    """
    ######################
    ## Do translation
    ######################
    print('Interpreting Model')
    reader = PyTorchDataReader(yamlConfig)
    core_layers = ['Linear']
    skip_layers = ['Dropout', 'Flatten']
    activation_layers = ['ReLU', 'Sigmoid', 'Tanh', 'SELU', 'LeakyReLU', 'Softmax', 'Softplus', 'Softsign']
    supported_layers = core_layers + skip_layers + activation_layers

    # `\d+` rather than `\d` so that layers with index >= 10 are not
    # silently skipped.
    layer_pattern = re.compile(r'\((\d+)\): (\w+)\((.*)\)')
    # Argument layout of nn.Linear's repr, e.g.
    # "in_features=16, out_features=64, bias=True".
    linear_pattern = re.compile(r'in_features=(\d+), out_features=(\d+)')

    # This is a list of dictionaries to hold all the layer info we need to generate HLS
    layer_list = []

    # Loop through layers
    print('Topology:')
    modelstr = repr(reader.torch_model).split('\n')
    for pytorch_layer in modelstr:
        layer_match = layer_pattern.match(pytorch_layer.strip())
        if layer_match is None:
            continue
        layer_idx, layer_type, layer_args = layer_match.groups()

        if layer_type not in supported_layers:
            raise ValueError('Unsupported layer type: {}'.format(layer_type))
        if layer_type in skip_layers:
            # Dropout is an identity at inference time.
            # NOTE(review): Flatten is assumed to be a no-op for the
            # already-flat inputs of a Linear-only model -- confirm.
            continue

        layer = {}
        if layer_type in core_layers:
            linear_match = linear_pattern.search(layer_args)
            if linear_match is None:
                raise ValueError(
                    'Unable to interpret Linear layer arguments: {}'.format(layer_args))
            layer['class_name'] = 'Dense'
            # NOTE(review): named by its Sequential index, presumably so the
            # reader can look up '<index>.weight' / '<index>.bias' in the
            # model's state_dict -- confirm against PyTorchDataReader.
            layer['name'] = layer_idx
            layer['n_in'] = int(linear_match.group(1))
            layer['n_out'] = int(linear_match.group(2))
        else:
            # Every remaining supported type is an activation.
            # NOTE(review): 'Activation' class_name and lower-cased activation
            # names are assumed to be what HLSModel expects -- confirm.
            # LeakyReLU's negative_slope argument is not captured here.
            layer['class_name'] = 'Activation'
            layer['activation'] = layer_type.lower()
            layer['name'] = layer['activation'] + '_' + str(layer_idx)

        print('Layer index: {}, layer type: {}'.format(layer_idx, layer_type))
        layer_list.append(layer)

    # The network input size is the fan-in of the first Linear layer; the
    # supported activations are element-wise, so any that precede it keep
    # that size unchanged.
    first_linear = next((entry for entry in layer_list if 'n_in' in entry), None)
    if first_linear is None:
        raise ValueError('No Linear layer found; cannot determine the input size')
    input_layer = {}
    input_layer['name'] = 'input1'
    input_layer['class_name'] = 'InputLayer'
    input_layer['input_shape'] = [first_linear['n_in']]
    layer_list.insert(0, input_layer)

    #################
    ## Generate HLS
    #################
    # Reuse the reader created above rather than constructing a second one
    # (which presumably reloads the model).
    print('Creating HLS model')
    hls_model = HLSModel(yamlConfig, reader, layer_list)
    optimizers = ['eliminate_linear_activation', 'merge_batch_norm_quantized_tanh', 'quantize_dense_output']
    optimize_model(hls_model, optimizers)
    return hls_model