Skip to content

Commit a9b0d95

Browse files
authored
Merge branch 'main' into fix-multiple-tokenizers-saved
2 parents 6108c54 + 090a894 commit a9b0d95

File tree

12 files changed

+83
-35
lines changed

12 files changed

+83
-35
lines changed

src/transformers/models/qwen2_audio/modeling_qwen2_audio.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ...activations import ACT2FN
2626
from ...cache_utils import Cache
2727
from ...generation import GenerationMixin
28+
from ...masking_utils import create_bidirectional_mask
2829
from ...modeling_layers import GradientCheckpointingLayer
2930
from ...modeling_outputs import BaseModelOutput, ModelOutput
3031
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -774,14 +775,19 @@ def forward(
774775
lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
775776
# Create mask
776777
padding_mask = seq_range >= lengths_expand
778+
audio_attention_mask_2d = (~padding_mask).to(dtype=torch.long, device=audio_feat_lengths.device)
777779

778-
audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
779-
batch_size, 1, max_seq_len, max_seq_len
780+
dummy_embeds = torch.zeros(
781+
(batch_size, max_seq_len, 1),
782+
dtype=inputs_embeds.dtype,
783+
device=inputs_embeds.device,
780784
)
781-
audio_attention_mask = audio_attention_mask_.to(
782-
dtype=self.audio_tower.conv1.weight.dtype, device=self.audio_tower.conv1.weight.device
785+
786+
audio_attention_mask = create_bidirectional_mask(
787+
config=self.audio_tower.config,
788+
input_embeds=dummy_embeds,
789+
attention_mask=audio_attention_mask_2d,
783790
)
784-
audio_attention_mask[audio_attention_mask_] = float("-inf")
785791

786792
audio_outputs = self.audio_tower(input_features, attention_mask=audio_attention_mask)
787793
selected_audio_feature = audio_outputs.last_hidden_state

src/transformers/models/qwen3_vl/modular_qwen3_vl.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1316,7 +1316,6 @@ def __call__(
13161316
video_metadata = videos_inputs.pop("video_metadata")
13171317
else:
13181318
video_metadata = videos_inputs["video_metadata"]
1319-
video_grid_thw = videos_inputs["video_grid_thw"]
13201319
else:
13211320
videos_inputs = {}
13221321
video_grid_thw = None

src/transformers/models/qwen3_vl/processing_qwen3_vl.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def __call__(
157157
video_metadata = videos_inputs.pop("video_metadata")
158158
else:
159159
video_metadata = videos_inputs["video_metadata"]
160-
video_grid_thw = videos_inputs["video_grid_thw"]
161160
else:
162161
videos_inputs = {}
163162
video_grid_thw = None

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,10 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
383383

384384
state_dict = model.state_dict()
385385

386+
# Get num_local_experts from model config
387+
num_local_experts = getattr(model.config, "num_local_experts", 32)
388+
hidden_size = getattr(model.config, "hidden_size", 2880)
389+
386390
for name, module in model.named_modules():
387391
if (
388392
isinstance(module, Mxfp4GptOssExperts)
@@ -392,7 +396,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
392396
state_dict[f"{name}.gate_up_proj_blocks"] = (
393397
module.gate_up_proj.storage.layout.unswizzle_data(module.gate_up_proj.storage.data)
394398
.transpose(-1, -2)
395-
.reshape(32, -1, 90, 16)
399+
.reshape(num_local_experts, -1, 90, 16)
396400
)
397401
state_dict[f"{name}.gate_up_proj_scales"] = (
398402
module.gate_up_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
@@ -402,7 +406,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
402406
state_dict[f"{name}.down_proj_blocks"] = (
403407
module.down_proj.storage.layout.unswizzle_data(module.down_proj.storage.data)
404408
.transpose(-1, -2)
405-
.reshape(32, 2880, 90, -1)
409+
.reshape(num_local_experts, hidden_size, 90, -1)
406410
)
407411
state_dict[f"{name}.down_proj_scales"] = (
408412
module.down_proj_precision_config.weight_scale.storage.layout.unswizzle_data(

src/transformers/utils/chat_parsing_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,12 @@ def recursive_parse(
173173
return parsed_schema
174174
elif isinstance(node_content, dict):
175175
for key, child_node in node_schema.get("properties", {}).items():
176-
if key in node_content:
176+
if "const" in child_node:
177+
parsed_schema[key] = child_node["const"]
178+
elif key in node_content:
177179
parsed_schema[key] = recursive_parse(node_content[key], child_node)
178180
elif "default" in child_node:
179181
parsed_schema[key] = child_node["default"]
180-
else:
181-
pass
182182
if "additionalProperties" in node_schema:
183183
for key, value in node_content.items():
184184
if key not in node_schema.get("properties", {}):

tests/models/blip_2/test_modeling_blip_2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
require_torch,
2929
require_torch_accelerator,
3030
require_torch_fp16,
31-
require_torch_gpu,
3231
require_torch_multi_accelerator,
3332
require_vision,
3433
slow,
@@ -1734,7 +1733,7 @@ def test_inference_t5_multi_accelerator(self):
17341733
self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0])
17351734
self.assertEqual(generated_text, expected_ids_and_text[1])
17361735

1737-
@require_torch_gpu
1736+
@require_torch_accelerator
17381737
def test_inference_itm(self):
17391738
model_name = "Salesforce/blip2-itm-vit-g"
17401739
processor = Blip2Processor.from_pretrained(model_name)

tests/models/falcon_h1/test_modeling_falcon_h1.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
Expectations,
2424
get_device_properties,
2525
require_torch,
26-
require_torch_gpu,
26+
require_torch_accelerator,
2727
slow,
2828
torch_device,
2929
)
@@ -400,7 +400,7 @@ def test_left_padding_compatibility(self):
400400

401401
@slow
402402
@require_torch
403-
@require_torch_gpu
403+
@require_torch_accelerator
404404
class FalconH1ModelIntegrationTest(unittest.TestCase):
405405
@slow
406406
def test_falcon_h1_hard(self):
@@ -448,10 +448,36 @@ def test_falcon_h1_hard(self):
448448
6.
449449
"""
450450

451+
EXPECTED_TEXT_XPU = """
452+
user
453+
Tell me about the french revolution.
454+
assistant
455+
The French Revolution (1789–1799) was a period of radical social and political upheaval in France that fundamentally transformed the nation and had profound effects on the rest of Europe and the world. Here are the key aspects of the revolution:
456+
457+
### **Causes**
458+
1. **Economic Crisis**: France was in severe financial trouble due to costly wars (particularly the American Revolution), extravagant spending by the monarchy, and inefficient taxation.
459+
2. **Social Inequality**: The rigid class system (the Ancien Régime) favored the nobility and clergy while the majority of the population (the Third Estate) bore the brunt of taxation and had limited rights.
460+
3. **Enlightenment Ideas**: Philosophers like Rousseau, Voltaire, and Montesquieu inspired ideas of liberty, equality, and popular sovereignty.
461+
4. **Settlement of 1789**: The Estates-General convened to address the financial crisis, leading to debates that exposed the weaknesses of the monarchy and the grievances of the common people.
462+
463+
### **Key Events**
464+
1. **Opening of the Revolution (1789)**:
465+
- **Storming of the Bastille**: A symbol of royal tyranny, marking the start of the revolution.
466+
- **Declaration of the Rights of Man and of the Citizen**: A foundational document proclaiming liberty, equality, and fraternity.
467+
468+
2. **Stages of the Revolution**:
469+
- **Staffords' Reforms (1789–1791)**: Attempts to address grievances, including the abolition of feudal privileges and the introduction of the Civil Constitution of the Church.
470+
- **Reign of Terror (1793–1794)**: Led by Maximilien Robespierre, characterized by mass executions of perceived enemies of the revolution, including King Louis XVI and Queen Marie Antoinette.
471+
- **Thermidorian Reaction (1794)**: The fall of Robespierre and the end of the Reign of Terror.
472+
473+
3. **
474+
"""
475+
451476
expected_texts = Expectations(
452477
{
453478
(None, None): EXPECTED_TEXT_DEFAULT,
454479
("cuda", 8): EXPECTED_TEXT_A10,
480+
("xpu", None): EXPECTED_TEXT_XPU,
455481
}
456482
)
457483
EXPECTED_TEXT = expected_texts.get_expectation()
@@ -466,10 +492,9 @@ def test_falcon_h1_hard(self):
466492
model_id = "tiiuae/Falcon-H1-1.5B-Deep-Instruct"
467493
tokenizer = AutoTokenizer.from_pretrained(model_id)
468494
model = FalconH1ForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
469-
device = "cuda"
470495
messages = [{"role": "user", "content": "Tell me about the french revolution."}]
471496
input_text = tokenizer.apply_chat_template(messages, tokenize=False)
472-
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
497+
inputs = tokenizer.encode(input_text, return_tensors="pt").to(torch_device)
473498

474499
with torch.no_grad():
475500
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False)

tests/models/helium/test_modeling_helium.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ class HeliumModelTest(CausalLMModelTest, unittest.TestCase):
4848

4949

5050
@slow
51-
# @require_torch_gpu
5251
class HeliumIntegrationTest(unittest.TestCase):
5352
input_text = ["Hello, today is a great day to"]
5453

tests/quantization/bitnet_integration/test_bitnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from transformers.testing_utils import (
2626
backend_empty_cache,
2727
require_accelerate,
28-
require_torch_gpu,
28+
require_torch_accelerator,
2929
slow,
3030
torch_device,
3131
)
@@ -39,7 +39,7 @@
3939
from accelerate import init_empty_weights
4040

4141

42-
@require_torch_gpu
42+
@require_torch_accelerator
4343
class BitNetQuantConfigTest(unittest.TestCase):
4444
def test_to_dict(self):
4545
"""
@@ -53,7 +53,7 @@ def test_to_dict(self):
5353

5454

5555
@slow
56-
@require_torch_gpu
56+
@require_torch_accelerator
5757
@require_accelerate
5858
class BitNetTest(unittest.TestCase):
5959
model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens"
@@ -197,7 +197,7 @@ def forward(self, x):
197197

198198

199199
@slow
200-
@require_torch_gpu
200+
@require_torch_accelerator
201201
@require_accelerate
202202
class BitNetSerializationTest(unittest.TestCase):
203203
def test_model_serialization(self):

tests/trainer/test_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4142,7 +4142,7 @@ def test_fp16_full_eval(self):
41424142
# perfect world: fp32_init/2 == fp16_eval
41434143
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
41444144

4145-
@require_torch_gpu
4145+
@require_torch_accelerator
41464146
@pytest.mark.torch_compile_test
41474147
def test_torch_compile_train(self):
41484148
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -4154,7 +4154,7 @@ def test_torch_compile_train(self):
41544154
metrics = trainer.train()
41554155
self.assertAlmostEqual(metrics.training_loss, original_train_loss)
41564156

4157-
@require_torch_gpu
4157+
@require_torch_accelerator
41584158
@pytest.mark.torch_compile_test
41594159
def test_torch_compile_eval(self):
41604160
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -4165,7 +4165,7 @@ def test_torch_compile_eval(self):
41654165
trainer = get_regression_trainer(torch_compile=True, output_dir=tmp_dir)
41664166
metrics = trainer.evaluate()
41674167

4168-
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
4168+
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss, delta=1e-6)
41694169

41704170
@require_torch_accelerator
41714171
@require_torch_bf16

0 commit comments

Comments
 (0)