# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus

from functools import reduce

from peft.tuners import lora
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from transformers.utils import logging

logger = logging.get_logger(__name__)
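
# LoRA+ (Hayou, Ghosh & Yu, 2024) trains the LoRA "B" matrices with a larger
# learning rate than the "A" matrices. This module builds the optimizer
# parameter groups that implement that scheme on top of a PEFT model.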


def get_module(name, opt_model):
    """
    Retrieve a module from a model using its parameter name.

    Args:
        name (str): Full name of the parameter, typically including the module path.
        opt_model (torch.nn.Module): The model from which to retrieve the module.

    Returns:
        Module corresponding to the given name.
    """
    # LoRA parameter names (e.g. "...lora_A.default.weight") nest one level
    # deeper than plain parameters (e.g. "...weight"), so walk up accordingly.
    parent_idx = 2 if "lora" in name else 1
    module_names = name.split(sep=".")[:-parent_idx]
    module = reduce(getattr, module_names, opt_model)
    return module
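
# Illustrative mapping (hypothetical names, assuming a PEFT-wrapped model):
# for "model.layers.0.q_proj.lora_A.default.weight" this returns the module at
# "model.layers.0.q_proj.lora_A", while for "model.layers.0.q_proj.weight" it
# returns the "q_proj" module itself.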


def create_loraplus_params(
    opt_model,
    optimizer_kwargs,
    lr_ratio=2**4,  # default ratio recommended by the LoRA+ authors
    lr_embedding=None,
    lr_magnitude=None,
):
    """
    Build optimizer parameter groups for the given model, applying LoRA+
    learning-rate adjustments to the different parameter groups.

    Args:
        opt_model (torch.nn.Module): The model for which the optimizer is being created.
        optimizer_kwargs (dict): Keyword arguments for the optimizer's
            initialization; must include the base learning rate under "lr".
        lr_ratio (float): Multiplier applied to the base learning rate for the
            LoRA B matrices.
        lr_embedding (float, optional): Learning rate for LoRA embedding
            parameters; defaults to 1e-6.
        lr_magnitude (float, optional): Learning rate for DoRA magnitude-vector
            parameters; defaults to the base learning rate.

    Returns:
        List of parameter-group dicts to pass to an optimizer.

    Example:
        param_groups = create_loraplus_params(
            model,
            optimizer_kwargs={"lr": 1e-5},
            lr_ratio=2**4,
            lr_embedding=2e-6,
        )
        optimizer = AdamW(param_groups, **optimizer_kwargs)
    """
    assert lr_ratio is not None
    if lr_embedding is None:
        lr_embedding = 1e-6
    if lr_magnitude is None:
        lr_magnitude = optimizer_kwargs["lr"]

    # Parameters eligible for weight decay: everything except LayerNorm
    # parameters and biases.
    decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]

    param_groups = {
        "groupA": {},           # LoRA A matrices and remaining trainable weights
        "groupB": {},           # LoRA B matrices (with weight decay)
        "groupB_no_decay": {},  # LoRA B matrices / 1-D params (no weight decay)
        "embedding": {},        # LoRA embedding parameters
        "magnitude": {},        # DoRA magnitude vectors
    }
    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue

        module = get_module(name, opt_model)
        if isinstance(module, lora.Embedding):
            param_groups["embedding"][name] = param
        elif "lora_magnitude_vector" in name:
            param_groups["magnitude"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param

    assigned_param_groups = ""
    for group in param_groups:
        assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
    logger.debug(assigned_param_groups)

    lr = optimizer_kwargs["lr"]
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["magnitude"].values()),
            "weight_decay": weight_decay,
            "lr": lr_magnitude,
        },
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": weight_decay,
            "lr": lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": weight_decay,
            "lr": lr * lr_ratio,  # the LoRA+ boost for the B matrices
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * lr_ratio,
        },
    ]

    # Drop empty parameter groups; some (e.g. embedding or magnitude) may be
    # unused depending on the adapter configuration.
    optimizer_grouped_parameters = [
        group for group in optimizer_grouped_parameters if len(group["params"]) > 0
    ]
    return optimizer_grouped_parameters
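

# --- Usage sketch -------------------------------------------------------------
# A minimal, hypothetical example of wiring the groups into torch.optim.AdamW.
# It assumes `torch` and `peft` are installed; `TinyModel` and all
# hyperparameters below are illustrative placeholders, not recommendations.
if __name__ == "__main__":
    import torch
    from peft import LoraConfig, get_peft_model

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(16, 16)
            self.head = torch.nn.Linear(16, 4)

        def forward(self, x):
            return self.head(self.proj(x))

    # Attach LoRA adapters to the "proj" layer only.
    peft_model = get_peft_model(TinyModel(), LoraConfig(target_modules=["proj"], r=8))

    optimizer_kwargs = {"lr": 1e-4, "weight_decay": 0.01}
    groups = create_loraplus_params(peft_model, optimizer_kwargs, lr_ratio=2**4)

    # Per-group "lr"/"weight_decay" entries take precedence over the defaults
    # passed here, so the B matrices train at lr * lr_ratio.
    optimizer = torch.optim.AdamW(groups, **optimizer_kwargs)
    print([g["lr"] for g in optimizer.param_groups])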