For some models, such as Terramind, some tensor dtypes are statically set to `torch.float`, which is 32-bit.
An example is in the Terramind backbone, where the positional embeddings are initialized (`terratorch/terratorch/models/backbones/terramind/model/tm_utils.py`, lines 60 to 62 at `986f0a7`):

```python
pos_dim = embed_dim // 4

omega = torch.arange(pos_dim, dtype=torch.float) / pos_dim  # Shape (D/4,)

omega = 1.0 / (temperature**omega)
```
However, it is common for models to be served in float16 or bfloat16 to reduce GPU memory usage and increase inference throughput. This is the case with vLLM, which by default downcasts everything to float16. Loading Terramind with vLLM fails precisely because of this dtype mismatch (float16 vs. float32).
I suggest using `torch.get_default_dtype()` in place of `torch.float` to guarantee that all tensors share the same dtype.
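Applied to the snippet above, the change would look roughly like this. This is a sketch, not a patch against the actual file: the wrapper function `build_omega` and its signature are hypothetical, introduced only to make the fragment self-contained; the three lines inside follow the original snippet.

```python
import torch

def build_omega(embed_dim: int, temperature: float = 10000.0) -> torch.Tensor:
    """Frequency vector for sin-cos positional embeddings.

    Hypothetical wrapper around the original three lines; the only real
    change is replacing the hard-coded torch.float with the globally
    configured default dtype.
    """
    pos_dim = embed_dim // 4
    # torch.get_default_dtype() returns float16/bfloat16 when the caller
    # (e.g. a serving stack such as vLLM) has changed the default dtype,
    # so downstream tensors all agree instead of mixing with float32.
    omega = torch.arange(pos_dim, dtype=torch.get_default_dtype()) / pos_dim  # Shape (D/4,)
    omega = 1.0 / (temperature ** omega)
    return omega
```

With `torch.set_default_dtype(torch.float16)` in effect, the returned tensor is float16 instead of float32, so the embedding no longer clashes with a model that was downcast for serving.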