-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedder.py
44 lines (34 loc) · 1.33 KB
/
embedder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from __future__ import annotations
from typing import List, Dict, Any, Optional, Union, Callable, NamedTuple, overload
import zipfile
from io import TextIOWrapper
from pathlib import Path
from tqdm import tqdm
import numpy as np
import os
class Glove:
    """In-memory GloVe word-embedding lookup with mean pooling over tokens."""

    def __init__(self, word_vectors: Dict[str, np.ndarray], dims: int):
        """
        Args:
            word_vectors: mapping from token to its embedding vector.
            dims: dimensionality of the embedding vectors.
        """
        self.word_vectors = word_vectors
        self.dims = dims
        # Zero vector substituted for out-of-vocabulary tokens.
        self.oov_vector = np.zeros((self.dims,), dtype=np.float64)

    @classmethod
    def tokenize(cls, text: str) -> List[str]:
        """Whitespace-split ``text`` and lowercase each token."""
        return [x.lower().strip() for x in text.split()]

    @classmethod
    def from_txt(cls, text_file: Path) -> "Glove":
        """Load embeddings from a GloVe ``.txt`` file.

        The dimensionality is inferred from the file stem, which must look
        like ``glove.<corpus>.<vocab_size>.<dims>d`` (four dot-separated
        parts, e.g. ``glove.6B.400000.50d.txt``).

        Raises:
            ValueError: if the stem does not have exactly four parts.
        """
        # Infer the dims (vocab size is parsed but unused).
        _, _, _vocab_size, dims_str = text_file.stem.split(".")
        dims = int(dims_str.strip("d"))
        word_vectors: Dict[str, np.ndarray] = {}
        # GloVe distributions are UTF-8; be explicit so loading does not
        # depend on the platform's default encoding.
        with text_file.open(encoding="utf-8") as f_in:
            for line in tqdm(f_in):
                split_line = line.split()
                word = split_line[0]
                # Pin float64 to match oov_vector's dtype.
                word_vectors[word] = np.array(
                    [float(val) for val in split_line[1:]], dtype=np.float64
                )
        return cls(word_vectors, dims)

    def encode(self, tokenized_txt: List[str]) -> np.ndarray:
        """Return the mean of the token embeddings, shape ``(dims,)``.

        Out-of-vocabulary tokens contribute the zero vector. An empty token
        list yields the all-zeros vector instead of raising.
        """
        if not tokenized_txt:
            # np.stack([]) raises ValueError; return a fresh zero vector
            # (a copy, so callers cannot mutate the shared OOV vector).
            return self.oov_vector.copy()
        tok_embs = [
            self.word_vectors.get(token, self.oov_vector) for token in tokenized_txt
        ]
        return np.stack(tok_embs).mean(axis=0)