-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathClassificationNode.hpp
More file actions
54 lines (39 loc) · 1.41 KB
/
ClassificationNode.hpp
File metadata and controls
54 lines (39 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#ifndef CLASSIFICATIONNODE_H
#define CLASSIFICATIONNODE_H
#include "Node.hpp"
#include <mlpack/core.hpp>

#include <string.h>

#include <algorithm>
#include <map>
#include <string>
// Base class that every classification node inherits from.
// It keeps the selected target column (name and values), the number of
// classes, the class predictions, and string-valued precision and
// confusion-matrix results.
// NOTE(review): the original comment calls this class abstract, but no pure
// virtual function is declared here; any abstractness presumably comes from
// Node — confirm.
class ClassificationNode : public Node {
public:
    // Constructs the node with the given name.
    // explicit: prevents a bare std::string from silently converting into a
    // node (derived classes calling ClassificationNode(name) are unaffected).
    explicit ClassificationNode(std::string name);

    // Checks whether a target column has been selected.
    bool IsVariableSelected();

    // Transforms the target column into an arma::Row of numbers instead of
    // the original class values.
    arma::Row<size_t> TransformToArma();

    // ---- Setters ----
    // Selects the target column by name.
    void setTarget(std::string targetName);
    void SetNumClasses(const size_t& size);
    void SetClassPredictions(const arma::Row<size_t> predictions);

    // ---- Metrics ----
    // Calculate precision and the confusion matrix from the true labels
    // (`values`) and the `predictions`.
    // NOTE(review): the results are presumably stored in `precision` /
    // `confusionMatrix` and read back via GetPrecision() /
    // GetConfusionMatrix() — confirm in the .cpp.
    void Precision(arma::Row<size_t> values, arma::Row<size_t> predictions);
    void ConfusionMatrix(arma::Row<size_t> values, arma::Row<size_t> predictions);

    // ---- Getters ----
    std::string TargetColumnName() const;
    arma::Row<size_t> TargetColumn() const;
    size_t NumClasses() const;
    std::string GetPrecision() const;
    std::string GetConfusionMatrix() const;
    arma::Row<size_t> ClassPredictions() const;

protected:
    // Name of the selected target column.
    std::string targetColumnName;
    // Target column values (presumably the numeric form from TransformToArma — confirm).
    arma::Row<size_t> targetColumn;
    // Initialized to 0 so NumClasses() never reads an indeterminate value
    // when a constructor leaves it unset before SetNumClasses() is called.
    size_t numClasses{0};
    // Precision result, kept as a string.
    std::string precision;
    // Confusion-matrix result, kept as a string.
    std::string confusionMatrix;
    // Class predictions (presumably those passed to SetClassPredictions — confirm).
    arma::Row<size_t> classPredictions;
};
#endif // CLASSIFICATIONNODE_H