diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py index 0c15e075..f9bf0847 100644 --- a/curated_transformers/tests/models/util.py +++ b/curated_transformers/tests/models/util.py @@ -98,8 +98,7 @@ def assert_causal_lm_output_equals_hf( ) orig_model.eval() - for _, param in orig_model.state_dict().items(): - assert param.device == torch_device + check_params_buffers(orig_model, torch_device) hf_model = transformers.AutoModelForCausalLM.from_pretrained( model_name, @@ -155,8 +154,7 @@ def assert_decoder_output_equals_hf( ) orig_model.eval() - for _, param in orig_model.state_dict().items(): - assert param.device == torch_device + check_params_buffers(orig_model, torch_device) hf_model = transformers.AutoModel.from_pretrained( model_name, revision=model_revision, trust_remote_code=trust_remote_code @@ -219,8 +217,7 @@ def assert_encoder_output_equals_hf( orig_model = model_class.from_hf_hub(name=model_name, device=torch_device) orig_model.eval() - for _, param in orig_model.state_dict().items(): - assert param.device == torch_device + check_params_buffers(orig_model, torch_device) hf_model = transformers.AutoModel.from_pretrained(model_name) hf_model.to(torch_device) @@ -384,8 +381,7 @@ def assert_model_hf_serialization_roundtrip( ) orig_model.eval() - for _, param in orig_model.state_dict().items(): - assert param.device == torch_device + check_params_buffers(orig_model, torch_device) auto_cls = ( transformers.AutoModelForCausalLM @@ -424,3 +420,16 @@ def assert_model_hf_serialization_roundtrip( assert ( hf_config[k] == v ), f"Key '{k}' value '{v}' is different in the Hugging Face model config ('{hf_config[k]}')" + + +def check_params_buffers(model: Module, device: torch.device): + """ + Check that parameters/buffers are placed on the correct device and that + parameters are leaf nodes. + """ + for buffer in model.buffers(): + assert buffer.device == device + + for param in model.parameters(): + assert param.device == device + assert param.is_leaf diff --git a/curated_transformers/util/serde/load.py b/curated_transformers/util/serde/load.py index 0243dbf2..4a02e930 100644 --- a/curated_transformers/util/serde/load.py +++ b/curated_transformers/util/serde/load.py @@ -103,7 +103,7 @@ def default_tensor_to_parameter_converter( old_param = module._parameters[parameter_name] assert old_param is not None _validate_replacement(old_param, tensor, module_prefix) - return Parameter(tensor, requires_grad=old_param.requires_grad).to(device=device) # type: ignore + return Parameter(tensor.to(device=device), requires_grad=old_param.requires_grad) # type: ignore def _emplace_module_state_dict(