    jwt_method=jwt_method,
    endpoints=endpoints,
)
else:
    app = Sanic(__name__, configure_logging=False)
    CORS(app, resources={r"/*": {"origins": cors or ""}}, automatic_options=True)

configure_file_logging(log_file)

if input_channels:
    rasa.core.channels.channel.register(input_channels, app, route=route)
else:
    input_channels = []

if logger.isEnabledFor(logging.DEBUG):
    utils.list_routes(app)

# configure async loop logging
async def configure_async_logging():
    if logger.isEnabledFor(logging.DEBUG):
        rasa.utils.io.enable_async_loop_debugging(asyncio.get_event_loop())

app.add_task(configure_async_logging)

if "cmdline" in {c.name() for c in input_channels}:

    async def run_cmdline_io(running_app: Sanic):
        """Small wrapper to shut down the server once cmd io is done."""
        await asyncio.sleep(1)  # allow server to start
        await console.record_messages(
            server_url=constants.DEFAULT_SERVER_FORMAT.format(port)
        )
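For context, `app.add_task` schedules a coroutine to run on the server's event loop once it starts. A minimal self-contained sketch of that pattern; the app name is hypothetical and the loop-debug call stands in for the Rasa-specific helper:

import asyncio
from sanic import Sanic

app = Sanic("demo_app")  # hypothetical app name

async def enable_loop_debugging():
    # stand-in for rasa.utils.io.enable_async_loop_debugging
    asyncio.get_event_loop().set_debug(True)

app.add_task(enable_loop_debugging)
# app.run(host="0.0.0.0", port=5005) would then start the server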
def persist_clean(self, filename: Text) -> None:
    """Write cleaned domain to a file."""
    cleaned_domain_data = self.cleaned_domain()
    utils.dump_obj_as_yaml_to_file(filename, cleaned_domain_data)
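Assuming the Rasa 1.x layout where `Domain.load` reads a domain file, usage might look like this (paths are illustrative):

from rasa.core.domain import Domain

domain = Domain.load("domain.yml")        # illustrative path
domain.persist_clean("domain_clean.yml")  # writes the cleaned domain as YAML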
def write_global_config_value(name: Text, value: Any) -> None:
    """Write a value to the global Rasa configuration."""
    try:
        os.makedirs(os.path.dirname(GLOBAL_USER_CONFIG_PATH), exist_ok=True)
        c = read_global_config()
        c[name] = value
        rasa.core.utils.dump_obj_as_yaml_to_file(GLOBAL_USER_CONFIG_PATH, c)
    except Exception as e:
        logger.warning(f"Failed to write global config. Error: {e}. Skipping.")
    )
else:
    graph.add_node(next_node_idx[0], label=utils.cap_length(cp.name))

graph.add_node(
    nodes["STORY_START"], label="START", fillcolor="green", style="filled"
)
graph.add_node(nodes["STORY_END"], label="END", fillcolor="red", style="filled")

for step in self.story_steps:
    next_node_idx[0] += 1
    step_idx = next_node_idx[0]

    graph.add_node(
        next_node_idx[0],
        label=utils.cap_length(step.block_name),
        style="filled",
        fillcolor="lightblue",
        shape="rect",
    )

    for c in step.start_checkpoints:
        ensure_checkpoint_is_drawn(c)
        graph.add_edge(nodes[c.name], step_idx)
    for c in step.end_checkpoints:
        ensure_checkpoint_is_drawn(c)
        graph.add_edge(step_idx, nodes[c.name])

    if not step.end_checkpoints:
        graph.add_edge(step_idx, nodes["STORY_END"])

if output_file:
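The `graph` object here behaves like a networkx graph carrying Graphviz-style node attributes. A minimal self-contained sketch of the same START/step/END wiring; the node indices and the step label are illustrative:

import networkx as nx

graph = nx.MultiDiGraph()
nodes = {"STORY_START": 0, "STORY_END": -1}

graph.add_node(nodes["STORY_START"], label="START", fillcolor="green", style="filled")
graph.add_node(nodes["STORY_END"], label="END", fillcolor="red", style="filled")

step_idx = 1
graph.add_node(step_idx, label="greet", style="filled", fillcolor="lightblue", shape="rect")
graph.add_edge(nodes["STORY_START"], step_idx)
graph.add_edge(step_idx, nodes["STORY_END"])  # a step with no end checkpoint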
) -> TrackerLookupDict:
    """This is where the augmentation magic happens.

    We will reuse all the trackers that reached the
    end checkpoint `None` (which is the end of a
    story) and start processing all steps again. So instead
    of starting with a fresh tracker, the second and
    all following phases will reuse a couple of the trackers
    that made their way to a story end.

    We need to do some cleanup before processing them again.
    """
    next_active_trackers = defaultdict(list)

    if self.config.use_story_concatenation:
        ending_trackers = utils.subsample_array(
            story_end_trackers,
            self.config.augmentation_factor,
            rand=self.config.rand,
        )
        for t in ending_trackers:
            # this is a nasty thing - all stories end and
            # start with action listen - so after logging the first
            # actions in the next phase the trackers would
            # contain action listen followed by action listen.
            # to fix this we are going to "undo" the last action listen

            # tracker should be copied,
            # otherwise original tracker is updated
            aug_t = t.copy()
            aug_t.is_augmented = True
            aug_t.update(ActionReverted())
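The copy-before-update step matters because trackers are mutable. A hypothetical stand-in (plain lists, not Rasa's tracker) showing why the copy is needed before "undoing" the last action listen:

# update() mutates the tracker in place, so without copy() the original
# story-end tracker would also lose its final action_listen.
events = ["action_listen", "utter_greet", "action_listen"]
aug_events = list(events)   # like t.copy()
aug_events.pop()            # like aug_t.update(ActionReverted())
assert events[-1] == "action_listen"  # the original is untouched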
def resolve_by_type(type_name) -> Type["Slot"]:
    """Returns a slot class by its type name."""
    for cls in utils.all_subclasses(Slot):
        if cls.type_name == type_name:
            return cls
    try:
        return class_from_module_path(type_name)
    except (ImportError, AttributeError):
        raise ValueError(
            "Failed to find slot type, '{}' is neither a known type nor "
            "user-defined. If you are creating your own slot type, make "
            "sure its module path is correct.".format(type_name)
        )
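A hedged usage sketch, assuming the Rasa 1.x import path for built-in slots, where `TextSlot.type_name` is "text":

from rasa.core.slots import Slot, TextSlot  # assumed import path

cls = Slot.resolve_by_type("text")
assert cls is TextSlot
slot = cls("location")  # slots are constructed with their name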
def as_yaml(self, clean_before_dump: bool = False) -> Text:
    """Dump the domain as a YAML string, optionally cleaned of defaults first."""
    if clean_before_dump:
        domain_data = self.cleaned_domain()
    else:
        domain_data = self.as_dict()
    return utils.dump_obj_as_yaml_to_string(domain_data)
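Continuing the hypothetical `domain` instance from the `persist_clean` example above:

yaml_str = domain.as_yaml(clean_before_dump=True)  # clean the domain before dumping
print(yaml_str)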
def _persist_metadata(self, path: Text) -> None:
    """Persists the policy ensemble's metadata to storage."""

    # make sure the directory we persist to exists
    domain_spec_path = os.path.join(path, "metadata.json")
    rasa.utils.io.create_directory_for_file(domain_spec_path)

    policy_names = [utils.module_path_from_instance(p) for p in self.policies]

    metadata = {
        "action_fingerprints": self.action_fingerprints,
        "python": ".".join([str(s) for s in sys.version_info[:3]]),
        "max_histories": self._max_histories(),
        "ensemble_name": self.__module__ + "." + self.__class__.__name__,
        "policy_names": policy_names,
        "trained_at": self.date_trained,
    }

    self._add_package_version_info(metadata)

    rasa.utils.io.dump_obj_as_json_to_file(domain_spec_path, metadata)
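The metadata file written above is plain JSON, so it can be inspected directly; the path below is illustrative:

import json

with open("models/core/metadata.json") as f:  # illustrative path
    metadata = json.load(f)

print(metadata["python"], metadata["policy_names"])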
def __hash__(self) -> int:
    from rasa.utils.common import sort_list_of_dicts_by_first_key

    self_as_dict = self.as_dict()
    self_as_dict["intents"] = sort_list_of_dicts_by_first_key(
        self_as_dict["intents"]
    )
    self_as_string = json.dumps(self_as_dict, sort_keys=True)
    text_hash = utils.get_text_hash(self_as_string)
    return int(text_hash, 16)
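Because the hash is derived from the serialized domain, comparing hashes is a cheap change check. A sketch, reusing the hypothetical `Domain.load` from earlier (note the built-in `hash()` truncates the value to the platform word size, which is fine for equality checks):

old_hash = hash(domain)
# ... edit domain.yml elsewhere, then reload and compare ...
if hash(Domain.load("domain.yml")) != old_hash:
    print("domain changed since the last check")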
endpoints = AvailableEndpoints()

if not kwargs:
    kwargs = {}

policies = config.load(policy_config)

agent = Agent(
    domain_file,
    generator=endpoints.nlg,
    action_endpoint=endpoints.action,
    interpreter=interpreter,
    policies=policies,
)

data_load_args, kwargs = utils.extract_args(
    kwargs,
    {
        "use_story_concatenation",
        "unique_last_num_states",
        "augmentation_factor",
        "remove_duplicates",
        "debug_plots",
    },
)

training_data = await agent.load_data(
    stories_file, exclusion_percentage=exclusion_percentage, **data_load_args
)
agent.train(training_data, **kwargs)
agent.persist(output_path, dump_stories)
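Since the snippet uses `await`, it lives inside a coroutine. Driving such an async train function from a script might look like this; the parameter names match the snippet, but the full signature and paths are assumptions:

import asyncio

async def main():
    await train(
        domain_file="domain.yml",        # illustrative paths
        stories_file="data/stories.md",
        output_path="models/core",
        policy_config="config.yml",
    )

asyncio.run(main())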