-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
83 lines (71 loc) · 2.66 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
# Charger le tokenizer et le modèle
tokenizer = AutoTokenizer.from_pretrained("trained_llm")
model = AutoModelForCausalLM.from_pretrained(
"trained_llm",
trust_remote_code=True,
local_files_only=True
)
model.eval()
print("✅ Modèle et tokenizer chargés")
def generate_response(prompt_text):
# Reformater le prompt pour de meilleures réponses
formatted_prompt = f"Question: {prompt_text}\nRéponse:"
# Encoder l'entrée
inputs = tokenizer(
f"{tokenizer.bos_token}{formatted_prompt}",
return_tensors="pt",
truncation=True,
max_length=512,
add_special_tokens=True,
padding=True
)
# Générer la réponse
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=50,
min_length=10,
num_return_sequences=1,
do_sample=True,
temperature=0.3,
top_k=20,
top_p=0.9,
repetition_penalty=1.4,
no_repeat_ngram_size=3,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True,
num_beams=2,
early_stopping=True
)
return tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
def clean_response(text, original_prompt):
# Enlever les parties du prompt et le contexte
text = text.replace("Question:", "").replace("Réponse:", "").strip()
text = text.replace(original_prompt, "").strip()
# Nettoyer la ponctuation et les caractères spéciaux
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'[^a-zA-ZÀ-ÿ\s\.,?!]', '', text)
# Enlever les segments non pertinents
text = re.sub(r'Notes? et références.*$', '', text, flags=re.IGNORECASE)
text = re.sub(r'Liens? externes?.*$', '', text, flags=re.IGNORECASE)
# Nettoyer les espaces multiples et les sauts de ligne
text = re.sub(r'\s+', ' ', text).strip()
return text if text else "Désolé, je n'ai pas de réponse claire à cette question."
# Boucle principale d'interaction
print("\nPosez votre question (ou 'q' pour quitter):")
while True:
user_input = input("> ")
if user_input.lower() == 'q':
print("Au revoir!")
break
if user_input.strip():
generated = generate_response(user_input)
cleaned = clean_response(generated, user_input)
print("\nRéponse:")
print(cleaned)
print("\nPosez votre question (ou 'q' pour quitter):")