Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
[DialogueStateTracker.from_events("one", [])],
[deque([]), UserMessage.DEFAULT_SENDER_ID],
),
(
[
str(i)
for i in range(
interactive.MAX_NUMBER_OF_TRAINING_STORIES_FOR_VISUALIZATION + 1
)
],
[UserMessage.DEFAULT_SENDER_ID],
),
],
)
async def test_initial_plotting_call(
mock_endpoint: EndpointConfig,
monkeypatch: MonkeyPatch,
async def test_dispatcher_template_invalid_vars():
    """Uttering a template whose placeholder has no value keeps the raw text.

    The template references `{variable}`, which no slot provides; the test
    checks that the collected bot message still starts with the unformatted
    template text, placeholder included.
    """
    template_name = "my_made_up_template"
    raw_text = "a template referencing an invalid {variable}."
    templates = {template_name: [{"text": raw_text}]}

    # Collecting channel so the uttered message can be inspected afterwards.
    collecting_channel = CollectingOutputChannel()
    generator = TemplatedNaturalLanguageGenerator(templates)
    dispatcher = Dispatcher("my-sender", collecting_channel, generator)
    conversation = DialogueStateTracker("my-sender", slots=[])

    await dispatcher.utter_template(template_name, conversation)

    latest = dispatcher.output_channel.latest_output()
    assert latest["text"].startswith(raw_text)
def test_tracker_entity_retrieval(default_domain):
tracker = DialogueStateTracker("default", default_domain.slots)
# the retrieved tracker should be empty
assert len(tracker.events) == 0
assert list(tracker.get_latest_entity_values("entity_name")) == []
intent = {"name": "greet", "confidence": 1.0}
tracker.update(
UserUttered(
"/greet",
intent,
[
{
"start": 1,
"end": 5,
"value": "greet",
"entity": "entity_name",
"extractor": "manual",
def test_revert_action_event(default_domain: Domain):
    """Check that an `ActionReverted` event undoes the last executed action.

    Builds a short conversation (listen, greet, my_action, listen) and checks
    the latest action and prior-tracker count both before and after the
    revert.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent, []))
    tracker.update(ActionExecuted("my_action"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 4:
    #   +3 executed actions
    #   +1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_trackers())) == 4

    tracker.update(ActionReverted())

    # Verify the revert actually took effect. Expecting count of 3:
    #   +3 executed actions
    #   +1 final state
    #   -1 reverted action (the trailing action_listen)
    assert tracker.latest_action_name == "my_action"
    assert len(list(tracker.generate_all_prior_trackers())) == 3
tracker_store_kwargs: Dict,
default_domain: Domain,
):
tracker_store = tracker_store_type(default_domain, **tracker_store_kwargs)
events = [
UserUttered("Hola", {"name": "greet"}),
BotUttered("Hi"),
SessionStarted(),
UserUttered("Ciao", {"name": "greet"}),
]
sender_id = "test_sql_tracker_store_with_session_events"
tracker = DialogueStateTracker.from_events(sender_id, events)
tracker_store.save(tracker)
# Save other tracker to ensure that we don't run into problems with other senders
other_tracker = DialogueStateTracker.from_events("other-sender", [SessionStarted()])
tracker_store.save(other_tracker)
# Retrieve tracker with events since latest SessionStarted
tracker = tracker_store.retrieve(sender_id)
assert len(tracker.events) == 2
assert all((event == tracker.events[i] for i, event in enumerate(events[2:])))
def test_revert_user_utterance_event(default_domain: Domain):
    """Build a two-turn conversation and check its state before any revert.

    Two user turns (`/greet` then `/goodbye`) are recorded, each followed by a
    custom action and an `action_listen`.
    """
    tracker = DialogueStateTracker("default", default_domain.slots)
    # the retrieved tracker should be empty
    assert len(tracker.events) == 0

    intent1 = {"name": "greet", "confidence": 1.0}
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
    tracker.update(UserUttered("/greet", intent1, []))
    tracker.update(ActionExecuted("my_action_1"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    intent2 = {"name": "goodbye", "confidence": 1.0}
    tracker.update(UserUttered("/goodbye", intent2, []))
    tracker.update(ActionExecuted("my_action_2"))
    tracker.update(ActionExecuted(ACTION_LISTEN_NAME))

    # Expecting count of 6:
    #   +5 executed actions
    #   +1 final state
    assert tracker.latest_action_name == ACTION_LISTEN_NAME
    assert len(list(tracker.generate_all_prior_trackers())) == 6
    # NOTE(review): no user-utterance revert is applied in this visible code;
    # if it is not exercised further on, add that step and its assertions.
stored = self.conversations.find_one({"sender_id": sender_id})
# look for conversations which have used an `int` sender_id in the past
# and update them.
if stored is None and sender_id.isdigit():
from pymongo import ReturnDocument
stored = self.conversations.find_one_and_update(
{"sender_id": int(sender_id)},
{"$set": {"sender_id": str(sender_id)}},
return_document=ReturnDocument.AFTER,
)
if stored is not None:
if self.domain:
return DialogueStateTracker.from_dict(
sender_id, stored.get("events"), self.domain.slots
)
else:
logger.warning(
"Can't recreate tracker from mongo storage "
"because no domain is set. Returning `None` "
"instead."
)
return None
else:
return None
async def log_message(
    self, message: UserMessage, should_save_tracker: bool = True
) -> Optional[DialogueStateTracker]:
    """Log `message` on tracker belonging to the message's conversation_id.

    Optionally save the tracker if `should_save_tracker` is `True`. Tracker saving
    can be skipped if the tracker returned by this method is used for further
    processing and saved at a later stage.

    Args:
        message: Incoming user message. Its `text` is overwritten in place when
            a message preprocessor is configured.
        should_save_tracker: Whether the tracker should be persisted after
            logging.

    NOTE(review): in this excerpt the body ends right after the tracker is
    fetched. `should_save_tracker` is never read, and there is no `return`,
    so the function as shown always returns `None`. It looks truncated;
    confirm that the message-handling, saving and `return tracker` steps exist.
    """
    # preprocess message if necessary
    if self.message_preprocessor is not None:
        # the preprocessed text replaces the original on the message object itself
        message.text = self.message_preprocessor(message.text)
    # we have a Tracker instance for each user
    # which maintains conversation state
    # presumably this fetches (or creates) the tracker and may start a new
    # session first — confirm against `get_tracker_with_session_start`
    tracker = await self.get_tracker_with_session_start(
        message.sender_id, message.output_channel, message.metadata
    )
def init_tracker(self, sender_id):
    """Create an empty `DialogueStateTracker` for `sender_id`.

    Slots are taken from `self.domain` when a domain is set; otherwise `None`
    is passed for the slots. The store's `max_event_history` setting is
    forwarded to the new tracker.
    """
    domain_slots = None
    if self.domain:
        domain_slots = self.domain.slots
    return DialogueStateTracker(
        sender_id, domain_slots, max_event_history=self.max_event_history
    )
def init_copy(self) -> "DialogueStateTracker":
    """Return a fresh tracker built from this tracker's slots and history limit.

    The new tracker uses the default sender id rather than this tracker's id.
    This tracker's slot objects and `_max_event_history` are passed to the
    constructor. Whether the slots are copied or reset is decided by
    `DialogueStateTracker.__init__`, which is not shown here.
    """
    # Imported locally, presumably to avoid an import cycle at module load — confirm.
    from rasa.core.channels.channel import UserMessage

    default_sender = UserMessage.DEFAULT_SENDER_ID
    current_slots = self.slots.values()
    history_limit = self._max_event_history
    return DialogueStateTracker(default_sender, current_slots, history_limit)