 import pytest
 import torch
-from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
+from tensordict import (
+    lazy_stack,
+    LazyStackedTensorDict,
+    NonTensorStack,
+    set_list_to_stack,
+    TensorDict,
+)
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
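
The new tensordict imports (lazy_stack, set_list_to_stack) are used by the long-sequence test added at the end of this diff. A minimal sketch of the pattern they enable, assuming that under set_list_to_stack(True) a list of strings assigned to a lazily stacked TensorDict is written one element per batch entry, which is how the new test uses it (this mirrors the test rather than being part of the diff):

@set_list_to_stack(True)
def build_prompt_batch(prompts):
    # one empty TensorDict per prompt, stacked lazily so per-element entries can stay ragged
    data = lazy_stack([TensorDict() for _ in range(len(prompts))])
    data["text"] = prompts  # each string lands in its own batch element
    return data
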
@@ -937,6 +943,38 @@ def vllm_instance(self):
         tokenizer.pad_token = tokenizer.eos_token
         return llm_model
 
+    @pytest.fixture(scope="module")
+    def transformers_instance(self):
+        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model = GPT2LMHeadModel(GPT2Config()).eval()
+        # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        # model = OPTModel(OPTConfig("facebook/opt-125m"))
+        # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        # model = OPTForCausalLM(OPTConfig())
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        return model, tokenizer
+
+    @pytest.fixture(scope="module")
+    def transformers_instance_pretrained(self):
+        from transformers import AutoTokenizer, OPTForCausalLM
+
+        # tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        # model = GPT2LMHeadModel(GPT2Config())
+        # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        # model = OPTModel(OPTConfig("facebook/opt-125m"))
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        return model, tokenizer
+
     @pytest.mark.parametrize(
         "from_text, generate, return_log_probs, tokens, attention_mask",
         [
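
Both fixtures above are module-scoped, so the GPT2 and OPT models and their tokenizers are built once per test module rather than once per parametrized case. A minimal consumption sketch, mirroring how the renamed tests below use them (the names are the ones introduced in this diff, not a separate API; the method is shown outside its test class for brevity):

def test_transformers_wrapper(self, transformers_instance):
    model, tokenizer = transformers_instance  # shared across all parametrizations in the module
    policy = TransformersWrapper(
        model, tokenizer=tokenizer, from_text=True, generate=True
    )
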
@@ -961,22 +999,18 @@ def vllm_instance(self):
             (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_TransformersWrapper(
-        self, from_text, generate, return_log_probs, tokens, attention_mask
+    def test_transformers_wrapper(
+        self,
+        from_text,
+        generate,
+        return_log_probs,
+        tokens,
+        attention_mask,
+        transformers_instance,
     ):
         torch.manual_seed(0)
-        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
-
-        # model_name = "distilbert-base-uncased"  # or "minilm" or "albert-tiny"
-        # Load the model and tokenizer
-        # model = AutoModel.from_pretrained(model_name)
-        # tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        model = GPT2LMHeadModel(GPT2Config())
-
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.padding_side = "left"
+        model, tokenizer = transformers_instance
 
         m = TransformersWrapper(
             model,
@@ -1019,7 +1053,7 @@ def test_TransformersWrapper(
             (False, True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm(
+    def test_vllm_wrapper(
         self,
         from_text,
         generate,
@@ -1163,15 +1197,11 @@ def _run_check(
             (True, None, None),
         ],
     )
-    def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
+    def test_transformers_logprobs(
+        self, from_text, tokens, attention_mask, transformers_instance
+    ):
         torch.manual_seed(0)
-        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
-
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        model = GPT2LMHeadModel(GPT2Config()).eval()
-
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.padding_side = "left"
+        model, tokenizer = transformers_instance
 
         m_generate = TransformersWrapper(
             model,
@@ -1201,7 +1231,7 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
             (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(
+    def test_vllm_logprobs(
         self, from_text, tokens, attention_mask, pad_output, vllm_instance
     ):
         torch.manual_seed(0)
@@ -1254,6 +1284,7 @@ def _check_lps(
         )
         td_logprobs = model_logprobs(tdin_logprobs)
         assert td_generate.log_probs.shape == td_generate.tokens_response.shape
+        assert td_logprobs.log_probs.shape == td_logprobs.tokens_response.shape
         assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
         torch.testing.assert_close(
             td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
@@ -1374,7 +1405,7 @@ def _run_check_collector(self, policy):
         assert "tokens" in data
         # assert ("next", "tokens") in data
 
-    def test_generate_multiple_trajs_vllm(self, vllm_instance):
+    def test_vllm_generate_multiple_trajs(self, vllm_instance):
         policy = vLLMWrapper(
             vllm_instance,
             return_log_probs=True,
@@ -1386,6 +1417,63 @@ def test_generate_multiple_trajs_vllm(self, vllm_instance):
         )
         data = policy(data)
 
+    @set_list_to_stack(True)
+    @pytest.mark.parametrize("from_text", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    def test_transformers_long_sequences(
+        self, from_text, generate, transformers_instance_pretrained
+    ):
+        torch.manual_seed(42)
+        model, tokenizer = transformers_instance_pretrained
+        prompts = [
+            "The quick brown fox jumps over the lazy dog.",  # Likely to finish soon
+            "Once upon a time in a land far, far away, there was a",  # Likely to continue longer
+            "In the beginning, the universe was created. This has made a lot of people very angry and been widely regarded as a bad move.",
+        ]
+        data = lazy_stack([TensorDict() for _ in range(len(prompts))])
+        data["text"] = prompts
+        eos_token_id = tokenizer.convert_tokens_to_ids(",")
+        if not from_text:
+            data["tokens"] = tokenizer(data["text"])["input_ids"]
+            data["attention_mask"] = (
+                0 * data.get("tokens", as_nested_tensor=True, layout=torch.strided) + 1
+            )
+        if not generate:
+            # we need responses
+            responses = prompts[1:] + [" et dolore magna aliqua."]
+            data["text_response"] = responses
+            if not from_text:
+                data["tokens_response"] = tokenizer(data["text_response"])["input_ids"]
+        # make sure dimensions are ragged for tokens entries
+        if "tokens" in data:
+            assert data.get_item_shape("tokens")[-1] == -1
+        if "tokens_response" in data:
+            assert data.get_item_shape("tokens_response")[-1] == -1
+        generate_kwargs = {}
+        if generate:
+            generate_kwargs = {
+                "max_new_tokens": 128,  # Set a reasonable number of new tokens to generate
+                "min_length": 20,  # Ensure a minimum length for the generated sequence
+                "pad_token_id": tokenizer.pad_token_id,  # Use the tokenizer's pad token
+                "forced_eos_token_id": eos_token_id,  # Use comma as an EOS token
+            }
+        policy = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            from_text=from_text,
+            generate=generate,
+            return_log_probs=True,
+            # TODO: use n trajs
+            generate_kwargs=generate_kwargs,
+        )
+        data_policy = policy(data)
+        if "tokens" in data_policy:
+            assert data_policy.get_item_shape("tokens")[-1] == -1
+        if "tokens_response" in data_policy:
+            assert (
+                data_policy.get_item_shape("tokens_response")[-1] == -1
+            )  # TODO: this fails
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
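
On the raggedness assertions in the new long-sequence test: when prompts of different lengths are stacked lazily, the last dimension of a tokens entry differs across elements and get_item_shape reports it as -1. A small hedged sketch of that check, assuming the LazyStackedTensorDict.get_item_shape semantics relied on above:

data = lazy_stack(
    [TensorDict({"tokens": torch.randint(1024, (n,))}, batch_size=[]) for n in (5, 9)]
)
assert data.get_item_shape("tokens")[-1] == -1  # ragged last dim across the stack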