From 9b7617ac6faac0fc58bd2502127496dc1a794196 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 13 Nov 2024 11:37:36 -0500 Subject: [PATCH] formatter --- tests/test_library_transforms.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/test_library_transforms.py b/tests/test_library_transforms.py index 66ffc128..f26f3a31 100644 --- a/tests/test_library_transforms.py +++ b/tests/test_library_transforms.py @@ -46,6 +46,7 @@ def test_config(): cfg3 = ArrowConfig(ab_only=True, scale=False) assert cfg3.save_name == cfg.save_name + def test_knot_merge(tmp_path, create_dummy_expert): config = ExpertConfig( **{ @@ -64,16 +65,19 @@ def test_knot_merge(tmp_path, create_dummy_expert): } ) model = ExpertModel(ExpertModelConfig(base_model="EleutherAI/gpt-neo-125m")) - + config.finetune_task_name = "cot_creak" expert1 = create_dummy_expert(config, "cot_creak") config.finetune_task_name = "cot_creak_ii" expert2 = create_dummy_expert(config, "cot_creak_ii") # only leave 1 layer to speed up things. - expert1.expert_weights = {k:v for k,v in expert1.expert_weights.items() if "8.mlp" in k} - expert2.expert_weights = {k:v for k,v in expert2.expert_weights.items() if "8.mlp" in k} - + expert1.expert_weights = { + k: v for k, v in expert1.expert_weights.items() if "8.mlp" in k + } + expert2.expert_weights = { + k: v for k, v in expert2.expert_weights.items() if "8.mlp" in k + } library = LocalExpertLibrary(tmp_path) library.add_expert(expert1) @@ -82,13 +86,13 @@ def test_knot_merge(tmp_path, create_dummy_expert): transform = KnotMerge(KnotMergeConfig(path=f"{tmp_path}/knot_ingredients.pt")) exp = transform.transform(library) state_dict = model.model.state_dict() - - # TODO: this can be implemented as a seperate modifier maybe or utils func. + + # TODO: this can be implemented as a seperate modifier maybe or utils func. merged_layers = [] for p_name, value in exp.expert_weights.items(): if p_name in state_dict: merged_layers.append(p_name) - state_dict[p_name]+=value + state_dict[p_name] += value assert len(merged_layers) == len(exp.expert_weights.keys()) == 1