diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0ee90c642..c801b96a7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: - name: Install Torch cpu run: pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Install Flair dependencies - run: pip install -e . + run: pip install -e .[word-embeddings] - name: Install unittest dependencies run: pip install -r requirements-dev.txt - name: Show installed dependencies diff --git a/flair/class_utils.py b/flair/class_utils.py index 842a53387..9aa95cd1e 100644 --- a/flair/class_utils.py +++ b/flair/class_utils.py @@ -1,5 +1,7 @@ +import importlib import inspect -from typing import Iterable, Optional, Type, TypeVar +from types import ModuleType +from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload T = TypeVar("T") @@ -17,3 +19,27 @@ def get_state_subclass_by_name(cls: Type[T], cls_name: Optional[str]) -> Type[T] if sub_cls.__name__ == cls_name: return sub_cls raise ValueError(f"Could not find any class with name '{cls_name}'") + + +@overload +def lazy_import(group: str, module: str, first_symbol: None) -> ModuleType: ... + + +@overload +def lazy_import(group: str, module: str, first_symbol: str, *symbols: str) -> List[Any]: ... + + +def lazy_import( + group: str, module: str, first_symbol: Optional[str] = None, *symbols: str +) -> Union[List[Any], ModuleType]: + try: + imported_module = importlib.import_module(module) + except ImportError: + raise ImportError( + f"Could not import {module}. Please install the optional '{group}' dependency via 'pip install flair[{group}]'." + ) + if first_symbol is None: + return imported_module + symbols = (first_symbol, *symbols) + + return [getattr(imported_module, symbol) for symbol in symbols] diff --git a/flair/embeddings/__init__.py b/flair/embeddings/__init__.py index 04e1d1376..308acfceb 100644 --- a/flair/embeddings/__init__.py +++ b/flair/embeddings/__init__.py @@ -40,7 +40,6 @@ # Expose token embedding classes from .token import ( - BPEmbSerializable, BytePairEmbeddings, CharacterEmbeddings, FastTextEmbeddings, diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index 7cfbd73b9..b06830580 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -1,22 +1,19 @@ import hashlib import logging -import os import re import tempfile from collections import Counter from pathlib import Path from typing import Any, Dict, List, Optional, Union -import gensim import numpy as np import torch -from bpemb import BPEmb from deprecated.sphinx import deprecated -from gensim.models import KeyedVectors -from gensim.models.fasttext import FastTextKeyedVectors, load_facebook_vectors +from sentencepiece import SentencePieceProcessor from torch import nn import flair +from flair.class_utils import lazy_import from flair.data import Corpus, Dictionary, Sentence, _iter_dataset from flair.embeddings.base import TokenEmbeddings, load_embeddings, register_embeddings from flair.embeddings.transformer import ( @@ -165,6 +162,9 @@ def __init__( Constructor downloads required files if not there. + Note: + When loading a new embedding, you need to have `flair[word-embeddings]` installed. + Args: embeddings: one of: 'glove', 'extvec', 'crawl' or two-letter language code or a path to a custom embedding field: if given, the word-embeddings embed the data for the specific label-type instead of the plain text.
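Usage sketch for the new lazy_import helper above (illustrative only; gensim stands in for any module gated behind the 'word-embeddings' extra, and the calls mirror those added to flair/embeddings/token.py below):

    from flair.class_utils import lazy_import

    # without symbol names, the imported module object itself is returned
    gensim_models = lazy_import("word-embeddings", "gensim.models")

    # with symbol names, a list of the requested attributes is returned;
    # hence the single-element tuple unpacking used throughout token.py below
    (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors")

    # if gensim is not installed, both calls raise:
    # ImportError: Could not import gensim.models. Please install the optional
    # 'word-embeddings' dependency via 'pip install flair[word-embeddings]'.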
@@ -195,12 +195,13 @@ def __init__( super().__init__() if embeddings_path is not None: + (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors") if embeddings_path.suffix in [".bin", ".txt"]: - precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format( + precomputed_word_embeddings = KeyedVectors.load_word2vec_format( str(embeddings_path), binary=embeddings_path.suffix == ".bin", no_header=no_header ) else: - precomputed_word_embeddings = gensim.models.KeyedVectors.load(str(embeddings_path)) + precomputed_word_embeddings = KeyedVectors.load(str(embeddings_path)) self.__embedding_length: int = precomputed_word_embeddings.vector_size @@ -218,7 +219,7 @@ def __init__( # gensim version 3 self.vocab = {k: v.index for k, v in precomputed_word_embeddings.vocab.items()} else: - # if no embedding is set, the vocab and embedding length is requried + # if no embedding is set, the vocab and embedding length is required assert vocab is not None assert embedding_length is not None self.vocab = vocab @@ -333,12 +334,6 @@ def get_cached_token_index(self, word: str) -> int: else: return len(self.vocab) # token - def get_vec(self, word: str) -> torch.Tensor: - word_embedding = self.vectors[self.get_cached_token_index(word)] - - word_embedding = torch.tensor(word_embedding.tolist(), device=flair.device, dtype=torch.float) - return word_embedding - def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: tokens = [token for sentence in sentences for token in sentence.tokens] @@ -398,7 +393,7 @@ def __setstate__(self, state: Dict[str, Any]): state.setdefault("fine_tune", False) state.setdefault("field", None) if "precomputed_word_embeddings" in state: - precomputed_word_embeddings: KeyedVectors = state.pop("precomputed_word_embeddings") + precomputed_word_embeddings = state.pop("precomputed_word_embeddings") vectors = np.vstack( ( precomputed_word_embeddings.vectors, @@ -1017,6 +1012,9 @@ def to_params(self): @register_embeddings +@deprecated( + reason="The FastTextEmbeddings are no longer supported and will be removed in version 0.16.0", version="0.14.0" +) class FastTextEmbeddings(TokenEmbeddings): """FastText Embeddings with oov functionality.""" @@ -1050,8 +1048,12 @@ def __init__( self.static_embeddings = True + FastTextKeyedVectors, load_facebook_vectors = lazy_import( + "word-embeddings", "gensim.models.fasttext", "FastTextKeyedVectors", "load_facebook_vectors" + ) + if embeddings_path.suffix == ".bin": - self.precomputed_word_embeddings: FastTextKeyedVectors = load_facebook_vectors(str(embeddings_path)) + self.precomputed_word_embeddings = load_facebook_vectors(str(embeddings_path)) else: self.precomputed_word_embeddings = FastTextKeyedVectors.load(str(embeddings_path)) @@ -1281,6 +1283,8 @@ def __init__( self.static_embeddings = True self.__embedding_length: int = 300 self.language_embeddings: Dict[str, Any] = {} + (KeyedVectors,) = lazy_import("word-embeddings", "gensim.models", "KeyedVectors") + self.kv = KeyedVectors super().__init__() self.eval() @@ -1343,7 +1347,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: embeddings_file = cached_path(f"{hu_path}/muse.{language_code}.vec.gensim", cache_dir=cache_dir) # load the model - self.language_embeddings[language_code] = gensim.models.KeyedVectors.load(str(embeddings_file)) + self.language_embeddings[language_code] = self.kv.load(str(embeddings_file)) for token, _token_idx in zip(sentence.tokens, range(len(sentence.tokens))): word_embedding =
self.get_cached_vec(language_code=language_code, word=token.text) @@ -1367,47 +1371,6 @@ def to_params(self): return {} -# TODO: keep for backwards compatibility, but remove in future -@deprecated( - reason="""'BPEmbSerializable' is only used in the legacy pickle-embeddings format. - Please save your model again to save it in the serializable json format. - """, - version="0.13.0", -) -class BPEmbSerializable(BPEmb): - """Helper class to allow pickle-seralizable BPE embeddings.""" - - def __getstate__(self): - state = self.__dict__.copy() - # save the sentence piece model as binary file (not as path which may change) - with self.model_file.open(mode="rb") as fin: - state["spm_model_binary"] = fin.read() - state["spm"] = None - return state - - def __setstate__(self, state): - from bpemb.util import sentencepiece_load - - model_file = self.model_tpl.format(lang=state["lang"], vs=state["vs"]) - self.__dict__ = state - - # write out the binary sentence piece model into the expected directory - self.cache_dir: Path = flair.cache_root / "embeddings" - if "spm_model_binary" in self.__dict__: - # if the model was saved as binary and it is not found on disk, write to appropriate path - if not os.path.exists(self.cache_dir / state["lang"]): - os.makedirs(self.cache_dir / state["lang"]) - self.model_file = self.cache_dir / model_file - with open(self.model_file, "wb") as out: - out.write(self.__dict__["spm_model_binary"]) - else: - # otherwise, use normal process and potentially trigger another download - self.model_file = self._load_file(model_file) - - # once the modes if there, load it with sentence piece - state["spm"] = sentencepiece_load(self.model_file) - - @register_embeddings class BytePairEmbeddings(TokenEmbeddings): def __init__( @@ -1419,6 +1382,9 @@ def __init__( model_file_path: Optional[Path] = None, embedding_file_path: Optional[Path] = None, name: Optional[str] = None, + force_cpu: bool = True, + field: Optional[str] = None, + preprocess: bool = True, **kwargs, ) -> None: """Initializes BP embeddings. @@ -1426,54 +1392,103 @@ def __init__( Constructor downloads required files if not there. 
""" self.instance_parameters = self.get_instance_parameters(locals=locals()) - if not cache_dir: cache_dir = flair.cache_root / "embeddings" - if language: - self.name: str = f"bpe-{language}-{syllables}-{dim}" + + if model_file_path is not None and embedding_file_path is None: + self.spm = SentencePieceProcessor() + self.spm.Load(str(model_file_path)) + vectors = np.zeros((self.spm.vocab_size() + 1, dim)) + if name is not None: + self.name = name + else: + raise ValueError( + "When only providing a SentencePieceProcessor, you need to specify a name for the BytePairEmbeddings" + ) else: - assert ( - model_file_path is not None and embedding_file_path is not None - ), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)" - dim = None # type: ignore[assignment] - - self.embedder = BPEmb( - lang=language, - vs=syllables, - dim=dim, - cache_dir=cache_dir, - model_file=model_file_path, - emb_file=embedding_file_path, - **kwargs, - ) + if not language and model_file_path is None: + raise ValueError("Need to specify model_file_path if no language is give in BytePairEmbeddings") + (BPEmb,) = lazy_import("word-embeddings", "bpemb", "BPEmb") + + if language: + self.name: str = f"bpe-{language}-{syllables}-{dim}" + embedder = BPEmb( + lang=language, + vs=syllables, + dim=dim, + cache_dir=cache_dir, + model_file=model_file_path, + emb_file=embedding_file_path, + **kwargs, + ) + else: + if model_file_path is None: + raise ValueError("Need to specify model_file_path if no language is give in BytePairEmbeddings") + embedder = BPEmb( + lang=language, + vs=syllables, + dim=dim, + cache_dir=cache_dir, + model_file=model_file_path, + emb_file=embedding_file_path, + **kwargs, + ) + self.spm = embedder.spm + vectors = np.vstack( + ( + embedder.vectors, + np.zeros(embedder.dim, dtype=embedder.vectors.dtype), + ) + ) + dim = embedder.dim + syllables = embedder.vs - if not language: - self.name = f"bpe-custom-{self.embedder.vs}-{self.embedder.dim}" + if not language: + self.name = f"bpe-custom-{syllables}-{dim}" if name is not None: self.name = name + super().__init__() + self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(vectors), freeze=True) + self.force_cpu = force_cpu self.static_embeddings = True + self.field = field + self.do_preproc = preprocess - self.__embedding_length: int = self.embedder.emb.vector_size * 2 - super().__init__() + self.__embedding_length: int = dim * 2 self.eval() + self.to(flair.device) + + def _preprocess(self, text: str) -> str: + return re.sub(r"\d", "0", text) @property def embedding_length(self) -> int: return self.__embedding_length def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: - for _i, sentence in enumerate(sentences): - for token, _token_idx in zip(sentence.tokens, range(len(sentence.tokens))): - word = token.text + tokens = [token for sentence in sentences for token in sentence.tokens] - if word.strip() == "": - # empty words get no embedding - token.set_embedding(self.name, torch.zeros(self.embedding_length, dtype=torch.float)) - else: - # all other words get embedded - embeddings = self.embedder.embed(word.lower()) - embedding = np.concatenate((embeddings[0], embeddings[len(embeddings) - 1])) - token.set_embedding(self.name, torch.tensor(embedding, dtype=torch.float)) + word_indices: List[List[int]] = [] + for token in tokens: + word = token.text if self.field is None else token.get_label(self.field).value + + if word.strip() == "": + ids = [self.spm.vocab_size(), 
self.spm.vocab_size()] + else: + if self.do_preproc: + word = self._preprocess(word) + ids = self.spm.EncodeAsIds(word.lower()) + ids = [ids[0], ids[-1]] + word_indices.append(ids) + + index_tensor = torch.tensor(word_indices, dtype=torch.long, device=self.device) + embeddings = self.embedding(index_tensor) + embeddings = embeddings.reshape((-1, self.embedding_length)) + if self.force_cpu: + embeddings = embeddings.to(flair.device) + + for emb, token in zip(embeddings, tokens): + token.set_embedding(self.name, emb) return sentences @@ -1489,21 +1504,69 @@ def from_params(cls, params): temp_path = Path(temp_dir) model_file_path = temp_path / "model.spm" model_file_path.write_bytes(params["spm_model_binary"]) - embedding_file_path = temp_path / "word2vec.bin" - embedding_file_path.write_bytes(params["word2vec_binary"]) - return cls(name=params["name"], model_file_path=model_file_path, embedding_file_path=embedding_file_path) - def to_params(self): - if not self.embedder.emb_file.exists(): - self.embedder.emb_file = self.embedder.emb_file.with_suffix(".bin") - self.embedder.emb.save_word2vec_format(str(self.embedder.emb_file), binary=True) + if "word2vec_binary" in params: + embedding_file_path = temp_path / "word2vec.bin" + embedding_file_path.write_bytes(params["word2vec_binary"]) + dim = None + else: + embedding_file_path = None + dim = params["dim"] + return cls( + name=params["name"], + dim=dim, + model_file_path=model_file_path, + embedding_file_path=embedding_file_path, + field=params.get("field"), + preprocess=params.get("preprocess", True), + ) + def to_params(self): return { "name": self.name, - "spm_model_binary": self.embedder.spm.serialized_model_proto(), - "word2vec_binary": self.embedder.emb_file.read_bytes(), + "spm_model_binary": self.spm.serialized_model_proto(), + "dim": self.embedding_length // 2, + "field": self.field, + "preprocess": self.do_preproc, } + def to(self, device): + if self.force_cpu: + device = torch.device("cpu") + self.device = device + super().to(device) + + def _apply(self, fn): + if fn.__name__ == "convert" and self.force_cpu: + # this is required to force the module on the cpu, + # if a parent module is moved to the gpu, _apply is called on each submodule + # self.to(..) actually sets the device properly + if not hasattr(self, "device"): + self.to(flair.device) + return + super()._apply(fn) + + def state_dict(self, *args, **kwargs): + # when loading the old versions from pickle, the embeddings might not be added as a pytorch module. + # we defer this to when the weights are collected (e.g.
for saving), as doing this earlier might + # lead to issues while loading (trying to load weights that weren't stored as pytorch weights and therefore + # not finding them) + if list(self.modules()) == [self]: + self.embedding = self.embedding + return super().state_dict(*args, **kwargs) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if not state_dict: + # old embeddings do not have a torch-embedding and therefore do not store the weights in the saved torch state_dict + # however they are already correctly initialized, so we just set the state dict from our current state dict + for k, v in self.state_dict(prefix=prefix).items(): + state_dict[k] = v + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + @register_embeddings class NILCEmbeddings(WordEmbeddings): diff --git a/requirements.txt b/requirements.txt index fdb507e44..7f159bcd1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ boto3>=1.20.27 -bpemb>=0.3.5 -conllu>=4.0 +conllu>=4.0,<5.0.0 deprecated>=1.2.13 ftfy>=6.1.0 gdown>=4.4.0 -gensim>=4.2.0 huggingface-hub>=0.10.0 langdetect>=1.0.9 lxml>=4.8.0 @@ -23,7 +21,6 @@ torch>=1.5.0,!=1.8 tqdm>=4.63.0 transformer-smaller-training-vocab>=0.2.3 transformers[sentencepiece]>=4.18.0,<5.0.0 -urllib3<2.0.0,>=1.0.0 # pin below 2 to make dependency resolution faster. wikipedia-api>=0.5.7 semver<4.0.0,>=3.0.0 bioc<3.0.0,>=2.0.0 diff --git a/setup.py b/setup.py index 172ab7758..17f60733f 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,9 @@ packages=find_packages(exclude=["tests", "tests.*"]), # same as name license="MIT", install_requires=required, + extras_require={ + "word-embeddings": ["gensim>=4.2.0", "bpemb>=0.3.5"], + }, include_package_data=True, python_requires=">=3.8", )
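Taken together, gensim and bpemb move from hard requirements into the optional 'word-embeddings' extra defined in setup.py above. A minimal sketch of the resulting user-facing behavior (illustrative only; the language code is an example):

    # pip install flair[word-embeddings]
    from flair.data import Sentence
    from flair.embeddings import BytePairEmbeddings

    # downloading a fresh BPEmb model goes through
    # lazy_import("word-embeddings", "bpemb", "BPEmb") and therefore needs the extra
    embedding = BytePairEmbeddings(language="en")

    sentence = Sentence("Berlin is great")
    embedding.embed(sentence)

    # each token vector concatenates the first and last subword embedding,
    # so embedding_length is 2 * dim
    print(embedding.embedding_length)

Models serialized with the new to_params()/from_params() format reload without the extra: from_params() rebuilds the tokenizer directly from the stored SentencePiece binary, and the vectors are restored from the torch state dict.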