@@ -1,8 +1,5 @@
 from typing import Any, List, Optional
 
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
 
@@ -14,24 +11,43 @@ class HuggingFaceWrapper(BaseLLMWrapper):
 
     def __init__(
         self,
-        model_path: str,
+        model: str,
         torch_dtype="auto",
         device_map="auto",
         trust_remote_code=True,
         temperature=0.0,
         top_p=1.0,
         topk=5,
-        **kwargs: Any
+        **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+
+        try:
+            import torch
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                GenerationConfig,
+            )
+        except ImportError as exc:
+            raise ImportError(
+                "HuggingFaceWrapper requires torch and transformers. "
+                "Install them with: pip install torch transformers"
+            ) from exc
+
+        self.torch = torch
+        self.AutoTokenizer = AutoTokenizer
+        self.AutoModelForCausalLM = AutoModelForCausalLM
+        self.GenerationConfig = GenerationConfig
+
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=trust_remote_code
+            model, trust_remote_code=trust_remote_code
         )
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_path,
+            model,
             torch_dtype=torch_dtype,
             device_map=device_map,
             trust_remote_code=trust_remote_code,
@@ -42,27 +58,28 @@ def __init__(
         self.topk = topk
 
     @staticmethod
-    def _build_inputs(prompt: str, history: Optional[List[str]] = None):
+    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
         msgs = history or []
         msgs.append(prompt)
-        full = "\n".join(msgs)
-        return full
+        return "\n".join(msgs)
 
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
         full = self._build_inputs(text, history)
         inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)
-        max_new = 512
-        with torch.no_grad():
-            out = self.model.generate(
-                **inputs,
-                max_new_tokens=max_new,
-                temperature=self.temperature if self.temperature > 0 else 0.0,
-                top_p=self.top_p if self.temperature > 0 else 1.0,
-                do_sample=self.temperature > 0,
-                pad_token_id=self.tokenizer.eos_token_id,
-            )
+
+        gen_config = self.GenerationConfig(
+            max_new_tokens=extra.get("max_new_tokens", 512),
+            temperature=self.temperature if self.temperature > 0 else 1.0,
+            top_p=self.top_p,
+            do_sample=self.temperature > 0,  # temperature==0 => greedy
+            pad_token_id=self.tokenizer.eos_token_id,
+        )
+
+        with self.torch.no_grad():
+            out = self.model.generate(**inputs, generation_config=gen_config)
+
         gen = out[0, inputs.input_ids.shape[-1] :]
         return self.tokenizer.decode(gen, skip_special_tokens=True)
 
@@ -71,17 +88,21 @@ async def generate_topk_per_token(
     ) -> List[Token]:
         full = self._build_inputs(text, history)
         inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)
-        with torch.no_grad():
+
+        with self.torch.no_grad():
             out = self.model.generate(
                 **inputs,
                 max_new_tokens=1,
-                temperature=0,
+                temperature=0.0,
                 return_dict_in_generate=True,
                 output_scores=True,
+                pad_token_id=self.tokenizer.eos_token_id,
             )
-        scores = out.scores[0][0]  # vocab
-        probs = torch.softmax(scores, dim=-1)
-        top_probs, top_idx = torch.topk(probs, k=self.topk)
+
+        scores = out.scores[0][0]  # (vocab,)
+        probs = self.torch.softmax(scores, dim=-1)
+        top_probs, top_idx = self.torch.topk(probs, k=self.topk)
+
         tokens = []
         for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()):
             tokens.append(Token(self.tokenizer.decode([idx]), float(p)))
@@ -93,12 +114,15 @@ async def generate_inputs_prob(
         full = self._build_inputs(text, history)
         ids = self.tokenizer.encode(full)
         logprobs = []
+
         for i in range(1, len(ids) + 1):
             trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1]
-            inputs = torch.tensor([trunc]).to(self.model.device)
-            with torch.no_grad():
+            inputs = self.torch.tensor([trunc]).to(self.model.device)
+
+            with self.torch.no_grad():
                 logits = self.model(inputs).logits[0, -1, :]
-            probs = torch.softmax(logits, dim=-1)
+            probs = self.torch.softmax(logits, dim=-1)
+
             true_id = ids[i - 1]
             logprobs.append(
                 Token(
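
For reference, a minimal usage sketch of the wrapper after this change. This is not part of the commit: the import path and the model name are assumptions, and error handling is omitted.

import asyncio

from graphgen.models import HuggingFaceWrapper  # import path assumed


async def main():
    # `model` (renamed from `model_path`) takes a local path or a hub id;
    # the model name below is illustrative.
    wrapper = HuggingFaceWrapper(model="Qwen/Qwen2.5-0.5B-Instruct", temperature=0.0)

    # Greedy completion: temperature == 0 sets do_sample=False.
    answer = await wrapper.generate_answer("What is a knowledge graph?")
    print(answer)

    # Top-k candidates for the first generated token.
    tokens = await wrapper.generate_topk_per_token("The capital of France is")
    print(tokens)


asyncio.run(main())

Note the design choice in the diff: torch and transformers are now imported lazily inside __init__, so importing the package no longer requires them; a missing dependency surfaces only at construction time, with the install hint from the raised ImportError.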