
Hardcoded data types prevent some models from being used with a dtype other than float32 #1032

@christian-pinto

Description


In some models, such as Terramind, tensor dtypes are statically set to torch.float, which is 32-bit.

An example is in the Terramind backbone, where the positional embeddings are initialized:

```python
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float) / pos_dim  # Shape (D/4,)
omega = 1.0 / (temperature**omega)
```

However, models are commonly served in float16 or bfloat16 to reduce GPU memory usage and increase inference throughput. This is the case with vLLM, which by default downcasts everything to float16. Loading Terramind with vLLM fails precisely because of this dtype mismatch (float16 vs. float32).
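A minimal standalone reproduction of the kind of mismatch described above (this is not the actual vLLM loading path, just an illustrative sketch): a layer whose weights have been downcast to float16 receives an input built with the hardcoded torch.float.

```python
import torch

# Hypothetical example: layer weights downcast to float16, as a
# float16-serving engine would do.
linear = torch.nn.Linear(8, 8).half()

# Input produced with the hardcoded 32-bit dtype, as in the snippet above.
pos = torch.arange(8, dtype=torch.float).unsqueeze(0)

try:
    linear(pos)
except RuntimeError as e:
    # PyTorch refuses to mix float16 weights with a float32 input.
    print("dtype mismatch:", e)
```

Casting the input at every call site would also work, but the issue is that the hardcoded dtype is buried inside the model, where the serving framework cannot reach it.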

I suggest using torch.get_default_dtype() in place of torch.float, to guarantee that all tensors share the same data type.
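A sketch of the proposed change, using a hypothetical helper that mirrors the snippet above (not the actual Terramind code): with torch.get_default_dtype(), the tensor automatically follows whatever global dtype the serving framework configures.

```python
import torch

def build_pos_omega(embed_dim: int, temperature: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper mirroring the backbone snippet above.

    torch.get_default_dtype() replaces the hardcoded torch.float, so the
    resulting tensor follows the globally configured dtype.
    """
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.get_default_dtype()) / pos_dim
    return 1.0 / (temperature ** omega)

# Under the stock default nothing changes...
assert build_pos_omega(64).dtype == torch.float32
# ...but after a global downcast (as vLLM performs) the tensor follows suit.
torch.set_default_dtype(torch.float16)
assert build_pos_omega(64).dtype == torch.float16
```

Modules created after torch.set_default_dtype(...) also pick up the new dtype, so this one change keeps the positional embeddings consistent with the rest of the model.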

Labels: enhancement