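# The snippets below appear to come from flair's test suite and a training script;
# their import lines are not shown in the original. The following is a best-guess
# set for the flair 0.4.x API these snippets use (TaggedCorpus, NLPTaskDataFetcher).
# The module paths for TaggedCorpus and TextRegressor in particular are assumptions.
import shutil
from typing import Tuple

import flair.datasets
from flair.data import Sentence, TaggedCorpus
from flair.data_fetcher import NLPTask, NLPTaskDataFetcher
from flair.embeddings import (
    DocumentRNNEmbeddings,
    FlairEmbeddings,
    TokenEmbeddings,
    WordEmbeddings,
)
from flair.models import TextClassifier
from flair.models.text_regression_model import TextRegressor
from flair.trainers import ModelTrainer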
def init(tasks_base_path) -> Tuple[TaggedCorpus, TextRegressor, ModelTrainer]:
    # Build a small regression setup: corpus, GloVe-based document embeddings,
    # a TextRegressor on top, and a ModelTrainer wired to both.
    corpus = NLPTaskDataFetcher.load_corpus(NLPTask.REGRESSION, tasks_base_path)
    glove_embedding: WordEmbeddings = WordEmbeddings("glove")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove_embedding], 128, 1, False, 64, False, False
    )
    model = TextRegressor(document_embeddings)
    trainer = ModelTrainer(model, corpus)
    return corpus, model, trainer
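# Usage sketch (not part of the original test file): train the regressor returned by
# init() for a couple of epochs, mirroring the trainer.train(...) calls used in the
# classifier tests below. The "regression_tasks" and "results/regression" paths are
# placeholders, not paths from the original.
corpus, model, trainer = init("regression_tasks")
trainer.train("results/regression", max_epochs=2, shuffle=False)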
def test_train_load_use_classifier(results_base_path, tasks_base_path):
    # Train a small single-label classifier on the IMDB sample corpus,
    # then predict on a fresh sentence and check the predicted labels.
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
    label_dict = corpus.make_label_dictionary()

    word_embedding: WordEmbeddings = WordEmbeddings("turian")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [word_embedding], 128, 1, False, 64, False, False
    )
    model: TextClassifier = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False)

    sentence = Sentence("Berlin is a really nice city.")
    for s in model.predict(sentence):
        for l in s.labels:
            assert l.value is not None
            assert 0.0 <= l.score <= 1.0
            assert type(l.score) is float
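# Sketch of the "load and reuse" step implied by the test name (an assumption, not
# shown in the original, and written as a continuation of the test above): the
# trainer writes a final-model.pt into the results directory, which can be loaded
# back for prediction. Depending on the flair version, this may need to be
# TextClassifier.load_from_file(...) instead of load(...).
loaded_model = TextClassifier.load(results_base_path / "final-model.pt")
loaded_model.predict(Sentence("Berlin is a really nice city."))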
def test_train_load_use_classifier_with_prob(results_base_path, tasks_base_path):
    # Same setup as above, but predict with multi_class_prob=True so that scores
    # for all labels in the dictionary are returned, not only the best one.
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
    label_dict = corpus.make_label_dictionary()

    word_embedding: WordEmbeddings = WordEmbeddings("turian")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [word_embedding], 128, 1, False, 64, False, False
    )
    model: TextClassifier = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False)

    sentence = Sentence("Berlin is a really nice city.")
    for s in model.predict(sentence, multi_class_prob=True):
        for l in s.labels:
            assert l.value is not None
            assert 0.0 <= l.score <= 1.0
            assert type(l.score) is float
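# Optional check (a sketch, assuming the multi_class_prob scores of a single-label
# classifier come from a softmax over the label dictionary): the per-label scores
# of one predicted sentence should then sum to roughly 1.
def assert_scores_form_distribution(labeled_sentence: Sentence) -> None:
    total = sum(label.score for label in labeled_sentence.labels)
    assert abs(total - 1.0) < 1e-3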
def test_train_resume_text_classification_training(results_base_path, tasks_base_path):
    # Train with checkpointing enabled, then reload the checkpoint and resume training.
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
    label_dict = corpus.make_label_dictionary()

    embeddings: TokenEmbeddings = FlairEmbeddings("news-forward-fast")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [embeddings], 128, 1, False
    )
    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)

    trainer = ModelTrainer.load_checkpoint(results_base_path / "checkpoint.pt", corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)

    # clean up results directory
    shutil.rmtree(results_base_path)
# NOTE: the first line of this statement is missing in the original snippet; the
# `corpus = flair.datasets.ClassificationCorpus(` head below is a reconstruction
# based on how the corpus is used further down (the exact class is an assumption).
corpus = flair.datasets.ClassificationCorpus(
    file_path,
    train_file=train,
    dev_file=dev,
    test_file=test,
)
# Create label dictionary from provided labels in data
label_dict = corpus.make_label_dictionary()
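# `stacked_embedding` is referenced below but defined elsewhere in the original
# script. A hypothetical definition (purely an assumption, for illustration): an
# optional extra embedding, or None, which the filter(None, ...) call then drops.
stacked_embedding = WordEmbeddings('glove')  # or: stacked_embedding = None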
# Stack Flair string-embeddings with optional embeddings
word_embeddings = list(filter(None, [
    stacked_embedding,
    FlairEmbeddings('news-forward'),
    FlairEmbeddings('news-backward'),
]))
# Initialize document embedding by passing list of word embeddings
document_embeddings = DocumentRNNEmbeddings(
    word_embeddings,
    hidden_size=512,
    reproject_words=True,
    reproject_words_dimension=256,
)
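# Optional sanity check (a sketch, not in the original script): embedding one
# sentence should attach a single fixed-size document vector to it; for a
# unidirectional RNN its length should match hidden_size (512 here).
probe = Sentence('A quick sanity-check sentence.')
document_embeddings.embed(probe)
print(probe.get_embedding().shape)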
# Define classifier
classifier = TextClassifier(
    document_embeddings,
    label_dictionary=label_dict,
    multi_label=False
)
if not checkpoint:
    trainer = ModelTrainer(classifier, corpus)
else:
    # If checkpoint file is defined, resume training