From f7858612f0e59affc0156c374c1719e2c6926653 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Thu, 15 Feb 2024 14:11:03 +0100
Subject: [PATCH 1/5] Create starcode_kv_cache_injection

---
 starcode_kv_cache_injection | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 starcode_kv_cache_injection

diff --git a/starcode_kv_cache_injection b/starcode_kv_cache_injection
new file mode 100644
index 00000000000..8b137891791
--- /dev/null
+++ b/starcode_kv_cache_injection
@@ -0,0 +1 @@
+

From 22d10a7dfca08c43970dc94887ee1b41ea365f62 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Thu, 15 Feb 2024 14:12:21 +0100
Subject: [PATCH 2/5] Delete starcode_kv_cache_injection

---
 starcode_kv_cache_injection | 1 -
 1 file changed, 1 deletion(-)
 delete mode 100644 starcode_kv_cache_injection

diff --git a/starcode_kv_cache_injection b/starcode_kv_cache_injection
deleted file mode 100644
index 8b137891791..00000000000
--- a/starcode_kv_cache_injection
+++ /dev/null
@@ -1 +0,0 @@
-

From c01f768856b486f8cd2365d081d87bac9c03b24b Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Thu, 15 Feb 2024 14:12:58 +0100
Subject: [PATCH 3/5] Add files via upload

---
 starcode_kv_cache_injection/__init__.py       |   0
 .../kv_cache_injection.py                     | 216 ++++++++++++++++++
 starcode_kv_cache_injection/validation.py     | 210 +++++++++++++++++
 3 files changed, 426 insertions(+)
 create mode 100644 starcode_kv_cache_injection/__init__.py
 create mode 100644 starcode_kv_cache_injection/kv_cache_injection.py
 create mode 100644 starcode_kv_cache_injection/validation.py

diff --git a/starcode_kv_cache_injection/__init__.py b/starcode_kv_cache_injection/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/starcode_kv_cache_injection/kv_cache_injection.py b/starcode_kv_cache_injection/kv_cache_injection.py
new file mode 100644
index 00000000000..2cc52a52da6
--- /dev/null
+++ b/starcode_kv_cache_injection/kv_cache_injection.py
@@ -0,0 +1,216 @@
+from transformers import AutoTokenizer, AutoConfig
+
+import onnx
+import logging
+import os
+from typing import List, Optional
+from onnx import TensorProto, ModelProto, helper, NodeProto
+from sparseml.onnx.utils import ONNXGraph
+from sparseml.exporters.transforms.kv_cache.transforms_codegen import AdditionalTransformsCodeGen
+
+_LOGGER = logging.getLogger(__name__)
+
+class AdditionalTransformsBigCode(AdditionalTransformsCodeGen):
+    """
+    Since the entries of the causal mask are similar in their values
+    and layout to the CodeGen causal mask, this class inherits from
+    the AdditionalTransformsCodeGen class.
+    """
+
+    # position ids are created by a Sub node (the one that is followed by a Where node
+    # in the onnx graph)
+    POSITION_IDS_MATCHING_PATTERN = dict(op_type="Sub", children_ops=[["Where"]])
+    # the causal mask is created by an Unsqueeze node (the one that is followed by Where and Softmax nodes
+    # in the onnx graph)
+    CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Unsqueeze", children_ops=[["Where", "Softmax"]])
+
+    def add_causal_mask_input(self, model: ModelProto) -> ModelProto:
+        """
+        reformulating this method (originally part of the AdditionalTransformsBase class)
+        so that the causal mask has shape [batch_size, input_ids_length, 1, sequence_length]
+        vs the original shape [batch_size, 1, input_ids_length, sequence_length]
+        """
+
+        input_ids = self._get_input_proto(model, "input_ids")
+        attention_mask = self._get_input_proto(model, "attention_mask")
+
+        batch_size = input_ids.type.tensor_type.shape.dim[0].dim_param
+        input_ids_length = input_ids.type.tensor_type.shape.dim[1].dim_value
+        sequence_length = attention_mask.type.tensor_type.shape.dim[1].dim_param
+
+        causal_mask_input = helper.make_tensor_value_info(
+            name=self.CAUSAL_MASK_NAME,
+            elem_type=TensorProto.INT64,
+            # this is de-facto the only change from the original method
+            shape=[batch_size, input_ids_length, 1, sequence_length],
+        )
+        model.graph.input.append(causal_mask_input)
+        _LOGGER.info(f"Inserted {self.CAUSAL_MASK_NAME} input to the ONNX model")
+        return model
+
+    def swap_nodes_for_input(
+        self,
+        model: ModelProto,
+        nodes: List[NodeProto],
+        input_name: str,
+        nodes_parent_op_type: Optional[str] = None,
+    ) -> ModelProto:
+
+        """
+        Injects the specified input to the graph, replacing the specified nodes.
+
+        :param model: the ONNX model to inject the input into
+        :param nodes: the nodes to replace with the input
+        :param input_name: the name of the input to replace the nodes with
+        :param nodes_parent_op_type: the parent op type of the nodes to replace
+
+        :return: the updated model
+        """
+
+        graph = ONNXGraph(model)
+        for node in nodes:
+            # edited so that multiple child nodes can be handled
+            children_nodes = graph.get_node_children(node)
+            for child_node in children_nodes:
+                if nodes_parent_op_type:
+                    assert child_node.op_type == nodes_parent_op_type, (
+                        f"Expected to find {nodes_parent_op_type} node, "
+                        f"found {child_node.op_type}"
+                    )
+                output_to_replace = node.output[0]
+                self.log_match(node)
+                for idx, input_name_child_node in enumerate(child_node.input):
+                    if input_name_child_node == output_to_replace:
+                        graph.update_node_input(child_node, input_name, idx)
+
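+        # clean up the branches that used to compute the swapped-out tensors and are now orphaned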
+        graph.delete_orphaned_node_branches()
+
+        _LOGGER.info(
+            f"Successfully swapped {len(nodes)} nodes for input '{input_name}'"
+        )
+
+        return model
+
+    def transform(self, model: ModelProto) -> ModelProto:
+        """
+        1. Adds `positions` as an input to the model
+        2. Adds `causal_mask` as an input to the model
+        3. Finds the node that initially creates the `position_ids` tensor
+        4. Updates the node to use the positions input instead of
+           computing it from the Range op
+        5. Finds the nodes that initially create the `causal_mask` tensors
+        6. Updates the nodes to use the causal_mask input instead of
+           computing it from the Slice op
+
+        :param model: model to update
+        :return: updated model
+        """
+        model = self.add_positions_input(model)
+        model = self.add_causal_mask_input(model)
+
+        position_ids_nodes = self.find_nodes_by_pattern(
+            model, pattern=self.POSITION_IDS_MATCHING_PATTERN
+        )
+        if len(position_ids_nodes) != 1:
+            raise ValueError(
+                "Expected to find exactly one node matching "
+                f"the pattern {self.POSITION_IDS_MATCHING_PATTERN}, "
+                f"found {len(position_ids_nodes)}"
+            )
+
+        model = self.inject_positions(model, position_ids_nodes, "Where")
+
+        causal_mask_nodes = self.find_nodes_by_pattern(
+            model, pattern=self.CAUSAL_MASK_MATCHING_PATTERN
+        )
+        model = self.inject_causal_mask(model, causal_mask_nodes, "Where")
+        model = self.adjust_causal_mask(model)
+        return model
+
+def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_tensors, hidden_size_kv_cache, batch_size = 1):
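+    """
+    Injects KV cache inputs and outputs into the model.
+
+    For every node whose name is listed, a new `past_key_values.*` graph input and a
+    `present.*` graph output are created, together with a Concat node that appends the
+    freshly computed key/value tensor to the past cache along the sequence-length axis.
+    """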
+    graph = ONNXGraph(model)
+
+    inputs_to_add = []
+    outputs_to_add = []
+
+    attention_layer_idx = 0
+
+    for node in model.graph.node:
+        if node.name in names_nodes_producing_kv_tensors:
+
+            # inject kv cache input/output
+            cache_input_name_concat = f"past_key_values.{attention_layer_idx}"
+            cache_output_name_concat = f"present.{attention_layer_idx}"
+
+            cache_input_info = onnx.helper.make_tensor_value_info(
+                cache_input_name_concat,
+                TensorProto.FLOAT,
+                [
+                    batch_size,
+                    "past_sequence_len",
+                    hidden_size_kv_cache,
+                ]
+            )
+
+            cache_output_info = onnx.helper.make_tensor_value_info(
+                cache_output_name_concat,
+                TensorProto.FLOAT,
+                [
+                    batch_size,
+                    "past_sequence_len + 1",
+                    hidden_size_kv_cache,
+                ]
+            )
+
+            cache_parent = node
+            concat_axis = 1 # concat over length axis
+
+            concat_node = onnx.helper.make_node(
+                op_type="Concat",
+                inputs=[cache_input_name_concat, cache_parent.output[1]],
+                outputs=[cache_output_name_concat],
+                axis=concat_axis,
+                name=f"concat.{cache_input_name_concat}",
+            )
+
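+            # rewire every consumer of the original key/value tensor (except the new Concat) to read from the concatenated cache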
+            for _node in model.graph.node:
+                for input_idx, input_id in enumerate(_node.input):
+                    if input_id == cache_parent.output[1] and _node.name != concat_node.name:
+                        _node.input[input_idx] = cache_output_name_concat
+
+            graph.add_node(concat_node)
+            inputs_to_add.extend([cache_input_info])
+            outputs_to_add.extend([cache_output_info])
+
+            attention_layer_idx += 1
+            _LOGGER.info(f"Injected kv cache input/output for attention layer {attention_layer_idx}")
+
+    model.graph.input.extend(inputs_to_add)
+    model.graph.output.extend(outputs_to_add)
+    return model
+
+
+def main(deployment_folder_path, save_name_injected_model):
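+    """
+    Loads the exported ONNX model, injects the KV cache inputs/outputs, adjusts the
+    causal mask and positions, and saves the resulting model back into the deployment folder.
+    """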
+    onnx_model = onnx.load(os.path.join(deployment_folder_path, "model.onnx"), load_external_data=False)
+    config = AutoConfig.from_pretrained(os.path.join(deployment_folder_path, "config.json"))
+    # KV Cache injection
+    onnx_model = inject_kv_cache_inputs_outputs(model = onnx_model,
+                                                names_nodes_producing_kv_tensors=[f"/transformer/h.{i}/attn/Split" for i in range(config.n_layer)],
+                                                hidden_size_kv_cache=2 * config.n_embd // config.n_head)
+    # Adjustment of causal masks and positions
+    transformation = AdditionalTransformsBigCode()
+    onnx_model = transformation.transform(model = onnx_model)
+    # Save the model
+    _LOGGER.info(f"Saved injected model to {os.path.join(deployment_folder_path, save_name_injected_model)}")
+    onnx.save_model(onnx_model, os.path.join(deployment_folder_path, save_name_injected_model))
+
+
+
+if __name__ == "__main__":
+    PATH_TO_DEPLOYMENT_FOLDER = "/Users/damian/Code/nm/sparseml/tiny_starcoder_py/deployment/"
+    # model created by running:
+    # sparseml.export /Users/damian/Code/nm/sparseml/tiny_starcoder_py/ --task text-generation --integration transformers  --sequence_length 256 --trust_remote_code True
+    NAME_INJECTED_MODEL = "test.onnx"
+    main(PATH_TO_DEPLOYMENT_FOLDER, NAME_INJECTED_MODEL)
+
+
diff --git a/starcode_kv_cache_injection/validation.py b/starcode_kv_cache_injection/validation.py
new file mode 100644
index 00000000000..cd3150d9ba0
--- /dev/null
+++ b/starcode_kv_cache_injection/validation.py
@@ -0,0 +1,210 @@
+import onnxruntime as ort
+import numpy as np
+import onnx
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from onnx.tools import update_model_dims
+from sparseml.onnx.utils import ONNXGraph
+import logging
+import numpy
+from typing import List, Union
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def create_causal_mask(
+    input_ids: Union[numpy.ndarray, List[int]],
+    attention_mask: Union[numpy.ndarray, List[int]],
+    dtype: numpy.dtype = numpy.int64,
+) -> numpy.ndarray:
+    """
+    Compute a causal mask from a set of module inputs.
+    In transformers, a causal mask is a boolean mask that is used to
+    prevent information from future positions in a sequence from
+    being used to predict the current position. Each element of the mask
+    is set to 1 if the corresponding query position is allowed to attend
+    to the corresponding key position, and 0 otherwise.
+
+    in case of single-token input, the causal mask is an array
+    of shape [1, 1, 1, sequence_length],
+    (essentially the reshaped attention_mask)
+
+    in case of a multi-token input, the causal mask is an array
+    of shape [batch_size, 1, input_ids_length, sequence_length]
+    it is a concatenation of:
+     - the past (cache) causal mask
+     - and the causal mask (a lower triangular matrix of 1's and 0's)
+    e.g.
+    ```
+    input_ids = [[1,2,3,4]]
+    attention_mask = [[1,1,1,1,1,1]]
+
+    causal_mask = [[[[ 1 1 | 1 0 0 0 ],
+                     [ 1 1 | 1 1 0 0 ],
+                     [ 1 1 | 1 1 1 0 ],
+                     [ 1 1 | 1 1 1 1 ]]]]
+    ```
+    or
+    ```
+    input_ids = [[1,2,3,4]]
+    attention_mask = [[0,0,1,1,1,1,1]]
+
+    causal_mask = [[[[ 0 0 1 1 | 1 0 0 0 ],
+                     [ 0 0 1 1 | 1 1 0 0 ],
+                     [ 0 0 1 1 | 1 1 1 0 ],
+                     [ 0 0 1 1 | 1 1 1 1 ]]]]
+    ```
+
+    :param input_ids: input ids of the model input
+    :param attention_mask: attention mask of the model input
+    :param dtype: data type of the mask
+    :return: causal mask
+    """
+    if isinstance(input_ids, numpy.ndarray):
+        batch_size, input_ids_length = input_ids.shape
+
+    else:
+        batch_size, input_ids_length = 1, len(input_ids)
+
+    if isinstance(attention_mask, numpy.ndarray):
+        sequence_length = attention_mask.shape[1]
+    else:
+        sequence_length = len(attention_mask)
+        attention_mask = numpy.array(attention_mask)[None, ...]
+
+    if input_ids_length == 1:
+        causal_mask = numpy.reshape(attention_mask, (batch_size, 1, 1, sequence_length))
+        return causal_mask.astype(dtype)
+
+    causal_mask = numpy.tril(
+        numpy.ones((batch_size, 1, input_ids_length, input_ids_length), dtype=dtype), 0
+    )
+    past_causal_mask = numpy.ones(
+        (batch_size, 1, input_ids_length, sequence_length - input_ids_length),
+        dtype=dtype,
+    )
+    causal_mask = numpy.concatenate((past_causal_mask, causal_mask), axis=-1)
+
+    num_zeros = numpy.count_nonzero(attention_mask == 0)
+
+    # changes to the original function
+    causal_mask[:, :, num_zeros:, :] = 0
+    causal_mask = causal_mask.reshape(1, sequence_length, 1, -1)
+
+    return causal_mask
+
+def apply_input_shapes(model, onnx_model_path, sequence_length, config):
+    kv_cache_hidden_dim = config.n_embd // config.n_head
+    cache_changes_in = {n.name: [1, "dynamic_len_1", 2 * kv_cache_hidden_dim] for n in model.graph.input if n.name.startswith("past_key_values")}
+    cache_changes_out = {n.name: [1, "dynamic_len_2", 2 * kv_cache_hidden_dim] for n in model.graph.output if n.name.startswith("present")}
+    graph = ONNXGraph(model)
+
+    graph.delete_unused_initializers()
+    graph.delete_orphaned_node_branches()
+    graph.sort_nodes_topologically()
+
+    model = update_model_dims.update_inputs_outputs_dims(model,
+                                                         {"input_ids": [1, "dynamic_len_3"],
+                                                          "positions": [1, "dynamic_len_4"],
+                                                          "attention_mask": [1, sequence_length],
+                                                          "causal_mask": [1, "dynamic_len_5", 1, "dynamic_len_6"],
+                                                          **cache_changes_in},
+
+                                                          {"logits": [1, "dynamic_len_6", config.vocab_size], **cache_changes_out})
+
+    onnx.save(model, onnx_model_path)
+    return model
+
+
+def multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt):
+    # feed the whole sequence to the model so that we can initially validate
+    # the correctness of the kv cache injected model
+    kv_cache_hidden_dim = config.n_embd // config.n_head
+    inputs = tokenizer(prompt, return_tensors="np", padding='max_length', max_length=sequence_length)
+    input_ids = inputs.input_ids  # (1, sequence_length)
+    attention_mask = inputs.attention_mask  # (1, sequence_length)
+    kv_cache = {f"past_key_values.{i}": np.zeros((1, 0, 2 * kv_cache_hidden_dim), dtype=np.float32) for i in
+                range(config.n_layer)}  # (1, 0, 2 * embedding [because we have k and v's concatenated])
+    causal_mask = create_causal_mask(input_ids, attention_mask)  # (1, sequence_length, 1, sequence_length)
+    positions = attention_mask.cumsum(-1) - 1  # (1, sequence_length)
+
+    session = ort.InferenceSession(onnx_model_path)
+
+    out = session.run(
+        None,
+        {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            **kv_cache,
+            "causal_mask": causal_mask,
+            "positions": positions,
+        },
+    )
+    logits, *kv_cache = out
+
+    num_tokens_processed = logits_gt.shape[1] # only test the relevant, non-padded tokens
+    assert np.allclose(logits[:, :num_tokens_processed, :], logits_gt, atol=1e-3)
+    assert all(np.allclose(x[:, :num_tokens_processed, :], y, atol=1e-3) for x, y in zip(kv_cache, kv_cache_gt))
+
+def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt):
+    # feed the model one token at a time to validate the correctness of the kv cache injected model
+    model = onnx.load(onnx_model_path, load_external_data=True)
+    apply_input_shapes(model, onnx_model_path, sequence_length, config)
+
+    kv_cache_hidden_dim = config.n_embd // config.n_head
+    inputs = tokenizer(prompt, return_tensors="np")
+    attention_mask = np.zeros((1, sequence_length), dtype=np.int64)
+    kv_cache = {f"past_key_values.{i}": np.zeros((1,sequence_length-1, 2 * kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)}
+    session = ort.InferenceSession(onnx_model_path)
+
+    for idx, token in enumerate(inputs.input_ids[0]):
+        if token == tokenizer.pad_token_id:
+            break
+        attention_mask[:, -(idx + 1):] = 1
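+        # unmask the rightmost (idx + 1) positions of the attention mask (the current token plus all previously processed tokens)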
+        positions = np.array([[idx]])
+        input_ids = np.array([[token]])
+        causal_mask = create_causal_mask(input_ids, attention_mask)
+
+        outputs = session .run(None, {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "positions": positions,
+            "causal_mask": causal_mask,
+            **kv_cache
+        })
+        # NOTE: this currently fails with an error; some missing pieces still need to be addressed
+
+def get_baseline(prompt, hf_model_name, tokenizer):
+    model = AutoModelForCausalLM.from_pretrained(hf_model_name)
+    tokens = tokenizer.encode(prompt, return_tensors="pt")
+    out = model(tokens, return_dict=True)
+    logits_gt = out.logits.detach().numpy()
+    kv_cache_gt = [t.detach().numpy() for t in out.past_key_values]
+    return logits_gt, kv_cache_gt
+
+def main(prompt, hf_model_name, onnx_model_path, sequence_length):
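+    """
+    Validates the kv cache injected ONNX model against the original HF model, first on
+    the whole prompt at once (multi-token) and then one token at a time (single-token).
+    """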
+    config = AutoConfig.from_pretrained(hf_model_name)
+    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    logits_gt, kv_cache_gt = get_baseline(prompt, hf_model_name, tokenizer)
+
+    multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
+    _LOGGER.info("Successfully ran multi-token inference on the kv cache injected model")
+    singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
+    _LOGGER.info("Successfully ran single-token inference on the kv cache injected model")
+
+
+
+if __name__ == "__main__":
+    PROMPT = "def eight_queens():\n    if True:\n        return 1\n    "
+    HF_MODEL_NAME = "bigcode/tiny_starcoder_py"
+    ONNX_MODEL_PATH = "/Users/damian/Code/nm/sparseml/tiny_starcoder_py/deployment/test.onnx"
+    SEQUENCE_LENGTH = 256
+    main(PROMPT, HF_MODEL_NAME, ONNX_MODEL_PATH, SEQUENCE_LENGTH)
+
+
+
+
+
+

From 8c7f7992073c7982f48424419df24533d8e67bb4 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Fri, 16 Feb 2024 19:15:00 +0100
Subject: [PATCH 4/5] Add files via upload

---
 .../kv_cache_injection.py                     | 88 +++++++++++++++++++
 starcode_kv_cache_injection/validation.py     | 14 +--
 2 files changed, 97 insertions(+), 5 deletions(-)

diff --git a/starcode_kv_cache_injection/kv_cache_injection.py b/starcode_kv_cache_injection/kv_cache_injection.py
index 2cc52a52da6..4b01ee0dd1a 100644
--- a/starcode_kv_cache_injection/kv_cache_injection.py
+++ b/starcode_kv_cache_injection/kv_cache_injection.py
@@ -7,6 +7,7 @@
 from onnx import TensorProto, ModelProto, helper, NodeProto
 from sparseml.onnx.utils import ONNXGraph
 from sparseml.exporters.transforms.kv_cache.transforms_codegen import AdditionalTransformsCodeGen
+from sparseml.onnx.utils.helpers import get_nodes_by_output_id
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -24,6 +25,60 @@ class AdditionalTransformsBigCode(AdditionalTransformsCodeGen):
     # in the onnx graph)
     CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Unsqueeze", children_ops=[["Where", "Softmax"]])
 
+    def swap_nodes_for_input(
+        self,
+        model: ModelProto,
+        nodes: List[NodeProto],
+        input_name: str,
+        nodes_parent_op_type: Optional[str] = None,
+    ) -> ModelProto:
+
+        """
+        Injects the specified input to the graph, replacing the specified nodes.
+
+        :param model: the ONNX model to inject the input into
+        :param nodes: the nodes to replace with the input
+        :param input_name: the name of the input to replace the nodes with
+        :param nodes_parent_op_type: the parent op type of the nodes to replace
+
+        :return: the updated model
+        """
+
+        graph = ONNXGraph(model)
+        for node in nodes:
+            child_node = graph.get_node_children(node)[0]
+
+            if nodes_parent_op_type:
+                assert child_node.op_type == nodes_parent_op_type, (
+                    f"Expected to find {nodes_parent_op_type} node, "
+                    f"found {child_node.op_type}"
+                )
+            output_to_replace = node.output[0]
+            self.log_match(node)
+            for idx, input_name_child_node in enumerate(child_node.input):
+                if input_name_child_node == output_to_replace:
+                    graph.update_node_input(child_node, input_name, idx)
+            children_nodes = graph.get_node_children(node)
+            for child_node in children_nodes:
+                if nodes_parent_op_type:
+                    assert child_node.op_type == nodes_parent_op_type, (
+                        f"Expected to find {nodes_parent_op_type} node, "
+                        f"found {child_node.op_type}"
+                    )
+                output_to_replace = node.output[0]
+                self.log_match(node)
+                for idx, input_name_child_node in enumerate(child_node.input):
+                    if input_name_child_node == output_to_replace:
+                        graph.update_node_input(child_node, input_name, idx)
+
+        graph.delete_orphaned_node_branches()
+
+        _LOGGER.info(
+            f"Successfully swapped {len(nodes)} nodes for input '{input_name}'"
+        )
+
+        return model
+
     def add_causal_mask_input(self, model: ModelProto) -> ModelProto:
         """
         reformulating this method (originally part of the AdditionalTransformsBase class)
@@ -91,6 +146,36 @@ def swap_nodes_for_input(
 
         return model
 
+    def add_constant_reshape_node(self, model: ModelProto) -> ModelProto:
+        """
+        Adds positions as an input to the model.
+
+        Positions is a tensor of shape and dtype
+        equal to input_ids.
+
+        :param model: model to update
+        :return: updated model
+        """
+        graph = ONNXGraph(model)
+        # create a Constant node that feeds the hard-coded shape (1, 256, 768) to the Reshape node
+        constant_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            name="abc",
+            outputs=["reshape_input"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=TensorProto.INT64,
+                dims=[3],
+                vals=[1, 256, 768],
+            ),
+        )
+        graph.add_node(constant_node)
+        reshape_node = get_nodes_by_output_id(model, "/transformer/Reshape_2_output_0")[0]
+        reshape_node.input[1] = "reshape_input"
+        _LOGGER.info(f"Inserted constant reshape node to the ONNX model")
+        return model
+
     def transform(self, model: ModelProto) -> ModelProto:
         """
         1. Adds `positions` as an input to the model
@@ -107,6 +192,8 @@ def transform(self, model: ModelProto) -> ModelProto:
         """
         model = self.add_positions_input(model)
         model = self.add_causal_mask_input(model)
+        model = self.add_constant_reshape_node(model)
+
 
         position_ids_nodes = self.find_nodes_by_pattern(
             model, pattern=self.POSITION_IDS_MATCHING_PATTERN
@@ -125,6 +212,7 @@ def transform(self, model: ModelProto) -> ModelProto:
         )
         model = self.inject_causal_mask(model, causal_mask_nodes, "Where")
         model = self.adjust_causal_mask(model)
+
         return model
 
 def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_tensors, hidden_size_kv_cache, batch_size = 1):
diff --git a/starcode_kv_cache_injection/validation.py b/starcode_kv_cache_injection/validation.py
index cd3150d9ba0..c8eec7a9426 100644
--- a/starcode_kv_cache_injection/validation.py
+++ b/starcode_kv_cache_injection/validation.py
@@ -73,7 +73,7 @@ def create_causal_mask(
         attention_mask = numpy.array(attention_mask)[None, ...]
 
     if input_ids_length == 1:
-        causal_mask = numpy.reshape(attention_mask, (batch_size, 1, 1, sequence_length))
+        causal_mask = numpy.reshape(attention_mask, (batch_size, sequence_length, 1, -1))
         return causal_mask.astype(dtype)
 
     causal_mask = numpy.tril(
@@ -164,19 +164,23 @@ def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, seque
         positions = np.array([[idx]])
         input_ids = np.array([[token]])
         causal_mask = create_causal_mask(input_ids, attention_mask)
-
-        outputs = session .run(None, {
+        outputs = session.run(None, {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "positions": positions,
             "causal_mask": causal_mask,
             **kv_cache
         })
+        logits, *kv_cache = outputs
+        for _idx, (cache_gt, cache) in enumerate(zip(kv_cache_gt, kv_cache)):
+            if np.allclose(cache_gt[:,idx,:], cache[:,-(idx + 1)],atol=1e-3):
+                print(f"Cache {_idx} matches for iteration {idx}")
         # NOTE: this currently fails with an error; some missing pieces still need to be addressed
 
 def get_baseline(prompt, hf_model_name, tokenizer):
     model = AutoModelForCausalLM.from_pretrained(hf_model_name)
     tokens = tokenizer.encode(prompt, return_tensors="pt")
+    model.generate(tokens[:,:1], max_length=256)
     out = model(tokens, return_dict=True)
     logits_gt = out.logits.detach().numpy()
     kv_cache_gt = [t.detach().numpy() for t in out.past_key_values]
@@ -189,8 +193,8 @@ def main(prompt, hf_model_name, onnx_model_path, sequence_length):
 
     logits_gt, kv_cache_gt = get_baseline(prompt, hf_model_name, tokenizer)
 
-    multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
-    _LOGGER.info("Successfully ran multi-token inference on the kv cache injected model")
+    #multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
+    #_LOGGER.info("Successfully ran multi-token inference on the kv cache injected model")
     singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
     _LOGGER.info("Successfully ran single-token inference on the kv cache injected model")
 

From c6ad1916ce50791b23b20c7222cb6f70c81f2ed7 Mon Sep 17 00:00:00 2001
From: "bogunowicz@arrival.com" <bogunowicz@arrival.com>
Date: Wed, 6 Mar 2024 17:27:40 +0100
Subject: [PATCH 5/5] producing running (hopefully) but incorrect model

---
 .../kv_cache_injection.py                     | 150 +++++++-----------
 starcode_kv_cache_injection/run_model.py      |  10 ++
 starcode_kv_cache_injection/validation.py     |  36 +++--
 3 files changed, 92 insertions(+), 104 deletions(-)
 create mode 100644 starcode_kv_cache_injection/run_model.py

diff --git a/starcode_kv_cache_injection/kv_cache_injection.py b/starcode_kv_cache_injection/kv_cache_injection.py
index 4b01ee0dd1a..907ed0a6e7f 100644
--- a/starcode_kv_cache_injection/kv_cache_injection.py
+++ b/starcode_kv_cache_injection/kv_cache_injection.py
@@ -6,6 +6,7 @@
 from typing import List, Optional
 from onnx import TensorProto, ModelProto, helper, NodeProto
 from sparseml.onnx.utils import ONNXGraph
+from sparseml.exporters.transforms.kv_cache.cache_keys_and_values import reshape_kv_cache_inputs_outputs
 from sparseml.exporters.transforms.kv_cache.transforms_codegen import AdditionalTransformsCodeGen
 from sparseml.onnx.utils.helpers import get_nodes_by_output_id
 
@@ -25,84 +26,6 @@ class AdditionalTransformsBigCode(AdditionalTransformsCodeGen):
     # in the onnx graph)
     CAUSAL_MASK_MATCHING_PATTERN = dict(op_type="Unsqueeze", children_ops=[["Where", "Softmax"]])
 
-    def swap_nodes_for_input(
-        self,
-        model: ModelProto,
-        nodes: List[NodeProto],
-        input_name: str,
-        nodes_parent_op_type: Optional[str] = None,
-    ) -> ModelProto:
-
-        """
-        Injects the specified input to the graph, replacing the specified nodes.
-
-        :param model: the ONNX model to inject the input into
-        :param nodes: the nodes to replace with the input
-        :param input_name: the name of the input to replace the nodes with
-        :param nodes_parent_op_type: the parent op type of the nodes to replace
-
-        :return: the updated model
-        """
-
-        graph = ONNXGraph(model)
-        for node in nodes:
-            child_node = graph.get_node_children(node)[0]
-
-            if nodes_parent_op_type:
-                assert child_node.op_type == nodes_parent_op_type, (
-                    f"Expected to find {nodes_parent_op_type} node, "
-                    f"found {child_node.op_type}"
-                )
-            output_to_replace = node.output[0]
-            self.log_match(node)
-            for idx, input_name_child_node in enumerate(child_node.input):
-                if input_name_child_node == output_to_replace:
-                    graph.update_node_input(child_node, input_name, idx)
-            children_nodes = graph.get_node_children(node)
-            for child_node in children_nodes:
-                if nodes_parent_op_type:
-                    assert child_node.op_type == nodes_parent_op_type, (
-                        f"Expected to find {nodes_parent_op_type} node, "
-                        f"found {child_node.op_type}"
-                    )
-                output_to_replace = node.output[0]
-                self.log_match(node)
-                for idx, input_name_child_node in enumerate(child_node.input):
-                    if input_name_child_node == output_to_replace:
-                        graph.update_node_input(child_node, input_name, idx)
-
-        graph.delete_orphaned_node_branches()
-
-        _LOGGER.info(
-            f"Successfully swapped {len(nodes)} nodes for input '{input_name}'"
-        )
-
-        return model
-
-    def add_causal_mask_input(self, model: ModelProto) -> ModelProto:
-        """
-        reformulating this method (originally part of the AdditionalTransformsBase class)
-        so that the causal mask has shape [batch_size, input_ids_length, 1, sequence_length]
-        vs the original shape [batch_size, 1, input_ids_length, sequence_length]
-        """
-
-        input_ids = self._get_input_proto(model, "input_ids")
-        attention_mask = self._get_input_proto(model, "attention_mask")
-
-        batch_size = input_ids.type.tensor_type.shape.dim[0].dim_param
-        input_ids_length = input_ids.type.tensor_type.shape.dim[1].dim_value
-        sequence_length = attention_mask.type.tensor_type.shape.dim[1].dim_param
-
-        causal_mask_input = helper.make_tensor_value_info(
-            name=self.CAUSAL_MASK_NAME,
-            elem_type=TensorProto.INT64,
-            # this is de-facto the only change from the original method
-            shape=[batch_size, input_ids_length, 1, sequence_length],
-        )
-        model.graph.input.append(causal_mask_input)
-        _LOGGER.info(f"Inserted {self.CAUSAL_MASK_NAME} input to the ONNX model")
-        return model
-
     def swap_nodes_for_input(
         self,
         model: ModelProto,
@@ -175,6 +98,31 @@ def add_constant_reshape_node(self, model: ModelProto) -> ModelProto:
         reshape_node.input[1] = "reshape_input"
         _LOGGER.info(f"Inserted constant reshape node to the ONNX model")
         return model
+    
+    def add_causal_mask_reshape_node(self, model: ModelProto) -> ModelProto:
+        """
+        Adds positions as an input to the model.
+
+        Positions is a tensor of shape and dtype
+        equal to input_ids.
+
+        :param model: model to update
+        :return: updated model
+        """
+        graph = ONNXGraph(model)
+
+        transpose_node = onnx.helper.make_node(
+            op_type="Transpose",
+            inputs=["causal_mask"],
+            outputs=["causal_mask_transpose"],
+            name="causal_mask_transpose",
+            perm=(0, 3, 2, 1),
+        )
+        graph.add_node(transpose_node)
+        reshape_node = get_nodes_by_output_id(model, "causal_mask_adjusted")[0]
+        reshape_node.input[0] = "causal_mask_transpose"
+        _LOGGER.info(f"Inserted transpose to the causal mask in  the ONNX model")
+        return model
 
     def transform(self, model: ModelProto) -> ModelProto:
         """
@@ -212,29 +160,31 @@ def transform(self, model: ModelProto) -> ModelProto:
         )
         model = self.inject_causal_mask(model, causal_mask_nodes, "Where")
         model = self.adjust_causal_mask(model)
-
+        model = self.add_causal_mask_reshape_node(model)
         return model
 
-def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_tensors, hidden_size_kv_cache, batch_size = 1):
+def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes: List[str], hidden_size_kv_cache, batch_size=1, key: bool = True, output_num: int = 0):
     graph = ONNXGraph(model)
 
     inputs_to_add = []
     outputs_to_add = []
-
+    num_attention_heads = 1
     attention_layer_idx = 0
 
     for node in model.graph.node:
-        if node.name in names_nodes_producing_kv_tensors:
+        if node.name in names_nodes:
 
             # inject kv cache input/output
-            cache_input_name_concat = f"past_key_values.{attention_layer_idx}"
-            cache_output_name_concat = f"present.{attention_layer_idx}"
+            cache_name = "key" if key else "value"
+            cache_input_name_concat = f"past_key_values.{attention_layer_idx}.{cache_name}"
+            cache_output_name_concat = f"present.{attention_layer_idx}.{cache_name}"
 
             cache_input_info = onnx.helper.make_tensor_value_info(
                 cache_input_name_concat,
                 TensorProto.FLOAT,
                 [
                     batch_size,
+                    num_attention_heads,
                     "past_sequence_len",
                     hidden_size_kv_cache,
                 ]
@@ -245,17 +195,30 @@ def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_t
                 TensorProto.FLOAT,
                 [
                     batch_size,
+                    num_attention_heads, 
                     "past_sequence_len + 1",
                     hidden_size_kv_cache,
                 ]
             )
 
+            model, cache_input_dims_concat, cache_input_name_concat, cache_output_name_concat = reshape_kv_cache_inputs_outputs(
+                model=model,
+                cache_input_name=cache_input_name_concat,
+                cache_output_name=cache_output_name_concat,
+                cache_input_dims= [
+                    batch_size,
+                    num_attention_heads,
+                    "past_sequence_len",
+                    hidden_size_kv_cache,
+                ],
+                batch_size=batch_size,
+                num_attention_heads=1,
+            )
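+            # NOTE (assumption): reshape_kv_cache_inputs_outputs is expected to insert Reshape
+            # nodes so that the 4D external cache tensors match the layout the Concat below uses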
             cache_parent = node
             concat_axis = 1 # concat over length axis
-
             concat_node = onnx.helper.make_node(
                 op_type="Concat",
-                inputs=[cache_input_name_concat, cache_parent.output[1]],
+                inputs=[cache_input_name_concat, cache_parent.output[output_num]],
                 outputs=[cache_output_name_concat],
                 axis=concat_axis,
                 name=f"concat.{cache_input_name_concat}",
@@ -263,7 +226,7 @@ def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_t
 
             for _node in model.graph.node:
                 for input_idx, input_id in enumerate(_node.input):
-                    if input_id == cache_parent.output[1] and _node.name != concat_node.name:
+                    if input_id == cache_parent.output[output_num] and _node.name != concat_node.name:
                         _node.input[input_idx] = cache_output_name_concat
 
             graph.add_node(concat_node)
@@ -271,7 +234,7 @@ def inject_kv_cache_inputs_outputs(model: ModelProto, names_nodes_producing_kv_t
             outputs_to_add.extend([cache_output_info])
 
             attention_layer_idx += 1
-            _LOGGER.info(f"Injected kv cache input/output for attention layer {attention_layer_idx}")
+            print(f"Injected kv cache input/output for {attention_layer_idx}:{cache_name}")
 
     model.graph.input.extend(inputs_to_add)
     model.graph.output.extend(outputs_to_add)
@@ -283,8 +246,15 @@ def main(deployment_folder_path, save_name_injected_model):
     config = AutoConfig.from_pretrained(os.path.join(deployment_folder_path, "config.json"))
     # KV Cache injection
     onnx_model = inject_kv_cache_inputs_outputs(model = onnx_model,
-                                                names_nodes_producing_kv_tensors=[f"/transformer/h.{i}/attn/Split" for i in range(config.n_layer)],
-                                                hidden_size_kv_cache=2 * config.n_embd // config.n_head)
+                                                names_nodes=[f"/transformer/h.{i}/attn/Split_1" for i in range(config.n_layer)],
+                                                hidden_size_kv_cache= config.n_embd // config.n_head,
+                                                key=True,
+                                                output_num=0)
+    onnx_model = inject_kv_cache_inputs_outputs(model = onnx_model,
+                                                names_nodes=[f"/transformer/h.{i}/attn/Split_1" for i in range(config.n_layer)],
+                                                hidden_size_kv_cache= config.n_embd // config.n_head,
+                                                key=False,
+                                                output_num=1)
     # Adjustment of causal masks and positions
     transformation = AdditionalTransformsBigCode()
     onnx_model = transformation.transform(model = onnx_model)
diff --git a/starcode_kv_cache_injection/run_model.py b/starcode_kv_cache_injection/run_model.py
new file mode 100644
index 00000000000..170fef0632a
--- /dev/null
+++ b/starcode_kv_cache_injection/run_model.py
@@ -0,0 +1,10 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+checkpoint = "bigcode/tiny_starcoder_py"
+device="cpu"
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
+
+inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to(device)
+outputs = model.generate(inputs, max_new_tokens=10)
+print(tokenizer.decode(outputs[0]))
\ No newline at end of file
diff --git a/starcode_kv_cache_injection/validation.py b/starcode_kv_cache_injection/validation.py
index c8eec7a9426..89e734cabd8 100644
--- a/starcode_kv_cache_injection/validation.py
+++ b/starcode_kv_cache_injection/validation.py
@@ -73,7 +73,7 @@ def create_causal_mask(
         attention_mask = numpy.array(attention_mask)[None, ...]
 
     if input_ids_length == 1:
-        causal_mask = numpy.reshape(attention_mask, (batch_size, sequence_length, 1, -1))
+        causal_mask = numpy.reshape(attention_mask, (batch_size, 1, 1, sequence_length))
         return causal_mask.astype(dtype)
 
     causal_mask = numpy.tril(
@@ -89,14 +89,13 @@ def create_causal_mask(
 
     # changes to the original function
     causal_mask[:, :, num_zeros:, :] = 0
-    causal_mask = causal_mask.reshape(1, sequence_length, 1, -1)
 
     return causal_mask
 
 def apply_input_shapes(model, onnx_model_path, sequence_length, config):
     kv_cache_hidden_dim = config.n_embd // config.n_head
-    cache_changes_in = {n.name: [1, "dynamic_len_1", 2 * kv_cache_hidden_dim] for n in model.graph.input if n.name.startswith("past_key_values")}
-    cache_changes_out = {n.name: [1, "dynamic_len_2", 2 * kv_cache_hidden_dim] for n in model.graph.output if n.name.startswith("present")}
+    cache_changes_in = {n.name: [1, 1,"dynamic_len_1", kv_cache_hidden_dim] for n in model.graph.input if n.name.startswith("past_key_values")}
+    cache_changes_out = {n.name: [1, 1,"dynamic_len_2", kv_cache_hidden_dim] for n in model.graph.output if n.name.startswith("present")}
     graph = ONNXGraph(model)
 
     graph.delete_unused_initializers()
@@ -107,9 +106,8 @@ def apply_input_shapes(model, onnx_model_path, sequence_length, config):
                                                          {"input_ids": [1, "dynamic_len_3"],
                                                           "positions": [1, "dynamic_len_4"],
                                                           "attention_mask": [1, sequence_length],
-                                                          "causal_mask": [1, "dynamic_len_5", 1, "dynamic_len_6"],
+                                                          "causal_mask": [1, 1, "dynamic_len_5", "dynamic_len_6"],
                                                           **cache_changes_in},
-
                                                           {"logits": [1, "dynamic_len_6", config.vocab_size], **cache_changes_out})
 
     onnx.save(model, onnx_model_path)
@@ -123,8 +121,11 @@ def multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequen
     inputs = tokenizer(prompt, return_tensors="np", padding='max_length', max_length=sequence_length)
     input_ids = inputs.input_ids  # (1, sequence_length)
     attention_mask = inputs.attention_mask  # (1, sequence_length)
-    kv_cache = {f"past_key_values.{i}": np.zeros((1, 0, 2 * kv_cache_hidden_dim), dtype=np.float32) for i in
-                range(config.n_layer)}  # (1, 0, 2 * embedding [because we have k and v's concatenated])
+    kv_cache_value = {f"past_key_values.{i}.value": np.zeros((1, 1, 0, kv_cache_hidden_dim), dtype=np.float32) for i in
+                range(config.n_layer)}  # (1, 0, embedding)
+    kv_cache_keys = {f"past_key_values.{i}.key": np.zeros((1, 1, 0, kv_cache_hidden_dim), dtype=np.float32) for i in
+                range(config.n_layer)}  # (1, 0, embedding)
+    kv_cache = {**kv_cache_keys, **kv_cache_value}
     causal_mask = create_causal_mask(input_ids, attention_mask)  # (1, sequence_length, 1, sequence_length)
     positions = attention_mask.cumsum(-1) - 1  # (1, sequence_length)
 
@@ -154,7 +155,9 @@ def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, seque
     kv_cache_hidden_dim = config.n_embd // config.n_head
     inputs = tokenizer(prompt, return_tensors="np")
     attention_mask = np.zeros((1, sequence_length), dtype=np.int64)
-    kv_cache = {f"past_key_values.{i}": np.zeros((1,sequence_length-1, 2 * kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)}
+    kv_cache_keys = {f"past_key_values.{i}.key": np.zeros((1,1,sequence_length-1, kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)}
+    kv_cache_values = {f"past_key_values.{i}.value": np.zeros((1,1,sequence_length-1, kv_cache_hidden_dim), dtype=np.float32) for i in range(config.n_layer)}
+    kv_cache = {**kv_cache_keys, **kv_cache_values}
     session = ort.InferenceSession(onnx_model_path)
 
     for idx, token in enumerate(inputs.input_ids[0]):
@@ -164,6 +167,11 @@ def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, seque
         positions = np.array([[idx]])
         input_ids = np.array([[token]])
         causal_mask = create_causal_mask(input_ids, attention_mask)
+        print(causal_mask.shape)
+        print(input_ids.shape)
+        print(attention_mask.shape)
+        print(positions)
+        print(kv_cache["past_key_values.0.key"].shape)
         outputs = session.run(None, {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
@@ -171,10 +179,10 @@ def singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, seque
             "causal_mask": causal_mask,
             **kv_cache
         })
-        logits, *kv_cache = outputs
-        for _idx, (cache_gt, cache) in enumerate(zip(kv_cache_gt, kv_cache)):
-            if np.allclose(cache_gt[:,idx,:], cache[:,-(idx + 1)],atol=1e-3):
-                print(f"Cache {_idx} matches for iteration {idx}")
+        #logits, *kv_cache = outputs
+        #for _idx, (cache_gt, cache) in enumerate(zip(kv_cache_gt, kv_cache)):
+        #    if np.allclose(cache_gt[:,idx,:], cache[:,-(idx + 1)],atol=1e-3):
+        #        print(f"Cache {_idx} matches for iteration {idx}")
         # NOTE: this currently fails with an error; some missing pieces still need to be addressed
 
 def get_baseline(prompt, hf_model_name, tokenizer):
@@ -194,7 +202,7 @@ def main(prompt, hf_model_name, onnx_model_path, sequence_length):
     logits_gt, kv_cache_gt = get_baseline(prompt, hf_model_name, tokenizer)
 
     #multitoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
-    #_LOGGER.info("Successfully ran multi-token inference on the kv cache injected model")
+    # _LOGGER.info("Successfully ran multi-token inference on the kv cache injected model")
     singletoken_inference_test(onnx_model_path, prompt, config, tokenizer, sequence_length, logits_gt, kv_cache_gt)
     _LOGGER.info("Successfully ran single-token inference on the kv cache injected model")