-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkmeans.h
More file actions
40 lines (38 loc) · 1.38 KB
/
kmeans.h
File metadata and controls
40 lines (38 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#ifndef KMEANS_H
#define KMEANS_H
// Kmeans类
class Kmeans {
public:
Kmeans(int numClusters, int numFeatures, float* clusters, int nsamples);
Kmeans(int numClusters,
int numFeatures,
float* clusters,
int nsamples,
int maxIters,
float epsilon);
virtual ~Kmeans();
virtual void getDistance(const float* v_data);
virtual void updateClusters(const float* v_data);
virtual void fit(const float* v_data);
virtual void saveLabels();
virtual float accuracy(const int* label);
virtual const char* getLabelFileName() const {
return "cluster_labels_cpp.csv";
}
// 样本配置
int m_nsamples; // 样本数量
int m_numFeatures; // 特征数量
// 训练参数
int m_numClusters; // 类别数量,即k
float m_optTarget; // 优化目标值,即loss
int m_maxIters; // 最大迭代次数
float m_epsilon; // 目标阈值,两次loss相差超过该值停止迭代
// 中间变量
float* m_clusters; // [numClusters, numFeatures],存储当前各个类的中心点坐标
float* m_distances; // [nsamples, numClusters],用于存储每个样本到每个类的中心两两之间的距离
int* m_sampleClasses; // [nsamples, ],记录每个样本的类比编号
private:
Kmeans(const Kmeans& model);
Kmeans& operator=(const Kmeans& model);
};
#endif // KMEANS_H