diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 83efc8d9d..780f7bc13 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -6,15 +6,11 @@ """Inference support for the PagedLLMV1 protocol of models.""" -from typing import Optional - -from safetensors import safe_open import math -import sys - +from ..models.llama.tools.data_utils import write_ndarray_to_bin import torch - +import numpy as np from ..layers import * from ..types import * @@ -158,21 +154,21 @@ def prefill(self): seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp) if self.dump_bins: - torch.save( - token_ids, - f"prefill_token_ids_{'_'.join([str(x) for x in token_ids.shape])}.bin", + write_ndarray_to_bin( + token_ids.numpy(), + f"prefill_token_ids_{'x'.join([str(x) for x in token_ids.shape])}xi64.bin", ) - torch.save( - torch.tensor(token_ids.shape[0]).to(torch.int64), - f"prefill_seq_lens_1.bin", + write_ndarray_to_bin( + np.array(token_ids.shape[0], dtype=np.int64), + f"prefill_seq_lens_1xi64.bin", ) - torch.save( - seq_block_ids_tensor, - f"prefill_seq_block_ids_{'_'.join([str(x) for x in seq_block_ids_tensor.shape])}.bin", + write_ndarray_to_bin( + seq_block_ids_tensor.numpy(), + f"prefill_seq_block_ids_{'x'.join([str(x) for x in seq_block_ids_tensor.shape])}xi64.bin", ) - torch.save( - self.cache_state[0].to(torch.float8_e4m3fnuz), - f"prefill_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin", + write_ndarray_to_bin( + self.cache_state[0].to(torch.float8_e4m3fnuz).to(torch.uint8).numpy(), + f"prefill_cache_state_{'x'.join([str(x) for x in self.cache_state[0].shape])}xf8E4M3FNUZ.bin", ) logits = model.prefill( token_ids, @@ -219,25 +215,25 @@ def decode(self): decode_attention_mask = replicate(decode_attention_mask, tp) if self.dump_bins: - torch.save( - self.next_tokens, - f"decode_next_tokens_{'_'.join([str(x)for x in self.next_tokens.shape])}.bin", + write_ndarray_to_bin( + self.next_tokens.numpy(), + f"decode_next_tokens_{'x'.join([str(x)for x in self.next_tokens.shape])}xi64.bin", ) - torch.save( - start_positions, - f"decode_start_positions_{'_'.join([str(x)for x in start_positions.shape])}.bin", + write_ndarray_to_bin( + start_positions.numpy(), + f"decode_start_positions_{'x'.join([str(x)for x in start_positions.shape])}xi64.bin", ) - torch.save( - seq_block_ids_tensor, - f"decode_seq_block_ids_tensor_{'_'.join([str(x)for x in seq_block_ids_tensor.shape])}.bin", + write_ndarray_to_bin( + seq_block_ids_tensor.numpy(), + f"decode_seq_block_ids_tensor_{'x'.join([str(x)for x in seq_block_ids_tensor.shape])}xi64.bin", ) - torch.save( - torch.tensor(self.next_tokens.shape[0]).to(torch.int64), - f"decode_seq_lens_1.bin", + write_ndarray_to_bin( + torch.tensor(self.next_tokens.shape[0]).to(torch.int64).numpy(), + f"decode_seq_lens_1xi64.bin", ) - torch.save( - self.cache_state[0].to(torch.float8_e4m3fnuz), - f"decode_cache_state_{'_'.join([str(x) for x in self.cache_state[0].shape])}.bin", + write_ndarray_to_bin( + self.cache_state[0].to(torch.float8_e4m3fnuz).to(torch.uint8).numpy(), + f"decode_cache_state_{'x'.join([str(x) for x in self.cache_state[0].shape])}xf8E4M3FNUZ.bin", ) logits = model.decode( self.next_tokens,