longlora-paddle #9939

Open
wants to merge 9 commits into base: develop
36 changes: 36 additions & 0 deletions llm/config/llama/longlora.json
@@ -0,0 +1,36 @@
{
"model_name_or_path": "meta-llama/Meta-Llama-3-8B",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 1,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"sharding": "stage1",
"lora": true,
"zero_padding": false,
"use_flash_attention": true,
"unified_checkpoint": true,
"pissa": false,
"use_mora": false
}
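For orientation, a rough arithmetic sketch (not part of the PR) of the effective batch size this config implies, assuming a single device and no data parallelism; the numbers are copied straight from the JSON above:

# Rough sketch; values taken from llm/config/llama/longlora.json.
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
max_length = 2048

sequences_per_optimizer_step = per_device_train_batch_size * gradient_accumulation_steps  # 16
tokens_per_optimizer_step = sequences_per_optimizer_step * max_length  # 32768 tokens per step per device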
10 changes: 10 additions & 0 deletions llm/run_finetune.py
@@ -66,6 +66,7 @@
Qwen2MoeForCausalLMPipe,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.longlora import replace_llama_attn, set_group_size
from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
from paddlenlp.trl.llm_utils import (
ZeroPaddingIterDatasetCallback,
@@ -168,6 +169,13 @@ def main():
quantization_config=quantization_config,
)

if training_args.use_ssa:
assert (
training_args.ssa_group_size_ratio is not None
), "ssa_group_size_ratio must be specified when use_ssa is True"
set_group_size(training_args.ssa_group_size_ratio)
replace_llama_attn()

architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
if (
any(architecture in str(model_config.architectures) for architecture in architectures_to_check)
@@ -192,6 +200,8 @@ def main():
model_config.fuse_attention_ffn = model_args.fuse_attention_ffn

model_config.seq_length = data_args.max_length
    orig_ctx_len = getattr(model_config, "max_position_embeddings", None)
    if orig_ctx_len and data_args.max_length > orig_ctx_len:
        model_args.rope_scaling_factor = data_args.max_length // orig_ctx_len

    # Config for model using long sequence strategy
if model_args.use_long_sequence_strategies:
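A hedged worked example of the rope_scaling_factor computation added above; the numbers are illustrative, not taken from the PR, and assume the fine-tuning length exceeds the base model's context window:

# Illustrative values only; at runtime orig_ctx_len comes from the model config.
orig_ctx_len = 4096                               # model_config.max_position_embeddings
max_length = 8192                                 # data_args.max_length
rope_scaling_factor = max_length // orig_ctx_len  # RoPE positions are interpolated to cover 2x the original context
print(rope_scaling_factor)                        # 2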
2 changes: 1 addition & 1 deletion llm/utils/data.py
@@ -84,7 +84,7 @@ def tokenize_unsupervised_example(tokenizer, example, data_args, is_test=True, z
source,
truncation=False,
padding=True,
max_length=data_args.scaled_max_length,
max_length=data_args.src_length,
add_special_tokens=True,
)

135 changes: 135 additions & 0 deletions paddlenlp/transformers/longlora.py
@@ -0,0 +1,135 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
import paddle.nn.functional as F

import paddlenlp
from paddlenlp.transformers.llama.modeling import get_triangle_upper_mask

ssa_group_size_ratio = 1 / 4


def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
assert qkv.shape == [bsz, num_heads, q_len, head_dim], "qkv shape does not match expected shape"
# Calculate the shift amount for rolling
shift_amount = -group_size // 2
# Roll the qkv tensor along the sequence length axis
qkv[:, num_heads // 2 :] = qkv[:, num_heads // 2 :].roll(shift_amount, axis=2)

# Reshape the tensor to the desired shape
qkv = qkv.reshape([bsz * (q_len // group_size), group_size, num_heads, head_dim])
return qkv


def ssa_scaled_dot_product_attention(
query_states,
config,
key_states,
value_states,
attention_mask,
output_attentions,
alibi=None,
sequence_parallel=False,
reshard_layer=None,
**kwargs
):
bsz, q_len, num_heads, head_dim = query_states.shape
if config.context_parallel_degree > 1:
raise ValueError("Context parallel requires `use_flash_attention=True`")

# [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
query_states = paddle.transpose(query_states, [0, 2, 1, 3])
    # merge with the next transpose
key_states = paddle.transpose(key_states, [0, 2, 1, 3])
value_states = paddle.transpose(value_states, [0, 2, 1, 3])
    assert ssa_group_size_ratio is not None, "ssa_group_size_ratio must be provided"

# Calculate the group size based on the sequence length and the group size ratio
group_size = q_len if int(q_len * ssa_group_size_ratio) == 0 else int(q_len * ssa_group_size_ratio)
assert q_len % group_size == 0, f"q_len {q_len} must be divisible by group size {group_size}."

num_group = q_len // group_size

# Apply shifting to the query, key, and value states
query_states = shift(query_states, bsz, q_len, group_size, num_heads, head_dim)
key_states = shift(key_states, bsz, q_len, group_size, num_heads, head_dim)
value_states = shift(value_states, bsz, q_len, group_size, num_heads, head_dim)
query_states = paddle.transpose(query_states, [0, 2, 1, 3])
key_states = paddle.transpose(key_states, [0, 2, 1, 3])
value_states = paddle.transpose(value_states, [0, 2, 1, 3])
    # matmul and divide by sqrt(head_dim)
attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2]))

# then add alibi bias
if alibi is not None:
alibi = alibi.reshape([bsz, num_heads, 1, -1])
attn_weights = attn_weights + alibi

if paddle.in_dynamic_mode() and attn_weights.shape != [bsz * num_group, num_heads, group_size, group_size]:
        raise ValueError(
f"Attention weights should be of shape {(bsz * num_group, num_heads, group_size, group_size)}, but is"
f" {attn_weights.shape}"
)

    # In sep mode, the attention mask should be created at runtime.
if reshard_layer is not None:
        attention_mask = None

if attention_mask is None:
attention_mask = get_triangle_upper_mask(attn_weights)
attention_mask = paddle.tile(
paddle.cast(attention_mask[:, :, :group_size, :group_size], dtype="float32"), [num_group, 1, 1, 1]
)

if attention_mask.shape != [bsz * num_group, 1, group_size, group_size]:
attention_mask = attention_mask[: bsz * num_group, :, :, :]

attn_weights = attn_weights + attention_mask
if not paddle.in_dynamic_mode():
        attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
else:
with paddle.amp.auto_cast(False):
attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)

attn_output = paddle.matmul(attn_weights, value_states)
attn_output = attn_output.transpose([0, 2, 1, 3])

# shift back
attn_output = attn_output.reshape([bsz, q_len, num_heads, head_dim])
attn_output[:, :, num_heads // 2 :] = attn_output[:, :, num_heads // 2 :].roll(group_size // 2, axis=1)

if reshard_layer is not None:
        attn_output = reshard_layer(
attn_output,
split_axis=1,
concat_axis=2,
)
q_len = q_len // config.sep_parallel_degree
        num_heads = num_heads * config.sep_parallel_degree

if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
else:
attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
return (attn_output, attn_weights) if output_attentions else attn_output


def set_group_size(group_size_ratio):
global ssa_group_size_ratio
ssa_group_size_ratio = group_size_ratio


def replace_llama_attn():
paddlenlp.transformers.llama.modeling.scaled_dot_product_attention = ssa_scaled_dot_product_attention
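To make the grouping concrete, here is a toy sketch (not part of the PR) of what shift does for a 1/4 group-size ratio; the tensor sizes are chosen purely for illustration and assume Paddle's dynamic mode:

import paddle

bsz, num_heads, q_len, head_dim = 1, 4, 8, 2
group_size = int(q_len * (1 / 4))  # 2 tokens per group
num_group = q_len // group_size    # 4 groups

qkv = paddle.randn([bsz, num_heads, q_len, head_dim])
# Half of the heads are rolled by -group_size // 2 along the sequence axis,
# so neighbouring groups overlap and information can flow between them.
qkv[:, num_heads // 2 :] = qkv[:, num_heads // 2 :].roll(-group_size // 2, axis=2)
# Tokens are then regrouped so attention is only computed inside each group.
grouped = qkv.reshape([bsz * num_group, group_size, num_heads, head_dim])
print(grouped.shape)  # [4, 2, 4, 2]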

12 changes: 12 additions & 0 deletions paddlenlp/trl/sft_config.py
@@ -62,6 +62,18 @@ class SFTConfig(TrainingArguments):
model_init_kwargs: Optional[dict[str, Any]] = None
dataset_kwargs: Optional[dict[str, Any]] = None
eval_packing: Optional[bool] = None
use_ssa: bool = field(
default=False,
metadata={
"help": "Whether to use Shifted Sparse Attention (SSA), an efficient attention mechanism introduced in the LongLoRA paper."
},
)
ssa_group_size_ratio: float = field(
default=0.25,
metadata={
"help": "The ratio parameter for grouping in SSA, controlling the number of tokens considered in each group for sparse attention calculation."
},
)

def __post_init__(self):
super().__post_init__()
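A minimal sketch of how the two new flags might be supplied, either in a training JSON alongside the other arguments or directly on SFTConfig in Python; the field names come from the diff above, all other values are illustrative:

from paddlenlp.trl import SFTConfig

# Illustrative values; output_dir is required by the underlying TrainingArguments.
training_args = SFTConfig(
    output_dir="./checkpoints/lora_ckpts",
    use_ssa=True,               # enable Shifted Sparse Attention
    ssa_group_size_ratio=0.25,  # each attention group spans 1/4 of the sequence
)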
18 changes: 1 addition & 17 deletions pyproject.toml
@@ -13,23 +13,7 @@ minversion = "6.0"
addopts = "-ra -q "
pythonpath = ["."]
testpaths = [
"tests/data",
"tests/dataaug",
"tests/datasets",
"tests/embeddings",
"tests/experimental",
"tests/generation",
"tests/layers",
"tests/metrics",
"tests/pose",
"tests/ops",
"tests/trainer",
"tests/transformers",
"tests/peft",
"tests/prompt",
"tests/mergekit",
# "tests/taskflow", TODO (paddle 2.5.1 breaks this test suite, debug later)
"tests/utils",
"tests/longlora",
]
python_files = [
"test.py",
2 changes: 2 additions & 0 deletions tests/fixtures/llm/autoregressive_data/dev.json
@@ -0,0 +1,2 @@
{"src": "Are you a resident of Pinnacle who owns a small business and operates from your home?\nCan you provide a service to your fellow residents of Pinnacle? If you've answered yes to both of these questions, supply your details below and we will list your business on our site.\nResidents of Pinnacle, support your local community by checking here first and seeing whether one of your neighbours can assist."}
{"src": "On October 27, 2016 GreenWorks led a tour for the College of Architecture and Landscape Architecture of Beijing University and staff of Landscape Architecture Frontiers publication. This tour group was particularly interested in technical issues related to: soil/vegetation approaches for water quality treatment; the ultra- violet finishing treatment that allows for human contact with the treated water; and soil capping issues for a former brownfield site. GreenWorks typically leads 4-6 tours per year since Tanner Springs Park opened in 2005. Tour groups have included national and international professional and environmental organizations and academic institutions. Visitors are interested in a variety of issues, including design inspiration, public involvement and outreach, and technical challenges.\nCouch Park comment forms for the playground and plaza improvements are due Thursday December 10th.\nGreenWorks was been hired by Portland Parks & Recreation to design the new playground, address accessibility issues in the plaza and install a new Portland Loo at Couch Park as part of the Parks Replacement Bond. We presented the three options below for the playground at an Open House on December 3rd . Online comments are due Thursday December 10th and can be found here: http://www.portlandoregon.gov/parks/68915. One of the top priorities for the playground is for it to be inclusive, which mean that it should be designed for children of all ages and abilities. We have been working closely with Mara Kaplan from Let Kids Play http://www.letkidsplay.com/ who is a national expert and advocate for inclusive playground design. Mara was brought on to the design team to help us design a playground that provides exceptional play opportunities for all children.\nGreenWorks met with the city of Astoria to present the Downtown Astoria Pedestrian Wayfinding Concept Plan. Those in attendance were city officials, focus group members, and Astoria community members. The presentation focused on distinct sign typologies that direct and inform pedestrians getting around downtown Astoria. Following the presentation was an interactive group discussion about sign locations, aesthetic preferences interpretive sign opportunities."}