@@ -47,9 +47,12 @@ async def test_early_stopping_on_threshold_met(self, tmp_path):
4747
4848 mock_client = MagicMock ()
4949 mock_training_client = MagicMock ()
50+ del mock_training_client .forward_backward_async
5051 mock_client .create_lora_training_client .return_value = mock_training_client
5152 mock_training_client .get_tokenizer .return_value = MagicMock ()
52- mock_training_client .save_state .return_value = "tinker://checkpoint-1"
53+ mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
54+ mock_training_client .forward_backward .return_value = MagicMock ()
55+ mock_training_client .optim_step .return_value = MagicMock ()
5356
5457 with patch ("trainer_with_eval.tinker.ServiceClient" , return_value = mock_client ):
5558 with patch ("trainer_with_eval.prepare_training_data" , return_value = [MagicMock ()]):
@@ -76,9 +79,12 @@ async def test_full_rounds_below_threshold(self, tmp_path):
7679
7780 mock_client = MagicMock ()
7881 mock_training_client = MagicMock ()
82+ del mock_training_client .forward_backward_async
7983 mock_client .create_lora_training_client .return_value = mock_training_client
8084 mock_training_client .get_tokenizer .return_value = MagicMock ()
81- mock_training_client .save_state .return_value = "tinker://checkpoint"
85+ mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
86+ mock_training_client .forward_backward .return_value = MagicMock ()
87+ mock_training_client .optim_step .return_value = MagicMock ()
8288
8389 with patch ("trainer_with_eval.tinker.ServiceClient" , return_value = mock_client ):
8490 with patch ("trainer_with_eval.prepare_training_data" , return_value = [MagicMock ()]):
@@ -109,9 +115,12 @@ async def test_evalops_integration_called(self, tmp_path):
109115
110116 mock_tinker_client = MagicMock ()
111117 mock_training_client = MagicMock ()
118+ del mock_training_client .forward_backward_async
112119 mock_tinker_client .create_lora_training_client .return_value = mock_training_client
113120 mock_training_client .get_tokenizer .return_value = MagicMock ()
114- mock_training_client .save_state .return_value = "tinker://checkpoint"
121+ mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
122+ mock_training_client .forward_backward .return_value = MagicMock ()
123+ mock_training_client .optim_step .return_value = MagicMock ()
115124
116125 async def mock_run_evals (* args , ** kwargs ):
117126 evalops_client = kwargs .get ('evalops_client' )
@@ -155,7 +164,7 @@ async def test_lr_decay_across_rounds(self, tmp_path):
155164
156165 mock_client = MagicMock ()
157166 mock_training_client = MagicMock ()
158- mock_training_client .forward_backward_async = None
167+ del mock_training_client .forward_backward_async
159168 mock_client .create_lora_training_client .return_value = mock_training_client
160169 mock_training_client .get_tokenizer .return_value = MagicMock ()
161170 mock_training_client .save_weights_for_sampler .return_value = MagicMock ()
@@ -167,4 +176,4 @@ async def test_lr_decay_across_rounds(self, tmp_path):
167176 with patch ("trainer_with_eval.run_evaluations" , new = AsyncMock (return_value = 0.7 )):
168177 await async_main (str (config_file ))
169178
170- assert mock_training_client .forward_backward .call_count == 3
179+ assert mock_training_client .save_weights_for_sampler .call_count == 3
0 commit comments