Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_retrain_release(annotations, release_model):
    """Check that the release weights are forwarded to the retinanet arguments.

    Args:
        annotations: fixture passed to ``utilities.format_args`` as the
            annotations argument.
        release_model: fixture exposing ``config`` (supports item access)
            and ``weights``.
    """
    # Keep the configuration minimal: one epoch, one step, no snapshot.
    release_model.config["epochs"] = 1
    release_model.config["save-snapshot"] = False
    release_model.config["steps"] = 1
    assert release_model.config["weights"] == release_model.weights

    # Test that it gets passed to retinanet
    arg_list = utilities.format_args(annotations, release_model.config, images_per_epoch=1)

    # A flag's value is the element immediately following the flag itself.
    weights_position = arg_list.index("--weights") + 1

    # The original comparison was missing `assert`, so it never checked anything.
    assert arg_list[weights_position] == release_model.weights
def test_format_args(annotations, config):
    """The formatted arguments must come back as a list."""
    formatted = utilities.format_args(annotations, config)
    assert isinstance(formatted, list)
def test_random_transform(annotations):
    """Setting ``random_transform`` in the config adds ``--random-transform``."""
    model = deepforest.deepforest()
    model.config["random_transform"] = True

    formatted = utilities.format_args(annotations, model.config)
    assert "--random-transform" in formatted
def test_format_args_steps(annotations, config):
    """``images_per_epoch=2`` should yield '2' as the value after ``--steps``."""
    formatted = utilities.format_args(annotations, config, images_per_epoch=2)
    assert isinstance(formatted, list)

    # The arguments are a flat list, so the value of --steps is the entry
    # right after the first entry that contains the flag.
    contains_flag = ["--steps" in entry for entry in formatted]
    value_position = np.flatnonzero(contains_flag)[0] + 1
    assert formatted[value_position] == '2'
def predict_generator(self, annotations, comet_experiment = None, iou_threshold=0.5, score_threshold=0.05, max_detections=200):
"""Predict bounding boxes for a model using a csv fit_generator
Args:
annotations (str): Path to csv label file, labels are in the format -> path/to/image.jpg,x1,y1,x2,y2,class_name
iou_threshold(float): IoU Threshold to count for a positive detection (defaults to 0.5)
score_threshold (float): Eliminate bounding boxes under this threshold
max_detections (int): Maximum number of bounding box predictions
comet_experiment(object): A comet experiment class objects to track
Return:
boxes_output: a pandas dataframe of bounding boxes for each image in the annotations file
"""
#Format args for CSV generator
arg_list = utilities.format_args(annotations, self.config)
args = parse_args(arg_list)
#create generator
generator = CSVGenerator(
args.annotations,
args.classes,
image_min_side=args.image_min_side,
image_max_side=args.image_max_side,
config=args.config,
shuffle_groups=False,
)
if self.prediction_model:
boxes_output = [ ]
#For each image, gather predictions
for i in range(generator.size()):
def evaluate_generator(self, annotations, comet_experiment = None, iou_threshold=0.5, score_threshold=0.05, max_detections=200):
""" Evaluate prediction model using a csv fit_generator
Args:
annotations (str): Path to csv label file, labels are in the format -> path/to/image.jpg,x1,y1,x2,y2,class_name
iou_threshold(float): IoU Threshold to count for a positive detection (defaults to 0.5)
score_threshold (float): Eliminate bounding boxes under this threshold
max_detections (int): Maximum number of bounding box predictions
comet_experiment(object): A comet experiment class objects to track
Return:
mAP: Mean average precision of the evaluated data
"""
#Format args for CSV generator
arg_list = utilities.format_args(annotations, self.config)
args = parse_args(arg_list)
#create generator
validation_generator = CSVGenerator(
args.annotations,
args.classes,
image_min_side=args.image_min_side,
image_max_side=args.image_max_side,
config=args.config,
shuffle_groups=False,
)
average_precisions = evaluate(
validation_generator,
self.prediction_model,
iou_threshold=iou_threshold,
'''Train a deep learning tree detection model using keras-retinanet.
This is the main entry point for training a new model based on either existing weights or scratch
Args:
annotations (str): Path to csv label file, labels are in the format -> path/to/image.jpg,x1,y1,x2,y2,class_name
comet_experiment: A comet ml object to log images. Optional.
list_of_tfrecords: Ignored if input_type != "tfrecord", list of tf records to process
input_type: "fit_generator" or "tfrecord"
images_per_epoch: number of images to override default config of # images in annotations file / batch size. Useful for debug
Returns:
model (object): A trained keras model
prediction model: with bbox nms
trained model: without nms
'''
arg_list = utilities.format_args(annotations, self.config, images_per_epoch)
print("Training retinanet with the following args {}".format(arg_list))
#Train model
self.model, self.prediction_model, self.training_model = retinanet_train(args=arg_list, input_type = input_type, list_of_tfrecords = list_of_tfrecords, comet_experiment = comet_experiment)