diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0ee90c642..c801b96a7 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -20,7 +20,7 @@ jobs:
       - name: Install Torch cpu
         run: pip install torch --index-url https://download.pytorch.org/whl/cpu
       - name: Install Flair dependencies
-        run: pip install -e .
+        run: pip install -e .[word-embeddings]
       - name: Install unittest dependencies
         run: pip install -r requirements-dev.txt
       - name: Show installed dependencies
diff --git a/flair/class_utils.py b/flair/class_utils.py
index 842a53387..9aa95cd1e 100644
--- a/flair/class_utils.py
+++ b/flair/class_utils.py
@@ -1,5 +1,7 @@
+import importlib
 import inspect
-from typing import Iterable, Optional, Type, TypeVar
+from types import ModuleType
+from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload
 
 T = TypeVar("T")
 
@@ -17,3 +19,27 @@ def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T]:
         if sub_cls.__name__ == cls_name:
             return sub_cls
     raise ValueError(f"Could not find any class with name '{cls_name}'")
+
+
+@overload
+def lazy_import(group: str, module: str, first_symbol: None) -> ModuleType: ...
+
+
+@overload
+def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> List[Any]: ...
+
+
+def lazy_import(
+    group: str, module: str, first_symbol: Optional[str] = None, *symbols: str
+) -> Union[List[Any], ModuleType]:
+    try:
+        imported_module = importlib.import_module(module)
+    except ImportError:
+        raise ImportError(
+            f"Could not import {module}. Please install the optional '{group}' dependency via 'pip install flair[{group}]'"
+        )
+    if first_symbol is None:
+        return imported_module
+    symbols = (first_symbol, *symbols)
+
+    return [getattr(imported_module, symbol) for symbol in symbols]
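Editor's note: a quick usage sketch of the new helper (not part of the diff). The two call shapes follow the overloads above: with no symbol name the module itself comes back, otherwise a list of the requested symbols, which is why single symbols are unpacked with `(X,) = ...` throughout this diff.

```python
from flair.class_utils import lazy_import

# no symbol requested -> the imported module itself is returned
gensim_models = lazy_import("word-embeddings", "gensim.models")

# symbols requested -> a list is returned, so single symbols are unpacked with (X,) = ...
(KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors")

# without `pip install flair[word-embeddings]`, both calls raise the ImportError defined above
```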
diff --git a/flair/embeddings/__init__.py b/flair/embeddings/__init__.py
index 04e1d1376..308acfceb 100644
--- a/flair/embeddings/__init__.py
+++ b/flair/embeddings/__init__.py
@@ -40,7 +40,6 @@
 
 # Expose token embedding classes
 from .token import (
-    BPEmbSerializable,
     BytePairEmbeddings,
     CharacterEmbeddings,
     FastTextEmbeddings,
diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index 7cfbd73b9..b06830580 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -1,22 +1,19 @@
 import hashlib
 import logging
-import os
 import re
 import tempfile
 from collections import Counter
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
-import gensim
 import numpy as np
 import torch
-from bpemb import BPEmb
 from deprecated.sphinx import deprecated
-from gensim.models import KeyedVectors
-from gensim.models.fasttext import FastTextKeyedVectors, load_facebook_vectors
+from sentencepiece import SentencePieceProcessor
 from torch import nn
 
 import flair
+from flair.class_utils import lazy_import
 from flair.data import Corpus, Dictionary, Sentence, _iter_dataset
 from flair.embeddings.base import TokenEmbeddings, load_embeddings, register_embeddings
 from flair.embeddings.transformer import (
@@ -165,6 +162,9 @@ def __init__(
 
         Constructor downloads required files if not there.
 
+        Note:
+            When loading a new embedding, you need to have `flair[word-embeddings]` installed.
+
         Args:
             embeddings: one of: 'glove', 'extvec', 'crawl' or two-letter language code or a path to a custom embedding
             field: if given, the word-embeddings embed the data for the specific label-type instead of the plain text.
@@ -195,12 +195,13 @@ def __init__(
         super().__init__()
 
         if embeddings_path is not None:
+            (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors")
             if embeddings_path.suffix in [".bin", ".txt"]:
-                precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(
+                precomputed_word_embeddings = KeyedVectors.load_word2vec_format(
                     str(embeddings_path), binary=embeddings_path.suffix == ".bin", no_header=no_header
                 )
             else:
-                precomputed_word_embeddings = gensim.models.KeyedVectors.load(str(embeddings_path))
+                precomputed_word_embeddings = KeyedVectors.load(str(embeddings_path))
 
             self.__embedding_length: int = precomputed_word_embeddings.vector_size
 
@@ -218,7 +219,7 @@ def __init__(
                 # gensim version 3
                 self.vocab = {k: v.index for k, v in precomputed_word_embeddings.vocab.items()}
         else:
-            # if no embedding is set, the vocab and embedding length is requried
+            # if no embedding is set, the vocab and embedding length is required
             assert vocab is not None
             assert embedding_length is not None
             self.vocab = vocab
@@ -333,12 +334,6 @@ def get_cached_token_index(self, word: str) -> int:
         else:
             return len(self.vocab)  # token
 
-    def get_vec(self, word: str) -> torch.Tensor:
-        word_embedding = self.vectors[self.get_cached_token_index(word)]
-
-        word_embedding = torch.tensor(word_embedding.tolist(), device=flair.device, dtype=torch.float)
-        return word_embedding
-
     def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
         tokens = [token for sentence in sentences for token in sentence.tokens]
 
@@ -398,7 +393,7 @@ def __setstate__(self, state: Dict[str, Any]):
         state.setdefault("fine_tune", False)
         state.setdefault("field", None)
         if "precomputed_word_embeddings" in state:
-            precomputed_word_embeddings: KeyedVectors = state.pop("precomputed_word_embeddings")
+            precomputed_word_embeddings = state.pop("precomputed_word_embeddings")
             vectors = np.vstack(
                 (
                     precomputed_word_embeddings.vectors,
@@ -1017,6 +1012,9 @@ def to_params(self):
 
 
 @register_embeddings
+@deprecated(
+    reason="The FastTextEmbeddings are no longer supported and will be removed in version 0.16.0", version="0.14.0"
+)
 class FastTextEmbeddings(TokenEmbeddings):
     """FastText Embeddings with oov functionality."""
 
@@ -1050,8 +1048,12 @@ def __init__(
 
         self.static_embeddings = True
 
+        FastTextKeyedVectors, load_facebook_vectors = lazy_import(
+            "word-embeddings", "gensim.models.fasttext", "FastTextKeyedVectors", "load_facebook_vectors"
+        )
+
         if embeddings_path.suffix == ".bin":
-            self.precomputed_word_embeddings: FastTextKeyedVectors = load_facebook_vectors(str(embeddings_path))
+            self.precomputed_word_embeddings = load_facebook_vectors(str(embeddings_path))
         else:
             self.precomputed_word_embeddings = FastTextKeyedVectors.load(str(embeddings_path))
 
@@ -1281,6 +1283,8 @@ def __init__(
         self.static_embeddings = True
         self.__embedding_length: int = 300
         self.language_embeddings: Dict[str, Any] = {}
+        (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors")
+        self.kv = KeyedVectors
         super().__init__()
         self.eval()
 
@@ -1343,7 +1347,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                 embeddings_file = cached_path(f"{hu_path}/muse.{language_code}.vec.gensim", cache_dir=cache_dir)
 
                 # load the model
-                self.language_embeddings[language_code] = gensim.models.KeyedVectors.load(str(embeddings_file))
+                self.language_embeddings[language_code] = self.kv.load(str(embeddings_file))
 
             for token, _token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
                 word_embedding = self.get_cached_vec(language_code=language_code, word=token.text)
@@ -1367,47 +1371,6 @@ def to_params(self):
         return {}
 
 
-# TODO: keep for backwards compatibility, but remove in future
-@deprecated(
-    reason="""'BPEmbSerializable' is only used in the legacy pickle-embeddings format.
-    Please save your model again to save it in the serializable json format.
-    """,
-    version="0.13.0",
-)
-class BPEmbSerializable(BPEmb):
-    """Helper class to allow pickle-seralizable BPE embeddings."""
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        # save the sentence piece model as binary file (not as path which may change)
-        with self.model_file.open(mode="rb") as fin:
-            state["spm_model_binary"] = fin.read()
-        state["spm"] = None
-        return state
-
-    def __setstate__(self, state):
-        from bpemb.util import sentencepiece_load
-
-        model_file = self.model_tpl.format(lang=state["lang"], vs=state["vs"])
-        self.__dict__ = state
-
-        # write out the binary sentence piece model into the expected directory
-        self.cache_dir: Path = flair.cache_root / "embeddings"
-        if "spm_model_binary" in self.__dict__:
-            # if the model was saved as binary and it is not found on disk, write to appropriate path
-            if not os.path.exists(self.cache_dir / state["lang"]):
-                os.makedirs(self.cache_dir / state["lang"])
-            self.model_file = self.cache_dir / model_file
-            with open(self.model_file, "wb") as out:
-                out.write(self.__dict__["spm_model_binary"])
-        else:
-            # otherwise, use normal process and potentially trigger another download
-            self.model_file = self._load_file(model_file)
-
-        # once the modes if there, load it with sentence piece
-        state["spm"] = sentencepiece_load(self.model_file)
-
-
 @register_embeddings
 class BytePairEmbeddings(TokenEmbeddings):
     def __init__(
@@ -1419,6 +1382,9 @@ def __init__(
         model_file_path: Optional[Path] = None,
         embedding_file_path: Optional[Path] = None,
         name: Optional[str] = None,
+        force_cpu: bool = True,
+        field: Optional[str] = None,
+        preprocess: bool = True,
+        **kwargs,
     ) -> None:
        """Initializes BP embeddings.
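Editor's note: the hunk below replaces the old per-word `BPEmb.embed()` call with a plain `nn.Embedding` lookup over SentencePiece ids; for each word only the first and last subword vectors are kept and concatenated, so the embedding length stays `2 * dim`. A minimal sketch of the id selection, assuming `spm` is a loaded `SentencePieceProcessor` and the embedding table has one extra all-zero row at index `spm.vocab_size()`:

```python
from typing import List

from sentencepiece import SentencePieceProcessor


def first_last_ids(spm: SentencePieceProcessor, word: str) -> List[int]:
    """Mirrors the lookup in _add_embeddings_internal below."""
    if word.strip() == "":
        # empty words map twice to the extra zero row appended after the vocabulary
        return [spm.vocab_size(), spm.vocab_size()]
    ids = spm.EncodeAsIds(word.lower())
    return [ids[0], ids[-1]]
```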
""" self.instance_parameters = self.get_instance_parameters(locals=locals()) - if not cache_dir: cache_dir = flair.cache_root / "embeddings" - if language: - self.name: str = f"bpe-{language}-{syllables}-{dim}" + + if model_file_path is not None and embedding_file_path is None: + self.spm = SentencePieceProcessor() + self.spm.Load(str(model_file_path)) + vectors = np.zeros((self.spm.vocab_size() + 1, dim)) + if name is not None: + self.name = name + else: + raise ValueError( + "When only providing a SentencePieceProcessor, you need to specify a name for the BytePairEmbeddings" + ) else: - assert ( - model_file_path is not None and embedding_file_path is not None - ), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)" - dim = None # type: ignore[assignment] - - self.embedder = BPEmb( - lang=language, - vs=syllables, - dim=dim, - cache_dir=cache_dir, - model_file=model_file_path, - emb_file=embedding_file_path, - **kwargs, - ) + if not language and model_file_path is None: + raise ValueError("Need to specify model_file_path if no language is give in BytePairEmbeddings") + (BPEmb,) = lazy_import("word-embeddings", "bpemb", "BPEmb") + + if language: + self.name: str = f"bpe-{language}-{syllables}-{dim}" + embedder = BPEmb( + lang=language, + vs=syllables, + dim=dim, + cache_dir=cache_dir, + model_file=model_file_path, + emb_file=embedding_file_path, + **kwargs, + ) + else: + if model_file_path is None: + raise ValueError("Need to specify model_file_path if no language is give in BytePairEmbeddings") + embedder = BPEmb( + lang=language, + vs=syllables, + dim=dim, + cache_dir=cache_dir, + model_file=model_file_path, + emb_file=embedding_file_path, + **kwargs, + ) + self.spm = embedder.spm + vectors = np.vstack( + ( + embedder.vectors, + np.zeros(embedder.dim, dtype=embedder.vectors.dtype), + ) + ) + dim = embedder.dim + syllables = embedder.vs - if not language: - self.name = f"bpe-custom-{self.embedder.vs}-{self.embedder.dim}" + if not language: + self.name = f"bpe-custom-{syllables}-{dim}" if name is not None: self.name = name + super().__init__() + self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(vectors), freeze=True) + self.force_cpu = force_cpu self.static_embeddings = True + self.field = field + self.do_preproc = preprocess - self.__embedding_length: int = self.embedder.emb.vector_size * 2 - super().__init__() + self.__embedding_length: int = dim * 2 self.eval() + self.to(flair.device) + + def _preprocess(self, text: str) -> str: + return re.sub(r"\d", "0", text) @property def embedding_length(self) -> int: return self.__embedding_length def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: - for _i, sentence in enumerate(sentences): - for token, _token_idx in zip(sentence.tokens, range(len(sentence.tokens))): - word = token.text + tokens = [token for sentence in sentences for token in sentence.tokens] - if word.strip() == "": - # empty words get no embedding - token.set_embedding(self.name, torch.zeros(self.embedding_length, dtype=torch.float)) - else: - # all other words get embedded - embeddings = self.embedder.embed(word.lower()) - embedding = np.concatenate((embeddings[0], embeddings[len(embeddings) - 1])) - token.set_embedding(self.name, torch.tensor(embedding, dtype=torch.float)) + word_indices: List[List[int]] = [] + for token in tokens: + word = token.text if self.field is None else token.get_label(self.field).value + + if word.strip() == "": + ids = [self.spm.vocab_size(), 
self.embedder.spm.vocab_size()] + else: + if self.do_preproc: + word = self._preprocess(word) + ids = self.spm.EncodeAsIds(word.lower()) + ids = [ids[0], ids[-1]] + word_indices.append(ids) + + index_tensor = torch.tensor(word_indices, dtype=torch.long, device=self.device) + embeddings = self.embedding(index_tensor) + embeddings = embeddings.reshape((-1, self.embedding_length)) + if self.force_cpu: + embeddings = embeddings.to(flair.device) + + for emb, token in zip(embeddings, tokens): + token.set_embedding(self.name, emb) return sentences @@ -1489,21 +1504,69 @@ def from_params(cls, params): temp_path = Path(temp_dir) model_file_path = temp_path / "model.spm" model_file_path.write_bytes(params["spm_model_binary"]) - embedding_file_path = temp_path / "word2vec.bin" - embedding_file_path.write_bytes(params["word2vec_binary"]) - return cls(name=params["name"], model_file_path=model_file_path, embedding_file_path=embedding_file_path) - def to_params(self): - if not self.embedder.emb_file.exists(): - self.embedder.emb_file = self.embedder.emb_file.with_suffix(".bin") - self.embedder.emb.save_word2vec_format(str(self.embedder.emb_file), binary=True) + if "word2vec_binary" in params: + embedding_file_path = temp_path / "word2vec.bin" + embedding_file_path.write_bytes(params["word2vec_binary"]) + dim = None + else: + embedding_file_path = None + dim = params["dim"] + return cls( + name=params["name"], + dim=dim, + model_file_path=model_file_path, + embedding_file_path=embedding_file_path, + field=params.get("field"), + preprocess=params.get("preprocess", True), + ) + def to_params(self): return { "name": self.name, - "spm_model_binary": self.embedder.spm.serialized_model_proto(), - "word2vec_binary": self.embedder.emb_file.read_bytes(), + "spm_model_binary": self.spm.serialized_model_proto(), + "dim": self.embedding_length // 2, + "field": self.field, + "preprocess": self.do_preproc, } + def to(self, device): + if self.force_cpu: + device = torch.device("cpu") + self.device = device + super().to(device) + + def _apply(self, fn): + if fn.__name__ == "convert" and self.force_cpu: + # this is required to force the module on the cpu, + # if a parent module is put to gpu, the _apply is called to each sub_module + # self.to(..) actually sets the device properly + if not hasattr(self, "device"): + self.to(flair.device) + return + super()._apply(fn) + + def state_dict(self, *args, **kwargs): + # when loading the old versions from pickle, the embeddings might not be added as pytorch module. + # we do this delayed, when the weights are collected (e.g. 
for saving), as doing this earlier might + # lead to issues while loading (trying to load weights that weren't stored as python weights and therefore + # not finding them) + if list(self.modules()) == [self]: + self.embedding = self.embedding + return super().state_dict(*args, **kwargs) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if not state_dict: + # old embeddings do not have a torch-embedding and therefore do not store the weights in the saved torch state_dict + # however they are already initialized rightfully, so we just set the state dict from our current state dict + for k, v in self.state_dict(prefix=prefix).items(): + state_dict[k] = v + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + @register_embeddings class NILCEmbeddings(WordEmbeddings): diff --git a/flair/file_utils.py b/flair/file_utils.py index dfb0049b7..f7f20a20f 100644 --- a/flair/file_utils.py +++ b/flair/file_utils.py @@ -171,13 +171,6 @@ def hf_download(model_name: str) -> str: ) except HTTPError: # output information - logger.error("-" * 80) - logger.error( - f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!" - ) - logger.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.") - logger.error(" -> Alternatively, point to a model file on your local drive.") - logger.error("-" * 80) Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid raise diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 1f2a93c68..e2ac00902 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -677,8 +677,6 @@ def _fetch_model(model_name) -> str: "chunk": "flair/chunk-english", "chunk-fast": "flair/chunk-english-fast", # Language-specific NER models - "ar-ner": "megantosh/flair-arabic-multi-ner", - "ar-pos": "megantosh/flair-arabic-dialects-codeswitch-egy-lev", "da-ner": "flair/ner-danish", "de-ner": "flair/ner-german", "de-ler": "flair/ner-german-legal", @@ -691,37 +689,13 @@ def _fetch_model(model_name) -> str: } hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models" - hunflair_paper_path = hu_path + "/hunflair_smallish_models" hunflair_main_path = hu_path + "/hunflair_allcorpus_models" hu_model_map = { # English NER models - "ner": "/".join([hu_path, "ner", "en-ner-conll03-v0.4.pt"]), "ner-pooled": "/".join([hu_path, "ner-pooled", "en-ner-conll03-pooled-v0.5.pt"]), - "ner-fast": "/".join([hu_path, "ner-fast", "en-ner-fast-conll03-v0.4.pt"]), - "ner-ontonotes": "/".join([hu_path, "ner-ontonotes", "en-ner-ontonotes-v0.4.pt"]), - "ner-ontonotes-fast": "/".join([hu_path, "ner-ontonotes-fast", "en-ner-ontonotes-fast-v0.4.pt"]), - # Multilingual NER models - "ner-multi": "/".join([hu_path, "multi-ner", "quadner-large.pt"]), - "multi-ner": "/".join([hu_path, "multi-ner", "quadner-large.pt"]), - "ner-multi-fast": "/".join([hu_path, "multi-ner-fast", "ner-multi-fast.pt"]), - # English POS models - "upos": "/".join([hu_path, "upos", "en-pos-ontonotes-v0.4.pt"]), - "upos-fast": "/".join([hu_path, "upos-fast", "en-upos-ontonotes-fast-v0.4.pt"]), - "pos": "/".join([hu_path, "pos", "en-pos-ontonotes-v0.5.pt"]), - "pos-fast": "/".join([hu_path, "pos-fast", "en-pos-ontonotes-fast-v0.5.pt"]), - # Multilingual POS models - "pos-multi": "/".join([hu_path, 
"multi-pos", "pos-multi-v0.1.pt"]), - "multi-pos": "/".join([hu_path, "multi-pos", "pos-multi-v0.1.pt"]), - "pos-multi-fast": "/".join([hu_path, "multi-pos-fast", "pos-multi-fast.pt"]), - "multi-pos-fast": "/".join([hu_path, "multi-pos-fast", "pos-multi-fast.pt"]), # English SRL models - "frame": "/".join([hu_path, "frame", "en-frame-ontonotes-v0.4.pt"]), - "frame-fast": "/".join([hu_path, "frame-fast", "en-frame-ontonotes-fast-v0.4.pt"]), "frame-large": "/".join([hu_path, "frame-large", "frame-large.pt"]), - # English chunking models - "chunk": "/".join([hu_path, "chunk", "en-chunk-conll2000-v0.4.pt"]), - "chunk-fast": "/".join([hu_path, "chunk-fast", "en-chunk-conll2000-fast-v0.4.pt"]), # Danish models "da-pos": "/".join([hu_path, "da-pos", "da-pos-v0.1.pt"]), "da-ner": "/".join([hu_path, "NER-danish", "da-ner-v0.1.pt"]), @@ -730,13 +704,14 @@ def _fetch_model(model_name) -> str: "de-pos-tweets": "/".join([hu_path, "de-pos-tweets", "de-pos-twitter-v0.1.pt"]), "de-ner": "/".join([hu_path, "de-ner", "de-ner-conll03-v0.4.pt"]), "de-ner-germeval": "/".join([hu_path, "de-ner-germeval", "de-ner-germeval-0.4.1.pt"]), - "de-ler": "/".join([hu_path, "de-ner-legal", "de-ner-legal.pt"]), - "de-ner-legal": "/".join([hu_path, "de-ner-legal", "de-ner-legal.pt"]), + # Arabic models + "ar-ner": "/".join([hu_path, "arabic", "ar-ner.pt"]), + "ar-pos": "/".join([hu_path, "arabic", "ar-pos.pt"]), # French models "fr-ner": "/".join([hu_path, "fr-ner", "fr-ner-wikiner-0.4.pt"]), # Dutch models "nl-ner": "/".join([hu_path, "nl-ner", "nl-ner-bert-conll02-v0.8.pt"]), - "nl-ner-rnn": "/".join([hu_path, "nl-ner-rnn", "nl-ner-conll02-v0.5.pt"]), + "nl-ner-rnn": "/".join([hu_path, "nl-ner-rnn", "nl-ner-conll02-v0.14.0.pt"]), # Malayalam models "ml-pos": "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-xpos-model.pt", "ml-upos": "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-upos-model.pt", @@ -748,20 +723,13 @@ def _fetch_model(model_name) -> str: "pucpr-flair-clinical-pos-tagging-best-model.pt", ] ), - # Keyphase models - "keyphrase": "/".join([hu_path, "keyphrase", "keyphrase-en-scibert.pt"]), - "negation-speculation": "/".join([hu_path, "negation-speculation", "negation-speculation-model.pt"]), + "negation-speculation": "/".join([hu_path, "negation-speculation-v14", "negation-speculation-v0.14.0.pt"]), # Biomedical models - "hunflair-paper-cellline": "/".join([hunflair_paper_path, "cellline", "hunflair-celline-v1.0.pt"]), - "hunflair-paper-chemical": "/".join([hunflair_paper_path, "chemical", "hunflair-chemical-v1.0.pt"]), - "hunflair-paper-disease": "/".join([hunflair_paper_path, "disease", "hunflair-disease-v1.0.pt"]), - "hunflair-paper-gene": "/".join([hunflair_paper_path, "gene", "hunflair-gene-v1.0.pt"]), - "hunflair-paper-species": "/".join([hunflair_paper_path, "species", "hunflair-species-v1.0.pt"]), - "hunflair-cellline": "/".join([hunflair_main_path, "cellline", "hunflair-celline-v1.0.pt"]), - "hunflair-chemical": "/".join([hunflair_main_path, "huner-chemical", "hunflair-chemical-full-v1.0.pt"]), - "hunflair-disease": "/".join([hunflair_main_path, "huner-disease", "hunflair-disease-full-v1.0.pt"]), - "hunflair-gene": "/".join([hunflair_main_path, "huner-gene", "hunflair-gene-full-v1.0.pt"]), - "hunflair-species": "/".join([hunflair_main_path, "huner-species", "hunflair-species-full-v1.1.pt"]), + "hunflair-cellline": "/".join([hunflair_main_path, "huner-cellline", "hunflair-cellline.pt"]), + 
"hunflair-chemical": "/".join([hunflair_main_path, "huner-chemical", "hunflair-chemical.pt"]), + "hunflair-disease": "/".join([hunflair_main_path, "huner-disease", "hunflair-disease.pt"]), + "hunflair-gene": "/".join([hunflair_main_path, "huner-gene", "hunflair-gene.pt"]), + "hunflair-species": "/".join([hunflair_main_path, "huner-species", "hunflair-species.pt"]), } cache_dir = Path("models") diff --git a/flair/nn/model.py b/flair/nn/model.py index 96b2c2d92..88f51f443 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -151,7 +151,17 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model": continue # if the model cannot be fetched, load as a file - state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path)) + try: + state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path)) + except Exception: + log.error("-" * 80) + log.error( + f"ERROR: The key '{model_path}' was neither found on the ModelHub nor is this a valid path to a file on your system!" + ) + log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.") + log.error(" -> Alternatively, point to a model file on your local drive.") + log.error("-" * 80) + raise ValueError(f"Could not find any model with name '{model_path}'") # try to get model class from state cls_name = state.pop("__cls__", None) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 6d9c3ec54..fb8590841 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -473,7 +473,7 @@ def train_custom( if inspect.isclass(sampler): sampler = sampler() # set dataset to sample from - sampler.set_dataset(train_data) # type: ignore[union-attr] + sampler.set_dataset(train_data) shuffle = False # this field stores the names of all dynamic embeddings in the model (determined after first forward pass) diff --git a/requirements.txt b/requirements.txt index fdb507e44..7f159bcd1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ boto3>=1.20.27 -bpemb>=0.3.5 -conllu>=4.0 +conllu>=4.0,<5.0.0 deprecated>=1.2.13 ftfy>=6.1.0 gdown>=4.4.0 -gensim>=4.2.0 huggingface-hub>=0.10.0 langdetect>=1.0.9 lxml>=4.8.0 @@ -23,7 +21,6 @@ torch>=1.5.0,!=1.8 tqdm>=4.63.0 transformer-smaller-training-vocab>=0.2.3 transformers[sentencepiece]>=4.18.0,<5.0.0 -urllib3<2.0.0,>=1.0.0 # pin below 2 to make dependency resolution faster. wikipedia-api>=0.5.7 semver<4.0.0,>=3.0.0 bioc<3.0.0,>=2.0.0 diff --git a/setup.py b/setup.py index 172ab7758..17f60733f 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,9 @@ packages=find_packages(exclude=["tests", "tests.*"]), # same as name license="MIT", install_requires=required, + extras_require={ + "word-embeddings": ["gensim>=4.2.0", "bpemb>=0.3.5"], + }, include_package_data=True, python_requires=">=3.8", ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 52fec1c5e..2d0391b26 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -418,6 +418,7 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path): _assert_universal_dependencies_conllu_dataset(corpus.train) +@pytest.mark.integration() def test_hipe_2022_corpus(tasks_base_path): # This test covers the complete HIPE 2022 dataset. 
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 52fec1c5e..2d0391b26 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -418,6 +418,7 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path):
     _assert_universal_dependencies_conllu_dataset(corpus.train)
 
 
+@pytest.mark.integration()
 def test_hipe_2022_corpus(tasks_base_path):
     # This test covers the complete HIPE 2022 dataset.
     # https://github.com/hipe-eval/HIPE-2022-data
@@ -681,6 +682,7 @@ def test_hipe_2022(dataset_version="v2.1", add_document_separator=True):
     test_hipe_2022(dataset_version="v2.1", add_document_separator=False)
 
 
+@pytest.mark.integration()
 def test_icdar_europeana_corpus(tasks_base_path):
     # This test covers the complete ICDAR Europeana corpus:
     # https://github.com/stefan-it/historic-domain-adaptation-icdar
@@ -698,6 +700,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str):
     check_number_sentences(len(corpus.test), gold_stats[language]["test"], "test")
 
 
+@pytest.mark.integration()
 def test_masakhane_corpus(tasks_base_path):
     # This test covers the complete MasakhaNER dataset, including support for v1 and v2.
     supported_versions = ["v1", "v2"]
@@ -781,6 +784,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag
     check_number_sentences(len(corpus.test), gold_stats["test"], "test", language, version)
 
 
+@pytest.mark.integration()
 def test_nermud_corpus(tasks_base_path):
     # This test covers the NERMuD dataset. Official stats can be found here:
     # https://github.com/dhfbk/KIND/tree/main/evalita-2023
@@ -808,6 +812,7 @@ def test_german_ler_corpus(tasks_base_path):
     assert len(corpus.test) == 6673, "Mismatch in number of sentences for test split"
 
 
+@pytest.mark.integration()
 def test_masakha_pos_corpus(tasks_base_path):
     # This test covers the complete MasakhaPOS dataset.
     supported_versions = ["v1"]
@@ -876,6 +881,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag
     check_number_sentences(len(corpus.test), gold_stats["test"], "test", language, version)
 
 
+@pytest.mark.integration()
 def test_german_mobie(tasks_base_path):
     corpus = flair.datasets.NER_GERMAN_MOBIE()
 
@@ -960,6 +966,7 @@ def test_jsonl_corpus_loads_metadata(tasks_base_path):
     assert dataset.sentences[2].get_metadata("from") == 125
 
 
+@pytest.mark.integration()
 def test_ontonotes_download():
     from urllib.parse import urlparse
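Editor's note: from the user's perspective the refactored `BytePairEmbeddings` behaves as before. A usage sketch (assumes the `word-embeddings` extra is installed and the default `dim=50`, so token vectors are 100-dimensional):

```python
from flair.data import Sentence
from flair.embeddings import BytePairEmbeddings

emb = BytePairEmbeddings(language="en")  # downloads the bpemb model on first use
sentence = Sentence("Berlin is great")
emb.embed(sentence)
print(sentence[0].embedding.shape)  # torch.Size([100]) -> first+last subword vectors concatenated
```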