Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@pytest.mark.parametrize("store", stores_to_be_tested(), ids=stores_to_be_tested_ids())
@pytest.mark.parametrize("pair", zip(TEST_DIALOGUES, EXAMPLE_DOMAINS))
def test_tracker_store(store, pair):
    """Round-trip a tracker through ``store`` and expect an equal tracker back.

    ``pair`` is a ``(dialogue file, domain file)`` tuple; the tracker is built
    from the dialogue file using the domain loaded from the domain file.
    """
    dialogue_file, domain_file = pair
    original = tracker_from_dialogue_file(dialogue_file, Domain.load(domain_file))

    store.save(original)

    # Look the tracker up again under the sender id it was saved with.
    assert store.retrieve(original.sender_id) == original
@pytest.fixture(scope="module")
async def trained_policy(self, featurizer, priority):
    """Build a policy via ``self.create_policy`` and train it before returning it.

    Training uses the domain loaded from ``DEFAULT_DOMAIN_PATH_WITH_SLOTS`` and
    the trackers produced by ``train_trackers`` with an augmentation factor
    of 20.
    """
    slot_domain = Domain.load(DEFAULT_DOMAIN_PATH_WITH_SLOTS)
    fresh_policy = self.create_policy(featurizer, priority)
    trackers = await train_trackers(slot_domain, augmentation_factor=20)
    fresh_policy.train(trackers, slot_domain)
    return fresh_policy
def test_transform_intents_for_file_with_mapping():
    """Check ``_transform_intents_for_file`` on the mapping test domain.

    The expected result lists one single-key dict per intent; ``greet`` and
    ``default`` carry a ``triggers`` entry, ``goodbye`` does not.
    """
    mapping_domain = Domain.load("data/test_domains/default_with_mapping.yml")

    assert mapping_domain._transform_intents_for_file() == [
        {"greet": {"triggers": "utter_greet", USE_ENTITIES_KEY: True}},
        {"default": {"triggers": "utter_default", USE_ENTITIES_KEY: True}},
        {"goodbye": {USE_ENTITIES_KEY: True}},
    ]
def test_domain_fails_on_unknown_custom_slot_type(tmpdir, domain_unkown_slot_type):
    """Loading the domain text supplied by the fixture must raise ``ValueError``.

    The text is first written to ``domain.yml`` inside ``tmpdir`` so that it
    can be loaded from a path.
    """
    written_path = utilities.write_text_to_file(
        tmpdir, "domain.yml", domain_unkown_slot_type
    )

    with pytest.raises(ValueError):
        Domain.load(written_path)
def test_dialogue_from_parameters():
    """A serialised tracker must rebuild into an equivalent ``Dialogue``.

    The tracker is serialised with ``InMemoryTrackerStore.serialise_tracker``,
    parsed as JSON, and passed to ``Dialogue.from_parameters``; the dict form
    of the result is compared with the dict form of the tracker's own dialogue.
    """
    restaurant_domain = Domain.load("examples/restaurantbot/domain.yml")
    dialogue_file = "data/test_dialogues/restaurantbot.json"
    tracker = tracker_from_dialogue_file(dialogue_file, restaurant_domain)

    payload = json.loads(InMemoryTrackerStore.serialise_tracker(tracker))
    rebuilt = Dialogue.from_parameters(payload)

    assert tracker.as_dialogue().as_dict() == rebuilt.as_dict()
# Checks that the agent loaded from `trained_moodbot_path` exposes the same
# domain contents as examples/moodbot/domain.yml and the expected policies.
async def test_agent_train(trained_moodbot_path: Text):
moodbot_domain = Domain.load("examples/moodbot/domain.yml")
loaded = Agent.load(trained_moodbot_path)
# test domain
# Actions, intents, entities and templates are compared directly; slots are
# compared by name only.
assert loaded.domain.action_names == moodbot_domain.action_names
assert loaded.domain.intents == moodbot_domain.intents
assert loaded.domain.entities == moodbot_domain.entities
assert loaded.domain.templates == moodbot_domain.templates
assert [s.name for s in loaded.domain.slots] == [
s.name for s in moodbot_domain.slots
]
# test policies
# The ensemble must be a SimplePolicyEnsemble and its policies must have
# exactly the listed types, in the listed order.
assert isinstance(loaded.policy_ensemble, SimplePolicyEnsemble)
# NOTE(review): the expected-type list below is cut off in this excerpt (no
# closing bracket is visible) - confirm the remaining entries in the full file.
assert [type(p) for p in loaded.policy_ensemble.policies] == [
TEDPolicy,
MemoizationPolicy,
def _validate_domain(domain_path: Text):
    """Exit with status 1 if the domain at ``domain_path`` is invalid.

    Only ``InvalidDomain`` is handled: the error is reported through
    ``print_error`` before exiting. Any other exception raised while loading
    propagates to the caller.
    """
    from rasa.core.domain import Domain, InvalidDomain

    try:
        Domain.load(domain_path)
    except InvalidDomain as load_error:
        message = "The provided domain file could not be loaded. Error: {}".format(
            load_error
        )
        print_error(message)
        sys.exit(1)
@staticmethod
def _create_domain(domain: Union[Domain, Text]) -> Domain:
    """Turn ``domain`` into a ``Domain`` instance.

    Args:
        domain: A path to a domain specification, an existing ``Domain``, or
            ``None`` (accepted here although the annotation does not list it).

    Returns:
        The domain loaded from the path (after ``check_missing_templates`` has
        been called on it), the instance passed in, or ``Domain.empty()`` for
        ``None``.

    Raises:
        ValueError: If ``domain`` is none of the accepted kinds.
    """
    if isinstance(domain, str):
        loaded = Domain.load(domain)
        loaded.check_missing_templates()
        return loaded

    if isinstance(domain, Domain):
        return domain

    if domain is None:
        return Domain.empty()

    raise ValueError(
        "Invalid param `domain`. Expected a path to a domain "
        "specification or a domain instance. But got "
        "type '{}' with value '{}'".format(type(domain), domain)
    )
def _create_domain(domain: Union[Domain, Text]) -> Domain:
    """Resolve ``domain`` to a ``Domain`` object.

    A string is treated as a path: the domain is loaded from it and
    ``check_missing_templates`` is called on the result. A ``Domain`` instance
    is returned as it is, and ``None`` yields ``Domain.empty()``.

    Raises:
        ValueError: If ``domain`` is not a string, a ``Domain`` or ``None``.
    """
    if isinstance(domain, str):
        from_file = Domain.load(domain)
        from_file.check_missing_templates()
        return from_file

    if isinstance(domain, Domain):
        return domain

    if domain is None:
        return Domain.empty()

    raise ValueError(
        "Invalid param `domain`. Expected a path to a domain "
        "specification or a domain instance. But got "
        "type '{}' with value '{}'".format(type(domain), domain)
    )
async def get_domain(self) -> Domain:
    """Return the domain loaded from ``self._domain_path``.

    If loading raises ``InvalidDomain``, a warning is logged and the empty
    domain created up front is returned instead of propagating the error.
    """
    # Assigned before the attempt so a failing load leaves this as the result.
    domain = Domain.empty()

    try:
        domain = Domain.load(self._domain_path)
        domain.check_missing_templates()
    except InvalidDomain as load_error:
        warning_text = (
            "Loading domain from '{}' failed. Using empty domain. Error: '{}'".format(
                self._domain_path, load_error.message
            )
        )
        logger.warning(warning_text)

    return domain