
Commit c8055f1

fix: fix transformers warning not using GenerationConfig
1 parent d4beb52 commit c8055f1

File tree

3 files changed: +64 additions, -29 deletions


graphgen/models/llm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from .api.http_client import HTTPClient
 from .api.ollama_client import OllamaClient
 from .api.openai_client import OpenAIClient
+from .local.hf_wrapper import HuggingFaceWrapper
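
The diff above re-exports the new wrapper from graphgen.models.llm. A minimal usage sketch, assuming only what this diff shows (the checkpoint id below is illustrative, not part of the commit):

from graphgen.models.llm import HuggingFaceWrapper

# Illustrative checkpoint; any Hugging Face causal LM id should work here.
wrapper = HuggingFaceWrapper(model="Qwen/Qwen2.5-0.5B-Instruct")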

graphgen/models/llm/local/hf_wrapper.py

Lines changed: 52 additions & 28 deletions
@@ -1,8 +1,5 @@
 from typing import Any, List, Optional
 
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
 

@@ -14,24 +11,43 @@ class HuggingFaceWrapper(BaseLLMWrapper):
 
     def __init__(
         self,
-        model_path: str,
+        model: str,
         torch_dtype="auto",
         device_map="auto",
         trust_remote_code=True,
         temperature=0.0,
         top_p=1.0,
         topk=5,
-        **kwargs: Any
+        **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
+
+        try:
+            import torch
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoTokenizer,
+                GenerationConfig,
+            )
+        except ImportError as exc:
+            raise ImportError(
+                "HuggingFaceWrapper requires torch and transformers. "
+                "Install them with: pip install torch transformers"
+            ) from exc
+
+        self.torch = torch
+        self.AutoTokenizer = AutoTokenizer
+        self.AutoModelForCausalLM = AutoModelForCausalLM
+        self.GenerationConfig = GenerationConfig
+
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=trust_remote_code
+            model, trust_remote_code=trust_remote_code
        )
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_path,
+            model,
             torch_dtype=torch_dtype,
             device_map=device_map,
             trust_remote_code=trust_remote_code,
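
Moving the torch/transformers imports into __init__ with the try/except above keeps the heavy optional dependencies out of module import time, so graphgen.models.llm stays importable when they are not installed. A minimal standalone sketch of the same optional-dependency pattern, with illustrative names not taken from the diff:

class OptionalBackend:
    def __init__(self):
        try:
            import torch  # heavy optional dependency
        except ImportError as exc:
            raise ImportError(
                "OptionalBackend requires torch. Install it with: pip install torch"
            ) from exc
        # Keep a handle on the module so later methods can use it
        # without a module-level import.
        self.torch = torch

    def zeros(self, n):
        return self.torch.zeros(n)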
@@ -42,27 +58,28 @@ def __init__(
         self.topk = topk
 
     @staticmethod
-    def _build_inputs(prompt: str, history: Optional[List[str]] = None):
+    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
         msgs = history or []
         msgs.append(prompt)
-        full = "\n".join(msgs)
-        return full
+        return "\n".join(msgs)
 
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
         full = self._build_inputs(text, history)
         inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)
-        max_new = 512
-        with torch.no_grad():
-            out = self.model.generate(
-                **inputs,
-                max_new_tokens=max_new,
-                temperature=self.temperature if self.temperature > 0 else 0.0,
-                top_p=self.top_p if self.temperature > 0 else 1.0,
-                do_sample=self.temperature > 0,
-                pad_token_id=self.tokenizer.eos_token_id,
-            )
+
+        gen_config = self.GenerationConfig(
+            max_new_tokens=extra.get("max_new_tokens", 512),
+            temperature=self.temperature if self.temperature > 0 else 1.0,
+            top_p=self.top_p,
+            do_sample=self.temperature > 0,  # temperature==0 => greedy
+            pad_token_id=self.tokenizer.eos_token_id,
+        )
+
+        with self.torch.no_grad():
+            out = self.model.generate(**inputs, generation_config=gen_config)
+
         gen = out[0, inputs.input_ids.shape[-1] :]
         return self.tokenizer.decode(gen, skip_special_tokens=True)
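The GenerationConfig change above is the substance of this commit: passing sampling flags loosely to generate(), for example temperature=0.0 alongside do_sample=False, is a typical trigger for the transformers warning the title refers to, while bundling the flags into a GenerationConfig keeps the call explicit. A standalone sketch of that pattern, assuming a small causal LM checkpoint such as gpt2 purely for illustration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Greedy decoding: do_sample=False, and no sampling-only flags are passed at all.
cfg = GenerationConfig(
    max_new_tokens=32,
    do_sample=False,
    pad_token_id=tok.eos_token_id,
)

inputs = tok("Hello, world", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, generation_config=cfg)

# Decode only the newly generated tokens, mirroring the wrapper above.
print(tok.decode(out[0, inputs.input_ids.shape[-1]:], skip_special_tokens=True))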

@@ -71,17 +88,21 @@ async def generate_topk_per_token(
     ) -> List[Token]:
         full = self._build_inputs(text, history)
         inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)
-        with torch.no_grad():
+
+        with self.torch.no_grad():
             out = self.model.generate(
                 **inputs,
                 max_new_tokens=1,
-                temperature=0,
+                temperature=0.0,
                 return_dict_in_generate=True,
                 output_scores=True,
+                pad_token_id=self.tokenizer.eos_token_id,
             )
-        scores = out.scores[0][0]  # vocab
-        probs = torch.softmax(scores, dim=-1)
-        top_probs, top_idx = torch.topk(probs, k=self.topk)
+
+        scores = out.scores[0][0]  # (vocab,)
+        probs = self.torch.softmax(scores, dim=-1)
+        top_probs, top_idx = self.torch.topk(probs, k=self.topk)
+
         tokens = []
         for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()):
             tokens.append(Token(self.tokenizer.decode([idx]), float(p)))
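
generate_topk_per_token relies on return_dict_in_generate and output_scores to read back the scores of the single generated step. A standalone sketch of that readout, again using gpt2 only as a stand-in checkpoint:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The capital of France is", return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_scores=True,  # keep the per-step scores
        pad_token_id=tok.eos_token_id,
    )

scores = out.scores[0][0]            # scores of the one generated step, shape (vocab,)
probs = torch.softmax(scores, dim=-1)
top_probs, top_idx = torch.topk(probs, k=5)
for p, idx in zip(top_probs.tolist(), top_idx.tolist()):
    print(repr(tok.decode([idx])), round(p, 4))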
@@ -93,12 +114,15 @@ async def generate_inputs_prob(
         full = self._build_inputs(text, history)
         ids = self.tokenizer.encode(full)
         logprobs = []
+
         for i in range(1, len(ids) + 1):
             trunc = ids[: i - 1] + ids[i:] if i < len(ids) else ids[:-1]
-            inputs = torch.tensor([trunc]).to(self.model.device)
-            with torch.no_grad():
+            inputs = self.torch.tensor([trunc]).to(self.model.device)
+
+            with self.torch.no_grad():
                 logits = self.model(inputs).logits[0, -1, :]
-            probs = torch.softmax(logits, dim=-1)
+            probs = self.torch.softmax(logits, dim=-1)
+
             true_id = ids[i - 1]
             logprobs.append(
                 Token(
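
The loop above re-runs the model once per position with the target token dropped and reads the distribution at the final position. For comparison only (this is not what the diff implements), a common single-pass alternative scores every prompt token under its prefix with one forward call; a sketch, again with gpt2 as a placeholder checkpoint:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = tok("The quick brown fox", return_tensors="pt").input_ids
with torch.no_grad():
    logits = model(ids).logits                     # (1, seq_len, vocab)

# Position t predicts token t+1, so align logits[:-1] with ids[1:].
logprobs = torch.log_softmax(logits[0, :-1], dim=-1)
token_logprobs = logprobs.gather(1, ids[0, 1:].unsqueeze(-1)).squeeze(-1)
for t, lp in zip(ids[0, 1:].tolist(), token_logprobs.tolist()):
    print(repr(tok.decode([t])), round(lp, 3))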

graphgen/operators/init/init_llm.py

Lines changed: 11 additions & 1 deletion
@@ -2,7 +2,7 @@
 from typing import Any, Dict
 
 from graphgen.bases import BaseLLMWrapper
-from graphgen.models import HTTPClient, OllamaClient, OpenAIClient, Tokenizer
+from graphgen.models import Tokenizer
 
 
 class LLMFactory:
@@ -28,11 +28,21 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
         )
         config["tokenizer"] = tokenizer
         if backend == "http_api":
+            from graphgen.models.llm.api.http_client import HTTPClient
+
             return HTTPClient(**config)
         if backend == "openai_api":
+            from graphgen.models.llm.api.openai_client import OpenAIClient
+
             return OpenAIClient(**config)
         if backend == "ollama_api":
+            from graphgen.models.llm.api.ollama_client import OllamaClient
+
             return OllamaClient(**config)
+        if backend == "huggingface":
+            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper
+
+            return HuggingFaceWrapper(**config)
         raise NotImplementedError(f"Backend {backend} is not implemented yet.")
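With the client imports now deferred into each branch, only the selected backend's dependencies are loaded. A hypothetical call sketch for the new branch; the exact keys expected in config beyond the model name depend on LLMFactory code outside this diff, so treat the dict below as illustrative only:

from graphgen.operators.init.init_llm import LLMFactory

# "huggingface" now resolves to the local wrapper; torch/transformers are only
# imported at this point, not when init_llm itself is imported.
llm = LLMFactory.create_llm_wrapper(
    backend="huggingface",
    config={"model": "Qwen/Qwen2.5-0.5B-Instruct"},  # illustrative model id
)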