-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild.rs
More file actions
36 lines (36 loc) · 1.4 KB
/
Copy pathbuild.rs
File metadata and controls
36 lines (36 loc) · 1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
use burn_import::onnx::ModelGen;
/// Value passed to `ModelGen::embed_states` for every converted model.
///
/// Per burn-import's documentation, `true` embeds the converted weights in the
/// generated Rust code instead of leaving them in a separate record file.
/// A plain compile-time flag, so a `const` (no fixed address needed).
const EMBED_WEIGHT: bool = true;
/// Cargo build script: converts each supported task's exported ONNX model into
/// Rust source using `burn_import`'s `ModelGen`.
///
/// Models are read from `<dir>/linear_<TASK>_9f6_d0.2.onnx`. `<dir>` defaults
/// to `model_onnx` (relative to the package root). It can be overridden without
/// editing this file by setting the `MODEL_ONNX_DIR` environment variable, e.g.
/// `MODEL_ONNX_DIR=../af_research/IAFGNN/model_ln cargo build`. An empty value
/// is treated as unset.
///
/// Each model is converted with the `Bincode` record type, with
/// `embed_states(EMBED_WEIGHT)`, into the `model/` output directory. When run
/// from a build script, burn-import resolves that directory under Cargo's
/// `OUT_DIR` (per burn-import docs).
///
/// # Panics
///
/// Panics if a model path is not valid UTF-8, because `ModelGen::input` takes
/// `&str`. Failures during conversion are reported by `ModelGen` itself.
fn main() {
    // Once a build script emits any `rerun-if-changed`, Cargo stops rerunning
    // it on every package change. From then on it only reruns when a listed
    // input changes, so build.rs itself must be listed too.
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-env-changed=MODEL_ONNX_DIR");

    let model_dir = std::env::var_os("MODEL_ONNX_DIR")
        .filter(|dir| !dir.is_empty())
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("model_onnx"));

    let supported_tasks = ["DC_CO", "DC_ST", "DC_SST", "DC_ID", "DS_ST", "DS_SST", "DS_PR"];
    for task in supported_tasks {
        let model_path = model_dir.join(format!("linear_{task}_9f6_d0.2.onnx"));
        let model_path = model_path
            .to_str()
            .expect("ONNX model path must be valid UTF-8");

        // Regenerate this task's model only when its ONNX file changes.
        println!("cargo:rerun-if-changed={model_path}");

        ModelGen::new()
            .input(model_path)
            .out_dir("model/")
            .record_type(burn_import::onnx::RecordType::Bincode)
            .embed_states(EMBED_WEIGHT)
            .run_from_script();
    }
}