Commit bc83994

examples/finetune -opt SGD (stochastic gradient descent) memory opt
Add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - it avoids allocating the m, v optimizer-moment tensors. Support the finetune.cpp arg -opt SGD (or sgd); the default remains adamw as before.

llama 3.2-1b-F32 result: observed 11 GB GPU RAM (41 sec/epoch) when using SGD instead of 19 GB (55 sec/epoch) using adamw (wikipedia 100-line finetune).

Using the same GPU memory, adamw can only fit 512 batch/context before OOM, reaching:

train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

SGD is superior, though it converges more slowly, with a max of 1728 batch/context before OOM (note especially the better validation performance):

train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00

Note: when finetuning long enough (or with a large enough -lr), validation accuracy *eventually* drops ('catastrophic forgetting').

The -lr-half (halflife) option is useful for SGD to avoid oscillation or very slow underdamped learning (it makes setting -lr more forgiving). The terminal -lr is for now set by lr-halvings, i.e. if you want at most 1/8 the initial -lr you set -lr-halvings 3.

Note: objective loss is not directly comparable between adamw and sgd - check perplexity or accuracy, or consider relative improvements, when judging convergence.

New finetune args: -wd 1e-9 to enable weight decay in sgd or adamw, and a max -epochs N (default 2 as before).

Cache (1 - wd*alpha) in the 'adamw' opt struct - no noticeable perf benefit, so disabled (still done for the new SGD, though).

Since optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params callback could probably switch between SGD and AdamW with each epoch, but it would need to use adamw for the first (unconfirmed - no cmdline arg to set such a policy yet).

test-opt checks adamw as before and now sgd (except for a few tests disabled for sgd only; these probably just need logging the values and adding alternate reference values); tolerance on the 'regression' test is broader for sgd (so we don't need many more epochs).
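For orientation, a minimal sketch (plain C++, illustrative only - not the actual ggml kernels, and the function names are invented) of the two update rules, showing why AdamW needs per-parameter m and v tensors while SGD with decoupled weight decay needs no extra state:

#include <cmath>
#include <cstddef>
#include <cstdint>

// AdamW: two persistent buffers (m, v) per parameter tensor, plus bias correction.
void adamw_step(float * x, const float * g, float * m, float * v, size_t n,
                float alpha, float beta1, float beta2, float eps, float wd, int64_t t) {
    for (size_t i = 0; i < n; ++i) {
        m[i] = beta1 * m[i] + (1.0f - beta1) * g[i];
        v[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];
        const float mh = m[i] / (1.0f - std::pow(beta1, (float) t));
        const float vh = v[i] / (1.0f - std::pow(beta2, (float) t));
        x[i] = x[i] * (1.0f - alpha * wd) - alpha * mh / (std::sqrt(vh) + eps);
    }
}

// SGD with decoupled weight decay: reads only the gradient, no m or v buffers.
void sgd_step(float * x, const float * g, size_t n, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd; // the cached (1 - wd*alpha) mentioned above
    for (size_t i = 0; i < n; ++i) {
        x[i] = x[i] * keep - alpha * g[i];
    }
}

For a model with N float32 parameters the m and v buffers add roughly 8*N bytes of optimizer state, which is in the same ballpark as the 19 GB vs 11 GB figures above.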
1 parent 55c2646 commit bc83994

23 files changed: +730 −211 lines

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

common/arg.cpp

Lines changed: 49 additions & 2 deletions
@@ -1196,6 +1196,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
             common_params_print_completion(ctx_arg);
             exit(0);
         }
+        params.lr.init();
     } catch (const std::invalid_argument & ex) {
        fprintf(stderr, "%s\n", ex.what());
        ctx_arg.params = params_org;
@@ -2609,9 +2610,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"-o", "--output", "--output-file"}, "FNAME",
        string_format("output file (default: '%s')", params.out_file.c_str()),
        [](common_params & params, const std::string & value) {
-            params.out_file = value;
+            params.out_file = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"-ofreq", "--output-frequency"}, "N",
        string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
@@ -3416,5 +3417,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
 
+    add_opt(
+        common_arg({ "-lr", "--learning-rate-initial" }, "ALPHA",
+                   string_format(
+                       "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
+                       (double) params.lr.lr0),
+                   [](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
+                   string_format(
+                       "(if >0) final learning rate (default=%.2g)",
+                       (double) params.lr.lr_min),
+                   [](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-min-epochs", "--learning-rate-min-epochs" }, "ALPHA",
+                   string_format(
+                       "(if >0) reach -lr-min after this many epochs (instead of only at the last) (default=%.2g)",
+                       (double) params.lr.min_epochs),
+                   [](common_params & params, const std::string & value) { params.lr.min_epochs = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        { "-wd", "--weight-decay" }, "WD",
+        string_format(
+            "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+            (double) params.lr.wd),
+        [](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-val", "--val-split" }, "FRACTION",
+                       string_format("fraction of data to use as validation set for training (default: %.2g).",
+                                     (double) params.val_split),
+                       [](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+                       string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
+                       [](common_params & params, int epochs) { params.lr.epochs = epochs; })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
+                       [](common_params & params, const std::string & name) {
+                           params.optimizer = common_opt_get_optimizer(name.c_str());
+                           if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
+                               throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
+                           }
+                       })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+
     return ctx_arg;
 }

common/common.cpp

Lines changed: 51 additions & 0 deletions
@@ -41,6 +41,7 @@
 #endif
 #include <locale>
 #include <windows.h>
+#include <string.h>
 #include <fcntl.h>
 #include <io.h>
 #else
@@ -1548,3 +1549,53 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
     return result;
 }
+
+ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
+    const lr_opt & d = *(lr_opt *) userdata;
+    result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
+    result.sgd.wd = result.adamw.wd = d.wd;
+    return result;
+}
+
+static inline bool eq_case_insensitive(char const* a, char const* b) {
+    return !
+#if defined(_MSC_VER)
+        _stricmp
+#else
+        strcasecmp
+#endif
+        (a, b);
+}
+
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
+    if (eq_case_insensitive("adamw", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    } else if (eq_case_insensitive("sgd", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_SGD;
+    } else {
+        return GGML_OPT_OPTIMIZER_TYPE_COUNT;
+    }
+}
+
+static float const k_log_2 = std::log(2.f);
+
+void lr_opt::init() {
+    if (lr_min > 0 && lr_min < lr0) {
+        float nhalf = std::log(lr0 / lr_min) / k_log_2;
+        float e = epochs;
+        if (min_epochs > 0 && min_epochs < e)
+            e = min_epochs;
+        else
+            min_epochs = e;
+        scale_epoch = nhalf / e;
+    }
+}
+
+float lr_opt::get_lr(float epoch) const {
+    float r = lr_min <= 0 ? lr0 :
+              epoch >= min_epochs ? lr_min :
+              lr0 * std::pow(.5, epoch * scale_epoch);
+    LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
+    return r;
+}
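Read against lr_opt::init() and lr_opt::get_lr() above: the schedule halves the learning rate every epochs/log2(lr0/lr_min) epochs so it reaches lr_min at min_epochs (or at the final epoch when -min-epochs is not given). A standalone sketch with assumed values (lr0, lr_min and epochs here are hypothetical, not the defaults):

#include <cmath>
#include <cstdio>

int main() {
    const float lr0 = 1e-4f, lr_min = 1e-5f; // assumed -lr / -lr-min
    const float epochs = 8.0f;               // -min-epochs unset, so min_epochs == epochs
    // nhalf = log2(lr0/lr_min) ≈ 3.32 halvings spread over 8 epochs -> scale_epoch ≈ 0.415
    const float scale_epoch = std::log(lr0 / lr_min) / std::log(2.0f) / epochs;
    for (int e = 0; e <= 8; ++e) {
        const float lr = (float) e >= epochs ? lr_min : lr0 * std::pow(0.5f, e * scale_epoch);
        std::printf("epoch %d lr=%.3g\n", e, lr); // epoch 0: 1e-04, epoch 4: ~3.2e-05, epoch 8: 1e-05
    }
    return 0;
}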

common/common.h

Lines changed: 34 additions & 2 deletions
@@ -2,14 +2,17 @@
 
 #pragma once
 
-#include "llama-cpp.h"
-
 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
 #include <map>
 #include <sstream>
+#include <cmath>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -81,6 +84,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_FINETUNE,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -223,6 +227,25 @@ enum common_reasoning_format {
     COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
+
+struct lr_opt {
+    float    lr0         = 1e-5; // learning rate at first epoch
+    float    lr_min      = -1;
+    float    min_epochs  = -1;   // if >0, constant (lr_min) after this many epochs
+    float    scale_epoch = 0;
+    float    wd          = 0;
+    unsigned epochs      = 2;
+
+    unsigned epoch; // set by optimizer outer (epochs) loop
+    // learning rate decay - constant LR per epoch only for now
+    float get_lr(float e) const;
+    float get_lr() const { return get_lr(epoch); }
+    // must call after arg parse, before get_lr
+    void init();
+};
+
+struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
+
 struct common_params {
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx     = 4096; // context size
@@ -354,6 +377,12 @@ struct common_params {
     bool no_mmproj = false;         // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
+    // finetune
+    struct lr_opt lr;
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    float val_split = 0.05f; // fraction of the data used for the validation set
+    std::string opt_save_model_to = "finetuned-model.gguf";
+
     // embedding
     bool embedding = false;     // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
@@ -677,3 +706,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 //
 
 ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
+
+// "adamw" or "sgd" (case insensitive)
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
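A hedged usage sketch of the new lr_opt / common_opt_lr_pars / common_opt_get_optimizer pieces outside of finetune.cpp; the chosen values are assumptions for illustration, not recommendations, and example_finetune_setup is an invented name:

#include "common.h"

static void example_finetune_setup(common_params & params) {
    params.optimizer = common_opt_get_optimizer("sgd"); // case-insensitive; COUNT means invalid
    params.lr.lr0    = 1e-4f;  // initial learning rate
    params.lr.lr_min = 1e-5f;  // decay towards this by the final epoch
    params.lr.wd     = 1e-9f;  // optional decoupled weight decay
    params.lr.epochs = 4;
    params.lr.init();          // must run after the fields are set, before any get_lr() call

    for (params.lr.epoch = 0; params.lr.epoch < params.lr.epochs; ++params.lr.epoch) {
        // This is the same callback finetune.cpp hands to llama_opt_init(); it fills the
        // default ggml optimizer params and overrides alpha/wd from the lr_opt state.
        ggml_opt_optimizer_params p = common_opt_lr_pars(&params.lr);
        (void) p; // p.sgd.alpha == p.adamw.alpha == params.lr.get_lr()
    }
}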

examples/training/finetune.cpp

Lines changed: 29 additions & 28 deletions
@@ -10,20 +10,20 @@
 #include <vector>
 
 #if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
+#pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
 int main(int argc, char ** argv) {
     common_params params;
-
     params.escape = false;
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
         return 1;
     }
 
     if (params.use_mmap) {
-        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
+        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n",
+                __func__);
         params.use_mmap = false;
     }
     if (params.cache_type_k != GGML_TYPE_F32) {
@@ -38,11 +38,11 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
-    common_init_result llama_init = common_init_from_params(params);
-    llama_model_ptr & model = llama_init.model;
-    llama_context_ptr & ctx = llama_init.context;
+    common_init_result llama_init = common_init_from_params(params);
+    llama_model_ptr & model = llama_init.model;
+    llama_context_ptr & ctx = llama_init.context;
+    auto pctx = ctx.get();
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n", __func__);
@@ -55,31 +55,32 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
 
-    constexpr float val_split = 0.05f;
-
-    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
-    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
-
-    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-    optimizer_params.adamw.alpha = 1e-7f; // learning rate
-
-    struct llama_opt_params lopt_params {
-        /*n_ctx_train =*/ 0,
-        /*param_filter =*/ llama_opt_param_filter_all,
-        /*param_filter_ud =*/ nullptr,
-        /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
-        /*get_opt_pars_ud =*/ &optimizer_params,
+    std::vector<llama_token> tokens = common_tokenize(pctx, params.prompt, true);
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(pctx, tokens, llama_n_ctx(pctx) / 2);
+
+    struct lr_opt & lr = params.lr;
+    LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
+            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.min_epochs,
+            (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);
+
+    struct llama_opt_params lopt_params{
+        /*n_ctx_train =*/0,
+        /*param_filter =*/llama_opt_param_filter_all,
+        /*param_filter_ud =*/nullptr,
+        /*get_opt_pars =*/common_opt_lr_pars,
+        /*get_opt_pars_ud =*/&params.lr,
+        /*optimizer_type =*/params.optimizer,
     };
-    llama_opt_init(ctx.get(), model.get(), lopt_params);
+    llama_opt_init(pctx, model.get(), lopt_params);
 
-    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
 
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
-        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
-            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
+        llama_opt_epoch(pctx, dataset, result_train, result_eval, idata_split,
+                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");
 
         ggml_opt_result_reset(result_train);
@@ -88,7 +89,7 @@ int main(int argc, char ** argv) {
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_eval);
 
-    llama_model_save_to_file(model.get(), "finetuned-model.gguf");
+    llama_model_save_to_file(model.get(), params.out_file.c_str());
 
     llama_backend_free();
 
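A small worked example of the validation split used above; the dataset size is an assumption chosen to line up with the train/val counts quoted in the commit message:

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ndata     = 148;   // assumed ggml_opt_dataset_ndata(dataset)
    const float   val_split = 0.05f; // default params.val_split, overridable with -val
    const int64_t idata_split = (int64_t) (ndata * (1.0f - val_split)); // batches used for training
    std::printf("train=%lld val=%lld\n", (long long) idata_split, (long long) (ndata - idata_split));
    return 0; // prints: train=140 val=8
}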
ggml/include/ggml-opt.h

Lines changed: 26 additions & 7 deletions
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float alpha; // learning rate
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
+            float eps;   // epsilon for numerical stability
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -113,7 +123,10 @@ extern "C" {
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW allocates m, v per parameter
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t);
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
             struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
             enum ggml_opt_loss_type loss_type, // loss to minimize
+            enum ggml_opt_optimizer_type optimizer, // sgd or adamw
            ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
            int64_t nepoch, // how many times the dataset should be iterated over
            int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
            float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
            bool silent); // whether or not info prints to stderr should be suppressed
 
+
 #ifdef __cplusplus
 }
 #endif
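For context, a hedged sketch of how a caller might drive the new SGD path through this header; the constants are placeholders and my_sgd_pars is an invented name:

#include "ggml-opt.h"

static struct ggml_opt_optimizer_params my_sgd_pars(void * userdata) {
    // start from the library defaults, then override the SGD fields
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(userdata);
    p.sgd.alpha = 1e-4f; // learning rate; the -lr help text suggests ~10x the typical adamw alpha
    p.sgd.wd    = 1e-9f; // weight decay, 0.0f to disable
    return p;
}

// A ggml_opt_params instance would then set .optimizer = GGML_OPT_OPTIMIZER_TYPE_SGD,
// .get_opt_pars = my_sgd_pars and .get_opt_pars_ud = nullptr; per the comment in the struct
// above, only GGML_OPT_OPTIMIZER_TYPE_ADAMW allocates the per-parameter m, v tensors.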

0 commit comments
