Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_graph_json_transform(self):
    """Unit test for the graph <-> JSON round trip.

    A generated CNN graph is passed through the wider, deeper and
    skip-connection morphisms in turn. It is then serialized with
    ``graph_to_json`` and restored with ``json_to_graph``, and the restored
    graph is compared with the original attribute by attribute.
    """
    source_graph = CnnGenerator(10, (32, 32, 3)).generate()
    # Every morphism receives a deep copy of the current graph rather than
    # the graph object itself.
    for morph in (to_wider_graph, to_deeper_graph, to_skip_connection_graph):
        source_graph = morph(deepcopy(source_graph))

    serialized = graph_to_json(source_graph, "temp.json")
    restored_graph = json_to_graph(serialized)

    # These structural attributes must compare equal after the round trip.
    for field in (
        "input_shape",
        "weighted",
        "layer_id_to_input_node_ids",
        "adj_list",
        "reverse_adj_list",
    ):
        self.assertEqual(getattr(source_graph, field), getattr(restored_graph, field))

    # For the operation history only the number of entries is compared.
    self.assertEqual(
        len(source_graph.operation_history),
        len(restored_graph.operation_history),
    )
def test_to_skip_connection_graph(self):
    """ unittest for to_skip_connection_graph function

    Generate a CNN graph with CnnGenerator(10, (32, 32, 3)) and round-trip
    it through graph_to_json / json_to_graph. Then apply the skip-connection
    morphism to a deep copy of the recovered graph, build a torch model from
    it, and check that a single all-ones (1, 3, 32, 32) input produces an
    output of shape (1, 10).
    """
    graph_init = CnnGenerator(10, (32, 32, 3)).generate()
    # NOTE(review): "temp.json" is a hard-coded path; confirm whether
    # graph_to_json writes it to disk and whether the test should clean up.
    json_out = graph_to_json(graph_init, "temp.json")
    graph_recover = json_to_graph(json_out)
    # Bug fix: this previously called to_wider_graph, which duplicated
    # test_to_wider_graph and never exercised the skip-connection morphism.
    skip_connection_graph = to_skip_connection_graph(deepcopy(graph_recover))
    model = skip_connection_graph.produce_torch_model()
    out = model(torch.ones(1, 3, 32, 32))
    self.assertEqual(out.shape, torch.Size([1, 10]))
def test_to_wider_graph(self):
    """Unit test for to_wider_graph.

    A graph from CnnGenerator(10, (32, 32, 3)) is round-tripped through
    JSON, and a deep copy of the result is widened. The widened graph is
    turned into a torch model and fed one all-ones (1, 3, 32, 32) tensor.
    The output must have shape (1, 10).
    """
    generated = CnnGenerator(10, (32, 32, 3)).generate()
    round_tripped = json_to_graph(graph_to_json(generated, "temp.json"))

    widened = to_wider_graph(deepcopy(round_tripped))
    torch_model = widened.produce_torch_model()

    prediction = torch_model(torch.ones(1, 3, 32, 32))
    self.assertEqual(prediction.shape, torch.Size([1, 10]))
def build_graph_from_json(ir_model_json):
    """Rebuild a torch model from its JSON graph representation.

    The JSON is parsed with ``json_to_graph``, the graph's operation history
    is logged at DEBUG level, and the result of ``produce_torch_model()``
    is returned.
    """
    restored = json_to_graph(ir_model_json)
    logging.debug(restored.operation_history)
    return restored.produce_torch_model()
def build_graph_from_json(ir_model_json):
    """Rebuild a Keras model from its JSON graph representation.

    The JSON is parsed with ``json_to_graph``, the graph's operation history
    is logged at DEBUG level, and the result of ``produce_keras_model()``
    is returned.
    """
    parsed_graph = json_to_graph(ir_model_json)
    history = parsed_graph.operation_history
    logging.debug(history)
    return parsed_graph.produce_keras_model()
def build_graph_from_json(ir_model_json):
    """Turn a JSON-encoded graph into a Keras model.

    The operation history of the decoded graph is emitted via
    ``logging.debug`` before ``produce_keras_model()`` is called.
    """
    ir_graph = json_to_graph(ir_model_json)
    logging.debug(ir_graph.operation_history)
    keras_model = ir_graph.produce_keras_model()
    return keras_model
def build_graph_from_json(ir_model_json):
    """Turn a JSON-encoded graph into a torch model.

    The operation history of the decoded graph is emitted via
    ``logging.debug`` before ``produce_torch_model()`` is called.
    """
    ir_graph = json_to_graph(ir_model_json)
    recorded_ops = ir_graph.operation_history
    logging.debug(recorded_ops)
    torch_model = ir_graph.produce_torch_model()
    return torch_model
Parameters
----------
model_id : int
model index
Returns
-------
load_model : Graph
the model graph representation
"""
with open(os.path.join(self.path, str(model_id) + ".json")) as fin:
json_str = fin.read().replace("\n", "")
load_model = json_to_graph(json_str)
return load_model