Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_fetch_events_within_time_range_with_session_events():
    """Check that fetching events also returns those around a session start.

    A tracker holding a user message, a ``SessionStarted`` event, an
    ``ActionExecuted`` for ``ACTION_SESSION_START_NAME`` and a second user
    message is saved to a SQLite-backed tracker store. The exporter is built
    with only that tracker store (no time bounds are passed here), so every
    saved event is expected to be fetched.
    """
    conversation_id = "test_fetch_events_within_time_range_with_sessions"

    # A random file name keeps the database of this test separate from others.
    # NOTE(review): the file name is relative and nothing here removes it
    # afterwards - presumably it is left behind in the working directory.
    tracker_store = SQLTrackerStore(
        dialect="sqlite", db=f"{uuid.uuid4().hex}.db", domain=Domain.empty()
    )

    # Timestamps 1..4 place one user message before and one after the
    # session start events.
    events = [
        random_user_uttered_event(1),
        SessionStarted(2),
        ActionExecuted(timestamp=3, action_name=ACTION_SESSION_START_NAME),
        random_user_uttered_event(4),
    ]
    tracker = DialogueStateTracker.from_events(conversation_id, evts=events)
    tracker_store.save(tracker)

    exporter = MockExporter(tracker_store=tracker_store)

    # noinspection PyProtectedMember
    fetched_events = exporter._fetch_events_within_time_range()

    # Without this check the test would pass no matter what was fetched.
    assert len(fetched_events) == len(events)
