forked from takezo5096/DNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptimizer.h
62 lines (35 loc) · 869 Bytes
/
optimizer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
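/*
 * optimizer.h
 *
 * Declares the Optimizer base class and OptimizerParams, the base type for
 * per-variable optimizer state.
 */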
#ifndef OPTIMIZER_H_
#define OPTIMIZER_H_
#include <thread>
#include "model.h"
#include "cuMat.h"
#include "variable.h"
/**
 * Base type for per-variable optimizer state.
 *
 * Objects of this type are handled through base-class pointers and
 * references: Optimizer::createOptimizerParams() returns an
 * OptimizerParams*, Optimizer::opts stores OptimizerParams*, and
 * Optimizer::update_param() receives an OptimizerParams&. Subclasses
 * (presumably one per concrete optimizer — confirm) add their own state.
 */
class OptimizerParams {
public:
    OptimizerParams() {}

    // Virtual so that deleting a derived state object through an
    // OptimizerParams* runs the derived destructor; with a non-virtual
    // destructor that delete is undefined behavior and derived members are
    // never destroyed. It also makes the type polymorphic, so overrides of
    // update_param() can safely dynamic_cast the OptimizerParams& they get.
    virtual ~OptimizerParams() {}
};
/**
 * Base class for optimizers that update the parameters of a Model.
 *
 * Subclasses customize behavior by overriding createOptimizerParams()
 * (allocation of per-variable state) and update_param() (the per-variable
 * update rule); both are virtual.
 *
 * NOTE(review): the class stores raw pointers (opts, ts, updateParams) but
 * keeps the implicitly generated copy constructor/assignment, so a copy
 * would share those pointers. Consider deleting the copy operations once
 * it is confirmed that no caller copies an Optimizer.
 */
class Optimizer {
public:
    // Per-variable optimizer state. NOTE(review): presumably one entry per
    // model parameter, allocated via createOptimizerParams() and released
    // by delOpts() — confirm in optimizer.cpp.
    vector<OptimizerParams *> opts;

    // Step counter, starting at 1. NOTE(review): presumably advanced by
    // update() — confirm.
    int epoch = 1;

    // Model whose parameters are optimized (ownership is not visible from
    // this header).
    Model *model = NULL;

    // Thread handles. NOTE(review): presumably used by update() to run
    // per-parameter updates concurrently — confirm.
    vector<thread *> ts;

    // NOTE(review): presumably the per-parameter arguments handed to the
    // worker threads in ts — confirm.
    vector<UpdateParams *> updateParams;

    // Learning rate; presumably initialized from the constructors'
    // learning_rate argument — confirm.
    float lr;

    // Gradient-clipping threshold, 0 by default. NOTE(review): presumably
    // 0 disables clipping — confirm in clip_grad().
    float clip_grad_threshold = 0;

    Optimizer(Model *model, float learning_rate);
    Optimizer(Model *model, float learning_rate, float clip_grad_threshold);

    // Virtual because this class is designed to be subclassed (it has
    // virtual member functions): deleting a derived optimizer through an
    // Optimizer* with a non-virtual destructor is undefined behavior.
    // The existing out-of-line definition Optimizer::~Optimizer() needs no
    // change.
    virtual ~Optimizer();

    // NOTE(review): presumably frees the objects held in opts — confirm.
    void delOpts();

    // Creates the optimizer state for variable v; subclasses override this.
    virtual OptimizerParams *createOptimizerParams(Variable *v);

    void init();

    // Applies one update to parameter w using its optimizer state opp;
    // subclasses override this with their update rule.
    virtual void update_param(Variable *w, OptimizerParams &opp);

    void zero_grads();
    void update();
    void clip_grad(Variable *v);
};
#endif /* OPTIMIZER_H_ */