Fit using a custom set of features, including a custom feature extractor.
This is only for advanced users.
>>> clf.fit(features={
'in-gaz': {},  # gazetteer features
'contrived': lambda exa, res: {'contrived': len(exa.text) == 26}
})
"""
# create model with given params
model_config = self._get_model_config(**kwargs)
model = create_model(model_config)
if not label_set:
label_set = model_config.train_label_set
label_set = label_set if label_set else DEFAULT_TRAIN_SET_REGEX
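# Hash the model config plus the training queries; if an identical model was already
# trained during an earlier incremental build, reuse it instead of refitting.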
new_hash = self._get_model_hash(model_config, queries, label_set)
cached_model = self._resource_loader.hash_to_model_path.get(new_hash)
if incremental_timestamp and cached_model:
logger.info("No need to fit. Loading previous model.")
self.load(cached_model)
return
queries, classes = self._get_queries_and_labels(queries, label_set)
if not queries:
logger.warning(
"Could not fit model since no relevant examples were found. "
'Make sure the labeled queries for training are placed in "%s" '
"files in your MindMeld project.",
def _get_model_hash(
self, model_config, queries=None, label_set=DEFAULT_TRAIN_SET_REGEX
):
"""Returns a hash representing the inputs into the model
Args:
model_config (ModelConfig): The model configuration
queries (list, optional): A list of ProcessedQuery objects to
train on. If not specified, a label set will be loaded.
label_set (str, optional): A label set to load. If not specified,
the default training set will be loaded.
Returns:
str: The hash
"""
# Hash queries
queries_hash = self._get_queries_and_labels_hash(
queries=queries, label_set=label_set
)
def _get_queries_and_labels_hash(
self, queries=None, label_set=DEFAULT_TRAIN_SET_REGEX
):
query_tree = self._get_query_tree(queries, label_set=label_set, raw=True)
queries = self._resource_loader.flatten_query_tree(query_tree)
hashable_queries = [
self.domain + "###" + self.intent + "###" + self.entity_type + "###"
] + sorted(queries)
return self._resource_loader.hash_list(hashable_queries)
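# Illustrative sketch only (not MindMeld's hash_list implementation): the idea above is
# that a deterministic prefix (domain###intent###entity_type###) plus the sorted query
# strings always hash to the same key, so a previously trained model can be found again.
# The helper name below, hash_query_list, is hypothetical.
import hashlib

def hash_query_list(strings):
    """Return a stable hex digest for an ordered list of strings."""
    digest = hashlib.sha256()
    for text in strings:
        digest.update(text.encode("utf-8"))
        digest.update(b"\x00")  # separator so ["ab", "c"] != ["a", "bc"]
    return digest.hexdigest()

# Same inputs (after sorting) map to the same cache key across runs.
key = hash_query_list(
    ["banking###check_balance###account_type###"]
    + sorted(["what's my balance", "show my savings account"])
)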
def _get_query_tree(
self, queries=None, label_set=DEFAULT_TRAIN_SET_REGEX, raw=False
):
"""Returns the set of queries to train on
Args:
queries (list, optional): A list of ProcessedQuery objects to
train on. If not specified, a label set will be loaded.
label_set (str, optional): A label set to load. If not specified,
the default training set will be loaded.
raw (bool, optional): When True, raw query strings will be returned
Returns:
(dict): A nested dict of queries, keyed by domain and then intent
"""
if queries:
return self._build_query_tree(queries, domain=self.domain, raw=raw)
def _get_queries_and_labels_hash(
self, queries=None, label_set=DEFAULT_TRAIN_SET_REGEX
):
query_tree = self._get_query_tree(queries, label_set=label_set, raw=True)
queries = []
for intent in query_tree.get(self.domain, []):
for query_text in query_tree[self.domain][intent]:
queries.append(
self.domain + "###" + intent + "###" + mark_down(query_text)
)
queries.sort()
return self._resource_loader.hash_list(queries)
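# Rough stand-in for mark_down (an approximation, not the library's parser): MindMeld's
# labeled query files annotate entities inline, roughly as "{entity text|entity_type}".
# Hashing the marked-down (plain) text means the cache key depends on the query wording
# rather than on annotation details. Nested annotations and role labels are ignored here.
import re

def approximate_mark_down(marked_up_text):
    """Strip simple {text|type} annotations, keeping only the inner text."""
    return re.sub(r"\{([^{}|]+)\|[^{}]+\}", r"\1", marked_up_text)

# approximate_mark_down("transfer {50 dollars|sys_amount-of-money} to savings")
# -> "transfer 50 dollars to savings"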
def _get_queries_and_labels(self, queries=None, label_set=DEFAULT_TRAIN_SET_REGEX):
"""Returns a set of queries and their labels based on the label set
Args:
queries (list, optional): A list of ProcessedQuery objects to
train on. If not specified, a label set will be loaded.
label_set (str, optional): A label set to load. If not specified,
the default training set will be loaded.
Returns:
tuple: The queries to train on and their entity labels
"""
query_tree = self._get_query_tree(queries, label_set=label_set)
queries = self._resource_loader.flatten_query_tree(query_tree)
raw_queries = [q.query for q in queries]
labels = [q.entities for q in queries]
return raw_queries, labels
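# Minimal sketch of the shape returned above, using stand-in objects: each ProcessedQuery
# is assumed to expose .query (the query to train on) and .entities (its gold entities),
# so the recognizer trains on two parallel, index-aligned lists.
from collections import namedtuple

FakeProcessedQuery = namedtuple("FakeProcessedQuery", ["query", "entities"])

processed = [
    FakeProcessedQuery(query="book a flight to boston", entities=("boston/city",)),
    FakeProcessedQuery(query="cancel my reservation", entities=()),
]
examples = [p.query for p in processed]    # model inputs
labels = [p.entities for p in processed]   # targets aligned by index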
def get_labeled_queries(
self, domain=None, intent=None, label_set=None, force_reload=False, raw=False
):
"""Gets labeled queries from the cache, or loads them from disk.
Args:
domain (str): The domain of queries to load
intent (str): The intent of queries to load
label_set (str): A regex pattern selecting the label set files to load. If not specified, the default training set is loaded.
force_reload (bool): Will not load queries from the cache when True
raw (bool): Will return raw query strings instead of ProcessedQuery objects when true
Returns:
dict: ProcessedQuery objects (or strings) loaded from labeled query files, organized by
domain and intent.
"""
label_set = label_set or DEFAULT_TRAIN_SET_REGEX
query_tree = {}
loaded_key = "loaded_raw" if raw else "loaded"
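# "loaded" / "loaded_raw" record when each form of the file was last loaded, so the
# timestamp comparison below can tell whether the cached queries are stale.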
file_iter = self._traverse_labeled_queries_files(domain, intent, label_set)
for a_domain, an_intent, filename in file_iter:
file_info = self.file_to_query_info[filename]
if force_reload or (
not file_info[loaded_key]
or file_info[loaded_key] < file_info["modified"]
):
# file is out of date, load it
self.load_query_file(a_domain, an_intent, filename, raw=raw)
if a_domain not in query_tree:
query_tree[a_domain] = {}
if an_intent not in query_tree[a_domain]:
query_tree[a_domain][an_intent] = []
def _get_queries_and_labels_hash(
self, queries=None, label_set=DEFAULT_TRAIN_SET_REGEX
):
query_tree = self._get_query_tree(queries, label_set=label_set, raw=True)
queries = []
for domain in query_tree:
for intent in query_tree[domain]:
for query_text in query_tree[domain][intent]:
queries.append(domain + "###" + mark_down(query_text))
queries.sort()
return self._resource_loader.hash_list(queries)
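# Note on the hash variants on this page: the domain-level key above prefixes each query
# with only its domain, the intent-level variant adds the intent, and the role/entity-level
# variant also folds in the entity type. A cached model is therefore reused only when the
# training data for exactly that slice of the project is unchanged.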
def _get_queries_and_labels_hash(
self, queries=None, label_set=DEFAULT_TRAIN_SET_REGEX
):
"""Returns a hashed string representing the labeled queries
Args:
queries (list, optional): A list of ProcessedQuery objects to
train on. If not specified, a label set will be loaded.
label_set (str, optional): A label set to load. If not specified,
the default training set will be loaded.
Returns:
str: The hash
"""
raise NotImplementedError("Subclasses must implement this method")
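# Hedged illustration of the template-method pattern above: the base classifier declares
# the hash hook and each concrete classifier decides how much context goes into its cache
# key. The class and method bodies below are simplified stand-ins, not MindMeld's classes.
import hashlib

class BaseClassifierSketch:
    def _get_queries_and_labels_hash(self, queries=None, label_set=None):
        raise NotImplementedError("Subclasses must implement this method")

class DomainClassifierSketch(BaseClassifierSketch):
    def __init__(self, domain):
        self.domain = domain

    def _get_queries_and_labels_hash(self, queries=None, label_set=None):
        # Domain-level key: only the domain and the plain query text go into the hash.
        hashable = sorted(self.domain + "###" + text for text in (queries or []))
        return hashlib.sha256("\n".join(hashable).encode("utf-8")).hexdigest()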
incremental_timestamp (str, optional): The timestamp folder to cache models in
"""
logger.info(
"Fitting role classifier: domain=%r, intent=%r, entity_type=%r",
self.domain,
self.intent,
self.entity_type,
)
# create model with given params
model_config = self._get_model_config(**kwargs)
model = create_model(model_config)
if not label_set:
label_set = model_config.train_label_set
label_set = label_set if label_set else DEFAULT_TRAIN_SET_REGEX
new_hash = self._get_model_hash(model_config, queries, label_set)
cached_model = self._resource_loader.hash_to_model_path.get(new_hash)
if incremental_timestamp and cached_model:
logger.info("No need to fit. Loading previous model.")
self.load(cached_model)
return
# Load labeled data
examples, labels = self._get_queries_and_labels(queries, label_set=label_set)
if examples:
# Build roles set
self.roles = set()
for label in labels:
self.roles.add(label)
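# Hedged usage note: the incremental_timestamp / model-hash machinery above is what lets
# an incremental build skip retraining classifiers whose config and training data are
# unchanged. In application code this is typically reached through something like the
# call below, assuming `nlp` is the project's NaturalLanguageProcessor.
#
#     nlp.build(incremental=True)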