Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def config():
    """Assemble the settings for the tfrecord tests and regenerate test data.

    Converts the OSBS_029 XML annotations to
    ``tests/data/testtfrecords_OSBS_029.csv`` and then runs
    ``preprocess.split_raster`` on the OSBS_029 raster using the configured
    patch size and overlap.

    Returns:
        dict: the settings assembled below.
    """
    print("Configuring tfrecord tests")
    config = {}
    config["patch_size"] = 200
    config["patch_overlap"] = 0.05
    config["annotations_xml"] = get_data("OSBS_029.xml")
    config["rgb_dir"] = "tests/data"
    config["annotations_file"] = "tests/data/OSBS_029.csv"
    config["path_to_raster"] = get_data("OSBS_029.tif")
    config["image-min-side"] = 800
    config["backbone"] = "resnet50"

    # Create clean test data from the XML annotations.
    annotations = utilities.xml_to_annotations(xml_path=config["annotations_xml"])
    annotations.to_csv("tests/data/testtfrecords_OSBS_029.csv", index=False)

    # The return value is not needed here; the call is kept as in the original.
    # NOTE(review): presumably split_raster writes its crops under base_dir - confirm.
    preprocess.split_raster(
        path_to_raster=config["path_to_raster"],
        annotations_file="tests/data/testtfrecords_OSBS_029.csv",
        base_dir="tests/data/",
        patch_size=config["patch_size"],
        patch_overlap=config["patch_overlap"],
    )

    # Hand the settings back; without this every caller received None.
    return config
def annotations():
    """Write a headerless annotations CSV whose image paths end in ``.jpg``.

    Reads the OSBS_029 XML annotations, rewrites the ``image_path`` column so
    ``.tif`` becomes ``.jpg``, and saves the result without index or header.

    Returns:
        The value of ``get_data("testfile_deepforest.csv")``, i.e. the path
        the CSV was written to.
    """
    # NOTE(review): presumably a pandas DataFrame (it exposes .str and .to_csv) - confirm.
    frame = utilities.xml_to_annotations(get_data("OSBS_029.xml"))

    # Point at the jpg version for tfrecords.
    # regex=False forces a literal match on ".tif"; under regex semantics the
    # leading "." would match any character, and the default for this flag
    # differs between pandas releases.
    frame.image_path = frame.image_path.str.replace(".tif", ".jpg", regex=False)

    annotations_file = get_data("testfile_deepforest.csv")
    frame.to_csv(annotations_file, index=False, header=False)
    return annotations_file
def test_use_release():
    """Check that ``NEON.h5`` exists after ``utilities.use_release()`` runs."""
    # Download latest model from github release; the call must yield a pair.
    _tag, _weights = utilities.use_release()

    expected_weights = get_data("NEON.h5")
    assert os.path.exists(expected_weights)
def config():
    """Build the tfrecord test settings and regenerate the CSV fixtures.

    Side effects, in order:
      * writes the parsed XML annotations to
        ``tests/data/testtfrecords_OSBS_029.csv``;
      * runs ``preprocess.split_raster`` on the raster;
      * writes the split result, without index or header, to
        ``tests/data/testfile_tfrecords.csv``.

    Returns:
        dict: the settings assembled below.
    """
    print("Configuring tfrecord tests")
    settings = {
        "patch_size": 200,
        "patch_overlap": 0.05,
        "annotations_xml": get_data("OSBS_029.xml"),
        "rgb_dir": "tests/data",
        "annotations_file": "tests/data/OSBS_029.csv",
        "path_to_raster": get_data("OSBS_029.tif"),
        "image-min-side": 800,
        "backbone": "resnet50",
    }

    # Start from freshly generated annotation data.
    source_csv = "tests/data/testtfrecords_OSBS_029.csv"
    parsed = utilities.xml_to_annotations(xml_path=settings["annotations_xml"])
    parsed.to_csv(source_csv, index=False)

    crops = preprocess.split_raster(
        path_to_raster=settings["path_to_raster"],
        annotations_file=source_csv,
        base_dir="tests/data/",
        patch_size=settings["patch_size"],
        patch_overlap=settings["patch_overlap"],
    )
    crops.to_csv("tests/data/testfile_tfrecords.csv", index=False, header=False)

    return settings
def config():
    """Read ``deepforest_config.yml`` and override it with test settings.

    The file is located through ``get_data`` and parsed by
    ``utilities.read_config``. The OSBS_029 XML annotations are also written
    to ``tests/data/OSBS_029.csv``.

    Returns:
        The object returned by ``utilities.read_config`` with the test
        overrides applied.
    """
    settings = utilities.read_config(get_data("deepforest_config.yml"))

    overrides = (
        ("patch_size", 200),
        ("patch_overlap", 0.25),
        ("annotations_xml", get_data("OSBS_029.xml")),
        ("rgb_dir", "tests/data"),
        ("annotations_file", "tests/data/OSBS_029.csv"),
        ("path_to_raster", get_data("OSBS_029.tif")),
    )
    for key, value in overrides:
        settings[key] = value

    # Create clean test data from the XML annotations.
    parsed = utilities.xml_to_annotations(xml_path=settings["annotations_xml"])
    parsed.to_csv("tests/data/OSBS_029.csv", index=False)

    return settings
def __init__(self, weights=None, saved_model=None):
    """Store the model sources, read the config and optionally load a model.

    Args:
        weights: stored on ``self.weights``; not otherwise used in the
            visible part of this method.
        saved_model: stored on ``self.saved_model``; when truthy it is passed
            to ``utilities.load_model`` and the result is kept on
            ``self.model``.

    Raises:
        ValueError: if there is no ``deepforest_config.yml`` in the working
            directory and ``get_data`` fails to provide one.
    """
    self.weights = weights
    self.saved_model = saved_model

    # Prefer a config file in the working directory; otherwise fall back to
    # whatever get_data resolves.
    config_path = "deepforest_config.yml"
    if not os.path.exists(config_path):
        try:
            config_path = get_data("deepforest_config.yml")
        except Exception as e:
            raise ValueError("No deepforest_config.yml found either in local directory or in installed package location. {}".format(e))

    print("Reading config file: {}".format(config_path))
    self.config = utilities.read_config(config_path)

    # Release version id to flag if a release is being used.
    self.__release_version__ = None

    # Load saved model if needed.
    if self.saved_model:
        print("Loading saved model")
        # UserWarnings raised during loading are not relevant here.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            self.model = utilities.load_model(saved_model)