Skip to content
Merged
2 changes: 1 addition & 1 deletion src/megatron/bridge/models/gemma/gemma2_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Gemma2ModelProvi
provider.query_pre_attn_scalar = hf_config.query_pre_attn_scalar
provider.attn_logit_softcapping = hf_config.attn_logit_softcapping
provider.final_logit_softcapping = hf_config.final_logit_softcapping
provider.window_size = (hf_config.sliding_window, 0)
provider.window_size = (hf_config.sliding_window - 1, 0)

provider.normalization = "RMSNorm"
provider.activation_func = fast_gelu
Expand Down
2 changes: 1 addition & 1 deletion src/megatron/bridge/models/gemma/gemma3_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def __init__(
config = copy.deepcopy(config)
if _is_local_attn_layer(layer_number, config.interleaved_attn_pattern):
# local attention, (q, k)
config.window_size = (config.window_size, 0)
config.window_size = (config.window_size - 1, 0)
else:
# global attention
config.window_size = None
Expand Down
2 changes: 1 addition & 1 deletion src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider
provider.glu_linear_offset = 1.0

provider.softmax_type = "learnable"
provider.window_size = (128, 0)
provider.window_size = (hf_pretrained.config.sliding_window - 1, 0)
provider.window_attn_skip_freq = 2

# GPT-OSS uses intermediate_size for MoE FFN hidden size
Expand Down
2 changes: 1 addition & 1 deletion src/megatron/bridge/models/gpt_oss/gpt_oss_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class GPTOSSProvider(GPTModelProvider):
moe_ffn_hidden_size: int = 2880
moe_router_load_balancing_type: str = "none"
seq_length: int = 131072
window_size: Optional[Tuple[int, int]] = (128, 0)
window_size: Optional[Tuple[int, int]] = (127, 0)
softmax_type: Literal["vanilla", "off-by-one", "learnable"] = "learnable"
activation_func: Callable = quick_gelu
glu_linear_offset: float = 1.0
Expand Down
2 changes: 1 addition & 1 deletion src/megatron/bridge/models/mistral/mistral_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MistralModelProv

window_size, cp_comm_type = (None, None)
if getattr(hf_config, "sliding_window", None) is not None:
window_size = [hf_config.sliding_window, 0]
window_size = [hf_config.sliding_window - 1, 0]
cp_comm_type = "a2a"

provider = cls(
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/models/gemma/test_gemma2_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def test_provider_bridge_gemma2_specific_features(self, mock_pretrained_gemma2_2
assert result.query_pre_attn_scalar == gemma2_2b_config.query_pre_attn_scalar
assert result.attn_logit_softcapping == gemma2_2b_config.attn_logit_softcapping
assert result.final_logit_softcapping == gemma2_2b_config.final_logit_softcapping
assert result.window_size == (gemma2_2b_config.sliding_window, 0)
assert result.window_size == (gemma2_2b_config.sliding_window - 1, 0)
assert result.add_bias_linear == False # Gemma2 doesn't use bias in linear layers
assert result.layernorm_zero_centered_gamma == True # Gemma2-specific RMSNorm behavior

Expand Down Expand Up @@ -406,8 +406,8 @@ def test_provider_bridge_sliding_window_config(self, mock_pretrained_gemma2_2b,
result = bridge.provider_bridge(mock_pretrained_gemma2_2b)

# Check sliding window configuration specific to Gemma2
assert result.window_size == (gemma2_2b_config.sliding_window, 0)
assert result.window_size == (4096, 0)
assert result.window_size == (gemma2_2b_config.sliding_window - 1, 0)
assert result.window_size == (4095, 0)

def test_provider_bridge_query_pre_attn_scalar_variants(self, mock_pretrained_gemma2_27b, gemma2_27b_config):
"""Test query_pre_attn_scalar for 27B model which has different value."""
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/models/gpt_oss/test_gpt_oss_bridges.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def gpt_oss_cfg(self):
"torch_dtype": "bfloat16",
"vocab_size": 201088,
"hidden_act": "silu",
"sliding_window": 4096,
}

@pytest.fixture
Expand Down
Loading