From 5d7f655bd00aa4cd28795ab9076d3b42ac1b896e Mon Sep 17 00:00:00 2001
From: Joshua Lochner <admin@xenova.com>
Date: Sat, 26 Oct 2024 05:09:52 +0000
Subject: [PATCH] Update llama-3.2-webgpu demo
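
Bump @huggingface/transformers from 3.0.0 to 3.0.1 (which in turn pulls in
a newer onnxruntime-web dev build), and switch the demo to the
onnx-community/Llama-3.2-1B-Instruct-onnx-web-gqa export, dropping the
per-load dtype override. With the new export, past_key_values caching in
the worker can be re-enabled (it was previously disabled behind a
"TODO: Add when model is fixed" comment), and the cache is now cleared on
"reset". Sampling is left as a commented-out option, so generation
defaults to greedy decoding, and one of the example prompts is refreshed.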

---
 llama-3.2-webgpu/package-lock.json | 16 ++++++++--------
 llama-3.2-webgpu/package.json      |  2 +-
 llama-3.2-webgpu/src/App.jsx       |  4 ++--
 llama-3.2-webgpu/src/worker.js     | 12 +++++-------
 4 files changed, 16 insertions(+), 18 deletions(-)
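
Note for reviewers: the functional change in worker.js below is re-enabling
cross-turn prompt caching. A minimal standalone sketch of that pattern
follows (names mirror worker.js; streaming, stopping criteria, and the
singleton wrapper are omitted, and the apply_chat_template call is an
assumption about how the prompt is built, not part of this diff):

  import {
    AutoTokenizer,
    AutoModelForCausalLM,
  } from "@huggingface/transformers";

  const model_id = "onnx-community/Llama-3.2-1B-Instruct-onnx-web-gqa";
  const tokenizer = await AutoTokenizer.from_pretrained(model_id);
  const model = await AutoModelForCausalLM.from_pretrained(model_id, {
    device: "webgpu",
  });

  // Key/value states carried over from the previous turn (null = fresh).
  let past_key_values_cache = null;

  async function generate(messages) {
    // Assumed prompt construction; returns { input_ids, attention_mask }.
    const inputs = tokenizer.apply_chat_template(messages, {
      add_generation_prompt: true,
      return_dict: true,
    });

    // Passing the cached states back in lets each turn process only the
    // new tokens instead of re-encoding the whole conversation history.
    const { past_key_values, sequences } = await model.generate({
      ...inputs,
      past_key_values: past_key_values_cache,
      max_new_tokens: 1024,
      return_dict_in_generate: true,
    });
    past_key_values_cache = past_key_values;

    return tokenizer.batch_decode(sequences, { skip_special_tokens: true });
  }

  // A "reset" message drops the cache so the next turn starts clean:
  //   past_key_values_cache = null;

(Top-level await assumes an ES module context, as in the demo's module
worker.)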

diff --git a/llama-3.2-webgpu/package-lock.json b/llama-3.2-webgpu/package-lock.json
index 6d223b01..aec5be77 100644
--- a/llama-3.2-webgpu/package-lock.json
+++ b/llama-3.2-webgpu/package-lock.json
@@ -8,7 +8,7 @@
       "name": "llama-3.2-webgpu",
       "version": "0.0.0",
       "dependencies": {
-        "@huggingface/transformers": "3.0.0",
+        "@huggingface/transformers": "3.0.1",
         "dompurify": "^3.1.2",
         "marked": "^12.0.2",
         "react": "^18.3.1",
@@ -835,13 +835,13 @@
       }
     },
     "node_modules/@huggingface/transformers": {
-      "version": "3.0.0",
-      "resolved": "https://registry.npmjs.org/@huggingface/transformers/-/transformers-3.0.0.tgz",
-      "integrity": "sha512-OWIPnTijAw4DQ+IFHBOrej2SDdYyykYlTtpTLCEt5MZq/e9Cb65RS2YVhdGcgbaW/6JAL3i8ZA5UhDeWGm4iRQ==",
+      "version": "3.0.1",
+      "resolved": "https://registry.npmjs.org/@huggingface/transformers/-/transformers-3.0.1.tgz",
+      "integrity": "sha512-lXmF0/p+ZdQX0NKTybLUCzIKr0sKD6BfqtjL7olaLx2JHAM3HKVrvFjWeFe2lQRkhL6cEcFw2WXs7o8nZU/WGg==",
       "dependencies": {
         "@huggingface/jinja": "^0.3.0",
         "onnxruntime-node": "1.19.2",
-        "onnxruntime-web": "1.20.0-dev.20241016-2b8fc5529b",
+        "onnxruntime-web": "1.21.0-dev.20241024-d9ca84ef96",
         "sharp": "^0.33.5"
       }
     },
@@ -4260,9 +4260,9 @@
       }
     },
     "node_modules/onnxruntime-web": {
-      "version": "1.20.0-dev.20241016-2b8fc5529b",
-      "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.20.0-dev.20241016-2b8fc5529b.tgz",
-      "integrity": "sha512-1XovqtgqeEFtupuyzdDQo7Tqj4GRyNHzOoXjapCEo4rfH3JrXok5VtqucWfRXHPsOI5qoNxMQ9VE+drDIp6woQ==",
+      "version": "1.21.0-dev.20241024-d9ca84ef96",
+      "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.21.0-dev.20241024-d9ca84ef96.tgz",
+      "integrity": "sha512-ANSQfMALvCviN3Y4tvTViKofKToV1WUb2r2VjZVCi3uUBPaK15oNJyIxhsNyEckBr/Num3JmSXlkHOD8HfVzSQ==",
       "dependencies": {
         "flatbuffers": "^1.12.0",
         "guid-typescript": "^1.0.9",
diff --git a/llama-3.2-webgpu/package.json b/llama-3.2-webgpu/package.json
index d786987d..b610fe5e 100644
--- a/llama-3.2-webgpu/package.json
+++ b/llama-3.2-webgpu/package.json
@@ -10,7 +10,7 @@
     "preview": "vite preview"
   },
   "dependencies": {
-    "@huggingface/transformers": "3.0.0",
+    "@huggingface/transformers": "3.0.1",
     "dompurify": "^3.1.2",
     "marked": "^12.0.2",
     "react": "^18.3.1",
diff --git a/llama-3.2-webgpu/src/App.jsx b/llama-3.2-webgpu/src/App.jsx
index e818f807..a4ba733d 100644
--- a/llama-3.2-webgpu/src/App.jsx
+++ b/llama-3.2-webgpu/src/App.jsx
@@ -10,7 +10,7 @@ const STICKY_SCROLL_THRESHOLD = 120;
 const EXAMPLES = [
   "Give me some tips to improve my time management skills.",
   "What is the difference between AI and ML?",
-  "Write python code to compute the nth fibonacci number.",
+  "Write Python code to perform merge sort.",
 ];
 
 function App() {
@@ -209,7 +209,7 @@ function App() {
               <br />
               You are about to load{" "}
               <a
-                href="https://huggingface.co/onnx-community/Llama-3.2-1B-Instruct-q4f16"
+                href="https://huggingface.co/onnx-community/Llama-3.2-1B-Instruct-onnx-web-gqa"
                 target="_blank"
                 rel="noreferrer"
                 className="font-medium underline"
diff --git a/llama-3.2-webgpu/src/worker.js b/llama-3.2-webgpu/src/worker.js
index 8bfedbea..da6f3331 100644
--- a/llama-3.2-webgpu/src/worker.js
+++ b/llama-3.2-webgpu/src/worker.js
@@ -9,7 +9,7 @@ import {
  * This class uses the Singleton pattern to enable lazy-loading of the pipeline
  */
 class TextGenerationPipeline {
-  static model_id = "onnx-community/Llama-3.2-1B-Instruct-q4f16";
+  static model_id = "onnx-community/Llama-3.2-1B-Instruct-onnx-web-gqa";
 
   static async getInstance(progress_callback = null) {
     this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
@@ -17,7 +17,6 @@ class TextGenerationPipeline {
     });
 
     this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
-      dtype: "q4f16",
       device: "webgpu",
       progress_callback,
     });
@@ -69,18 +68,17 @@ async function generate(messages) {
 
   const { past_key_values, sequences } = await model.generate({
     ...inputs,
-    // TODO: Add when model is fixed
-    // past_key_values: past_key_values_cache,
+    past_key_values: past_key_values_cache,
 
     // Sampling
-    do_sample: false,
+    // do_sample: true,
 
     max_new_tokens: 1024,
     streamer,
     stopping_criteria,
     return_dict_in_generate: true,
   });
-  // past_key_values_cache = past_key_values;
+  past_key_values_cache = past_key_values;
 
   const decoded = tokenizer.batch_decode(sequences, {
     skip_special_tokens: true,
@@ -153,7 +151,7 @@ self.addEventListener("message", async (e) => {
       break;
 
     case "reset":
-      // past_key_values_cache = null;
+      past_key_values_cache = null;
       stopping_criteria.reset();
       break;
   }