from rasa.core.actions.action import ACTION_SESSION_START_NAME
from rasa.core.domain import Domain
from rasa.core.events import (
SessionStarted,
SlotSet,
UserUttered,
ActionExecuted,
)
from rasa.core.trackers import DialogueStateTracker
from rasa.core.training.structures import Story
# Load the moodbot example domain from its YAML file (relative path).
domain = Domain.load("examples/moodbot/domain.yml")
def test_session_start_is_not_serialised(default_domain: Domain):
    """Build a tracker whose history contains a session start.

    The tracker receives, in order: a slot event, the session start action,
    a ``SessionStarted`` event and a user message.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)

    # a freshly created tracker must not contain any events
    assert len(tracker.events) == 0

    # slot event first
    slot_event = SlotSet("slot", "value")
    tracker.update(slot_event)

    # then the session start pair: the executed action and the marker event
    session_action = ActionExecuted(ACTION_SESSION_START_NAME)
    tracker.update(session_action)
    session_marker = SessionStarted()
    tracker.update(session_marker)

    # and finally a user message
    user_message = UserUttered("say something")
    tracker.update(user_message)

    # make sure session start is not serialised
    # NOTE(review): the serialisation check itself is not part of the
    # visible body - confirm it has not been lost.
def default_domain(self):
    """Return the domain loaded from ``DEFAULT_DOMAIN_PATH_WITH_SLOTS``."""
    domain_with_slots = Domain.load(DEFAULT_DOMAIN_PATH_WITH_SLOTS)
    return domain_with_slots
async def test_rasa_file_importer_with_invalid_domain(tmp_path: Path):
    """Check that the importer yields an empty domain in this setup.

    The importer is created from an empty dict, the path of an empty
    ``config.yml``, ``None`` and an empty list. The domain it returns must
    serialise to the same dict as ``Domain.empty()``.
    """
    empty_config = tmp_path / "config.yml"
    empty_config.write_text("")

    # NOTE(review): the third and fourth positional arguments are presumably
    # the domain path and the training data paths - confirm against
    # ``TrainingDataImporter.load_from_dict``.
    importer = TrainingDataImporter.load_from_dict(
        {}, str(empty_config), None, []
    )
    loaded_domain = await importer.get_domain()

    loaded_as_dict = loaded_domain.as_dict()
    expected_as_dict = Domain.empty().as_dict()
    assert loaded_as_dict == expected_as_dict
def test_domain_action_instantiation():
    """Check the number and leading order of a domain's instantiated actions.

    A domain declaring three action names is expected to produce twelve
    actions, the first five of which carry the names of the ``ACTION_*_NAME``
    constants checked below, in that order.

    NOTE(review): a second test with this exact name is defined later in the
    file; at import time the later definition replaces this one.
    """
    custom_action_names = ["my_module.ActionTest", "utter_test", "respond_test"]
    domain = Domain(
        intents={},
        entities=[],
        slots=[],
        templates={},
        action_names=custom_action_names,
        form_names=[],
    )

    instantiated_actions = domain.actions(None)
    assert len(instantiated_actions) == 12

    # only the first five positions are pinned to specific names
    expected_leading_names = (
        ACTION_LISTEN_NAME,
        ACTION_RESTART_NAME,
        ACTION_SESSION_START_NAME,
        ACTION_DEFAULT_FALLBACK_NAME,
        ACTION_DEACTIVATE_FORM_NAME,
    )
    for position, expected_name in enumerate(expected_leading_names):
        assert instantiated_actions[position].name() == expected_name
async def test_dispatcher_utter_buttons_from_domain_templ(default_tracker):
    """Check that uttering ``utter_greet`` emits one message with two buttons.

    The templates come from the moodbot example domain and are sent through a
    dispatcher into a collecting output channel, whose single message is then
    compared against the expected text and button payloads.
    """
    moodbot_domain = Domain.load("examples/moodbot/domain.yml")
    channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(moodbot_domain.templates)
    dispatcher = Dispatcher("my-sender", channel, generator)

    await dispatcher.utter_template("utter_greet", default_tracker)

    # exactly one message must have been collected
    assert len(channel.messages) == 1

    expected_buttons = [
        {"payload": "/mood_great", "title": "great"},
        {"payload": "/mood_unhappy", "title": "super sad"},
    ]
    assert channel.messages[0]["text"] == "Hey! How are you?"
    assert channel.messages[0]["buttons"] == expected_buttons
def test_domain_action_instantiation():
    """Verify how many actions a domain instantiates and which come first.

    The domain is built with three action names and no intents, entities,
    slots, templates or forms. It must instantiate twelve actions, and the
    first five must be named after the ``ACTION_*_NAME`` constants asserted
    below.

    NOTE(review): an earlier definition with the same name exists in this
    file and is shadowed by this one.
    """
    domain_kwargs = dict(
        intents={},
        entities=[],
        slots=[],
        templates={},
        action_names=["my_module.ActionTest", "utter_test", "respond_test"],
        form_names=[],
    )
    domain = Domain(**domain_kwargs)

    instantiated_actions = domain.actions(None)
    assert len(instantiated_actions) == 12

    # the remaining seven actions are not checked by name
    first, second, third, fourth, fifth = instantiated_actions[:5]
    assert first.name() == ACTION_LISTEN_NAME
    assert second.name() == ACTION_RESTART_NAME
    assert third.name() == ACTION_SESSION_START_NAME
    assert fourth.name() == ACTION_DEFAULT_FALLBACK_NAME
    assert fifth.name() == ACTION_DEACTIVATE_FORM_NAME
def _create_from_endpoint_config(
endpoint_config: Optional[EndpointConfig] = None,
domain: Optional[Domain] = None,
event_broker: Optional[EventBroker] = None,
) -> "TrackerStore":
"""Given an endpoint configuration, create a proper tracker store object."""
domain = domain or Domain.empty()
if endpoint_config is None or endpoint_config.type is None:
# default tracker store if no type is set
tracker_store = InMemoryTrackerStore(domain, event_broker)
elif endpoint_config.type.lower() == "redis":
tracker_store = RedisTrackerStore(
domain=domain,
host=endpoint_config.url,
event_broker=event_broker,
**endpoint_config.kwargs,
)
elif endpoint_config.type.lower() == "mongod":
tracker_store = MongoTrackerStore(
domain=domain,
host=endpoint_config.url,
event_broker=event_broker,
def is_empty(self) -> bool:
    """Return whether this domain serialises to the same dict as an empty one."""
    own_contents = self.as_dict()
    empty_contents = Domain.empty().as_dict()
    return own_contents == empty_contents
domain_path: Text, events: List[Dict[Text, Any]], old_domain: Domain
) -> None:
"""Write an updated domain file to the file path."""
io_utils.create_path(domain_path)
messages = _collect_messages(events)
actions = _collect_actions(events)
templates = NEW_TEMPLATES # type: Dict[Text, List[Dict[Text, Any]]]
# TODO for now there is no way to distinguish between action and form
collected_actions = list(
{e["name"] for e in actions if e["name"] not in default_action_names()}
)
new_domain = Domain(
intents=_intents_from_messages(messages),
entities=_entities_from_messages(messages),
slots=[],
templates=templates,
action_names=collected_actions,
form_names=[],
)
old_domain.merge(new_domain).persist_clean(domain_path)