
Commit 4107ad1

Update
[ghstack-poisoned]
1 parent: ec9e7fb

3 files changed: 303 additions, 83 deletions


test/test_actors.py (+113, -25)
@@ -10,7 +10,13 @@

 import pytest
 import torch
-from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
+from tensordict import (
+    lazy_stack,
+    LazyStackedTensorDict,
+    NonTensorStack,
+    set_list_to_stack,
+    TensorDict,
+)
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor

@@ -937,6 +943,38 @@ def vllm_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return llm_model

+    @pytest.fixture(scope="module")
+    def transformers_instance(self):
+        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model = GPT2LMHeadModel(GPT2Config()).eval()
+        # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        # model = OPTModel(OPTConfig("facebook/opt-125m"))
+        # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        # model = OPTForCausalLM(OPTConfig())
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        return model, tokenizer
+
+    @pytest.fixture(scope="module")
+    def transformers_instance_pretrained(self):
+        from transformers import AutoTokenizer, OPTForCausalLM
+
+        # tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        # model = GPT2LMHeadModel(GPT2Config())
+        # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        # model = OPTModel(OPTConfig("facebook/opt-125m"))
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        return model, tokenizer
+
     @pytest.mark.parametrize(
         "from_text, generate, return_log_probs, tokens, attention_mask",
         [
@@ -961,22 +999,18 @@ def vllm_instance(self):
             (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_TransformersWrapper(
-        self, from_text, generate, return_log_probs, tokens, attention_mask
+    def test_transformers_wrapper(
+        self,
+        from_text,
+        generate,
+        return_log_probs,
+        tokens,
+        attention_mask,
+        transformers_instance,
     ):
         torch.manual_seed(0)
-        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
-
-        # model_name = "distilbert-base-uncased" # or "minilm" or "albert-tiny"
-        # Load the model and tokenizer
-        # model = AutoModel.from_pretrained(model_name)
-        # tokenizer = AutoTokenizer.from_pretrained(model_name)

-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        model = GPT2LMHeadModel(GPT2Config())
-
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.padding_side = "left"
+        model, tokenizer = transformers_instance

         m = TransformersWrapper(
             model,
@@ -1019,7 +1053,7 @@ def test_TransformersWrapper(
             (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm(
+    def test_vllm_wrapper(
         self,
         from_text,
         generate,
@@ -1163,15 +1197,11 @@ def _run_check(
             (True, None, None),
         ],
     )
-    def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
+    def test_transformers_logprobs(
+        self, from_text, tokens, attention_mask, transformers_instance
+    ):
         torch.manual_seed(0)
-        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
-
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        model = GPT2LMHeadModel(GPT2Config()).eval()
-
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.padding_side = "left"
+        model, tokenizer = transformers_instance

         m_generate = TransformersWrapper(
             model,
@@ -1201,7 +1231,7 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
             (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(
+    def test_vllm_logprobs(
         self, from_text, tokens, attention_mask, pad_output, vllm_instance
     ):
         torch.manual_seed(0)
@@ -1254,6 +1284,7 @@ def _check_lps(
         )
         td_logprobs = model_logprobs(tdin_logprobs)
         assert td_generate.log_probs.shape == td_generate.tokens_response.shape
+        assert td_logprobs.log_probs.shape == td_logprobs.tokens_response.shape
         assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
         torch.testing.assert_close(
             td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
@@ -1374,7 +1405,7 @@ def _run_check_collector(self, policy):
         assert "tokens" in data
         # assert ("next", "tokens") in data

-    def test_generate_multiple_trajs_vllm(self, vllm_instance):
+    def test_vllm_generate_multiple_trajs(self, vllm_instance):
         policy = vLLMWrapper(
             vllm_instance,
             return_log_probs=True,
@@ -1386,6 +1417,63 @@ def test_generate_multiple_trajs_vllm(self, vllm_instance):
         )
         data = policy(data)

+    @set_list_to_stack(True)
+    @pytest.mark.parametrize("from_text", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    def test_transformers_long_sequences(
+        self, from_text, generate, transformers_instance_pretrained
+    ):
+        torch.manual_seed(42)
+        model, tokenizer = transformers_instance_pretrained
+        prompts = [
+            "The quick brown fox jumps over the lazy dog.",  # Likely to finish soon
+            "Once upon a time in a land far, far away, there was a",  # Likely to continue longer
+            "In the beginning, the universe was created. This has made a lot of people very angry and been widely regarded as a bad move.",
+        ]
+        data = lazy_stack([TensorDict() for _ in range(len(prompts))])
+        data["text"] = prompts
+        eos_token_id = tokenizer.convert_tokens_to_ids(",")
+        if not from_text:
+            data["tokens"] = tokenizer(data["text"])["input_ids"]
+            data["attention_mask"] = (
+                0 * data.get("tokens", as_nested_tensor=True, layout=torch.strided) + 1
+            )
+        if not generate:
+            # we need responses
+            responses = prompts[1:] + [" et dolore magna aliqua."]
+            data["text_response"] = responses
+            if not from_text:
+                data["tokens_response"] = tokenizer(data["text_response"])["input_ids"]
+        # make sure dimensions are ragged for tokens entries
+        if "tokens" in data:
+            assert data.get_item_shape("tokens")[-1] == -1
+        if "tokens_response" in data:
+            assert data.get_item_shape("tokens_response")[-1] == -1
+        generate_kwargs = {}
+        if generate:
+            generate_kwargs = {
+                "max_new_tokens": 128,  # Set a reasonable number of new tokens to generate
+                "min_length": 20,  # Ensure a minimum length for the generated sequence
+                "pad_token_id": tokenizer.pad_token_id,  # Use the tokenizer's pad token
+                "forced_eos_token_id": eos_token_id,  # Use comma as an EOS token
+            }
+        policy = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=True,
+            # TODO: use n trajs
+            generate_kwargs=generate_kwargs,
+        )
+        data_policy = policy(data)
+        if "tokens" in data_policy:
+            assert data_policy.get_item_shape("tokens")[-1] == -1
+        if "tokens_response" in data_policy:
+            assert (
+                data_policy.get_item_shape("tokens_response")[-1] == -1
+            )  # TODO: this fails
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
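
Note: the new test_transformers_long_sequences test leans on tensordict's lazy-stack handling of ragged token sequences. As a point of reference, a minimal sketch of that pattern outside the diff (assuming set_list_to_stack also works as a context manager, and that list entries are distributed per sample exactly as in the test; the values are illustrative only):

    from tensordict import TensorDict, lazy_stack, set_list_to_stack

    # Build a lazy stack of empty TensorDicts, one per prompt, as the test does.
    data = lazy_stack([TensorDict() for _ in range(2)])

    with set_list_to_stack(True):  # the test applies this as a decorator instead
        data["text"] = ["short prompt", "a somewhat longer prompt"]
        # Per-sample token lists of different lengths stay ragged in the stack.
        data["tokens"] = [[101, 2023, 102], [101, 2023, 2003, 1037, 102]]

    # Ragged dimensions are reported as -1, which is what the new assertions check.
    assert data.get_item_shape("tokens")[-1] == -1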

torchrl/modules/llm/common.py (+19, -3)
@@ -34,10 +34,20 @@ def get_dist(
     forward = TensorDictSequential.forward

     @property
-    def log_prob_keys(self):
-        return ["log_probs"]
+    def log_prob_keys(self) -> list[NestedKey]:
+        return getattr(self, "_log_prob_keys", ["log_probs"])

-    log_prob_key = ProbabilisticTensorDictModule.log_prob_key
+    @log_prob_keys.setter
+    def log_prob_keys(self, value: list[NestedKey]):
+        self._log_prob_keys = value
+
+    @property
+    def log_prob_key(self) -> NestedKey:
+        return self.log_prob_keys[0]
+
+    @log_prob_key.setter
+    def log_prob_key(self, value: NestedKey) -> None:
+        self.log_prob_keys[0] = value

     @property
     def dist_params_keys(self) -> list[NestedKey]:
@@ -46,3 +56,9 @@ def dist_params_keys(self) -> list[NestedKey]:
     @property
     def dist_sample_keys(self) -> list[NestedKey]:
         return ["tokens_response"]
+
+    def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase:
+        if not self.generate:
+            data = self(data)
+            return data.get(self.log_prob_key, **get_kwargs)
+        raise RuntimeError("log_prob not callable when generate=True.")
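
For reference, a minimal sketch of how the new log_prob entry point and the now-writable log_prob_keys might be exercised through TransformersWrapper. The wrapper arguments mirror the test fixtures above; the import path and input layout are assumptions, not part of this diff:

    from tensordict import TensorDict, lazy_stack, set_list_to_stack
    from torchrl.modules.llm import TransformersWrapper  # import path assumed
    from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    model = GPT2LMHeadModel(GPT2Config()).eval()

    # generate=False puts the wrapper in scoring mode, so log_prob() runs the
    # module and returns the entry stored under log_prob_key.
    wrapper = TransformersWrapper(
        model,
        tokenizer=tokenizer,
        from_text=True,
        generate=False,
        return_log_probs=True,
    )
    # log_prob_keys is now settable rather than hard-coded (set to its default here).
    wrapper.log_prob_keys = ["log_probs"]

    data = lazy_stack([TensorDict() for _ in range(1)])
    with set_list_to_stack(True):
        data["text"] = ["The quick brown fox"]
        data["text_response"] = [" jumps over the lazy dog."]
    log_probs = wrapper.log_prob(data)  # raises RuntimeError if generate=True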
