import shutil

import flair.datasets
from flair.data import Sentence
from flair.embeddings import DocumentRNNEmbeddings, FlairEmbeddings, TokenEmbeddings, WordEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def test_train_resume_text_classification_training(results_base_path, tasks_base_path):
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
    label_dict = corpus.make_label_dictionary()

    embeddings: TokenEmbeddings = FlairEmbeddings("news-forward-fast")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [embeddings], 128, 1, False
    )

    model = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)

    # resume training from the stored checkpoint
    trainer = ModelTrainer.load_checkpoint(results_base_path / "checkpoint.pt", corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False, checkpoint=True)

    # clean up results directory
    shutil.rmtree(results_base_path)
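
# Hedged usage sketch (not from the original tests): reload the model that
# trainer.train() serializes at the end of a run and use it for inference.
# The "final-model.pt" filename and path-based Model.load() are assumptions
# about the flair version in use.
def _example_load_final_model(results_base_path):
    classifier = TextClassifier.load(results_base_path / "final-model.pt")
    sentence = Sentence("A gripping movie from start to finish.")
    classifier.predict(sentence)  # annotates the sentence in place
    print(sentence.labels)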
def test_train_load_use_classifier_with_prob(results_base_path, tasks_base_path):
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb")
    label_dict = corpus.make_label_dictionary()

    word_embedding: WordEmbeddings = WordEmbeddings("turian")
    document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [word_embedding], 128, 1, False, 64, False, False
    )

    model: TextClassifier = TextClassifier(document_embeddings, label_dict, False)

    trainer = ModelTrainer(model, corpus)
    trainer.train(results_base_path, max_epochs=2, shuffle=False)

    sentence = Sentence("Berlin is a really nice city.")

    # with multi_class_prob=True, every label is returned with a probability
    for s in model.predict(sentence, multi_class_prob=True):
        for l in s.labels:
            assert l.value is not None
            assert 0.0 <= l.score <= 1.0
            assert type(l.score) is float
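
# Hedged sketch (not from the original tests): without multi_class_prob=True,
# predict() attaches only the single most likely label to the sentence; its
# confidence is exposed as `score`.
def _example_predict_top_label(model: TextClassifier):
    sentence = Sentence("Berlin is a really nice city.")
    model.predict(sentence)
    for label in sentence.labels:
        print(label.value, label.score)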
def test_load_imdb_data_max_tokens(tasks_base_path):
    # get training, test and dev data
    corpus = flair.datasets.ClassificationCorpus(
        tasks_base_path / "imdb", in_memory=True, max_tokens_per_doc=3
    )

    assert len(corpus.train[0]) <= 3
    assert len(corpus.dev[0]) <= 3
    assert len(corpus.test[0]) <= 3
def test_load_imdb_data_streaming(tasks_base_path):
    # get training, test and dev data
    corpus = flair.datasets.ClassificationCorpus(
        tasks_base_path / "imdb", in_memory=False
    )

    assert len(corpus.train) == 5
    assert len(corpus.dev) == 5
    assert len(corpus.test) == 5
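
# Hedged sketch (not from the original tests): with in_memory=False the corpus
# is not parsed up front; indexing a split reads and parses that document from
# disk on access, which keeps memory usage flat for large corpora.
def _example_streaming_access(tasks_base_path):
    corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb", in_memory=False)
    first = corpus.train[0]  # parsed lazily on access
    print(len(first), first.labels)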
if split == "test":
url = f"http://saifmohammad.com/WebDocs/EmoInt%20Test%20Gold%20Data/{emotion}-ratings-0to1.test.gold.txt"
path = cached_path(url, Path("datasets") / dataset_name)
with open(path, "r") as f:
with open(data_file, "w") as out:
next(f)
for line in f:
fields = line.split("\t")
out.write(f"__label__{fields[3].rstrip()} {fields[1]}\n")
os.remove(path)
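
# Hedged illustration of the rewrite above on one invented tab-separated line
# (id, tweet, emotion, intensity): field 3 becomes the label, field 1 the text.
def _example_wassa_line_conversion():
    line = "10000\tJust got back from seeing @band. Feeling happy.\tjoy\t0.80\n"
    fields = line.split("\t")
    converted = f"__label__{fields[3].rstrip()} {fields[1]}\n"
    assert converted == "__label__0.80 Just got back from seeing @band. Feeling happy.\n"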
class WASSA_ANGER(ClassificationCorpus):
    def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = False):
        if type(base_path) == str:
            base_path: Path = Path(base_path)

        # this dataset name
        dataset_name = self.__class__.__name__.lower()

        # default dataset folder is the cache root
        if not base_path:
            base_path = Path(flair.cache_root) / "datasets"
        data_folder = base_path / dataset_name

        # download data if necessary
        _download_wassa_if_not_there("anger", data_folder, dataset_name)

        super(WASSA_ANGER, self).__init__(
            data_folder, tokenizer=space_tokenizer, in_memory=in_memory
        )
# Fragment of ClassificationCorpus.__init__ (its signature did not survive in
# this snippet): use the dev file when one exists, otherwise hold out a tenth
# of the training data as the dev split.
if dev_file is not None:
    dev: Dataset = ClassificationDataset(
        dev_file,
        tokenizer=tokenizer,
        max_tokens_per_doc=max_tokens_per_doc,
        max_chars_per_doc=max_chars_per_doc,
        in_memory=in_memory,
    )
# otherwise, sample dev data from train data
else:
    train_length = len(train)
    dev_size: int = round(train_length / 10)
    splits = random_split(train, [train_length - dev_size, dev_size])
    train = splits[0]
    dev = splits[1]

super(ClassificationCorpus, self).__init__(
    train, dev, test, name=data_folder.name
)
class WASSA_SADNESS(ClassificationCorpus):
    def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = False):
        if type(base_path) == str:
            base_path: Path = Path(base_path)

        # this dataset name
        dataset_name = self.__class__.__name__.lower()

        # default dataset folder is the cache root
        if not base_path:
            base_path = Path(flair.cache_root) / "datasets"
        data_folder = base_path / dataset_name

        # download data if necessary
        _download_wassa_if_not_there("sadness", data_folder, dataset_name)

        super(WASSA_SADNESS, self).__init__(
            data_folder, tokenizer=space_tokenizer, in_memory=in_memory
        )
class WASSA_FEAR(ClassificationCorpus):
    def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = False):
        if type(base_path) == str:
            base_path: Path = Path(base_path)

        # this dataset name
        dataset_name = self.__class__.__name__.lower()

        # default dataset folder is the cache root
        if not base_path:
            base_path = Path(flair.cache_root) / "datasets"
        data_folder = base_path / dataset_name

        # download data if necessary
        _download_wassa_if_not_there("fear", data_folder, dataset_name)

        super(WASSA_FEAR, self).__init__(
            data_folder, tokenizer=space_tokenizer, in_memory=in_memory
        )
class WASSA_JOY(ClassificationCorpus):
    def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = False):
        if type(base_path) == str:
            base_path: Path = Path(base_path)

        # this dataset name
        dataset_name = self.__class__.__name__.lower()

        # default dataset folder is the cache root
        if not base_path:
            base_path = Path(flair.cache_root) / "datasets"
        data_folder = base_path / dataset_name

        # download data if necessary
        _download_wassa_if_not_there("joy", data_folder, dataset_name)

        super(WASSA_JOY, self).__init__(
            data_folder, tokenizer=space_tokenizer, in_memory=in_memory
        )
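
# Hedged usage sketch, not part of the original module: instantiating one of
# the WASSA corpora above downloads the data on first use and exposes the
# usual train/dev/test splits of a flair Corpus.
def _example_wassa_usage():
    corpus = WASSA_JOY(in_memory=True)
    print(corpus)
    print(corpus.make_label_dictionary())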
question = " ".join(fields[1:])
# Create flair compatible labels
# TREC-6 : NUM:dist -> __label__NUM
# TREC-50: NUM:dist -> __label__NUM:dist
new_label = "__label__"
new_label += old_label
write_fp.write(f"{new_label} {question}\n")
super(TREC_50, self).__init__(
data_folder, tokenizer=space_tokenizer, in_memory=in_memory
)
class TREC_6(ClassificationCorpus):
    def __init__(self, base_path: Union[str, Path] = None, in_memory: bool = True):
        if type(base_path) == str:
            base_path: Path = Path(base_path)

        # this dataset name
        dataset_name = self.__class__.__name__.lower()

        # default dataset folder is the cache root
        if not base_path:
            base_path = Path(flair.cache_root) / "datasets"
        data_folder = base_path / dataset_name

        # download data if necessary
        trec_path = "https://cogcomp.seas.upenn.edu/Data/QA/QC/"
        # (the snippet ends here; the full __init__ goes on to download and
        # convert the TREC files before calling the superclass constructor)