Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_registry_get_set():
    """A fresh registry rejects unknown names and serves names once registered."""
    test_registry = catalogue.create("test")
    # Nothing has been registered under "foo" yet, so lookup must fail.
    with pytest.raises(catalogue.RegistryError):
        test_registry.get("foo")
    test_registry.register("foo", func=lambda x: x)
    assert "foo" in test_registry
    # The "get" half of get/set: the registered function must be retrievable
    # and callable (the lambda is the identity function).
    assert test_registry.get("foo")(1) == 1
def test_get_set():
    """Exercise the low-level _set/_get/_remove helpers on the global registry."""
    abc = ("a", "b", "c")
    catalogue._set(abc, "test")
    assert len(catalogue.REGISTRY) == 1
    assert abc in catalogue.REGISTRY
    assert catalogue.check_exists(*abc)
    assert catalogue.REGISTRY[abc] == "test"
    assert catalogue._get(abc) == "test"
    # Neither a sibling key nor a longer key under ("a", "b", "c") was stored.
    for unknown in (("a", "b", "d"), ("a", "b", "c", "d")):
        with pytest.raises(catalogue.RegistryError):
            catalogue._get(unknown)
    for leaf, value in (("z1", "test1"), ("z2", "test2")):
        catalogue._set(("x", "y", leaf), value)
    assert catalogue._remove(abc) == "test"
    catalogue._set(("x", "y2"), "test3")
    # ("x", "y") is only a prefix of stored keys, not a stored key itself.
    with pytest.raises(catalogue.RegistryError):
        catalogue._remove(("x", "y"))
    assert catalogue._remove(("x", "y", "z2")) == "test2"
def _get(namespace: Sequence[str]) -> Any:
    """Look up the value stored under a namespace in the global registry.

    namespace (Sequence[str]): The namespace, one string per level.
    RETURNS (Any): The value registered for the namespace.
    RAISES ValueError: If any element of the namespace is not a string.
    RAISES RegistryError: If the namespace is not in the registry.
    """
    if any(not isinstance(part, str) for part in namespace):
        raise ValueError(
            "Invalid namespace. Expected tuple of strings, but got: {}".format(namespace)
        )
    # Registry keys are tuples, so normalise whatever sequence was passed in.
    key = tuple(namespace)
    if key in REGISTRY:
        return REGISTRY[key]
    raise RegistryError("Can't get namespace {} (not in registry)".format(key))
def create(*namespace: str, entry_points: bool = False) -> "Registry":
    """Build a Registry for a namespace that is not taken yet.

    *namespace (str): The namespace, e.g. "spacy" or "spacy", "architectures".
    entry_points (bool): Accept registered functions from entry points.
    RETURNS (Registry): The newly created Registry object.
    RAISES RegistryError: If the namespace already exists.
    """
    if not check_exists(*namespace):
        return Registry(namespace, entry_points=entry_points)
    raise RegistryError("Namespace already exists: {}".format(namespace))
def get(self, name: str) -> Any:
    """Get the registered function for a given name.

    When this registry accepts entry points, a matching entry point is
    returned in preference to a function stored in the global registry.

    name (str): The name.
    RETURNS (Any): The registered function.
    RAISES RegistryError: If the name is not registered under this
        registry's namespace.
    """
    if self.entry_points:
        from_entry_point = self.get_entry_point(name)
        # NOTE(review): presumably get_entry_point returns a falsy default when
        # no entry point matches; a truthy result short-circuits the lookup.
        if from_entry_point:
            return from_entry_point
    namespace = list(self.namespace) + [name]
    if not check_exists(*namespace):
        err = "Can't find '{}' in registry {}. Available names: {}"
        current_namespace = " -> ".join(self.namespace)
        available = ", ".join(sorted(self.get_all().keys())) or "none"
        raise RegistryError(err.format(name, current_namespace, available))
    return _get(namespace)
def _remove(namespace: Sequence[str]) -> Any:
    """Remove a value for a given namespace and return it.

    namespace (Sequence[str]): The namespace.
    RETURNS (Any): The removed value.
    RAISES RegistryError: If the namespace is not in the registry.
    """
    # Registry keys are tuples, so normalise whatever sequence was passed in.
    namespace = tuple(namespace)
    try:
        # Single lookup: fetch and delete in one step (mutating the dict does
        # not rebind the module-level name, so no `global` is needed).
        return REGISTRY.pop(namespace)
    except KeyError:
        err = "Can't remove namespace {} (not in registry)".format(namespace)
        raise RegistryError(err) from None