@@ -84,7 +84,9 @@ def capture_pre_run_state(self):
84
84
export_args .save_dir = self .save_dir .name
85
85
else :
86
86
export_args .checkpoint_path = self .expected_checkpoint_path
87
- export_args .save_dir = train_args .save_dir
87
+ export_args .save_dir = train_args .save_dir + "_exported"
88
+ export_args .model_tag = train_args .model_tag
89
+ export_args .arch_key = train_args .arch_key
88
90
89
91
if "deploy" in self .configs :
90
92
deploy_args = self .configs ["deploy" ].run_args
@@ -96,8 +98,8 @@ def capture_pre_run_state(self):
96
98
97
99
def add_abridged_configs (self ):
98
100
if "train" in self .command_types :
99
- self .configs ["train" ].max_train_steps = 10
100
- self .configs ["train" ].max_eval_steps = 10
101
+ self .configs ["train" ].run_args . max_train_steps = 2
102
+ self .configs ["train" ].run_args . max_eval_steps = 2
101
103
102
104
def teardown (self ):
103
105
"""
@@ -181,7 +183,7 @@ def test_train_metrics(self, integration_manager):
181
183
def test_export_onnx_graph (self , integration_manager ):
182
184
export_args = integration_manager .configs ["export" ]
183
185
expected_onnx_path = os .path .join (
184
- integration_manager . save_dir . name ,
186
+ export_args . run_args . save_dir ,
185
187
export_args .run_args .model_tag ,
186
188
"model.onnx" ,
187
189
)
@@ -201,7 +203,7 @@ def test_export_target_model(self, integration_manager):
201
203
zoo_model = Zoo .load_model_from_stub (target_model_path )
202
204
target_model_path = zoo_model .onnx_file .downloaded_path ()
203
205
export_model_path = os .path .join (
204
- integration_manager . save_dir . name ,
206
+ export_args . run_args . save_dir ,
205
207
export_args .run_args .model_tag ,
206
208
"model.onnx" ,
207
209
)
0 commit comments