-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
124 lines (87 loc) · 4.31 KB
/
Copy pathmodel.py
File metadata and controls
124 lines (87 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from VisionModel.vision import get_models, get_nouns, run_vision_model
from LanguageModel.Language import get_language_model, generate_text, ScenePrompt, get_prompts, evaluate
from torch.cuda import is_available
import numpy as np
class AtRM():
    """Full Assistant to the Regional Manager model.

    Combines the vision models (used to turn an image into a character
    vector and a list of nouns) with a language model that generates text
    from a ``ScenePrompt`` built out of those vision outputs.
    """

    def __init__(self, verbose=False, lm="tuned", in_context_learning=False):
        """Load the vision and language models and move the LM to the device.

        verbose: print out model loading progress
        lm: "tuned", "untuned", "GPTJ" (forwarded to ``get_language_model``)
        in_context_learning: collect training data (via ``get_prompts``) to
            provide context to the model
        """
        if verbose:
            print("Loading Vision models...")
        (
            self.clip_model,
            self.processor,
            self.face_model,
            self.facecascade,
        ) = get_models()
        self.nouns, self.words = get_nouns()
        if verbose:
            print("Vision models loaded.")
            print("Loading Language models...")

        self.lm, self.tokenizer = get_language_model(model_name=lm)
        if verbose:
            print("Language models loaded.")

        self.device = "cuda" if is_available() else "cpu"
        self.lm.to(self.device)

        self.in_context_learning = bool(in_context_learning)
        # Always define ``context`` so the attribute exists even when
        # in-context learning is disabled; it is only read when enabled.
        self.context = get_prompts() if self.in_context_learning else []

        # Most recent ScenePrompt built by __call__ (None until first call).
        self.last_prompt = None

    def init_overide(self, clip_model=None, processor=None, face_model=None,
                     facecascade=None, nouns=None, words=None, lm=None,
                     tokenizer=None, device=None):
        """For debugging purposes, allows you to overide the models with your own.

        Every argument left as ``None`` keeps the current attribute. Any
        other value, including falsy ones such as an empty list, replaces
        the attribute of the same name.
        """
        overrides = {
            "clip_model": clip_model,
            "processor": processor,
            "face_model": face_model,
            "facecascade": facecascade,
            "nouns": nouns,
            "words": words,
            "lm": lm,
            "tokenizer": tokenizer,
            "device": device,
        }
        for attribute, value in overrides.items():
            # Compare against None rather than relying on truthiness:
            # empty containers are valid overrides, and array-like objects
            # may raise when evaluated in a boolean context.
            if value is not None:
                setattr(self, attribute, value)

    def __str__(self) -> str:
        """Return a fixed, human-friendly description of the model."""
        return "Your Assistant (to) the Regional Manager is here!"

    def get_context(self):
        """(Re)load the in-context examples into ``self.context``.

        Does not change ``self.in_context_learning``.
        """
        self.context = get_prompts()

    def __call__(self, img_file, first_character="", first_line="",
                 include_nouns=True, include_prompt=True, n_context_scene=0):
        """Generate text for an image.

        img_file: image passed to the vision models
        first_character, first_line: optional opening line; it is only
            added to the prompt when ``first_character`` is not empty
        include_nouns: whether the detected nouns are put in the prompt
        include_prompt: whether the returned text includes the prompt
        n_context_scene: number of context scenes to prepend (only used
            when in-context learning is enabled)

        The prompt that was built is kept in ``self.last_prompt``.
        """
        character_vector, nouns = self.promptify_img(img_file)
        opening_lines = []
        if first_character != "":
            opening_lines = [(first_character, first_line)]
        self.last_prompt = ScenePrompt(
            characters=character_vector,
            nouns=nouns if include_nouns else [],
            lines=opening_lines,
        )
        return self.generate_text(
            self.last_prompt,
            include_prompt=include_prompt,
            n_context_scene=n_context_scene,
        )

    def promptify_img(self, img_file):
        """Run the vision models on ``img_file``.

        Returns the ``(character_vector, nouns)`` pair produced by
        ``run_vision_model``, or ``([], [])`` if a ``TypeError`` is raised.
        """
        try:
            character_vector, nouns = run_vision_model(
                img_file,
                self.clip_model,
                self.processor,
                self.face_model,
                self.facecascade,
                self.nouns,
                self.words,
            )
        except TypeError:
            # For some images, the output of run_vision_model is None, which
            # raises a TypeError when unpacked. Not sure why this happens.
            print('It seems the models have produced erroneous results. Please try again with a different image. Continuing with empty prompt.')
            return [], []
        return character_vector, nouns

    def generate_text(self, prompt, include_prompt=True, n_context_scene=0):
        """Generate text from ``prompt`` with the language model.

        prompt: object exposing ``to_text()`` and ``len_prompt``
        include_prompt: if False, the leading context and prompt are
            sliced off the returned text
        n_context_scene: number of randomly drawn context scenes to place
            before the prompt; ignored (with a warning) when in-context
            learning is disabled

        Raises AssertionError if ``prompt`` is None.
        """
        # Raised explicitly (instead of using ``assert``) so the check is
        # not stripped when Python runs with -O; the exception type and
        # message are unchanged.
        if prompt is None:
            raise AssertionError("You must provide a prompt to generate text.")

        context_text = ""
        if self.in_context_learning:
            # One random draw per context scene, in order.
            context_text = "".join(
                self.context[np.random.randint(len(self.context))].to_text()
                + "\n\n New Scene \n\n"
                for _ in range(n_context_scene)
            )
        elif n_context_scene > 0:
            print("Warning: In context learning is not enabled for this model. n_context_scene will be ignored.")

        output_w_prompt = generate_text(
            context_text + prompt.to_text(),
            self.lm,
            self.tokenizer,
            device=self.device,
        )
        if include_prompt:
            return output_w_prompt
        # Drop the context prefix plus the prompt itself.
        # NOTE(review): presumably ``len_prompt`` is the character length of
        # ``prompt.to_text()`` - confirm against ScenePrompt.
        return output_w_prompt[prompt.len_prompt + len(context_text):]

    def evaluate_lm(self, n=100):
        """Evaluate the language model via ``evaluate`` with ``test_size=n``."""
        return evaluate(self.lm, self.tokenizer, test_size=n, device=self.device)