From 080a53a58ee65aad262016b8460459fc77b5ad84 Mon Sep 17 00:00:00 2001 From: john chen Date: Fri, 13 Jun 2025 17:19:00 +0800 Subject: [PATCH 01/79] fix Rotation matri form of RoPE (#25) --- book/src/week1-02-positional-encodings.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/book/src/week1-02-positional-encodings.md b/book/src/week1-02-positional-encodings.md index aaf9c97..c105411 100644 --- a/book/src/week1-02-positional-encodings.md +++ b/book/src/week1-02-positional-encodings.md @@ -41,10 +41,10 @@ Note that, practically, D can be even or odd. In the case of D being odd, the la and is typically left untouched in most implementations. For simplicity, we just assume that D is always even. ``` -output[0] = x[0] * cos_freqs[0] + x[1] * sin_freqs[0] -output[1] = x[0] * -sin_freqs[0] + x[1] * cos_freqs[0] -output[2] = x[2] * cos_freqs[1] + x[3] * sin_freqs[1] -output[3] = x[2] * -sin_freqs[1] + x[3] * cos_freqs[1] +output[0] = x[0] * cos_freqs[0] + x[1] * -sin_freqs[0] +output[1] = x[0] * sin_freqs[0] + x[1] * cos_freqs[0] +output[2] = x[2] * cos_freqs[1] + x[3] * -sin_freqs[1] +output[3] = x[2] * sin_freqs[1] + x[3] * cos_freqs[1] ...and so on ``` @@ -67,10 +67,10 @@ The Qwen2 model uses a non-traditional form of RoPE. In this form, the head embe and the two halves are applied with different frequencies. Let's say `x1 = x[.., :HALF_DIM]` and `x2 = x[.., HALF_DIM:]`. ``` -output[0] = x1[0] * cos_freqs[0] + x2[0] * sin_freqs[0] -output[HALF_DIM] = x1[0] * -sin_freqs[0] + x2[0] * cos_freqs[0] -output[1] = x1[1] * cos_freqs[1] + x2[1] * sin_freqs[1] -output[HALF_DIM + 1] = x1[1] * -sin_freqs[1] + x2[1] * cos_freqs[1] +output[0] = x1[0] * cos_freqs[0] + x2[0] * -sin_freqs[0] +output[HALF_DIM] = x1[0] * sin_freqs[0] + x2[0] * cos_freqs[0] +output[1] = x1[1] * cos_freqs[1] + x2[1] * -sin_freqs[1] +output[HALF_DIM + 1] = x1[1] * sin_freqs[1] + x2[1] * cos_freqs[1] ...and so on ``` From cd5445923a66a4a141d06c4f97aba68affc8ceee Mon Sep 17 00:00:00 2001 From: Alex Chi Date: Sat, 14 Jun 2025 14:42:12 +0800 Subject: [PATCH 02/79] add back installation check script Signed-off-by: Alex Chi --- book/src/setup.md | 2 +- pyproject.toml | 1 + scripts/check-installation.py | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 scripts/check-installation.py diff --git a/book/src/setup.md b/book/src/setup.md index 5197bab..dedf1c3 100644 --- a/book/src/setup.md +++ b/book/src/setup.md @@ -34,7 +34,7 @@ pdm install -v # this will automatically create a virtual environment and instal ## Check the Installation ```bash -pdm run python check.py +pdm run check-installation # The reference solution should pass all the *week 1* tests pdm run test-refsol -- -- -k week_1 ``` diff --git a/pyproject.toml b/pyproject.toml index 9dd6ce2..afd3f1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ main-week1.cmd = "python main.py --loader week1" main-week2.cmd = "python main.py --loader week2" batch-main.cmd = "python batch-main.py" test.cmd = "python scripts/dev-tools.py test" +check-installation.cmd = "python scripts/check-installation.py" test-refsol.cmd = "python scripts/dev-tools.py test-refsol" bench.cmd = "pytest benches" format = "ruff format" diff --git a/scripts/check-installation.py b/scripts/check-installation.py new file mode 100644 index 0000000..783b79b --- /dev/null +++ b/scripts/check-installation.py @@ -0,0 +1,20 @@ +import mlx.core as mx +import torch + +with mx.stream(mx.cpu): + a = mx.array([1, 2, 3]) + b = mx.array([4, 5, 6]) + c = mx.add(a, b) + print(c) + +with mx.stream(mx.gpu): + a = mx.array([1, 2, 3]) + b = mx.array([4, 5, 6]) + c = mx.add(a, b) + print(c) + +print( + torch.add( + torch.tensor([1, 2, 3], device="cpu"), torch.tensor([4, 5, 6], device="cpu") + ) +) From 55066c343aafae0de2400e69e83da0dedf1bd85d Mon Sep 17 00:00:00 2001 From: "Alex Chi Z." <4198311+skyzh@users.noreply.github.com> Date: Tue, 22 Jul 2025 21:24:41 -0400 Subject: [PATCH 03/79] bump mlx to latest version (#33) Signed-off-by: Alex Chi Z --- pdm.lock | 183 ++++++++++---------- pyproject.toml | 4 + src/extensions/src/axpby.cpp | 3 +- src/extensions/src/axpby.h | 2 +- src/extensions/src/utils.cpp | 2 +- src/extensions_ref/src/flash_attention.cpp | 3 +- src/extensions_ref/src/quantized_matmul.cpp | 3 +- src/extensions_ref/src/tiny_llm_ext.h | 4 +- src/extensions_ref/src/utils.cpp | 2 +- 9 files changed, 110 insertions(+), 96 deletions(-) diff --git a/pdm.lock b/pdm.lock index 4868855..d649d46 100644 --- a/pdm.lock +++ b/pdm.lock @@ -783,28 +783,31 @@ files = [ [[package]] name = "mlx" -version = "0.25.2" +version = "0.26.5" requires_python = ">=3.9" summary = "A framework for machine learning on Apple silicon." groups = ["default"] +dependencies = [ + "mlx-metal==0.26.5", +] files = [ - {file = "mlx-0.25.2-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:112a02030b14c0e08eeb4e6a1c6e0d259d2c6ae28cfb684d02a32e720e8656da"}, - {file = "mlx-0.25.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:5b1173bafca02c2e4b7902a2440bf635d2588f5ab40f68fa83bad43ceca818f2"}, - {file = "mlx-0.25.2-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:22f5e80bf2424dc73052e48eeb769056514b78ea28d66a33a5daf1cb83e1acaa"}, - {file = "mlx-0.25.2-cp310-cp310-manylinux_2_31_x86_64.whl", hash = "sha256:e44385a2f5d37d586b245bd7f6ed79998429129880cff0e89f308da73d8f4b56"}, - {file = "mlx-0.25.2-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:ab9d9eda941004d805d7365006d65789c88a37b6df896bf8c7e9ea1ca85630d3"}, - {file = "mlx-0.25.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:c89ec8726e8cb48196fc07435579482a49d3d2d484b5e2c4d15dfd3a390eec91"}, - {file = "mlx-0.25.2-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:f3ef4e8e9a6adc43371f242445423147cce338d1b1368366b84f9dd74483a980"}, - {file = "mlx-0.25.2-cp311-cp311-manylinux_2_31_x86_64.whl", hash = "sha256:3bbce1ad8f6066376ed82002af05fe2451776fd345600e7d4743777898aaaf49"}, - {file = "mlx-0.25.2-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:c9d35eac8d3224c9e005e7d469dd927dbdecd3d01b2e23390fd0d496e0010b7f"}, - {file = "mlx-0.25.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:c965289d64814005f85c4b121d5df00889a0ea5160883a62edb396b360bf5736"}, - {file = "mlx-0.25.2-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:24055a8fba9a6261d94dbfd31c37ede75106c7a832cb0609cb5b65e7e903e1fb"}, - {file = "mlx-0.25.2-cp312-cp312-manylinux_2_31_x86_64.whl", hash = "sha256:d7cc749e66e759ccad70d736896e0bc71d15f99fdbccfaa292da65fd6988bab8"}, + {file = "mlx-0.26.5-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:027cf842643ee27176e24c604b453f3200d9c03a96aa0a24c7f9de7027e87485"}, + {file = "mlx-0.26.5-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:f7f26bb955b7b33564ff93f06050d086639dc7e55d4addda19f3658960822cf4"}, + {file = "mlx-0.26.5-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:907c1fadbd3a13db40e0f455d19cc1750ec9961432b603605b501945126927a7"}, + {file = "mlx-0.26.5-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl", hash = "sha256:c9dfd46c9cf60e12b3440d8224068826f3c58b0ff2afadd6185f907ae587df17"}, + {file = "mlx-0.26.5-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:04367b20970a081c4359111c667ba62788592fd66184943c7eff10d660a13fce"}, + {file = "mlx-0.26.5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:aa14defb1c5b330c94e9873d09c90f27d72681970c20c1490457c8a2a51b93e5"}, + {file = "mlx-0.26.5-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:1e2d90d391c83932dee9a21ce5c95de30a60ce7fdac61143e3a15a4d8e55271d"}, + {file = "mlx-0.26.5-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl", hash = "sha256:b6ab9ee55d0c2e2a67d66b35bdd52adf5e21adf9fdc4c2cc2e4c6f465ec30145"}, + {file = "mlx-0.26.5-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:a29a060d089906a8ce753d0cb605c9c20005ce1865ccde8d06ad91c8c4325dc7"}, + {file = "mlx-0.26.5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:d754e964c7c320c67f4a1a00d6d2d60c1eec8213b83ebe754702bf6dc51a36a4"}, + {file = "mlx-0.26.5-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:77c5b750b24a18ed6e433dc46d787193b124a869b1b65d1532bfb7d37ec7172f"}, + {file = "mlx-0.26.5-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl", hash = "sha256:9c2b5dc36c57c8e2fe2245ca8cb2db167d053395afd53979f18728e2311cd633"}, ] [[package]] name = "mlx-lm" -version = "0.24.1" +version = "0.26.0" requires_python = ">=3.8" summary = "LLMs on Apple silicon with MLX and the Hugging Face Hub" groups = ["default"] @@ -814,11 +817,23 @@ dependencies = [ "numpy", "protobuf", "pyyaml", - "transformers[sentencepiece]>=4.39.3", + "transformers>=4.39.3", ] files = [ - {file = "mlx_lm-0.24.1-py3-none-any.whl", hash = "sha256:8a47c2ac73b1ea5f1f4569d48420b4cd2fdb38972eb14fd45340c3341bf0689f"}, - {file = "mlx_lm-0.24.1.tar.gz", hash = "sha256:5df8aa87504df28bd77c91be09f68c551bde6f141f204d3b37157f3c3ede26c1"}, + {file = "mlx_lm-0.26.0-py3-none-any.whl", hash = "sha256:b00294c26242cd50db4b6e3ec3a2baf1cfdf8ca49a5e6057dce14642fabe0d21"}, + {file = "mlx_lm-0.26.0.tar.gz", hash = "sha256:78980ad994baf976779cc1c34c0d55c1c6b63dffef4899d67fec240d0c443b52"}, +] + +[[package]] +name = "mlx-metal" +version = "0.26.5" +requires_python = ">=3.9" +summary = "A framework for machine learning on Apple silicon." +groups = ["default"] +files = [ + {file = "mlx_metal-0.26.5-py3-none-macosx_13_0_arm64.whl", hash = "sha256:6e58b4ec234bd23d04b817642245832b5a5ac48d06ed26b97b4dc6dba5b40aa3"}, + {file = "mlx_metal-0.26.5-py3-none-macosx_14_0_arm64.whl", hash = "sha256:d6f6a7d1110562544978d6f87a5bbb47be293d3b85c742043f8f2048e9f387cf"}, + {file = "mlx_metal-0.26.5-py3-none-macosx_15_0_arm64.whl", hash = "sha256:f5bd394c7ff6eebaaf8db6d7bd4f8dec96f2d232f08ba641673ab66f22e9727b"}, ] [[package]] @@ -1661,22 +1676,34 @@ files = [ ] [[package]] -name = "pytest" -version = "8.3.5" +name = "pygments" +version = "2.19.2" requires_python = ">=3.8" +summary = "Pygments is a syntax highlighting package written in Python." +groups = ["default"] +files = [ + {file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"}, + {file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"}, +] + +[[package]] +name = "pytest" +version = "8.4.1" +requires_python = ">=3.9" summary = "pytest: simple powerful testing with Python" groups = ["default"] dependencies = [ - "colorama; sys_platform == \"win32\"", - "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"", - "iniconfig", - "packaging", + "colorama>=0.4; sys_platform == \"win32\"", + "exceptiongroup>=1; python_version < \"3.11\"", + "iniconfig>=1", + "packaging>=20", "pluggy<2,>=1.5", + "pygments>=2.7.2", "tomli>=1; python_version < \"3.11\"", ] files = [ - {file = "pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820"}, - {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"}, + {file = "pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7"}, + {file = "pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c"}, ] [[package]] @@ -1881,29 +1908,29 @@ files = [ [[package]] name = "ruff" -version = "0.11.10" +version = "0.12.4" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["default"] files = [ - {file = "ruff-0.11.10-py3-none-linux_armv6l.whl", hash = "sha256:859a7bfa7bc8888abbea31ef8a2b411714e6a80f0d173c2a82f9041ed6b50f58"}, - {file = "ruff-0.11.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:968220a57e09ea5e4fd48ed1c646419961a0570727c7e069842edd018ee8afed"}, - {file = "ruff-0.11.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1067245bad978e7aa7b22f67113ecc6eb241dca0d9b696144256c3a879663bca"}, - {file = "ruff-0.11.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4854fd09c7aed5b1590e996a81aeff0c9ff51378b084eb5a0b9cd9518e6cff2"}, - {file = "ruff-0.11.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b4564e9f99168c0f9195a0fd5fa5928004b33b377137f978055e40008a082c5"}, - {file = "ruff-0.11.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b6a9cc5b62c03cc1fea0044ed8576379dbaf751d5503d718c973d5418483641"}, - {file = "ruff-0.11.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:607ecbb6f03e44c9e0a93aedacb17b4eb4f3563d00e8b474298a201622677947"}, - {file = "ruff-0.11.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b3a522fa389402cd2137df9ddefe848f727250535c70dafa840badffb56b7a4"}, - {file = "ruff-0.11.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f071b0deed7e9245d5820dac235cbdd4ef99d7b12ff04c330a241ad3534319f"}, - {file = "ruff-0.11.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a60e3a0a617eafba1f2e4186d827759d65348fa53708ca547e384db28406a0b"}, - {file = "ruff-0.11.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:da8ec977eaa4b7bf75470fb575bea2cb41a0e07c7ea9d5a0a97d13dbca697bf2"}, - {file = "ruff-0.11.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ddf8967e08227d1bd95cc0851ef80d2ad9c7c0c5aab1eba31db49cf0a7b99523"}, - {file = "ruff-0.11.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5a94acf798a82db188f6f36575d80609072b032105d114b0f98661e1679c9125"}, - {file = "ruff-0.11.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3afead355f1d16d95630df28d4ba17fb2cb9c8dfac8d21ced14984121f639bad"}, - {file = "ruff-0.11.10-py3-none-win32.whl", hash = "sha256:dc061a98d32a97211af7e7f3fa1d4ca2fcf919fb96c28f39551f35fc55bdbc19"}, - {file = "ruff-0.11.10-py3-none-win_amd64.whl", hash = "sha256:5cc725fbb4d25b0f185cb42df07ab6b76c4489b4bfb740a175f3a59c70e8a224"}, - {file = "ruff-0.11.10-py3-none-win_arm64.whl", hash = "sha256:ef69637b35fb8b210743926778d0e45e1bffa850a7c61e428c6b971549b5f5d1"}, - {file = "ruff-0.11.10.tar.gz", hash = "sha256:d522fb204b4959909ecac47da02830daec102eeb100fb50ea9554818d47a5fa6"}, + {file = "ruff-0.12.4-py3-none-linux_armv6l.whl", hash = "sha256:cb0d261dac457ab939aeb247e804125a5d521b21adf27e721895b0d3f83a0d0a"}, + {file = "ruff-0.12.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:55c0f4ca9769408d9b9bac530c30d3e66490bd2beb2d3dae3e4128a1f05c7442"}, + {file = "ruff-0.12.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a8224cc3722c9ad9044da7f89c4c1ec452aef2cfe3904365025dd2f51daeae0e"}, + {file = "ruff-0.12.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9949d01d64fa3672449a51ddb5d7548b33e130240ad418884ee6efa7a229586"}, + {file = "ruff-0.12.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:be0593c69df9ad1465e8a2d10e3defd111fdb62dcd5be23ae2c06da77e8fcffb"}, + {file = "ruff-0.12.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7dea966bcb55d4ecc4cc3270bccb6f87a337326c9dcd3c07d5b97000dbff41c"}, + {file = "ruff-0.12.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:afcfa3ab5ab5dd0e1c39bf286d829e042a15e966b3726eea79528e2e24d8371a"}, + {file = "ruff-0.12.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c057ce464b1413c926cdb203a0f858cd52f3e73dcb3270a3318d1630f6395bb3"}, + {file = "ruff-0.12.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e64b90d1122dc2713330350626b10d60818930819623abbb56535c6466cce045"}, + {file = "ruff-0.12.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2abc48f3d9667fdc74022380b5c745873499ff827393a636f7a59da1515e7c57"}, + {file = "ruff-0.12.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2b2449dc0c138d877d629bea151bee8c0ae3b8e9c43f5fcaafcd0c0d0726b184"}, + {file = "ruff-0.12.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:56e45bb11f625db55f9b70477062e6a1a04d53628eda7784dce6e0f55fd549eb"}, + {file = "ruff-0.12.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:478fccdb82ca148a98a9ff43658944f7ab5ec41c3c49d77cd99d44da019371a1"}, + {file = "ruff-0.12.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0fc426bec2e4e5f4c4f182b9d2ce6a75c85ba9bcdbe5c6f2a74fcb8df437df4b"}, + {file = "ruff-0.12.4-py3-none-win32.whl", hash = "sha256:4de27977827893cdfb1211d42d84bc180fceb7b72471104671c59be37041cf93"}, + {file = "ruff-0.12.4-py3-none-win_amd64.whl", hash = "sha256:fe0b9e9eb23736b453143d72d2ceca5db323963330d5b7859d60d101147d461a"}, + {file = "ruff-0.12.4-py3-none-win_arm64.whl", hash = "sha256:0618ec4442a83ab545e5b71202a5c0ed7791e8471435b94e655b570a5031a98e"}, + {file = "ruff-0.12.4.tar.gz", hash = "sha256:13efa16df6c6eeb7d0f091abae50f58e9522f3843edb40d56ad52a5a4a4b6873"}, ] [[package]] @@ -1973,13 +2000,13 @@ files = [ [[package]] name = "setuptools" -version = "80.7.1" +version = "80.9.0" requires_python = ">=3.9" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default"] files = [ - {file = "setuptools-80.7.1-py3-none-any.whl", hash = "sha256:ca5cc1069b85dc23070a6628e6bcecb3292acac802399c7f8edc0100619f9009"}, - {file = "setuptools-80.7.1.tar.gz", hash = "sha256:f6ffc5f0142b1bd8d0ca94ee91b30c0ca862ffd50826da1ea85258a06fd94552"}, + {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, + {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] [[package]] @@ -2122,7 +2149,7 @@ files = [ [[package]] name = "torch" -version = "2.7.0" +version = "2.7.1" requires_python = ">=3.9.0" summary = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" groups = ["default"] @@ -2147,32 +2174,32 @@ dependencies = [ "nvidia-nvtx-cu12==12.6.77; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "setuptools; python_version >= \"3.12\"", "sympy>=1.13.3", - "triton==3.3.0; platform_system == \"Linux\" and platform_machine == \"x86_64\"", + "triton==3.3.1; platform_system == \"Linux\" and platform_machine == \"x86_64\"", "typing-extensions>=4.10.0", ] files = [ - {file = "torch-2.7.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c9afea41b11e1a1ab1b258a5c31afbd646d6319042bfe4f231b408034b51128b"}, - {file = "torch-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0b9960183b6e5b71239a3e6c883d8852c304e691c0b2955f7045e8a6d05b9183"}, - {file = "torch-2.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:2ad79d0d8c2a20a37c5df6052ec67c2078a2c4e9a96dd3a8b55daaff6d28ea29"}, - {file = "torch-2.7.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:34e0168ed6de99121612d72224e59b2a58a83dae64999990eada7260c5dd582d"}, - {file = "torch-2.7.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2b7813e904757b125faf1a9a3154e1d50381d539ced34da1992f52440567c156"}, - {file = "torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fd5cfbb4c3bbadd57ad1b27d56a28008f8d8753733411a140fcfb84d7f933a25"}, - {file = "torch-2.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:58df8d5c2eeb81305760282b5069ea4442791a6bbf0c74d9069b7b3304ff8a37"}, - {file = "torch-2.7.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:0a8d43caa342b9986101ec5feb5bbf1d86570b5caa01e9cb426378311258fdde"}, - {file = "torch-2.7.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:36a6368c7ace41ad1c0f69f18056020b6a5ca47bedaca9a2f3b578f5a104c26c"}, - {file = "torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:15aab3e31c16feb12ae0a88dba3434a458874636f360c567caa6a91f6bfba481"}, - {file = "torch-2.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f56d4b2510934e072bab3ab8987e00e60e1262fb238176168f5e0c43a1320c6d"}, - {file = "torch-2.7.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:30b7688a87239a7de83f269333651d8e582afffce6f591fff08c046f7787296e"}, + {file = "torch-2.7.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a103b5d782af5bd119b81dbcc7ffc6fa09904c423ff8db397a1e6ea8fd71508f"}, + {file = "torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:fe955951bdf32d182ee8ead6c3186ad54781492bf03d547d31771a01b3d6fb7d"}, + {file = "torch-2.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:885453d6fba67d9991132143bf7fa06b79b24352f4506fd4d10b309f53454162"}, + {file = "torch-2.7.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:d72acfdb86cee2a32c0ce0101606f3758f0d8bb5f8f31e7920dc2809e963aa7c"}, + {file = "torch-2.7.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:236f501f2e383f1cb861337bdf057712182f910f10aeaf509065d54d339e49b2"}, + {file = "torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:06eea61f859436622e78dd0cdd51dbc8f8c6d76917a9cf0555a333f9eac31ec1"}, + {file = "torch-2.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:8273145a2e0a3c6f9fd2ac36762d6ee89c26d430e612b95a99885df083b04e52"}, + {file = "torch-2.7.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aea4fc1bf433d12843eb2c6b2204861f43d8364597697074c8d38ae2507f8730"}, + {file = "torch-2.7.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:27ea1e518df4c9de73af7e8a720770f3628e7f667280bce2be7a16292697e3fa"}, + {file = "torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c33360cfc2edd976c2633b3b66c769bdcbbf0e0b6550606d188431c81e7dd1fc"}, + {file = "torch-2.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:d8bf6e1856ddd1807e79dc57e54d3335f2b62e6f316ed13ed3ecfe1fc1df3d8b"}, + {file = "torch-2.7.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:787687087412c4bd68d315e39bc1223f08aae1d16a9e9771d95eabbb04ae98fb"}, ] [[package]] name = "torchao" -version = "0.11.0" +version = "0.12.0" summary = "Package for applying ao techniques to GPU models" groups = ["default"] files = [ - {file = "torchao-0.11.0-cp39-abi3-manylinux_2_28_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:5d9804833a7ec4d98a759d4750cd0f07dbd90503e3f3f4c59c77a9b9172bf043"}, - {file = "torchao-0.11.0-py3-none-any.whl", hash = "sha256:22be86f06e95b3c1f700eb8eb9fc73a5c2c9d88845ef1f37b3448bb1848a2ccd"}, + {file = "torchao-0.12.0-cp39-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:326ea2800cc7d9d50f0d17742ad923e5c6d4c4dd5942558f4ed13db00bdebc7c"}, + {file = "torchao-0.12.0-py3-none-any.whl", hash = "sha256:103f2a9164d2e4f705332af1aafbb8473eadd14d9164e45857ca187cde1f13d2"}, ] [[package]] @@ -2253,26 +2280,9 @@ files = [ {file = "transformers-4.51.3.tar.gz", hash = "sha256:e292fcab3990c6defe6328f0f7d2004283ca81a7a07b2de9a46d67fd81ea1409"}, ] -[[package]] -name = "transformers" -version = "4.51.3" -extras = ["sentencepiece"] -requires_python = ">=3.9.0" -summary = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" -groups = ["default"] -dependencies = [ - "protobuf", - "sentencepiece!=0.1.92,>=0.1.91", - "transformers==4.51.3", -] -files = [ - {file = "transformers-4.51.3-py3-none-any.whl", hash = "sha256:fd3279633ceb2b777013234bbf0b4f5c2d23c4626b05497691f00cfda55e8a83"}, - {file = "transformers-4.51.3.tar.gz", hash = "sha256:e292fcab3990c6defe6328f0f7d2004283ca81a7a07b2de9a46d67fd81ea1409"}, -] - [[package]] name = "triton" -version = "3.3.0" +version = "3.3.1" summary = "A language and compiler for custom Deep Learning operations" groups = ["default"] marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" @@ -2280,12 +2290,9 @@ dependencies = [ "setuptools>=40.8.0", ] files = [ - {file = "triton-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fad99beafc860501d7fcc1fb7045d9496cbe2c882b1674640304949165a916e7"}, - {file = "triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3161a2bf073d6b22c4e2f33f951f3e5e3001462b2570e6df9cd57565bdec2984"}, - {file = "triton-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b68c778f6c4218403a6bd01be7484f6dc9e20fe2083d22dd8aef33e3b87a10a3"}, - {file = "triton-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47bc87ad66fa4ef17968299acacecaab71ce40a238890acc6ad197c3abe2b8f1"}, - {file = "triton-3.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce4700fc14032af1e049005ae94ba908e71cd6c2df682239aed08e49bc71b742"}, - {file = "triton-3.3.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f41403bfa0cbb3e24fd958ca7fee04e9681e55e539296db9aca30c42acae693"}, + {file = "triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e"}, + {file = "triton-3.3.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b31e3aa26f8cb3cc5bf4e187bf737cbacf17311e1112b781d4a059353dfd731b"}, + {file = "triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9999e83aba21e1a78c1f36f21bce621b77bcaa530277a50484a7cb4a822f6e43"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index afd3f1a..ba3825e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,10 @@ build-ext-test.cmd = "python test.py" build-ext-test.working_dir = "src/extensions" build-ext-ref.cmd = "python build.py" build-ext-ref.working_dir = "src/extensions_ref" +clean-ext.cmd = "rm -rf build" +clean-ext.working_dir = "src/extensions" +clean-ext-ref.cmd = "rm -rf build" +clean-ext-ref.working_dir = "src/extensions_ref" main.cmd = "python main.py" main-week1.cmd = "python main.py --loader week1" main-week2.cmd = "python main.py --loader week2" diff --git a/src/extensions/src/axpby.cpp b/src/extensions/src/axpby.cpp index 21bcbd0..e70027e 100644 --- a/src/extensions/src/axpby.cpp +++ b/src/extensions/src/axpby.cpp @@ -148,7 +148,8 @@ void Axpby::eval_gpu(const std::vector &inputs, std::vector &axes) override; /** Print the primitive. */ - void print(std::ostream &os) override { os << "Axpby"; } + const char* name() const override { return "Axpby"; } /** Equivalence check **/ bool is_equivalent(const mx::Primitive &other) const override; diff --git a/src/extensions/src/utils.cpp b/src/extensions/src/utils.cpp index 97a5b80..3b52303 100644 --- a/src/extensions/src/utils.cpp +++ b/src/extensions/src/utils.cpp @@ -9,7 +9,7 @@ namespace tiny_llm_ext { void load_library(mx::Device d, const char *path) { #ifdef _METAL_ auto &md = mx::metal::device(d); - md.register_library("tiny_llm_ext", path); + md.get_library("tiny_llm_ext", path); #endif } diff --git a/src/extensions_ref/src/flash_attention.cpp b/src/extensions_ref/src/flash_attention.cpp index 10362b8..3ac0277 100644 --- a/src/extensions_ref/src/flash_attention.cpp +++ b/src/extensions_ref/src/flash_attention.cpp @@ -215,7 +215,8 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< out.set_data(mx::allocator::malloc(out.nbytes())); // Make a kernel from this metal library - auto kernel = d.get_kernel("flash_attention_f32_e128", "tiny_llm_ext_ref"); + auto library = d.get_library("tiny_llm_ext_ref"); + auto kernel = d.get_kernel("flash_attention_f32_e128", library); // Prepare to encode kernel auto &compute_encoder = d.get_command_encoder(s.index); diff --git a/src/extensions_ref/src/quantized_matmul.cpp b/src/extensions_ref/src/quantized_matmul.cpp index b5874e4..d5eed7d 100644 --- a/src/extensions_ref/src/quantized_matmul.cpp +++ b/src/extensions_ref/src/quantized_matmul.cpp @@ -163,7 +163,8 @@ void QuantizedMatmul::eval_gpu(const std::vector &inputs, std::vector out.set_data(mx::allocator::malloc(out.nbytes())); // Make a kernel from this metal library - auto kernel = d.get_kernel("quantized_matmul_w4a16_g64", "tiny_llm_ext_ref"); + auto library = d.get_library("tiny_llm_ext_ref"); + auto kernel = d.get_kernel("quantized_matmul_w4a16_g64", library); // Prepare to encode kernel auto &compute_encoder = d.get_command_encoder(s.index); diff --git a/src/extensions_ref/src/tiny_llm_ext.h b/src/extensions_ref/src/tiny_llm_ext.h index 4c2bf66..c7afbd3 100644 --- a/src/extensions_ref/src/tiny_llm_ext.h +++ b/src/extensions_ref/src/tiny_llm_ext.h @@ -32,7 +32,7 @@ class QuantizedMatmul : public mx::Primitive { throw std::runtime_error("QuantizedMatmul has no vmap implementation."); } - void print(std::ostream &os) override { os << "QuantizedMatmul"; } + const char* name() const override { return "QuantizedMatmul"; } bool is_equivalent(const mx::Primitive &other) const override; @@ -57,7 +57,7 @@ class FlashAttention : public mx::Primitive { throw std::runtime_error("FlashAttention has no vmap implementation."); } - void print(std::ostream &os) override { os << "FlashAttention"; } + const char* name() const override { return "FlashAttention"; } bool is_equivalent(const mx::Primitive &other) const override { const FlashAttention &r_other = static_cast(other); diff --git a/src/extensions_ref/src/utils.cpp b/src/extensions_ref/src/utils.cpp index ac87b1c..9ce0003 100644 --- a/src/extensions_ref/src/utils.cpp +++ b/src/extensions_ref/src/utils.cpp @@ -9,7 +9,7 @@ namespace tiny_llm_ext_ref { void load_library(mx::Device d, const char *path) { #ifdef _METAL_ auto &md = mx::metal::device(d); - md.register_library("tiny_llm_ext_ref", path); + md.get_library("tiny_llm_ext_ref", path); #endif } From 7fc05dc0f5580cc12fe6b4f5c92b3cd68d4aa073 Mon Sep 17 00:00:00 2001 From: Phoenix Date: Wed, 23 Jul 2025 09:25:58 +0800 Subject: [PATCH 04/79] test:add a test case to cover week_1_day_3_task3 (#31) * test:add a test case to cover week_1_day_3_task3 Closes: #23 Signed-off-by: Jiawei Zhao * fmt Signed-off-by: Alex Chi Z --------- Signed-off-by: Jiawei Zhao Signed-off-by: Alex Chi Z Co-authored-by: Alex Chi Z --- tests_refsol/test_week_1_day_3.py | 65 ++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/tests_refsol/test_week_1_day_3.py b/tests_refsol/test_week_1_day_3.py index 5158150..32c27ae 100644 --- a/tests_refsol/test_week_1_day_3.py +++ b/tests_refsol/test_week_1_day_3.py @@ -131,5 +131,66 @@ def test_task_2_grouped_attention_causal_mask( grouped_attention_helper(stream, precision, batch_dimension, scale, True) -def test_task_3_qwen2_grouped_query_attention(): - pass +@pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) +@pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) +@pytest.mark.parametrize("mask", [None, "causal"], ids=["no_mask", "causal_mask"]) +def test_task_3_qwen2_grouped_query_attention( + stream: mx.Stream, precision: mx.Dtype, mask: str | None +): + with mx.stream(stream): + batch_size = 1 + seq_len = 4 + hidden_size = 32 + num_heads = 4 + num_kv_heads = 2 + max_seq_len = 64 + theta = 10000 + + from mlx_lm.models import qwen2 + + args = qwen2.ModelArgs( + model_type="qwen2", + hidden_size=hidden_size, + num_hidden_layers=2, + intermediate_size=hidden_size * 4, + num_attention_heads=num_heads, + num_key_value_heads=num_kv_heads, + rms_norm_eps=1e-6, + vocab_size=1000, + rope_theta=theta, + rope_traditional=False, + max_position_embeddings=max_seq_len, + ) + + mlx_attention = qwen2.Attention(args) + wq = mlx_attention.q_proj.weight + wk = mlx_attention.k_proj.weight + wv = mlx_attention.v_proj.weight + wo = mlx_attention.o_proj.weight + bq = mlx_attention.q_proj.bias + bk = mlx_attention.k_proj.bias + bv = mlx_attention.v_proj.bias + mx.random.seed(42) + x = mx.random.uniform( + -1.0, 1.0, shape=(batch_size, seq_len, hidden_size), dtype=precision + ) + + user_attention = qwen2_week1.Qwen2MultiHeadAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + wq=wq, + wk=wk, + wv=wv, + wo=wo, + bq=bq, + bk=bk, + bv=bv, + max_seq_len=max_seq_len, + theta=theta, + ) + + user_output = user_attention(x, offset=0, mask=mask) + mlx_output = mlx_attention(x, mask=mask, cache=None) + + assert_allclose(user_output, mlx_output, precision=precision) From 13295fc64c396a799f8283d35813db3e9e904fac Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 27 Jul 2025 14:12:42 -0400 Subject: [PATCH 05/79] fix tokp implementation Signed-off-by: Alex Chi Z --- src/tiny_llm_ref/sampler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tiny_llm_ref/sampler.py b/src/tiny_llm_ref/sampler.py index c4234b7..deacfb8 100644 --- a/src/tiny_llm_ref/sampler.py +++ b/src/tiny_llm_ref/sampler.py @@ -15,8 +15,10 @@ def sample(logprobs: mx.array): if top_p is not None and top_p > 0: sorted_idx = mx.argsort(-logprobs, axis=-1) sorted_logprobs = logprobs[:, sorted_idx] - cumsum = mx.cumsum(sorted_logprobs, axis=-1) - logprobs[:, sorted_idx] = mx.where(cumsum < top_p, sorted_logprobs, -mx.inf) + cumsum = mx.cumsum(mx.exp(sorted_logprobs), axis=-1) + mask_elements = cumsum < top_p + mask_elements[..., 0] = True + logprobs[:, sorted_idx] = mx.where(mask_elements, sorted_logprobs, -mx.inf) logprobs = logprobs / temp return mx.random.categorical(logprobs, axis=-1) From 4e9101c32ac39fd2885140c988165ea602674b0c Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 27 Jul 2025 14:26:08 -0400 Subject: [PATCH 06/79] more precision tweaks Signed-off-by: Alex Chi Z --- tests/utils.py | 9 ++++++++- tests_refsol/test_week_2_day_3.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 32906b5..4b9a332 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,14 +23,21 @@ def assert_allclose( elif precision == mx.float16: rtol = rtol or 3.0e-2 atol = atol or 1.0e-5 + else: + raise ValueError(f"Unsupported precision: {precision}") assert a.shape == b.shape, f"shape mismatch: {a.shape} vs {b.shape}" if not np.allclose(a, b, rtol=rtol, atol=atol): + diff = np.invert(np.isclose(a, b, rtol=rtol, atol=atol)) + if diff.size > 10000 and np.sum(diff) <= 1: + # if only one element is different in a large array, probably fine + return with np.printoptions(precision=3, suppress=True): print("a=", a) print("b=", b) - diff = np.invert(np.isclose(a, b, rtol=rtol, atol=atol)) print("diff_a=", a * diff) print("diff_b=", b * diff) + print("diff_a_val=", a[diff]) + print("diff_b_val=", b[diff]) assert False, f"result mismatch" diff --git a/tests_refsol/test_week_2_day_3.py b/tests_refsol/test_week_2_day_3.py index f787834..ac8f2a7 100644 --- a/tests_refsol/test_week_2_day_3.py +++ b/tests_refsol/test_week_2_day_3.py @@ -28,7 +28,7 @@ def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH): scale=scale, ) mx.eval(user_output) # so that any error will be caught here - assert_allclose(user_output, reference_output, precision=precision) + assert_allclose(user_output, reference_output, precision=mx.float16) def test_flash_attention_cpu_small(): From ae06a7f9f995b6e3de52bae986b594ce1586f607 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 27 Jul 2025 15:37:56 -0400 Subject: [PATCH 07/79] fix bugs in continuous batching Signed-off-by: Alex Chi Z --- batch-main.py | 3 +- src/tiny_llm_ref/attention.py | 9 +- src/tiny_llm_ref/generate.py | 19 +--- src/tiny_llm_ref/kv_cache.py | 189 +++++++++----------------------- src/tiny_llm_ref/qwen2_week2.py | 11 +- 5 files changed, 72 insertions(+), 159 deletions(-) diff --git a/batch-main.py b/batch-main.py index d7af798..9d7ed3f 100644 --- a/batch-main.py +++ b/batch-main.py @@ -4,7 +4,7 @@ import random parser = argparse.ArgumentParser() -parser.add_argument("--model", type=str, default="Qwen/Qwen2-7B-Instruct-MLX") +parser.add_argument("--model", type=str, default="Qwen/Qwen2-0.5B-Instruct-MLX") shanghai_wikipedia = """ Shanghai[a] is a direct-administered municipality and the most populous urban area in China. The city is located on the Chinese shoreline on the southern estuary of the Yangtze River, with the Huangpu River flowing through it. The population of the city proper is the second largest in the world after Chongqing, with around 24.87 million inhabitants in 2023, while the urban area is the most populous in China, with 29.87 million residents. As of 2022, the Greater Shanghai metropolitan area was estimated to produce a gross metropolitan product (nominal) of nearly 13 trillion RMB ($1.9 trillion).[13] Shanghai is one of the world's major centers for finance, business and economics, research, science and technology, manufacturing, transportation, tourism, and culture. The Port of Shanghai is the world's busiest container port. @@ -76,5 +76,6 @@ prefill_step=args.prefill_step, ) for prompt_idx, text in result: + print(f"--- {prompt_idx} ---") print(f"Q: {prompts[prompt_idx]}") print(f"A: {text}") diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index d628446..9905056 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -40,12 +40,13 @@ def scaled_dot_product_attention_grouped( H_q, L, D = query.shape[-3:] H, S, _ = key.shape[-3:] + B = query.shape[:-3] assert H_q % H == 0 n_repeats = H_q // H - query = query.reshape(-1, H, n_repeats, L, D) - key = key.reshape(-1, H, 1, S, D) - value = value.reshape(-1, H, 1, S, D) + query = query.reshape(*B, -1, H, n_repeats, L, D) + key = key.reshape(*B, -1, H, 1, S, D) + value = value.reshape(*B, -1, H, 1, S, D) scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor if mask is not None: @@ -53,7 +54,7 @@ def scaled_dot_product_attention_grouped( mask = causal_mask(L, S, scores.dtype) scores = scores + mask else: - mask = mask.reshape(-1, H, n_repeats, mask.shape[-2], mask.shape[-1]) + mask = mask.reshape(*B, 1, 1, 1, mask.shape[-2], mask.shape[-1]) scores = scores + mask result = mx.matmul(softmax(scores, axis=-1), value) return result.reshape(expected_shape) diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index fda7929..eb2b87a 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -213,23 +213,14 @@ def batch_generate( for i in range(batch_size): if not is_idle[i]: detokenizers[i].add_token(next_tokens[i].item()) - if ( - next_tokens[i].item() == tokenizer.eos_token_id - or offsets[i] >= max_seq_len - ): - print( - f"(Finished) {prompt_idx[i]}: " + detokenizers[i].text, - flush=True, - ) + remove_due_to_eos = next_tokens[i].item() == tokenizer.eos_token_id + remove_due_to_max_seq_len = offsets[i] >= max_seq_len + if remove_due_to_eos or remove_due_to_max_seq_len: + reason = "EOS" if remove_due_to_eos else "Max Seq Len" result.append((prompt_idx[i], detokenizers[i].text)) - print(f"Removing request {i}", flush=True) + print(f"Removing request {i} due to {reason}", flush=True) batch_cache.remove_request(i) is_idle[i] = True continue - else: - print( - f"(In Progress) {prompt_idx[i]}: " + detokenizers[i].text, - flush=True, - ) _print_progress(detokenizers, prompt_idx, is_idle, pending_prefill_requests) return result diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index e344ebb..c23e32c 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -1,12 +1,13 @@ from typing import Optional +from .attention import causal_mask import mlx.core as mx class TinyKvCache: def update_and_fetch( - self, key: mx.array, value: mx.array - ) -> tuple[mx.array, mx.array, int]: + self, key: mx.array, value: mx.array, q_L: int | None = None + ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: pass @@ -14,40 +15,49 @@ class BatchingKvCache(TinyKvCache): def __init__(self, max_active_requests: int, max_seq_len: int): self.max_active_requests = max_active_requests self.max_seq_len = max_seq_len - self.key_values = None - self.head_offsets = mx.array([0] * max_active_requests) - self.head = 0 + self.key_values = [None] * max_active_requests + self.real_seq_len = [0] * max_active_requests + self.HD = None def update_and_fetch( - self, key: mx.array, value: mx.array - ) -> tuple[mx.array, mx.array, int]: - B, H, L, D = key.shape + self, key: mx.array, value: mx.array, q_L: int | None = None + ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: + B, H, S, D = key.shape assert key.shape == value.shape - assert L <= self.max_seq_len - keys, values = self.key_values - if self.head + L <= self.max_seq_len: - keys[:, :, self.head : self.head + L, :] = key - values[:, :, self.head : self.head + L, :] = value - self.head += L - self.head_offsets += L - else: - fill_size = self.max_seq_len - self.head - keys[:, :, self.head : self.max_seq_len, :] = key[:, :, :fill_size, :] - values[:, :, self.head : self.max_seq_len, :] = value[:, :, :fill_size, :] - remaining_size = L - fill_size - keys[:, :, :remaining_size, :] = key[:, :, fill_size:, :] - values[:, :, :remaining_size, :] = value[:, :, fill_size:, :] - self.head = remaining_size - self.head_offsets += L - self.key_values = (keys, values) - - before_keys = keys[:, :, self.head :, :] - before_values = values[:, :, self.head :, :] - after_keys = keys[:, :, : self.head, :] - after_values = values[:, :, : self.head, :] - keys = mx.concat([after_keys, before_keys], axis=2) - values = mx.concat([after_values, before_values], axis=2) - return keys, values, self.head_offsets + assert S <= self.max_seq_len + assert self.HD == (H, D), f"expect {self.HD} but got {H, D}" + assert B == self.max_active_requests + # Step 1: append the result to the cache + for b in range(B): + if self.key_values[b] is None: + continue + cached_keys, cached_values = self.key_values[b] + keys, values = key[b], value[b] + keys = mx.concat([cached_keys, keys], axis=1) + values = mx.concat([cached_values, values], axis=1) + self.key_values[b] = (keys, values) + self.real_seq_len[b] += S + # Step 2: compute seq_len of this batch + seq_len = max(self.real_seq_len) + # Step 3: generate masks and a single array of keys and values + masks = [] + keys = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=key.dtype) + values = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=value.dtype) + masks = mx.full( + (self.max_active_requests, q_L, seq_len), -mx.inf, dtype=key.dtype + ) + for b in range(B): + if self.key_values[b] is None: + # for some reasons we need to do this, otherwise it will cause wrong output? + # maybe precision issues? + masks[b, :, :] = causal_mask(q_L, seq_len, dtype=key.dtype) + continue + cached_keys, cached_values = self.key_values[b] + S = self.real_seq_len[b] + keys[b, :, seq_len - S : seq_len, :] = cached_keys + values[b, :, seq_len - S : seq_len, :] = cached_values + masks[b, :, seq_len - S : seq_len] = causal_mask(q_L, S, dtype=key.dtype) + return keys, values, None, masks def add_request(self, prefilled: TinyKvCache, id: int): if id >= self.max_active_requests: @@ -55,50 +65,18 @@ def add_request(self, prefilled: TinyKvCache, id: int): keys, values = prefilled.key_values B, H, L, D = keys.shape assert B == 1 - if self.key_values is None: - self.key_values = ( - mx.zeros((self.max_active_requests, H, self.max_seq_len, D)), - mx.zeros((self.max_active_requests, H, self.max_seq_len, D)), - ) - if L > self.max_seq_len: - keys = keys[:, :, -self.max_seq_len :, :] - values = values[:, :, -self.max_seq_len :, :] - take_size = self.max_seq_len + if self.HD is None: + self.HD = (H, D) else: - take_size = L - cached_keys, cached_values = self.key_values - # Firstly, fill the cache with zeros - cached_keys[id, :, :, :] = 0 - cached_values[id, :, :, :] = 0 - # Then, fill the cache with the prefilled values up to self.head (may wrap) - start_pos = (self.head - take_size + self.max_seq_len) % self.max_seq_len - if start_pos + take_size <= self.max_seq_len: - cached_keys[id, :, start_pos : start_pos + take_size, :] = keys[0, :, :, :] - cached_values[id, :, start_pos : start_pos + take_size, :] = values[ - 0, :, :, : - ] - else: - cached_keys[id, :, start_pos : self.max_seq_len, :] = keys[ - 0, :, : self.max_seq_len - start_pos, : - ] - cached_values[id, :, start_pos : self.max_seq_len, :] = values[ - 0, :, : self.max_seq_len - start_pos, : - ] - cached_keys[id, :, : take_size - (self.max_seq_len - start_pos), :] = keys[ - 0, :, self.max_seq_len - start_pos :, : - ] - cached_values[id, :, : take_size - (self.max_seq_len - start_pos), :] = ( - values[0, :, self.max_seq_len - start_pos :, :] - ) - self.head_offsets[id] = L - self.key_values = (cached_keys, cached_values) + assert self.HD == (H, D) + self.real_seq_len[id] = L + self.key_values[id] = (keys[0], values[0]) def remove_request(self, id: int): if self.key_values is None: raise ValueError(f"Request id {id} is not in the cache") - cached_keys, cached_values = self.key_values - cached_keys[id, :, :, :] = 0 - cached_values[id, :, :, :] = 0 + self.key_values[id] = None + self.real_seq_len[id] = 0 class TinyKvFullCache(TinyKvCache): @@ -107,14 +85,14 @@ def __init__(self): self.offset = 0 def update_and_fetch( - self, key: mx.array, value: mx.array - ) -> tuple[mx.array, mx.array, int]: + self, key: mx.array, value: mx.array, q_L: int | None = None + ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: if self.key_values is None: assert self.offset == 0 self.key_values = (key, value) B, H, S, D = key.shape self.offset = S - return key, value, 0 + return key, value, 0, None else: B, H, S, D = key.shape assert key.shape == value.shape @@ -126,63 +104,4 @@ def update_and_fetch( self.key_values = (new_keys, new_values) start_offset = self.offset self.offset += S - return new_keys, new_values, start_offset - - -class TinyKvRotatingCache(TinyKvCache): - def __init__(self, max_seq_len: int): - self.max_seq_len = max_seq_len - self.key_values = None - self.head = 0 - self.head_offset = 0 - - def update_and_fetch( - self, key: mx.array, value: mx.array, offset: int - ) -> tuple[mx.array, mx.array]: - if self.key_values is None: - assert offset == 0 - B, H, L, D = key.shape - assert L <= self.max_seq_len - keys = mx.zeros((B, H, self.max_seq_len, D)) - values = mx.zeros((B, H, self.max_seq_len, D)) - keys[:, :, :L, :] = key - values[:, :, :L, :] = value - self.key_values = (keys, values) - self.head = L - self.head_offset = L - return keys[:, :, :L, :], values[:, :, :L, :] - else: - B, H, L, D = key.shape - assert key.shape == value.shape - assert offset == self.head_offset - assert L <= self.max_seq_len - keys, values = self.key_values - if self.head + L <= self.max_seq_len: - keys[:, :, self.head : self.head + L, :] = key - values[:, :, self.head : self.head + L, :] = value - self.head += L - self.head_offset += L - else: - fill_size = self.max_seq_len - self.head - keys[:, :, self.head : self.max_seq_len, :] = key[:, :, :fill_size, :] - values[:, :, self.head : self.max_seq_len, :] = value[ - :, :, :fill_size, : - ] - remaining_size = L - fill_size - keys[:, :, :remaining_size, :] = key[:, :, fill_size:, :] - values[:, :, :remaining_size, :] = value[:, :, fill_size:, :] - self.head = remaining_size - self.head_offset += L - self.key_values = (keys, values) - if self.head_offset < self.max_seq_len: - return keys[:, :, : self.head_offset, :], values[ - :, :, : self.head_offset, : - ] - else: - before_keys = keys[:, :, self.head_offset :, :] - before_values = values[:, :, self.head_offset :, :] - after_keys = keys[:, :, : self.head_offset, :] - after_values = values[:, :, : self.head_offset, :] - keys = mx.concat([after_keys, before_keys], axis=2) - values = mx.concat([after_values, before_values], axis=2) - return keys, values + return new_keys, new_values, start_offset, None diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index 63bd47a..49ff6cc 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -62,6 +62,7 @@ def __call__( projection_v = quantized_linear(x, self.wv, bias=self.bv).reshape( B, L, self.num_kv_heads, self.head_dim ) + # todo: move offsets to kv cache if isinstance(offsets, int): offset_slice = [slice(int(offsets), int(offsets + L))] else: @@ -71,9 +72,11 @@ def __call__( projection_q = projection_q.transpose(0, 2, 1, 3) projection_k = projection_k.transpose(0, 2, 1, 3) projection_v = projection_v.transpose(0, 2, 1, 3) - projection_k, projection_v, _ = cache.update_and_fetch( - projection_k, projection_v + projection_k, projection_v, _, kv_cache_mask = cache.update_and_fetch( + projection_k, projection_v, q_L=L ) + if kv_cache_mask is not None: + mask = kv_cache_mask x = scaled_dot_product_attention_grouped( projection_q.astype(mx.float32), projection_k.astype(mx.float32), @@ -252,9 +255,7 @@ def __call__( ) -> mx.array: h = self.embedding(inputs) for layer in range(self.num_hidden_layers): - h = self.layers_inner[layer]( - h, offset, cache[layer], mask="causal" if h.shape[1] > 1 else None - ) + h = self.layers_inner[layer](h, offset, cache[layer], mask="causal") h = self.norm(h) if self.w_lm_head is not None: return quantized_linear(h, self.w_lm_head) From cfbc43e5bd377e5d8ac40585b3ee1bdac74575e6 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 27 Jul 2025 15:42:38 -0400 Subject: [PATCH 08/79] fix mask tests Signed-off-by: Alex Chi Z --- src/tiny_llm_ref/attention.py | 3 ++- tests/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index 9905056..b53778b 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -54,7 +54,8 @@ def scaled_dot_product_attention_grouped( mask = causal_mask(L, S, scores.dtype) scores = scores + mask else: - mask = mask.reshape(*B, 1, 1, 1, mask.shape[-2], mask.shape[-1]) + mask = mx.broadcast_to(mask, (*B, H_q, L, S)) + mask = mask.reshape(*B, 1, H, n_repeats, L, S) scores = scores + mask result = mx.matmul(softmax(scores, axis=-1), value) return result.reshape(expected_shape) diff --git a/tests/utils.py b/tests/utils.py index 4b9a332..4aa7cd4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -28,8 +28,8 @@ def assert_allclose( assert a.shape == b.shape, f"shape mismatch: {a.shape} vs {b.shape}" if not np.allclose(a, b, rtol=rtol, atol=atol): diff = np.invert(np.isclose(a, b, rtol=rtol, atol=atol)) - if diff.size > 10000 and np.sum(diff) <= 1: - # if only one element is different in a large array, probably fine + if diff.size > 10000 and np.sum(diff) <= 3: + # if only a small number of elements are different in a large array, probably fine return with np.printoptions(precision=3, suppress=True): print("a=", a) From 5863d964736e970e5627a9b2899ff30f29feb5ae Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 27 Jul 2025 15:47:12 -0400 Subject: [PATCH 09/79] reshape in kvcache Signed-off-by: Alex Chi Z --- book/src/week2-overview.md | 2 ++ src/tiny_llm_ref/kv_cache.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/book/src/week2-overview.md b/book/src/week2-overview.md index 2b9e2b4..63d4e49 100644 --- a/book/src/week2-overview.md +++ b/book/src/week2-overview.md @@ -27,3 +27,5 @@ https://huggingface.co/docs/transformers/pad_truncation https://siboehm.com/articles/22/CUDA-MMM https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-metal/ggml-metal.metal + +pdm run batch-main --solution ref --model Qwen/Qwen2-7B-Instruct-MLX --prefill-step 16 diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index c23e32c..8345f99 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -57,7 +57,7 @@ def update_and_fetch( keys[b, :, seq_len - S : seq_len, :] = cached_keys values[b, :, seq_len - S : seq_len, :] = cached_values masks[b, :, seq_len - S : seq_len] = causal_mask(q_L, S, dtype=key.dtype) - return keys, values, None, masks + return keys, values, None, masks.reshape(B, 1, q_L, seq_len) def add_request(self, prefilled: TinyKvCache, id: int): if id >= self.max_active_requests: From 55e7b0cb8921cea938a7c27ffc4da4868ee75b07 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 2 Aug 2025 16:49:14 -0400 Subject: [PATCH 10/79] fix flash attention Signed-off-by: Alex Chi Z --- src/extensions_ref/bindings.cpp | 3 +- src/extensions_ref/src/flash_attention.cpp | 96 +++++++++++------- src/extensions_ref/src/flash_attention.metal | 57 +++++++---- src/extensions_ref/src/quantized_matmul.metal | 2 +- src/extensions_ref/src/tiny_llm_ext.h | 8 +- src/tiny_llm_ref/attention.py | 47 +++++---- src/tiny_llm_ref/qwen2_week2.py | 33 +++++-- tests_refsol/test_week_2_day_6.py | 99 +++++++++++++++++++ 8 files changed, 258 insertions(+), 87 deletions(-) create mode 100644 tests_refsol/test_week_2_day_6.py diff --git a/src/extensions_ref/bindings.cpp b/src/extensions_ref/bindings.cpp index c988509..afb4aec 100644 --- a/src/extensions_ref/bindings.cpp +++ b/src/extensions_ref/bindings.cpp @@ -31,7 +31,7 @@ NB_MODULE(_ext, m) { array: ``a * b`` )"); - m.def("flash_attention", &tiny_llm_ext_ref::flash_attention, "query"_a, "key"_a, "value"_a, "scale"_a = 1.0, + m.def("flash_attention", &tiny_llm_ext_ref::flash_attention, "query"_a, "key"_a, "value"_a, "mask"_a, "scale"_a = 1.0, "num_kv_heads"_a, "num_heads"_a, "stream"_a = nb::none(), R"( Flash attention layer @@ -39,6 +39,7 @@ NB_MODULE(_ext, m) { query (array): Query array. key (array): Key array. value (array): Value array. + mask (array): Mask array. scale (float): Scaling factor. Returns: diff --git a/src/extensions_ref/src/flash_attention.cpp b/src/extensions_ref/src/flash_attention.cpp index 3ac0277..768b384 100644 --- a/src/extensions_ref/src/flash_attention.cpp +++ b/src/extensions_ref/src/flash_attention.cpp @@ -13,9 +13,9 @@ #endif namespace tiny_llm_ext_ref { -mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const float scale, - const int num_kv_heads, const int num_heads, mx::StreamOrDevice s) { - if (q.dtype() != mx::float32 || k.dtype() != mx::float32 || v.dtype() != mx::float32) { +mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask, + const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s) { + if (q.dtype() != mx::float32 || k.dtype() != mx::float32 || v.dtype() != mx::float32 || mask.dtype() != mx::float32) { throw std::runtime_error("flash_attention: all input arrays must be float32"); } if (q.shape().size() != 3 || k.shape().size() != 3 || v.shape().size() != 3) { @@ -24,10 +24,15 @@ mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::arra if (num_heads % num_kv_heads != 0) { throw std::runtime_error("flash_attention: num_heads must be divisible by num_kv_heads"); } - // Q: [N, S, E] - // K: [N_KV, L, E] - // V: [N_KV, L, E] - // O: [N, S, E] + if (mask.shape().size() != 3) { + throw std::runtime_error("flash_attention: mask must be 3D"); + } + + // Q: [N, L, E] + // K: [N_KV, S, E] + // V: [N_KV, S, E] + // O: [N, L, E] + // M: [N, L, S] (optional, needs broadcasting) if (q.shape()[0] % num_heads != 0) { throw std::runtime_error("flash_attention: q.shape[0] must be divisible by num_heads"); @@ -44,15 +49,19 @@ mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::arra if (k.shape()[1] != v.shape()[1]) { throw std::runtime_error("flash_attention: k.shape[1] must be equal to v.shape[1]"); } + if (mask.shape()[0] != q.shape()[0] || mask.shape()[1] != q.shape()[1] || mask.shape()[2] != k.shape()[1]) { + throw std::runtime_error("flash_attention: mask must be broadcastable to q, k, v"); + } return mx::array(q.shape(), mx::float32, - std::make_shared(to_stream(s), scale, num_kv_heads, num_heads), {q, k, v}); + std::make_shared(to_stream(s), scale, num_kv_heads, num_heads), {q, k, v, mask}); } void FlashAttention::eval_cpu(const std::vector &inputs, std::vector &outputs) { auto &q = inputs[0]; auto &k = inputs[1]; auto &v = inputs[2]; + auto &mask = inputs[3]; auto &out = outputs[0]; out.set_data(mx::allocator::malloc(out.nbytes())); @@ -61,6 +70,7 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< encoder.set_input_array(q); encoder.set_input_array(k); encoder.set_input_array(v); + encoder.set_input_array(mask); encoder.set_output_array(out); if (!q.flags().row_contiguous) { @@ -75,23 +85,25 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< // Launch the CPU kernel encoder.dispatch([out_ptr = out.data(), out_shape = out.shape(), q = mx::array::unsafe_weak_copy(q), - k = mx::array::unsafe_weak_copy(k), v = mx::array::unsafe_weak_copy(v), num_heads = num_heads_, - num_kv_heads = num_kv_heads_, scale = scale_]() { + k = mx::array::unsafe_weak_copy(k), v = mx::array::unsafe_weak_copy(v), + mask = mx::array::unsafe_weak_copy(mask), num_heads = num_heads_, num_kv_heads = num_kv_heads_, + scale = scale_]() { const int64_t N = q.shape()[0]; - const int64_t S = q.shape()[1]; - const int64_t L = k.shape()[1]; + const int64_t L = q.shape()[1]; + const int64_t S = k.shape()[1]; const int64_t E = q.shape()[2]; - const int64_t N_Q_HEAD = S * E; - const int64_t N_K_HEAD = L * E; + const int64_t N_Q_HEAD = L * E; + const int64_t N_K_HEAD = S * E; const int64_t Br = 32; const int64_t Bc = 32; - const int64_t Tr = (S + Br - 1) / Br; - const int64_t Tc = (L + Bc - 1) / Bc; + const int64_t Tr = (L + Br - 1) / Br; + const int64_t Tc = (S + Bc - 1) / Bc; const int64_t q_kv_heads_ratio = num_heads / num_kv_heads; const float *q_ptr = q.data(); const float *k_ptr = k.data(); const float *v_ptr = v.data(); + const float *m_ptr = mask.data(); for (int64_t n = 0; n < N; n++) { const float *q_batch = q_ptr + n * N_Q_HEAD; @@ -99,7 +111,7 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< const float *v_batch = v_ptr + (n / q_kv_heads_ratio) * N_K_HEAD; for (int64_t i = 0; i < Tr; i++) { std::vector q_i(Br * E, 0.0); - int br_upper_bound = std::min(S - i * Br, Br); + int br_upper_bound = std::min(L - i * Br, Br); // Load Qi for (int64_t a = 0; a < br_upper_bound; a++) { for (int64_t b = 0; b < E; b++) { @@ -111,7 +123,7 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< std::vector l_i(Br, 0.0); std::vector m_i(Br, -std::numeric_limits::infinity()); for (int64_t j = 0; j < Tc; j++) { - int bc_upper_bound = std::min(L - j * Bc, Bc); + int bc_upper_bound = std::min(S - j * Bc, Bc); // Each kernel processes a block of Br x Bc // Load Kj and Vj std::vector k_j(Bc * E, 0.0); @@ -120,7 +132,7 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< int64_t kv_idx_base = j * Bc + a; for (int64_t b = 0; b < E; b++) { int kv_idx = kv_idx_base * E + b; - if (kv_idx_base < L) { + if (kv_idx_base < S) { k_j[a * E + b] = k_batch[kv_idx]; v_j[a * E + b] = v_batch[kv_idx]; } @@ -137,6 +149,18 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< } } + // Add mask and scale + for (int64_t a = 0; a < br_upper_bound; a++) { + for (int64_t b = 0; b < bc_upper_bound; b++) { + int m_idx_1 = n; + int m_idx_2 = i * Br + a; + int m_idx_3 = j * Bc + b; + int m_idx_converted = mx::elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask); + s_i[a * Bc + b] *= scale; + s_i[a * Bc + b] += m_ptr[m_idx_converted]; + } + } + // m_i from iteration j = max(m_i from iteration j-1, rowmax(s_i)) std::vector m_i_diff(Br, 0.0); for (int64_t a = 0; a < br_upper_bound; a++) { @@ -193,7 +217,7 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< for (int64_t a = 0; a < br_upper_bound; a++) { for (int64_t b = 0; b < E; b++) { int out_idx = i * Br + a; - if (out_idx < S) { + if (out_idx < L) { out_ptr[n * N_Q_HEAD + out_idx * E + b] = o_i[a * E + b]; } } @@ -208,6 +232,7 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< const auto &q = inputs[0]; const auto &k = inputs[1]; const auto &v = inputs[2]; + const auto &mask = inputs[3]; auto &out = outputs[0]; auto &s = stream(); @@ -230,9 +255,12 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); + compute_encoder.set_input_array(mask, 3); + compute_encoder.set_vector_bytes(mask.shape(), 4); + compute_encoder.set_vector_bytes(mask.strides(), 5); // Encode output arrays to kernel - compute_encoder.set_output_array(out, 3); + compute_encoder.set_output_array(out, 6); if (!q.flags().row_contiguous) { throw std::runtime_error("flash_attention: q must be contiguous"); @@ -245,18 +273,18 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< } const int64_t N = q.shape()[0]; - const int64_t S = q.shape()[1]; - const int64_t L = k.shape()[1]; + const int64_t L = q.shape()[1]; + const int64_t S = k.shape()[1]; const int64_t E = q.shape()[2]; - compute_encoder.set_bytes(N, 4); - compute_encoder.set_bytes(S, 5); - compute_encoder.set_bytes(L, 6); - compute_encoder.set_bytes(E, 7); + compute_encoder.set_bytes(N, 7); + compute_encoder.set_bytes(L, 8); + compute_encoder.set_bytes(S, 9); + compute_encoder.set_bytes(E, 10); - compute_encoder.set_bytes(num_kv_heads_, 8); - compute_encoder.set_bytes(num_heads_, 9); - compute_encoder.set_bytes(scale_, 10); + compute_encoder.set_bytes(num_kv_heads_, 11); + compute_encoder.set_bytes(num_heads_, 12); + compute_encoder.set_bytes(scale_, 13); size_t tgp_size = kernel->maxTotalThreadsPerThreadgroup(); size_t simd_width = kernel->threadExecutionWidth(); @@ -281,10 +309,10 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< const int64_t Tr = (S + Br - 1) / Br; const int64_t Tc = (L + Bc - 1) / Bc; - compute_encoder.set_bytes(Br, 11); - compute_encoder.set_bytes(Bc, 12); - compute_encoder.set_bytes(Tr, 13); - compute_encoder.set_bytes(Tc, 14); + compute_encoder.set_bytes(Br, 14); + compute_encoder.set_bytes(Bc, 15); + compute_encoder.set_bytes(Tr, 16); + compute_encoder.set_bytes(Tc, 17); MTL::Size num_threadgroups = MTL::Size(N, Tr, 1); MTL::Size num_threads_per_group = MTL::Size(Br, simd_width, 1); diff --git a/src/extensions_ref/src/flash_attention.metal b/src/extensions_ref/src/flash_attention.metal index ab15eaf..bb895c1 100644 --- a/src/extensions_ref/src/flash_attention.metal +++ b/src/extensions_ref/src/flash_attention.metal @@ -1,22 +1,27 @@ #include +#include "mlx/backend/metal/kernels/utils.h" + using namespace metal; [[kernel]] void flash_attention_f32_e128( device const float* q [[buffer(0)]], device const float* k [[buffer(1)]], device const float* v [[buffer(2)]], - device float* out [[buffer(3)]], - [[maybe_unused]] device const int64_t &N [[buffer(4)]], - device const int64_t &S [[buffer(5)]], - device const int64_t &L [[buffer(6)]], - device const int64_t &E [[buffer(7)]], - device const int64_t &num_kv_heads [[buffer(8)]], - device const int64_t &num_heads [[buffer(9)]], - device const float &scale [[buffer(10)]], - device const int64_t &Br [[buffer(11)]], - device const int64_t &Bc [[buffer(12)]], - [[maybe_unused]] device const int64_t &Tr [[buffer(13)]], - device const int64_t &Tc [[buffer(14)]], + device const float* mask [[buffer(3)]], + constant const int* mask_shape [[buffer(4)]], + constant const int64_t* mask_strides [[buffer(5)]], + device float* out [[buffer(6)]], + [[maybe_unused]] device const int64_t &N [[buffer(7)]], + device const int64_t &L [[buffer(8)]], + device const int64_t &S [[buffer(9)]], + device const int64_t &E [[buffer(10)]], + device const int64_t &num_kv_heads [[buffer(11)]], + device const int64_t &num_heads [[buffer(12)]], + device const float &scale [[buffer(13)]], + device const int64_t &Br [[buffer(14)]], + device const int64_t &Bc [[buffer(15)]], + [[maybe_unused]] device const int64_t &Tr [[buffer(16)]], + device const int64_t &Tc [[buffer(17)]], uint2 group_id [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -31,12 +36,12 @@ using namespace metal; // 128*32*sizeof(float) bytes * number of arrays and use them as the threadgroup // shared memory. - bool is_i_in_range = i * Br + a < S && a < Br; + bool is_i_in_range = i * Br + a < L && a < Br; const int q_kv_ratio = num_heads / num_kv_heads; - device const float *q_ptr = q + n * S * E + i * Br * E; - device const float *k_ptr_base = k + (n / q_kv_ratio) * L * E; - device const float *v_ptr_base = v + (n / q_kv_ratio) * L * E; + device const float *q_ptr = q + n * L * E + i * Br * E; + device const float *k_ptr_base = k + (n / q_kv_ratio) * S * E; + device const float *v_ptr_base = v + (n / q_kv_ratio) * S * E; threadgroup float o_i[32][128]; // Br x E, each simd group shares an o_i, only lane 0 writes to it if (simd_lid == 0) { @@ -45,9 +50,9 @@ using namespace metal; } } - // q_ptr: S * E - // k_ptr: L * E - // v_ptr: L * E + // q_ptr: L * E + // k_ptr: S * E + // v_ptr: S * E // To access q[a, c]: use a * E + c // To access k/v[b, c]: use b * E + c @@ -55,7 +60,7 @@ using namespace metal; float l_i = 0.0; // per thread; sync to threadgroup memory later for (int j = 0; j < Tc; j++) { - bool is_j_in_range = j * Bc + b < L && b < Bc; + bool is_j_in_range = j * Bc + b < S && b < Bc; device const float *k_ptr = k_ptr_base + j * Bc * E; device const float *v_ptr = v_ptr_base + j * Bc * E; @@ -67,6 +72,16 @@ using namespace metal; s_a_b += q_ptr[a * E + c] * k_ptr[b * E + c]; } } + + s_a_b *= scale; + if (is_i_in_range && is_j_in_range) { + int64_t m_idx_1 = n; + int64_t m_idx_2 = i * Br + a; + int64_t m_idx_3 = j * Bc + b; + int64_t m_idx_converted = elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask_shape, mask_strides, 3); + s_a_b += mask[m_idx_converted]; + } + // for each cell, get the rowmax of the corresponding row, and compute m_i in each // of the cells float rowmax = simd_max(s_a_b); @@ -112,7 +127,7 @@ using namespace metal; } for (int c = 0; c < E; c++) { if (is_i_in_range) { - out[n * S * E + (i * Br + a) * E + c] = o_i[a][c]; + out[n * L * E + (i * Br + a) * E + c] = o_i[a][c]; } } } diff --git a/src/extensions_ref/src/quantized_matmul.metal b/src/extensions_ref/src/quantized_matmul.metal index 1f997c3..5a58eb6 100644 --- a/src/extensions_ref/src/quantized_matmul.metal +++ b/src/extensions_ref/src/quantized_matmul.metal @@ -10,7 +10,7 @@ uint3 group_id [[threadgroup_position_in_grid]], uint3 thread_id [[thread_position_in_threadgroup]], uint3 threads_per_threadgroup [[threads_per_threadgroup]], - threadgroup char * shmem [[threadgroup(0)]]) { + [[maybe_unused]] threadgroup char * shmem [[threadgroup(0)]]) { const int group_size = 64; const int bits = 4; const int packs_per_item = 32 / bits; diff --git a/src/extensions_ref/src/tiny_llm_ext.h b/src/extensions_ref/src/tiny_llm_ext.h index c7afbd3..824c2b4 100644 --- a/src/extensions_ref/src/tiny_llm_ext.h +++ b/src/extensions_ref/src/tiny_llm_ext.h @@ -32,7 +32,7 @@ class QuantizedMatmul : public mx::Primitive { throw std::runtime_error("QuantizedMatmul has no vmap implementation."); } - const char* name() const override { return "QuantizedMatmul"; } + const char *name() const override { return "QuantizedMatmul"; } bool is_equivalent(const mx::Primitive &other) const override; @@ -41,8 +41,8 @@ class QuantizedMatmul : public mx::Primitive { int bits_; }; -mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const float scale, - const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {}); +mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask, + const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {}); class FlashAttention : public mx::Primitive { public: @@ -57,7 +57,7 @@ class FlashAttention : public mx::Primitive { throw std::runtime_error("FlashAttention has no vmap implementation."); } - const char* name() const override { return "FlashAttention"; } + const char *name() const override { return "FlashAttention"; } bool is_equivalent(const mx::Primitive &other) const override { const FlashAttention &r_other = static_cast(other); diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index b53778b..2970e76 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -11,8 +11,8 @@ def scaled_dot_product_attention_simple( mask: mx.array | None = None, ) -> mx.array: """ - A simple implementation of scaled dot product attention. Assuming Q,K,V are of the same shape. - Assuming mask is always a float array. + A simple implementation of scaled dot product attention. Assuming Q, K, V are of the same shape. + Assuming mask is always a float array that you can add to the scores. """ factor = mx.rsqrt(query.shape[-1]) if scale is None else scale scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor @@ -26,14 +26,18 @@ def causal_mask(L: int, S: int, dtype: mx.Dtype) -> mx.array: mask = mx.where(mask, mx.array(0), mx.array(-mx.inf)).astype(dtype) return mask - def scaled_dot_product_attention_grouped( query: mx.array, key: mx.array, value: mx.array, scale: float | None = None, - mask: mx.array | str | None = None, + mask: mx.array | None = None, ) -> mx.array: + """ + Potential input of the mask: + - mx.array that can broadcast to B * H_q * L * S, which needs to be reshaped to match multi-head dimensions + - None which will be ignored + """ factor = mx.rsqrt(query.shape[-1]) if scale is None else mx.array(scale) factor = factor.astype(query.dtype) expected_shape = query.shape @@ -50,13 +54,9 @@ def scaled_dot_product_attention_grouped( scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor if mask is not None: - if mask == "causal": - mask = causal_mask(L, S, scores.dtype) - scores = scores + mask - else: - mask = mx.broadcast_to(mask, (*B, H_q, L, S)) - mask = mask.reshape(*B, 1, H, n_repeats, L, S) - scores = scores + mask + mask = mx.broadcast_to(mask, (*B, H_q, L, S)) + mask = mask.reshape(*B, 1, H, n_repeats, L, S) + scores = scores + mask result = mx.matmul(softmax(scores, axis=-1), value) return result.reshape(expected_shape) @@ -66,20 +66,31 @@ def flash_attention( key: mx.array, value: mx.array, scale: float | None = None, + mask: mx.array | None = None, ) -> mx.array: - *B, H_q, S, E = query.shape - _, H, L, _ = key.shape + factor = mx.rsqrt(query.shape[-1]) if scale is None else mx.array(scale) + factor = factor.astype(query.dtype) + + *B, H_q, L, E = query.shape + _, H, S, _ = key.shape assert H_q % H == 0 - query = query.reshape(-1, S, E) - key = key.reshape(-1, L, E) - value = value.reshape(-1, L, E) + query = query.reshape(-1, L, E) + key = key.reshape(-1, S, E) + value = value.reshape(-1, S, E) query = mx.contiguous(query) key = mx.contiguous(key) value = mx.contiguous(value) + N = query.shape[0] + if mask is None: + mask = mx.zeros((N, L, S)) + else: + mask = mx.reshape(mx.broadcast_to(mask, (*B, H_q, L, S)), (N, L, S)) + # Seems like an MLX bug: need to make contiguous before passing to the kernel + mask = mx.contiguous(mask) result = tiny_llm_ext_ref.flash_attention( - query, key, value, scale, num_heads=H_q, num_kv_heads=H + query, key, value, mask, factor, num_heads=H_q, num_kv_heads=H, ) - return result.reshape(*B, H_q, S, E) + return result.reshape(*B, H_q, L, E) class SimpleMultiHeadAttention: diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index 49ff6cc..d61d254 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -1,6 +1,6 @@ import mlx.core as mx from .basics import silu -from .attention import scaled_dot_product_attention_grouped +from .attention import scaled_dot_product_attention_grouped, flash_attention, causal_mask from .layer_norm import RMSNorm from .positional_encoding import RoPE from typing import Any @@ -24,6 +24,7 @@ def __init__( bv: mx.array, max_seq_len: int = 32768, theta: int = 1000000, + use_flash_attention: bool = False, ): self.hidden_size = hidden_size self.num_heads = num_heads @@ -44,6 +45,7 @@ def __init__( self.bk = bk self.bv = bv self.rope = RoPE(self.head_dim, max_seq_len, theta) + self.use_flash_attention = use_flash_attention def __call__( self, @@ -77,13 +79,25 @@ def __call__( ) if kv_cache_mask is not None: mask = kv_cache_mask - x = scaled_dot_product_attention_grouped( - projection_q.astype(mx.float32), - projection_k.astype(mx.float32), - projection_v.astype(mx.float32), - scale=self.scale, - mask=mask, - ).astype(x.dtype) + S = projection_k.shape[-2] + if mask == "causal": + mask = causal_mask(L, S, mx.float32) + if self.use_flash_attention: + x = flash_attention( + projection_q.astype(mx.float32), + projection_k.astype(mx.float32), + projection_v.astype(mx.float32), + scale=self.scale, + mask=mask, + ).astype(x.dtype) + else: + x = scaled_dot_product_attention_grouped( + projection_q.astype(mx.float32), + projection_k.astype(mx.float32), + projection_v.astype(mx.float32), + scale=self.scale, + mask=mask, + ).astype(x.dtype) x = x.transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size) return quantized_linear(x, self.wo) @@ -132,6 +146,7 @@ def __init__( w_post_attention_layernorm: mx.array, max_seq_len: int = 32768, theta: int = 1000000, + use_flash_attention: bool = False, ): self.num_attention_heads = num_attention_heads self.hidden_size = hidden_size @@ -153,6 +168,7 @@ def __init__( bv=bv, max_seq_len=max_seq_len, theta=theta, + use_flash_attention=use_flash_attention, ) def __call__( @@ -234,6 +250,7 @@ def __init__( ].post_attention_layernorm.weight.astype(precision), max_seq_len=mlx_model.args.max_position_embeddings, theta=mlx_model.args.rope_theta, + use_flash_attention=False, ) self.layers_inner.append(layer) self.norm = RMSNorm( diff --git a/tests_refsol/test_week_2_day_6.py b/tests_refsol/test_week_2_day_6.py new file mode 100644 index 0000000..492fbe6 --- /dev/null +++ b/tests_refsol/test_week_2_day_6.py @@ -0,0 +1,99 @@ +import pytest +import mlx.core as mx +from .tiny_llm_base import * +from .utils import * + + +def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, use_flash_attention: bool = False): + precision = mx.float32 + with mx.stream(stream): + q_shape = (BATCH, H_q, L, E) + kv_shape = (BATCH, H, S, E) + scale = 0.8 + for _ in range(100): + query = mx.random.uniform(shape=q_shape, dtype=precision) + key = mx.random.uniform(shape=kv_shape, dtype=precision) + value = mx.random.uniform(shape=kv_shape, dtype=precision) + mask = mx.random.uniform(shape=(BATCH, 1, L, S), dtype=precision) + + reference_output = mx.fast.scaled_dot_product_attention( + q=query, + k=key, + v=value, + scale=scale, + mask=mask, + ) + if use_flash_attention: + user_output = flash_attention( + query, + key, + value, + scale=scale, + mask=mask, + ) + else: + user_output = scaled_dot_product_attention_grouped( + query, + key, + value, + scale=scale, + mask=mask, + ) + mx.eval(user_output) # so that any error will be caught here + assert_allclose(user_output, reference_output, precision=mx.float16) + + +def test_flash_attention_with_mask_cpu_small(): + attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, use_flash_attention=True) + + +def test_flash_attention_with_mask_cpu(): + attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, use_flash_attention=True) + + +def test_flash_attention_with_mask_cpu_large(): + attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, use_flash_attention=True) + + +def test_flash_attention_with_mask_gpu_extra_small(): + attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, use_flash_attention=True) + + +def test_flash_attention_with_mask_gpu_small(): + attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, use_flash_attention=True) + + +def test_flash_attention_with_mask_gpu(): + attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, use_flash_attention=True) + + +def test_flash_attention_with_mask_gpu_large(): + attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=True) + + +def test_attention_with_mask_cpu_small(): + attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, use_flash_attention=False) + + +def test_attention_with_mask_cpu(): + attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, use_flash_attention=False) + + +def test_attention_with_mask_cpu_large(): + attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False) + + +def test_attention_with_mask_gpu_extra_small(): + attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, use_flash_attention=False) + + +def test_attention_with_mask_gpu_small(): + attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, use_flash_attention=False) + + +def test_attention_with_mask_gpu(): + attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, use_flash_attention=False) + + +def test_attention_with_mask_gpu_large(): + attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False) From d954fb525a86595752a725f3b72c7ca8a0dfff31 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 2 Aug 2025 16:51:02 -0400 Subject: [PATCH 11/79] add back causal mask to gqa Signed-off-by: Alex Chi Z --- src/tiny_llm_ref/attention.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index 2970e76..b23bae9 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -31,7 +31,7 @@ def scaled_dot_product_attention_grouped( key: mx.array, value: mx.array, scale: float | None = None, - mask: mx.array | None = None, + mask: mx.array | str | None = None, ) -> mx.array: """ Potential input of the mask: @@ -54,9 +54,13 @@ def scaled_dot_product_attention_grouped( scores = mx.matmul(query, key.swapaxes(-2, -1)) * factor if mask is not None: - mask = mx.broadcast_to(mask, (*B, H_q, L, S)) - mask = mask.reshape(*B, 1, H, n_repeats, L, S) - scores = scores + mask + if mask == "causal": + mask = causal_mask(L, S, scores.dtype) + scores = scores + mask + else: + mask = mx.broadcast_to(mask, (*B, H_q, L, S)) + mask = mask.reshape(*B, 1, H, n_repeats, L, S) + scores = scores + mask result = mx.matmul(softmax(scores, axis=-1), value) return result.reshape(expected_shape) From e21a583e7fcc408edb2cddb0e833be805b8344ab Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 2 Aug 2025 16:52:59 -0400 Subject: [PATCH 12/79] flash attention works for the first token, maybe some mem init issue Signed-off-by: Alex Chi Z --- src/extensions_ref/src/flash_attention.metal | 2 ++ src/tiny_llm_ref/qwen2_week2.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/extensions_ref/src/flash_attention.metal b/src/extensions_ref/src/flash_attention.metal index bb895c1..ac7e001 100644 --- a/src/extensions_ref/src/flash_attention.metal +++ b/src/extensions_ref/src/flash_attention.metal @@ -80,6 +80,8 @@ using namespace metal; int64_t m_idx_3 = j * Bc + b; int64_t m_idx_converted = elem_to_loc(m_idx_1 * L * S + m_idx_2 * S + m_idx_3, mask_shape, mask_strides, 3); s_a_b += mask[m_idx_converted]; + } else { + s_a_b = -1e9; } // for each cell, get the rowmax of the corresponding row, and compute m_i in each diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index d61d254..ae76d81 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -250,7 +250,7 @@ def __init__( ].post_attention_layernorm.weight.astype(precision), max_seq_len=mlx_model.args.max_position_embeddings, theta=mlx_model.args.rope_theta, - use_flash_attention=False, + use_flash_attention=True, ) self.layers_inner.append(layer) self.norm = RMSNorm( From 9eacc3b57197a4a7ccd8dc258c9add36266c98cd Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 3 Aug 2025 13:52:37 -0400 Subject: [PATCH 13/79] try debug flashattention on multi test run Signed-off-by: Alex Chi Z --- src/extensions_ref/src/flash_attention.cpp | 33 ++++++++++++------ src/extensions_ref/src/flash_attention.metal | 36 ++++++++++---------- src/tiny_llm_ref/attention.py | 17 ++++++--- src/tiny_llm_ref/qwen2_week2.py | 6 +++- tests_refsol/test_week_2_day_6.py | 4 ++- 5 files changed, 61 insertions(+), 35 deletions(-) diff --git a/src/extensions_ref/src/flash_attention.cpp b/src/extensions_ref/src/flash_attention.cpp index 768b384..17106a2 100644 --- a/src/extensions_ref/src/flash_attention.cpp +++ b/src/extensions_ref/src/flash_attention.cpp @@ -64,6 +64,10 @@ void FlashAttention::eval_cpu(const std::vector &inputs, std::vector< auto &mask = inputs[3]; auto &out = outputs[0]; + if (out.dtype() != mx::float32) { + throw std::runtime_error("flash_attention: output dtype must be float32"); + } + out.set_data(mx::allocator::malloc(out.nbytes())); auto &encoder = mx::cpu::get_command_encoder(stream()); @@ -235,9 +239,14 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< const auto &mask = inputs[3]; auto &out = outputs[0]; + if (out.dtype() != mx::float32) { + throw std::runtime_error("flash_attention: output dtype must be float32"); + } + + out.set_data(mx::allocator::malloc(out.nbytes())); + auto &s = stream(); auto &d = mx::metal::device(s.device); - out.set_data(mx::allocator::malloc(out.nbytes())); // Make a kernel from this metal library auto library = d.get_library("tiny_llm_ext_ref"); @@ -247,20 +256,17 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< auto &compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - // Kernel parameters are registered with buffer indices corresponding to - // those in the kernel declaration at axpby.metal - int ndim = out.ndim(); - // Encode input arrays to kernel compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); compute_encoder.set_input_array(mask, 3); - compute_encoder.set_vector_bytes(mask.shape(), 4); - compute_encoder.set_vector_bytes(mask.strides(), 5); // Encode output arrays to kernel - compute_encoder.set_output_array(out, 6); + compute_encoder.set_output_array(out, 4); + + compute_encoder.set_vector_bytes(mask.shape(), 5); + compute_encoder.set_vector_bytes(mask.strides(), 6); if (!q.flags().row_contiguous) { throw std::runtime_error("flash_attention: q must be contiguous"); @@ -271,6 +277,9 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< if (!v.flags().row_contiguous) { throw std::runtime_error("flash_attention: v must be contiguous"); } + if (!out.flags().row_contiguous) { + throw std::runtime_error("flash_attention: out must be contiguous"); + } const int64_t N = q.shape()[0]; const int64_t L = q.shape()[1]; @@ -306,8 +315,12 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< throw std::runtime_error("flash_attention: Br must be less than 32"); } - const int64_t Tr = (S + Br - 1) / Br; - const int64_t Tc = (L + Bc - 1) / Bc; + if (Bc > 32) { + throw std::runtime_error("flash_attention: Bc must be less than 32"); + } + + const int64_t Tr = (L + Br - 1) / Br; + const int64_t Tc = (S + Bc - 1) / Bc; compute_encoder.set_bytes(Br, 14); compute_encoder.set_bytes(Bc, 15); diff --git a/src/extensions_ref/src/flash_attention.metal b/src/extensions_ref/src/flash_attention.metal index ac7e001..dcee63d 100644 --- a/src/extensions_ref/src/flash_attention.metal +++ b/src/extensions_ref/src/flash_attention.metal @@ -8,10 +8,10 @@ using namespace metal; device const float* k [[buffer(1)]], device const float* v [[buffer(2)]], device const float* mask [[buffer(3)]], - constant const int* mask_shape [[buffer(4)]], - constant const int64_t* mask_strides [[buffer(5)]], - device float* out [[buffer(6)]], - [[maybe_unused]] device const int64_t &N [[buffer(7)]], + device float* out [[buffer(4)]], + constant const int* mask_shape [[buffer(5)]], + constant const int64_t* mask_strides [[buffer(6)]], + device const int64_t &N [[buffer(7)]], device const int64_t &L [[buffer(8)]], device const int64_t &S [[buffer(9)]], device const int64_t &E [[buffer(10)]], @@ -31,18 +31,12 @@ using namespace metal; int a = simd_gid; // max=Br int b = simd_lid; // max=Bc - // We do not use the shared memory for the threadgroup in this course -- - // this is left as an exercise for the students. For example, you can allocate - // 128*32*sizeof(float) bytes * number of arrays and use them as the threadgroup - // shared memory. - bool is_i_in_range = i * Br + a < L && a < Br; - const int q_kv_ratio = num_heads / num_kv_heads; device const float *q_ptr = q + n * L * E + i * Br * E; device const float *k_ptr_base = k + (n / q_kv_ratio) * S * E; device const float *v_ptr_base = v + (n / q_kv_ratio) * S * E; - threadgroup float o_i[32][128]; // Br x E, each simd group shares an o_i, only lane 0 writes to it + threadgroup float o_i[32][128]; // assume max(E) = 128, max(Br) = 32, only lane 0 writes to it if (simd_lid == 0) { for (int c = 0; c < E; c++) { @@ -50,6 +44,7 @@ using namespace metal; } } + threadgroup float q_local[32][128]; // assume max(E) = 128, max(Br) = 32, access by a, c // q_ptr: L * E // k_ptr: S * E // v_ptr: S * E @@ -59,6 +54,13 @@ using namespace metal; float m_i = -1e9; // per thread; sync to threadgroup memory later float l_i = 0.0; // per thread; sync to threadgroup memory later + // load q_local + if (simd_lid == 0) { + for (int c = 0; c < E; c++) { + q_local[a][c] = q_ptr[a * E + c]; + } + } + for (int j = 0; j < Tc; j++) { bool is_j_in_range = j * Bc + b < S && b < Bc; @@ -69,10 +71,9 @@ using namespace metal; float s_a_b = 0.0; for (int c = 0; c < E; c++) { if (is_i_in_range && is_j_in_range) { - s_a_b += q_ptr[a * E + c] * k_ptr[b * E + c]; + s_a_b += q_local[a][c] * k_ptr[b * E + c]; } } - s_a_b *= scale; if (is_i_in_range && is_j_in_range) { int64_t m_idx_1 = n; @@ -125,11 +126,10 @@ using namespace metal; // write to output if (simd_lid == 0) { for (int c = 0; c < E; c++) { - o_i[a][c] /= l_i; - } - for (int c = 0; c < E; c++) { - if (is_i_in_range) { - out[n * L * E + (i * Br + a) * E + c] = o_i[a][c]; + if (is_i_in_range && n < N) { + float o_i_c = o_i[a][c]; + o_i_c /= l_i; + out[n * L * E + (i * Br + a) * E + c] = o_i_c; } } } diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index b23bae9..82032f4 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -26,6 +26,7 @@ def causal_mask(L: int, S: int, dtype: mx.Dtype) -> mx.array: mask = mx.where(mask, mx.array(0), mx.array(-mx.inf)).astype(dtype) return mask + def scaled_dot_product_attention_grouped( query: mx.array, key: mx.array, @@ -88,13 +89,19 @@ def flash_attention( if mask is None: mask = mx.zeros((N, L, S)) else: - mask = mx.reshape(mx.broadcast_to(mask, (*B, H_q, L, S)), (N, L, S)) - # Seems like an MLX bug: need to make contiguous before passing to the kernel - mask = mx.contiguous(mask) + mask = mx.contiguous( + mx.reshape(mx.contiguous(mx.broadcast_to(mask, (*B, H_q, L, S))), (N, L, S)) + ) result = tiny_llm_ext_ref.flash_attention( - query, key, value, mask, factor, num_heads=H_q, num_kv_heads=H, + query, + key, + value, + mask, + factor, + num_heads=H_q, + num_kv_heads=H, ) - return result.reshape(*B, H_q, L, E) + return mx.contiguous(result.reshape(*B, H_q, L, E)) class SimpleMultiHeadAttention: diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index ae76d81..540f8fb 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -1,6 +1,10 @@ import mlx.core as mx from .basics import silu -from .attention import scaled_dot_product_attention_grouped, flash_attention, causal_mask +from .attention import ( + scaled_dot_product_attention_grouped, + flash_attention, + causal_mask, +) from .layer_norm import RMSNorm from .positional_encoding import RoPE from typing import Any diff --git a/tests_refsol/test_week_2_day_6.py b/tests_refsol/test_week_2_day_6.py index 492fbe6..e5f6d8d 100644 --- a/tests_refsol/test_week_2_day_6.py +++ b/tests_refsol/test_week_2_day_6.py @@ -4,7 +4,9 @@ from .utils import * -def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, use_flash_attention: bool = False): +def attention_helper( + stream: mx.Stream, H_q, H, L, E, S, BATCH, use_flash_attention: bool = False +): precision = mx.float32 with mx.stream(stream): q_shape = (BATCH, H_q, L, E) From 657c0b3de3fa571a7d5d28cf575dad48778f57eb Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 3 Aug 2025 14:11:18 -0400 Subject: [PATCH 14/79] small fixes of flash attention Signed-off-by: Alex Chi Z --- src/extensions_ref/src/flash_attention.metal | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/extensions_ref/src/flash_attention.metal b/src/extensions_ref/src/flash_attention.metal index dcee63d..4ac4ccd 100644 --- a/src/extensions_ref/src/flash_attention.metal +++ b/src/extensions_ref/src/flash_attention.metal @@ -36,11 +36,11 @@ using namespace metal; device const float *q_ptr = q + n * L * E + i * Br * E; device const float *k_ptr_base = k + (n / q_kv_ratio) * S * E; device const float *v_ptr_base = v + (n / q_kv_ratio) * S * E; - threadgroup float o_i[32][128]; // assume max(E) = 128, max(Br) = 32, only lane 0 writes to it + threadgroup float o_i[128 * 32]; // assume max(E) = 128, max(Br) = 32, only lane 0 writes to it if (simd_lid == 0) { for (int c = 0; c < E; c++) { - o_i[a][c] = 0.0; + o_i[a * E + c] = 0.0; } } @@ -61,6 +61,14 @@ using namespace metal; } } + if (simd_lid == 0) { + for (int c = 0; c < E; c++) { + if (is_i_in_range && n < N) { + out[n * L * E + (i * Br + a) * E + c] = -233.0; + } + } + } + for (int j = 0; j < Tc; j++) { bool is_j_in_range = j * Bc + b < S && b < Bc; @@ -108,6 +116,7 @@ using namespace metal; // compute o_i, where O is Br x E; note that this does not align // with the threadgroup we dispatch, so we have to do threadgroup sync + threadgroup_barrier(mem_flags::mem_threadgroup); for (int c = 0; c < E; c++) { float v; if (is_i_in_range && is_j_in_range) { @@ -117,17 +126,18 @@ using namespace metal; } float res = simd_sum(v); // res = sum(p_a_b * v_j) on each cell // only lane 0 will write to threadgroup memory - if (simd_lid == 0) { - o_i[a][c] = m_i_diff_exp * o_i[a][c] + res; + if (simd_lid == 0 && is_i_in_range && is_j_in_range) { + o_i[a * E + c] = m_i_diff_exp * o_i[a * E + c] + res; } } } // write to output if (simd_lid == 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); for (int c = 0; c < E; c++) { if (is_i_in_range && n < N) { - float o_i_c = o_i[a][c]; + float o_i_c = o_i[a * E + c]; o_i_c /= l_i; out[n * L * E + (i * Br + a) * E + c] = o_i_c; } From 8dfe61c1d5a5d4451151e54f47e0b925e25d0d52 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 3 Aug 2025 16:40:00 -0400 Subject: [PATCH 15/79] finally fully fix flash attention Signed-off-by: Alex Chi Z --- pdm.lock | 85 ++++++++++---------- src/extensions_ref/src/flash_attention.cpp | 21 +++-- src/extensions_ref/src/flash_attention.metal | 20 ++--- src/extensions_ref/src/quantized_matmul.cpp | 5 -- src/extensions_ref/src/tiny_llm_ext.h | 10 +-- src/tiny_llm_ref/attention.py | 6 +- tests/utils.py | 3 +- tests_refsol/test_week_2_day_3.py | 2 +- tests_refsol/test_week_2_day_6.py | 30 +++++-- 9 files changed, 95 insertions(+), 87 deletions(-) diff --git a/pdm.lock b/pdm.lock index d649d46..ce49d33 100644 --- a/pdm.lock +++ b/pdm.lock @@ -783,57 +783,58 @@ files = [ [[package]] name = "mlx" -version = "0.26.5" +version = "0.27.1" requires_python = ">=3.9" summary = "A framework for machine learning on Apple silicon." groups = ["default"] dependencies = [ - "mlx-metal==0.26.5", + "mlx-metal==0.27.1; platform_system == \"Darwin\"", ] files = [ - {file = "mlx-0.26.5-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:027cf842643ee27176e24c604b453f3200d9c03a96aa0a24c7f9de7027e87485"}, - {file = "mlx-0.26.5-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:f7f26bb955b7b33564ff93f06050d086639dc7e55d4addda19f3658960822cf4"}, - {file = "mlx-0.26.5-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:907c1fadbd3a13db40e0f455d19cc1750ec9961432b603605b501945126927a7"}, - {file = "mlx-0.26.5-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl", hash = "sha256:c9dfd46c9cf60e12b3440d8224068826f3c58b0ff2afadd6185f907ae587df17"}, - {file = "mlx-0.26.5-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:04367b20970a081c4359111c667ba62788592fd66184943c7eff10d660a13fce"}, - {file = "mlx-0.26.5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:aa14defb1c5b330c94e9873d09c90f27d72681970c20c1490457c8a2a51b93e5"}, - {file = "mlx-0.26.5-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:1e2d90d391c83932dee9a21ce5c95de30a60ce7fdac61143e3a15a4d8e55271d"}, - {file = "mlx-0.26.5-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl", hash = "sha256:b6ab9ee55d0c2e2a67d66b35bdd52adf5e21adf9fdc4c2cc2e4c6f465ec30145"}, - {file = "mlx-0.26.5-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:a29a060d089906a8ce753d0cb605c9c20005ce1865ccde8d06ad91c8c4325dc7"}, - {file = "mlx-0.26.5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:d754e964c7c320c67f4a1a00d6d2d60c1eec8213b83ebe754702bf6dc51a36a4"}, - {file = "mlx-0.26.5-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:77c5b750b24a18ed6e433dc46d787193b124a869b1b65d1532bfb7d37ec7172f"}, - {file = "mlx-0.26.5-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl", hash = "sha256:9c2b5dc36c57c8e2fe2245ca8cb2db167d053395afd53979f18728e2311cd633"}, + {file = "mlx-0.27.1-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:a033b65fe46425ad5032867d5a71a556a5168108d89aa7092b457556c70d84fc"}, + {file = "mlx-0.27.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:2836c6fd9803dc0c6cd06f204e31b3e0191e6c5b6bc8570b28661d926908cba3"}, + {file = "mlx-0.27.1-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:8b0566054c46d84c470cc99cda2afc3914ad6c7808fbb724dc1ec235e5b2a98c"}, + {file = "mlx-0.27.1-cp310-cp310-manylinux_2_35_x86_64.whl", hash = "sha256:91ef93ce09900c9a8ca662cf34e3c39ab5af2762822ecd6b12fecae518be167f"}, + {file = "mlx-0.27.1-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:d2e5dedbfbcbe558e51a5c476ca6a18e307676f9e49854eb27e53778bc474699"}, + {file = "mlx-0.27.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:9f04b9778897a879c9ca22e5413dfa1efc192d86d7211b184e079efec49dfb8b"}, + {file = "mlx-0.27.1-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:01d794f9e390438ab4f942a18d9a8ca65bef10c2c2007ef38ca988d039d6d9d3"}, + {file = "mlx-0.27.1-cp311-cp311-manylinux_2_35_x86_64.whl", hash = "sha256:fae11432d0639789f1e172b19b35ac8987c8ab9716e55a23fc7a170d6545fc33"}, + {file = "mlx-0.27.1-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:0c570c9afb57c697bd864504115be8a7c4de97f0b80557a597d496ee426a6812"}, + {file = "mlx-0.27.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:ccff7bbbd9df302b26e79013ef6d0c3531c9ba5963ead521e2d85856811b86a0"}, + {file = "mlx-0.27.1-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:9ccadaed449c07dfeae484620992b904c17dfea7564f8df63095c60eed3af02b"}, + {file = "mlx-0.27.1-cp312-cp312-manylinux_2_35_x86_64.whl", hash = "sha256:742c413e75605b71db69379a176da63e32ba19b9e9ad03b8763adbd1fcfcd394"}, ] [[package]] name = "mlx-lm" -version = "0.26.0" +version = "0.26.2" requires_python = ">=3.8" -summary = "LLMs on Apple silicon with MLX and the Hugging Face Hub" +summary = "LLMs with MLX and the Hugging Face Hub" groups = ["default"] dependencies = [ "jinja2", - "mlx>=0.25.0", + "mlx>=0.26.0", "numpy", "protobuf", "pyyaml", "transformers>=4.39.3", ] files = [ - {file = "mlx_lm-0.26.0-py3-none-any.whl", hash = "sha256:b00294c26242cd50db4b6e3ec3a2baf1cfdf8ca49a5e6057dce14642fabe0d21"}, - {file = "mlx_lm-0.26.0.tar.gz", hash = "sha256:78980ad994baf976779cc1c34c0d55c1c6b63dffef4899d67fec240d0c443b52"}, + {file = "mlx_lm-0.26.2-py3-none-any.whl", hash = "sha256:632624c4753a290dfe68f368d21f24883105cd8bba4c6ba5cf0905fecd626c1e"}, + {file = "mlx_lm-0.26.2.tar.gz", hash = "sha256:77e6f875bdea90a71174357363622b90a071d9c279bc958021e17e38d69fbc2a"}, ] [[package]] name = "mlx-metal" -version = "0.26.5" +version = "0.27.1" requires_python = ">=3.9" summary = "A framework for machine learning on Apple silicon." groups = ["default"] +marker = "platform_system == \"Darwin\"" files = [ - {file = "mlx_metal-0.26.5-py3-none-macosx_13_0_arm64.whl", hash = "sha256:6e58b4ec234bd23d04b817642245832b5a5ac48d06ed26b97b4dc6dba5b40aa3"}, - {file = "mlx_metal-0.26.5-py3-none-macosx_14_0_arm64.whl", hash = "sha256:d6f6a7d1110562544978d6f87a5bbb47be293d3b85c742043f8f2048e9f387cf"}, - {file = "mlx_metal-0.26.5-py3-none-macosx_15_0_arm64.whl", hash = "sha256:f5bd394c7ff6eebaaf8db6d7bd4f8dec96f2d232f08ba641673ab66f22e9727b"}, + {file = "mlx_metal-0.27.1-py3-none-macosx_13_0_arm64.whl", hash = "sha256:c66d9b1adb3c0ea19492fba6493f672bc7542e65dd65f7e2995918815fbeb907"}, + {file = "mlx_metal-0.27.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:fe4415ddd242974d91c7ca0699cd01507d17da8a5ba304122ef137cdb5e7fff4"}, + {file = "mlx_metal-0.27.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:d025dea30bda8baa32c928cfa333eac64a5adc8d07656f8fc55072d99403ebc9"}, ] [[package]] @@ -1908,29 +1909,29 @@ files = [ [[package]] name = "ruff" -version = "0.12.4" +version = "0.12.7" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["default"] files = [ - {file = "ruff-0.12.4-py3-none-linux_armv6l.whl", hash = "sha256:cb0d261dac457ab939aeb247e804125a5d521b21adf27e721895b0d3f83a0d0a"}, - {file = "ruff-0.12.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:55c0f4ca9769408d9b9bac530c30d3e66490bd2beb2d3dae3e4128a1f05c7442"}, - {file = "ruff-0.12.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a8224cc3722c9ad9044da7f89c4c1ec452aef2cfe3904365025dd2f51daeae0e"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9949d01d64fa3672449a51ddb5d7548b33e130240ad418884ee6efa7a229586"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:be0593c69df9ad1465e8a2d10e3defd111fdb62dcd5be23ae2c06da77e8fcffb"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7dea966bcb55d4ecc4cc3270bccb6f87a337326c9dcd3c07d5b97000dbff41c"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:afcfa3ab5ab5dd0e1c39bf286d829e042a15e966b3726eea79528e2e24d8371a"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c057ce464b1413c926cdb203a0f858cd52f3e73dcb3270a3318d1630f6395bb3"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e64b90d1122dc2713330350626b10d60818930819623abbb56535c6466cce045"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2abc48f3d9667fdc74022380b5c745873499ff827393a636f7a59da1515e7c57"}, - {file = "ruff-0.12.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2b2449dc0c138d877d629bea151bee8c0ae3b8e9c43f5fcaafcd0c0d0726b184"}, - {file = "ruff-0.12.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:56e45bb11f625db55f9b70477062e6a1a04d53628eda7784dce6e0f55fd549eb"}, - {file = "ruff-0.12.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:478fccdb82ca148a98a9ff43658944f7ab5ec41c3c49d77cd99d44da019371a1"}, - {file = "ruff-0.12.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0fc426bec2e4e5f4c4f182b9d2ce6a75c85ba9bcdbe5c6f2a74fcb8df437df4b"}, - {file = "ruff-0.12.4-py3-none-win32.whl", hash = "sha256:4de27977827893cdfb1211d42d84bc180fceb7b72471104671c59be37041cf93"}, - {file = "ruff-0.12.4-py3-none-win_amd64.whl", hash = "sha256:fe0b9e9eb23736b453143d72d2ceca5db323963330d5b7859d60d101147d461a"}, - {file = "ruff-0.12.4-py3-none-win_arm64.whl", hash = "sha256:0618ec4442a83ab545e5b71202a5c0ed7791e8471435b94e655b570a5031a98e"}, - {file = "ruff-0.12.4.tar.gz", hash = "sha256:13efa16df6c6eeb7d0f091abae50f58e9522f3843edb40d56ad52a5a4a4b6873"}, + {file = "ruff-0.12.7-py3-none-linux_armv6l.whl", hash = "sha256:76e4f31529899b8c434c3c1dede98c4483b89590e15fb49f2d46183801565303"}, + {file = "ruff-0.12.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:789b7a03e72507c54fb3ba6209e4bb36517b90f1a3569ea17084e3fd295500fb"}, + {file = "ruff-0.12.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2e1c2a3b8626339bb6369116e7030a4cf194ea48f49b64bb505732a7fce4f4e3"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32dec41817623d388e645612ec70d5757a6d9c035f3744a52c7b195a57e03860"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47ef751f722053a5df5fa48d412dbb54d41ab9b17875c6840a58ec63ff0c247c"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a828a5fc25a3efd3e1ff7b241fd392686c9386f20e5ac90aa9234a5faa12c423"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5726f59b171111fa6a69d82aef48f00b56598b03a22f0f4170664ff4d8298efb"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:74e6f5c04c4dd4aba223f4fe6e7104f79e0eebf7d307e4f9b18c18362124bccd"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0bfe4e77fba61bf2ccadf8cf005d6133e3ce08793bbe870dd1c734f2699a3e"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06bfb01e1623bf7f59ea749a841da56f8f653d641bfd046edee32ede7ff6c606"}, + {file = "ruff-0.12.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e41df94a957d50083fd09b916d6e89e497246698c3f3d5c681c8b3e7b9bb4ac8"}, + {file = "ruff-0.12.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4000623300563c709458d0ce170c3d0d788c23a058912f28bbadc6f905d67afa"}, + {file = "ruff-0.12.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:69ffe0e5f9b2cf2b8e289a3f8945b402a1b19eff24ec389f45f23c42a3dd6fb5"}, + {file = "ruff-0.12.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a07a5c8ffa2611a52732bdc67bf88e243abd84fe2d7f6daef3826b59abbfeda4"}, + {file = "ruff-0.12.7-py3-none-win32.whl", hash = "sha256:c928f1b2ec59fb77dfdf70e0419408898b63998789cc98197e15f560b9e77f77"}, + {file = "ruff-0.12.7-py3-none-win_amd64.whl", hash = "sha256:9c18f3d707ee9edf89da76131956aba1270c6348bfee8f6c647de841eac7194f"}, + {file = "ruff-0.12.7-py3-none-win_arm64.whl", hash = "sha256:dfce05101dbd11833a0776716d5d1578641b7fddb537fe7fa956ab85d1769b69"}, + {file = "ruff-0.12.7.tar.gz", hash = "sha256:1fc3193f238bc2d7968772c82831a4ff69252f673be371fb49663f0068b7ec71"}, ] [[package]] diff --git a/src/extensions_ref/src/flash_attention.cpp b/src/extensions_ref/src/flash_attention.cpp index 17106a2..dc31175 100644 --- a/src/extensions_ref/src/flash_attention.cpp +++ b/src/extensions_ref/src/flash_attention.cpp @@ -261,10 +261,7 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< compute_encoder.set_input_array(k, 1); compute_encoder.set_input_array(v, 2); compute_encoder.set_input_array(mask, 3); - - // Encode output arrays to kernel - compute_encoder.set_output_array(out, 4); - + compute_encoder.set_output_array(out, 4); compute_encoder.set_vector_bytes(mask.shape(), 5); compute_encoder.set_vector_bytes(mask.strides(), 6); @@ -281,10 +278,10 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< throw std::runtime_error("flash_attention: out must be contiguous"); } - const int64_t N = q.shape()[0]; - const int64_t L = q.shape()[1]; - const int64_t S = k.shape()[1]; - const int64_t E = q.shape()[2]; + const int N = q.shape()[0]; + const int L = q.shape()[1]; + const int S = k.shape()[1]; + const int E = q.shape()[2]; compute_encoder.set_bytes(N, 7); compute_encoder.set_bytes(L, 8); @@ -298,8 +295,8 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< size_t tgp_size = kernel->maxTotalThreadsPerThreadgroup(); size_t simd_width = kernel->threadExecutionWidth(); - const int64_t Br = 32; - const int64_t Bc = 32; + const int Br = 32; + const int Bc = 32; if (simd_width * Br > tgp_size) { throw std::runtime_error("flash_attention: simd_width * Br must be equal to tgp_size"); } @@ -319,8 +316,8 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< throw std::runtime_error("flash_attention: Bc must be less than 32"); } - const int64_t Tr = (L + Br - 1) / Br; - const int64_t Tc = (S + Bc - 1) / Bc; + const int Tr = (L + Br - 1) / Br; + const int Tc = (S + Bc - 1) / Bc; compute_encoder.set_bytes(Br, 14); compute_encoder.set_bytes(Bc, 15); diff --git a/src/extensions_ref/src/flash_attention.metal b/src/extensions_ref/src/flash_attention.metal index 4ac4ccd..98e5821 100644 --- a/src/extensions_ref/src/flash_attention.metal +++ b/src/extensions_ref/src/flash_attention.metal @@ -11,17 +11,17 @@ using namespace metal; device float* out [[buffer(4)]], constant const int* mask_shape [[buffer(5)]], constant const int64_t* mask_strides [[buffer(6)]], - device const int64_t &N [[buffer(7)]], - device const int64_t &L [[buffer(8)]], - device const int64_t &S [[buffer(9)]], - device const int64_t &E [[buffer(10)]], - device const int64_t &num_kv_heads [[buffer(11)]], - device const int64_t &num_heads [[buffer(12)]], + device const int &N [[buffer(7)]], + device const int &L [[buffer(8)]], + device const int &S [[buffer(9)]], + device const int &E [[buffer(10)]], + device const int &num_kv_heads [[buffer(11)]], + device const int &num_heads [[buffer(12)]], device const float &scale [[buffer(13)]], - device const int64_t &Br [[buffer(14)]], - device const int64_t &Bc [[buffer(15)]], - [[maybe_unused]] device const int64_t &Tr [[buffer(16)]], - device const int64_t &Tc [[buffer(17)]], + device const int &Br [[buffer(14)]], + device const int &Bc [[buffer(15)]], + [[maybe_unused]] device const int &Tr [[buffer(16)]], + device const int &Tc [[buffer(17)]], uint2 group_id [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { diff --git a/src/extensions_ref/src/quantized_matmul.cpp b/src/extensions_ref/src/quantized_matmul.cpp index d5eed7d..b599b58 100644 --- a/src/extensions_ref/src/quantized_matmul.cpp +++ b/src/extensions_ref/src/quantized_matmul.cpp @@ -215,9 +215,4 @@ void QuantizedMatmul::eval_gpu(const std::vector &inputs, std::vector compute_encoder.dispatch_threadgroups(num_threadgroups, num_threads_per_group); } -bool QuantizedMatmul::is_equivalent(const Primitive &other) const { - const QuantizedMatmul &r_other = static_cast(other); - return group_size_ == r_other.group_size_ && bits_ == r_other.bits_; -} - } // namespace tiny_llm_ext_ref diff --git a/src/extensions_ref/src/tiny_llm_ext.h b/src/extensions_ref/src/tiny_llm_ext.h index 824c2b4..0fe663e 100644 --- a/src/extensions_ref/src/tiny_llm_ext.h +++ b/src/extensions_ref/src/tiny_llm_ext.h @@ -34,8 +34,6 @@ class QuantizedMatmul : public mx::Primitive { const char *name() const override { return "QuantizedMatmul"; } - bool is_equivalent(const mx::Primitive &other) const override; - private: int group_size_; int bits_; @@ -44,6 +42,9 @@ class QuantizedMatmul : public mx::Primitive { mx::array flash_attention(const mx::array &q, const mx::array &k, const mx::array &v, const mx::array &mask, const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {}); +mx::array flash_attention_no_mask(const mx::array &q, const mx::array &k, const mx::array &v, + const float scale, const int num_kv_heads, const int num_heads, mx::StreamOrDevice s = {}); + class FlashAttention : public mx::Primitive { public: explicit FlashAttention(mx::Stream stream, const float scale, const int num_kv_heads, const int num_heads) @@ -59,11 +60,6 @@ class FlashAttention : public mx::Primitive { const char *name() const override { return "FlashAttention"; } - bool is_equivalent(const mx::Primitive &other) const override { - const FlashAttention &r_other = static_cast(other); - return scale_ == r_other.scale_ && num_kv_heads_ == r_other.num_kv_heads_ && num_heads_ == r_other.num_heads_; - } - private: float scale_; int num_kv_heads_; diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index 82032f4..7f69615 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -87,11 +87,9 @@ def flash_attention( value = mx.contiguous(value) N = query.shape[0] if mask is None: - mask = mx.zeros((N, L, S)) + mask = mx.reshape(mx.broadcast_to(mx.zeros((L, S)), (*B, H_q, L, S)), (N, L, S)).astype(mx.float32) else: - mask = mx.contiguous( - mx.reshape(mx.contiguous(mx.broadcast_to(mask, (*B, H_q, L, S))), (N, L, S)) - ) + mask = mx.reshape(mx.broadcast_to(mask, (*B, H_q, L, S)), (N, L, S)).astype(mx.float32) result = tiny_llm_ext_ref.flash_attention( query, key, diff --git a/tests/utils.py b/tests/utils.py index 4aa7cd4..c34584f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,6 +14,7 @@ def assert_allclose( precision: mx.Dtype, rtol: float | None = None, atol: float | None = None, + message: str | None = None, ): a = np.array(a) b = np.array(b) @@ -38,7 +39,7 @@ def assert_allclose( print("diff_b=", b * diff) print("diff_a_val=", a[diff]) print("diff_b_val=", b[diff]) - assert False, f"result mismatch" + assert False, f"result mismatch: {message}" def np_type_to_mx_type(np_type: np.dtype) -> mx.Dtype: diff --git a/tests_refsol/test_week_2_day_3.py b/tests_refsol/test_week_2_day_3.py index ac8f2a7..8e8aba4 100644 --- a/tests_refsol/test_week_2_day_3.py +++ b/tests_refsol/test_week_2_day_3.py @@ -9,7 +9,7 @@ def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH): with mx.stream(stream): q_shape = (BATCH, H_q, L, E) kv_shape = (BATCH, H, S, E) - scale = 1.0 + scale = 0.9 for _ in range(100): query = mx.random.uniform(shape=q_shape, dtype=precision) key = mx.random.uniform(shape=kv_shape, dtype=precision) diff --git a/tests_refsol/test_week_2_day_6.py b/tests_refsol/test_week_2_day_6.py index e5f6d8d..1214756 100644 --- a/tests_refsol/test_week_2_day_6.py +++ b/tests_refsol/test_week_2_day_6.py @@ -18,31 +18,51 @@ def attention_helper( value = mx.random.uniform(shape=kv_shape, dtype=precision) mask = mx.random.uniform(shape=(BATCH, 1, L, S), dtype=precision) - reference_output = mx.fast.scaled_dot_product_attention( + reference_output_1 = mx.fast.scaled_dot_product_attention( q=query, k=key, v=value, scale=scale, mask=mask, ) + reference_output_2 = mx.fast.scaled_dot_product_attention( + q=query, + k=key, + v=value, + scale=scale, + ) if use_flash_attention: - user_output = flash_attention( + user_output_1 = flash_attention( query, key, value, scale=scale, mask=mask, ) + user_output_2 = flash_attention( + query, + key, + value, + scale=scale, + ) else: - user_output = scaled_dot_product_attention_grouped( + user_output_1 = scaled_dot_product_attention_grouped( query, key, value, scale=scale, mask=mask, ) - mx.eval(user_output) # so that any error will be caught here - assert_allclose(user_output, reference_output, precision=mx.float16) + user_output_2 = scaled_dot_product_attention_grouped( + query, + key, + value, + scale=scale, + ) + mx.eval(user_output_1) + mx.eval(user_output_2) + assert_allclose(user_output_2, reference_output_2, precision=mx.float16, message="no mask") + assert_allclose(user_output_1, reference_output_1, precision=mx.float16, message="with mask") def test_flash_attention_with_mask_cpu_small(): From 0ca2bf17e37663ddb6ab9b1f71d2931be959c085 Mon Sep 17 00:00:00 2001 From: Magic Mai Date: Sun, 3 Aug 2025 14:14:33 -0700 Subject: [PATCH 16/79] feat(kv-cache): add KV cache imports and week 2 day 1 tests (#35) Add KV cache module imports to both tiny_llm and tiny_llm_ref packages to enable KV cache functionality. Include comprehensive test suite for week 2 day 1 covering embedding operations, model inference with KV cache, and sequential token generation with offset support. - Add KV cache imports to __init__.py files - Create test_week_2_day_1.py with task 2-4 test coverage - Support multiple Qwen2 model variants (0.5B, 1.5B, 7B) - Include embedding call and as_linear functionality tests - Add sequential generation tests with proper cache management --- src/tiny_llm/__init__.py | 1 + src/tiny_llm_ref/__init__.py | 1 + tests_refsol/test_week_2_day_1.py | 129 ++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 tests_refsol/test_week_2_day_1.py diff --git a/src/tiny_llm/__init__.py b/src/tiny_llm/__init__.py index c05fdd8..b476493 100644 --- a/src/tiny_llm/__init__.py +++ b/src/tiny_llm/__init__.py @@ -8,3 +8,4 @@ from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 from .sampler import * +from .kv_cache import * \ No newline at end of file diff --git a/src/tiny_llm_ref/__init__.py b/src/tiny_llm_ref/__init__.py index f3dbc0e..917e931 100644 --- a/src/tiny_llm_ref/__init__.py +++ b/src/tiny_llm_ref/__init__.py @@ -9,3 +9,4 @@ from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 from .sampler import * +from .kv_cache import * \ No newline at end of file diff --git a/tests_refsol/test_week_2_day_1.py b/tests_refsol/test_week_2_day_1.py new file mode 100644 index 0000000..bc412e5 --- /dev/null +++ b/tests_refsol/test_week_2_day_1.py @@ -0,0 +1,129 @@ +import pytest +from .utils import * +from .tiny_llm_base import Qwen2ModelWeek2, Embedding, dequantize_linear, qwen2_week2, TinyKvFullCache +from mlx_lm import load + +# TODO: task 1 tests + + +@pytest.mark.skipif( + not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" +) +def test_utils_qwen_2_05b(): + pass + + +@pytest.mark.skipif( + not qwen_2_7b_model_exists(), reason="Qwen2-7B-Instruct-MLX model not found" +) +def test_utils_qwen_2_7b(): + pass + + +@pytest.mark.skipif( + not qwen_2_15b_model_exists(), reason="Qwen2-1.5B-Instruct-MLX model not found" +) +def test_utils_qwen_2_15b(): + pass + + +def helper_test_task_3(model_name: str, iters: int = 10): + mlx_model, tokenizer = load(model_name) + model = Qwen2ModelWeek2(mlx_model) + for _ in range(iters): + cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + input = mx.random.randint(low=0, high=tokenizer.vocab_size, shape=(1, 10)) + user_output = model(input, 0, cache) + user_output = user_output - mx.logsumexp(user_output, keepdims=True) + ref_output = mlx_model(input) + ref_output = ref_output - mx.logsumexp(ref_output, keepdims=True) + assert_allclose(user_output, ref_output, precision=mx.float16, rtol=1e-1) + + +@pytest.mark.skipif( + not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" +) +def test_task_2_embedding_call(): + mlx_model, _ = load("Qwen/Qwen2-0.5B-Instruct-MLX") + embedding = Embedding( + mlx_model.args.vocab_size, + mlx_model.args.hidden_size, + dequantize_linear(mlx_model.model.embed_tokens).astype(mx.float16), + ) + for _ in range(50): + input = mx.random.randint(low=0, high=mlx_model.args.vocab_size, shape=(1, 10)) + user_output = embedding(input) + ref_output = mlx_model.model.embed_tokens(input) + assert_allclose(user_output, ref_output, precision=mx.float16) + + +@pytest.mark.skipif( + not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" +) +def test_task_2_embedding_as_linear(): + mlx_model, _ = load("Qwen/Qwen2-0.5B-Instruct-MLX") + embedding = Embedding( + mlx_model.args.vocab_size, + mlx_model.args.hidden_size, + dequantize_linear(mlx_model.model.embed_tokens).astype(mx.float16), + ) + for _ in range(50): + input = mx.random.uniform(shape=(1, 10, mlx_model.args.hidden_size)) + user_output = embedding.as_linear(input) + ref_output = mlx_model.model.embed_tokens.as_linear(input) + assert_allclose(user_output, ref_output, precision=mx.float16, atol=1e-1) + + +@pytest.mark.skipif( + not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" +) +def test_task_3_qwen_2_05b(): + helper_test_task_3("Qwen/Qwen2-0.5B-Instruct-MLX", 5) + + +@pytest.mark.skipif( + not qwen_2_7b_model_exists(), reason="Qwen2-7B-Instruct-MLX model not found" +) +def test_task_3_qwen_2_7b(): + helper_test_task_3("Qwen/Qwen2-7B-Instruct-MLX", 1) + + +@pytest.mark.skipif( + not qwen_2_15b_model_exists(), reason="Qwen2-1.5B-Instruct-MLX model not found" +) +def test_task_3_qwen_2_15b(): + helper_test_task_3("Qwen/Qwen2-1.5B-Instruct-MLX", 3) + +def helper_test_task_4(model_name: str, seq_len: int, iters: int = 1): + mlx_model, tokenizer = load(model_name) + model = Qwen2ModelWeek2(mlx_model) + for _ in range(iters): + cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + inputs = mx.random.randint(0, tokenizer.vocab_size, (1, seq_len)) + ref_outputs = mlx_model(inputs) + for offset in range(seq_len): + user_out = model(inputs=inputs[:, offset:offset+1], offset=offset, cache=cache) + ref_out = ref_outputs[:, offset:offset+1, :] + user_out = user_out - mx.logsumexp(user_out, keepdims=True) + ref_out = ref_out - mx.logsumexp(ref_out, keepdims=True) + assert_allclose(user_out, ref_out, precision=mx.float16, rtol=1e-1) + +@pytest.mark.skipif( + not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" +) +def test_task_4_qwen_2_05b(): + helper_test_task_4("Qwen/Qwen2-0.5B-Instruct-MLX", seq_len=3) + + +@pytest.mark.skipif( + not qwen_2_7b_model_exists(), reason="Qwen2-7B-Instruct-MLX model not found" +) +def test_task_4_qwen_2_7b(): + helper_test_task_4("Qwen/Qwen2-7B-Instruct-MLX", seq_len=3) + + +@pytest.mark.skipif( + not qwen_2_15b_model_exists(), reason="Qwen2-1.5B-Instruct-MLX model not found" +) +def test_task_4_qwen_2_15b(): + helper_test_task_4("Qwen/Qwen2-1.5B-Instruct-MLX", seq_len=3) \ No newline at end of file From 3cd7d840c5e65e53ab5f4697a7c2c39ffd777fb5 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 3 Aug 2025 18:31:13 -0400 Subject: [PATCH 17/79] refactor continuous batching Signed-off-by: Alex Chi Z --- src/extensions_ref/src/flash_attention.cpp | 1 + src/tiny_llm/__init__.py | 2 +- src/tiny_llm_ref/__init__.py | 3 +- src/tiny_llm_ref/attention.py | 8 +- src/tiny_llm_ref/batch.py | 219 +++++++++++++++++++++ src/tiny_llm_ref/generate.py | 147 -------------- src/tiny_llm_ref/kv_cache.py | 96 +++++---- src/tiny_llm_ref/qwen2_week2.py | 6 +- tests_refsol/test_week_2_day_1.py | 18 +- tests_refsol/test_week_2_day_6.py | 14 +- 10 files changed, 317 insertions(+), 197 deletions(-) create mode 100644 src/tiny_llm_ref/batch.py diff --git a/src/extensions_ref/src/flash_attention.cpp b/src/extensions_ref/src/flash_attention.cpp index dc31175..006ae51 100644 --- a/src/extensions_ref/src/flash_attention.cpp +++ b/src/extensions_ref/src/flash_attention.cpp @@ -288,6 +288,7 @@ void FlashAttention::eval_gpu(const std::vector &inputs, std::vector< compute_encoder.set_bytes(S, 9); compute_encoder.set_bytes(E, 10); + // Make sure the data type matches with the metal kernel: otherwise you'll get flaky issues and stuck :( compute_encoder.set_bytes(num_kv_heads_, 11); compute_encoder.set_bytes(num_heads_, 12); compute_encoder.set_bytes(scale_, 13); diff --git a/src/tiny_llm/__init__.py b/src/tiny_llm/__init__.py index b476493..99801a2 100644 --- a/src/tiny_llm/__init__.py +++ b/src/tiny_llm/__init__.py @@ -8,4 +8,4 @@ from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 from .sampler import * -from .kv_cache import * \ No newline at end of file +from .kv_cache import * diff --git a/src/tiny_llm_ref/__init__.py b/src/tiny_llm_ref/__init__.py index 917e931..402f98f 100644 --- a/src/tiny_llm_ref/__init__.py +++ b/src/tiny_llm_ref/__init__.py @@ -9,4 +9,5 @@ from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 from .sampler import * -from .kv_cache import * \ No newline at end of file +from .kv_cache import * +from .batch import * diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index 7f69615..bc2d30c 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -87,9 +87,13 @@ def flash_attention( value = mx.contiguous(value) N = query.shape[0] if mask is None: - mask = mx.reshape(mx.broadcast_to(mx.zeros((L, S)), (*B, H_q, L, S)), (N, L, S)).astype(mx.float32) + mask = mx.reshape( + mx.broadcast_to(mx.zeros((L, S)), (*B, H_q, L, S)), (N, L, S) + ).astype(mx.float32) else: - mask = mx.reshape(mx.broadcast_to(mask, (*B, H_q, L, S)), (N, L, S)).astype(mx.float32) + mask = mx.reshape(mx.broadcast_to(mask, (*B, H_q, L, S)), (N, L, S)).astype( + mx.float32 + ) result = tiny_llm_ext_ref.flash_attention( query, key, diff --git a/src/tiny_llm_ref/batch.py b/src/tiny_llm_ref/batch.py new file mode 100644 index 0000000..a990902 --- /dev/null +++ b/src/tiny_llm_ref/batch.py @@ -0,0 +1,219 @@ +import mlx.core as mx +from mlx_lm.tokenizer_utils import TokenizerWrapper +from .kv_cache import * +from .qwen2_week2 import Qwen2ModelWeek2 +from typing import Callable + + +def _step(model, y, offsets, kv_cache): + logits = model(y, offsets, kv_cache) + logits = logits[:, -1, :] + logprobs = logits - mx.logsumexp(logits, keepdims=True) + sampler = lambda x: mx.argmax(x, axis=-1) + y = sampler(logprobs) + return y + + +class Request: + def __init__( + self, + model: any, + tokenizer: TokenizerWrapper, + prompt: str, + prefill_max_step: int = 128, + prompt_idx: int = 0, + ): + self.prompt = prompt + self.kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + self.model = model + self.detokenizer = tokenizer.detokenizer.__class__(tokenizer._tokenizer) + self.prefill_tokens = mx.array( + tokenizer.encode(prompt, add_special_tokens=False) + ) + self.prefill_max_step = prefill_max_step + self.is_done = False + self.is_prefill_done = False + self.eos_token_id = tokenizer.eos_token_id + self.next_token = None + self.offset = 0 + self.prompt_idx = prompt_idx + + def try_prefill(self): + """ + Prefill this request up to max_step size, returns None if prefill is not done + """ + if self.is_prefill_done: + raise ValueError("prefill called after done") + tokens_to_prefill = min( + self.prefill_max_step, self.prefill_tokens.size - self.offset + ) + token = _step( + self.model, + self.prefill_tokens[self.offset : self.offset + tokens_to_prefill][None], + [self.offset], + self.kv_cache, + ) + self.offset += tokens_to_prefill + for i in self.kv_cache: + mx.eval(i.key_values[0]) + mx.eval(i.key_values[1]) + if self.offset == self.prefill_tokens.size: + self.is_prefill_done = True + mx.eval(token) + self.decode_done(token.item(), False) + + def decode_done(self, token, update_offset=True): + if self.is_done: + raise ValueError("decode called after done") + if token == self.eos_token_id: + self.is_done = True + return + self.detokenizer.add_token(token) + self.next_token = token + if update_offset: + self.offset += 1 + + def text(self): + return self.detokenizer.text + + +def _print_progress( + requests: list[Request | None], + is_idle: list[bool], + pending_prefill_request: Request | None, + queue_size: int, + progress_cnt: int, +): + print(" ---") + animation_frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + animation_frame = animation_frames[progress_cnt % len(animation_frames)] + for i in range(len(requests)): + if is_idle[i]: + print(f" Decode #{i}: idle", flush=True) + else: + print( + f"{animation_frame} Decode [req {requests[i].prompt_idx}, {requests[i].offset}]: {requests[i].text()[-80:].replace('\n', ' ')}", + flush=True, + ) + if pending_prefill_request is not None: + if pending_prefill_request.is_prefill_done: + print( + f" Prefill [req {pending_prefill_request.prompt_idx}]: done, waiting for slot, {queue_size} requests in queue", + flush=True, + ) + return + precentage = ( + pending_prefill_request.offset / pending_prefill_request.prefill_tokens.size + ) * 100 + print( + f"{animation_frame} Prefill [req {pending_prefill_request.prompt_idx}]: {precentage:.2f}% ({pending_prefill_request.prefill_tokens.size - pending_prefill_request.offset} remaining tokens)", + flush=True, + ) + else: + print(f" Prefill: idle, {queue_size} requests in queue", flush=True) + + +def batch_generate( + model: any, + tokenizer: TokenizerWrapper, + prompts: list[str], + max_seq_len=512, + batch_size=5, + prefill_step=128, +): + decode_requests: list[Request] = [None] * batch_size + is_idle = [True] * batch_size + kv_cache = [ + BatchingKvCache(max_active_requests=batch_size, max_seq_len=max_seq_len) + for _ in range(model.num_hidden_layers) + ] + result = [] + pending_prefill_request = None + next_request_idx = 0 + progress_cnt = 0 + + while True: + if len(prompts) == 0 and all(is_idle): + break + # prefill until no idle slots + if len(prompts) > 0 and pending_prefill_request is None: + prompt = prompts.pop(0) + pending_prefill_request = Request( + model, tokenizer, prompt, prefill_step, next_request_idx + ) + next_request_idx += 1 + + # In every iteration, we do a prefill first + if pending_prefill_request is not None: + made_progress = False + if not pending_prefill_request.is_prefill_done: + pending_prefill_request.try_prefill() + made_progress = True + if pending_prefill_request.is_prefill_done: + prefill_kv_cache = pending_prefill_request.kv_cache + found_slot = False + for i in range(batch_size): + if is_idle[i]: + # Add this request to the decode requests + is_idle[i] = False + for prefill_cache, batch_cache in zip( + prefill_kv_cache, kv_cache + ): + batch_cache.add_request(prefill_cache, i) + decode_requests[i] = pending_prefill_request + found_slot = True + made_progress = True + break + if found_slot: + pending_prefill_request = None + if made_progress: + _print_progress( + decode_requests, + is_idle, + pending_prefill_request, + len(prompts), + progress_cnt, + ) + progress_cnt += 1 + + # After the prefill request moves forward one step, we do the decode + if not all(is_idle): + next_tokens = [] + offsets = [] + for req in decode_requests: + if req is not None: + next_tokens.append(req.next_token) + offsets.append(req.offset) + else: + next_tokens.append(0) + offsets.append(0) + next_tokens = mx.array(next_tokens) + # decode + next_tokens = _step(model, next_tokens.reshape(-1, 1), offsets, kv_cache) + for i in range(batch_size): + if not is_idle[i]: + req = decode_requests[i] + remove_reason = None + if req.is_done: + remove_reason = "EOS" + elif req.offset >= max_seq_len: + remove_reason = "max seq len" + if remove_reason is not None: + print( + f"Removing request {i} due to {remove_reason}", flush=True + ) + batch_cache.remove_request(i) + is_idle[i] = True + result.append((req.prompt_idx, req.text())) + decode_requests[i] = None + continue + req.decode_done(next_tokens[i].item()) + _print_progress( + decode_requests, + is_idle, + pending_prefill_request, + len(prompts), + progress_cnt, + ) + progress_cnt += 1 + return result diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index eb2b87a..d32896b 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -77,150 +77,3 @@ def _step(model, y, offset, kv_cache): break offset += tokens.size tokens = token - - -def _step(model, y, offsets, kv_cache): - logits = model(y, offsets, kv_cache) - logits = logits[:, -1, :] - logprobs = logits - mx.logsumexp(logits, keepdims=True) - sampler = lambda x: mx.argmax(x, axis=-1) - y = sampler(logprobs) - return y - - -class _PrefillRequest: - def __init__( - self, model: any, tokenizer: TokenizerWrapper, prompt: str, max_step: int = 128 - ): - self.prompt = prompt - self.kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] - self.model = model - self.prefill_tokens = mx.array( - tokenizer.encode(prompt, add_special_tokens=False) - ) - self.offset = 0 - self.max_step = max_step - - def prefill(self): - # returns None if prefill is not done - tokens_to_prefill = min(self.max_step, self.prefill_tokens.size - self.offset) - token = _step( - self.model, - self.prefill_tokens[self.offset : self.offset + tokens_to_prefill][None], - [self.offset], - self.kv_cache, - ) - self.offset += tokens_to_prefill - for i in self.kv_cache: - mx.eval(i.key_values[0]) - mx.eval(i.key_values[1]) - if self.offset == self.prefill_tokens.size: - mx.eval(token) - return token, self.kv_cache, self.offset - else: - return None - - -def _print_progress( - detokenizers: list[TokenizerWrapper], - prompt_idx: list[int], - is_idle: list[bool], - pending_prefill_requests: _PrefillRequest | None, -): - for i in range(len(detokenizers)): - if is_idle[i]: - print(f"Decode {i}: idle", flush=True) - else: - print(f"Decode {i}[{prompt_idx[i]}]: {detokenizers[i].text}", flush=True) - if pending_prefill_requests is not None: - print( - f"Prefill {pending_prefill_requests.offset}/{pending_prefill_requests.prefill_tokens.size}", - flush=True, - ) - else: - print("Prefill: idle", flush=True) - - -def batch_generate( - model: any, - tokenizer: TokenizerWrapper, - prompts: list[str], - max_seq_len=512, - batch_size=5, - prefill_step=128, -): - is_idle = [True] * batch_size - prompt_idx = [0] * batch_size - next_tokens = mx.array([0] * batch_size) - offsets = mx.array([0] * batch_size) - detokenizers = [None] * batch_size - kv_cache = [ - BatchingKvCache(max_active_requests=batch_size, max_seq_len=max_seq_len) - for _ in range(model.num_hidden_layers) - ] - result = [] - pending_prefill_requests = None - - print(f"Processing {len(prompts)} prompts") - prompts = enumerate(prompts) - more_prompts = True - while True: - if not more_prompts and all(is_idle): - break - # prefill until no idle slots - while any(is_idle) and more_prompts and pending_prefill_requests is None: - try: - idx, prompt = next(prompts) - except StopIteration: - more_prompts = False - break - pending_prefill_requests = _PrefillRequest( - model, tokenizer, prompt, prefill_step - ) - break - - if pending_prefill_requests is not None: - res = pending_prefill_requests.prefill() - if res is not None: - pending_prefill_requests = None - token, prefill_kv_cache, offset = res - - if token.item() == tokenizer.eos_token_id: - # if the first token is eos, we skip this prompt - continue - - for i in range(batch_size): - if is_idle[i]: - detokenizers[i] = tokenizer.detokenizer.__class__( - tokenizer._tokenizer - ) - detokenizers[i].add_token(token.item()) - prompt_idx[i] = idx - is_idle[i] = False - for prefill_cache, batch_cache in zip( - prefill_kv_cache, kv_cache - ): - batch_cache.add_request(prefill_cache, i) - next_tokens[i] = token - offsets[i] = offset - break - - if not all(is_idle): - next_tokens = mx.array(next_tokens) - # decode - next_tokens = _step(model, next_tokens.reshape(-1, 1), offsets, kv_cache) - offsets += 1 - for i in range(batch_size): - if not is_idle[i]: - detokenizers[i].add_token(next_tokens[i].item()) - remove_due_to_eos = next_tokens[i].item() == tokenizer.eos_token_id - remove_due_to_max_seq_len = offsets[i] >= max_seq_len - if remove_due_to_eos or remove_due_to_max_seq_len: - reason = "EOS" if remove_due_to_eos else "Max Seq Len" - result.append((prompt_idx[i], detokenizers[i].text)) - print(f"Removing request {i} due to {reason}", flush=True) - batch_cache.remove_request(i) - is_idle[i] = True - continue - _print_progress(detokenizers, prompt_idx, is_idle, pending_prefill_requests) - return result diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index 8345f99..418f996 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -6,8 +6,12 @@ class TinyKvCache: def update_and_fetch( - self, key: mx.array, value: mx.array, q_L: int | None = None - ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: + self, + key: mx.array, + value: mx.array, + mask_length: int | None = None, + mask: mx.array | str | None = None, + ) -> tuple[mx.array, mx.array, int, mx.array]: pass @@ -15,68 +19,82 @@ class BatchingKvCache(TinyKvCache): def __init__(self, max_active_requests: int, max_seq_len: int): self.max_active_requests = max_active_requests self.max_seq_len = max_seq_len - self.key_values = [None] * max_active_requests - self.real_seq_len = [0] * max_active_requests + self.kv_caches: list[TinyKvCache] = [None] * max_active_requests self.HD = None def update_and_fetch( - self, key: mx.array, value: mx.array, q_L: int | None = None + self, + keys: mx.array, + values: mx.array, + mask_length: int | None = None, + mask: mx.array | str | None = None, ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: - B, H, S, D = key.shape - assert key.shape == value.shape + B, H, S, D = keys.shape + assert keys.shape == values.shape assert S <= self.max_seq_len assert self.HD == (H, D), f"expect {self.HD} but got {H, D}" assert B == self.max_active_requests # Step 1: append the result to the cache + data = [] for b in range(B): - if self.key_values[b] is None: + if self.kv_caches[b] is None: + data.append(None) continue - cached_keys, cached_values = self.key_values[b] - keys, values = key[b], value[b] - keys = mx.concat([cached_keys, keys], axis=1) - values = mx.concat([cached_values, values], axis=1) - self.key_values[b] = (keys, values) - self.real_seq_len[b] += S + key, value = keys[b : b + 1], values[b : b + 1] + new_key, new_value, seq_len, mask = self.kv_caches[b].update_and_fetch( + key, value + ) + data.append((new_key[0], new_value[0], seq_len, mask)) + # Step 2: compute seq_len of this batch - seq_len = max(self.real_seq_len) + def get_seq_len(data): + if data is None: + return 0 + _, _, seq_len, _ = data + return seq_len + + seq_len = max(map(get_seq_len, data)) # Step 3: generate masks and a single array of keys and values - masks = [] keys = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=key.dtype) values = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=value.dtype) masks = mx.full( - (self.max_active_requests, q_L, seq_len), -mx.inf, dtype=key.dtype + (self.max_active_requests, mask_length, seq_len), -mx.inf, dtype=key.dtype ) for b in range(B): - if self.key_values[b] is None: + if data[b] is None: # for some reasons we need to do this, otherwise it will cause wrong output? # maybe precision issues? - masks[b, :, :] = causal_mask(q_L, seq_len, dtype=key.dtype) + masks[b, :, :] = causal_mask(mask_length, seq_len, dtype=key.dtype) continue - cached_keys, cached_values = self.key_values[b] - S = self.real_seq_len[b] - keys[b, :, seq_len - S : seq_len, :] = cached_keys - values[b, :, seq_len - S : seq_len, :] = cached_values - masks[b, :, seq_len - S : seq_len] = causal_mask(q_L, S, dtype=key.dtype) - return keys, values, None, masks.reshape(B, 1, q_L, seq_len) + key, value, S, mask = data[b] + keys[b, :, seq_len - S : seq_len, :] = key + values[b, :, seq_len - S : seq_len, :] = value + if mask is None or mask == "causal": + masks[b, :, seq_len - S : seq_len] = causal_mask( + mask_length, S, dtype=key.dtype + ) + elif isinstance(mask, mx.array): + masks[b, :, seq_len - S : seq_len] = mask + else: + raise NotImplemented + return keys, values, None, masks.reshape(B, 1, mask_length, seq_len) def add_request(self, prefilled: TinyKvCache, id: int): if id >= self.max_active_requests: raise ValueError(f"Request id {id} is out of range") - keys, values = prefilled.key_values - B, H, L, D = keys.shape + keys, _ = prefilled.key_values + B, H, _, D = keys.shape assert B == 1 if self.HD is None: self.HD = (H, D) else: assert self.HD == (H, D) - self.real_seq_len[id] = L - self.key_values[id] = (keys[0], values[0]) + self.kv_caches[id] = prefilled def remove_request(self, id: int): - if self.key_values is None: + if self.kv_caches is None: raise ValueError(f"Request id {id} is not in the cache") - self.key_values[id] = None - self.real_seq_len[id] = 0 + self.kv_caches[id] = None class TinyKvFullCache(TinyKvCache): @@ -85,14 +103,18 @@ def __init__(self): self.offset = 0 def update_and_fetch( - self, key: mx.array, value: mx.array, q_L: int | None = None + self, + key: mx.array, + value: mx.array, + mask_length: int | None = None, + mask: mx.array | str | None = None, ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: if self.key_values is None: assert self.offset == 0 self.key_values = (key, value) B, H, S, D = key.shape self.offset = S - return key, value, 0, None + return key, value, 0, mask else: B, H, S, D = key.shape assert key.shape == value.shape @@ -102,6 +124,8 @@ def update_and_fetch( new_keys = mx.concat([prev_keys, key], axis=2) new_values = mx.concat([prev_values, value], axis=2) self.key_values = (new_keys, new_values) - start_offset = self.offset self.offset += S - return new_keys, new_values, start_offset, None + return new_keys, new_values, self.offset, mask + + def get_offset(self): + return self.offset diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index 540f8fb..636dc4b 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -78,11 +78,9 @@ def __call__( projection_q = projection_q.transpose(0, 2, 1, 3) projection_k = projection_k.transpose(0, 2, 1, 3) projection_v = projection_v.transpose(0, 2, 1, 3) - projection_k, projection_v, _, kv_cache_mask = cache.update_and_fetch( - projection_k, projection_v, q_L=L + projection_k, projection_v, _, mask = cache.update_and_fetch( + projection_k, projection_v, mask_length=L, mask=mask ) - if kv_cache_mask is not None: - mask = kv_cache_mask S = projection_k.shape[-2] if mask == "causal": mask = causal_mask(L, S, mx.float32) diff --git a/tests_refsol/test_week_2_day_1.py b/tests_refsol/test_week_2_day_1.py index bc412e5..3372a0b 100644 --- a/tests_refsol/test_week_2_day_1.py +++ b/tests_refsol/test_week_2_day_1.py @@ -1,6 +1,12 @@ import pytest from .utils import * -from .tiny_llm_base import Qwen2ModelWeek2, Embedding, dequantize_linear, qwen2_week2, TinyKvFullCache +from .tiny_llm_base import ( + Qwen2ModelWeek2, + Embedding, + dequantize_linear, + qwen2_week2, + TinyKvFullCache, +) from mlx_lm import load # TODO: task 1 tests @@ -94,6 +100,7 @@ def test_task_3_qwen_2_7b(): def test_task_3_qwen_2_15b(): helper_test_task_3("Qwen/Qwen2-1.5B-Instruct-MLX", 3) + def helper_test_task_4(model_name: str, seq_len: int, iters: int = 1): mlx_model, tokenizer = load(model_name) model = Qwen2ModelWeek2(mlx_model) @@ -102,12 +109,15 @@ def helper_test_task_4(model_name: str, seq_len: int, iters: int = 1): inputs = mx.random.randint(0, tokenizer.vocab_size, (1, seq_len)) ref_outputs = mlx_model(inputs) for offset in range(seq_len): - user_out = model(inputs=inputs[:, offset:offset+1], offset=offset, cache=cache) - ref_out = ref_outputs[:, offset:offset+1, :] + user_out = model( + inputs=inputs[:, offset : offset + 1], offset=offset, cache=cache + ) + ref_out = ref_outputs[:, offset : offset + 1, :] user_out = user_out - mx.logsumexp(user_out, keepdims=True) ref_out = ref_out - mx.logsumexp(ref_out, keepdims=True) assert_allclose(user_out, ref_out, precision=mx.float16, rtol=1e-1) + @pytest.mark.skipif( not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" ) @@ -126,4 +136,4 @@ def test_task_4_qwen_2_7b(): not qwen_2_15b_model_exists(), reason="Qwen2-1.5B-Instruct-MLX model not found" ) def test_task_4_qwen_2_15b(): - helper_test_task_4("Qwen/Qwen2-1.5B-Instruct-MLX", seq_len=3) \ No newline at end of file + helper_test_task_4("Qwen/Qwen2-1.5B-Instruct-MLX", seq_len=3) diff --git a/tests_refsol/test_week_2_day_6.py b/tests_refsol/test_week_2_day_6.py index 1214756..c6079d7 100644 --- a/tests_refsol/test_week_2_day_6.py +++ b/tests_refsol/test_week_2_day_6.py @@ -61,8 +61,18 @@ def attention_helper( ) mx.eval(user_output_1) mx.eval(user_output_2) - assert_allclose(user_output_2, reference_output_2, precision=mx.float16, message="no mask") - assert_allclose(user_output_1, reference_output_1, precision=mx.float16, message="with mask") + assert_allclose( + user_output_2, + reference_output_2, + precision=mx.float16, + message="no mask", + ) + assert_allclose( + user_output_1, + reference_output_1, + precision=mx.float16, + message="with mask", + ) def test_flash_attention_with_mask_cpu_small(): From 5930135a0e119a899a7da46044045098c07321be Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 3 Aug 2025 18:32:52 -0400 Subject: [PATCH 18/79] chunked prefill only in continuous batching Signed-off-by: Alex Chi Z --- src/tiny_llm_ref/generate.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index d32896b..82e59a8 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -54,19 +54,9 @@ def _step(model, y, offset, kv_cache): # prefill with the prompt tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False)) - offset = 0 - prefill_max = 64 - total_tokens = tokens.size - while tokens.size > prefill_max: - token, _ = _step(model, tokens[:prefill_max], offset, kv_cache) - for i in kv_cache: - mx.eval(i.key_values[0]) - mx.eval(i.key_values[1]) - offset += prefill_max - tokens = tokens[prefill_max:] - print(f"Prefill progress: {offset}/{total_tokens}", flush=True) detokenizer = tokenizer.detokenizer detokenizer.reset() + offset = tokens.size # generate/decode while True: token, _ = _step(model, tokens, offset, kv_cache) From 8b4d9a78a94effbc6060fcb0e626b9be2151d335 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 3 Aug 2025 18:42:24 -0400 Subject: [PATCH 19/79] update readme and roadmap Signed-off-by: Alex Chi Z --- README.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 0b32d88..5258f52 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,8 @@ can build the model serving infrastructure from scratch and dig into the optimiz The goal is to learn the techniques behind efficiently serving a large language model (e.g., Qwen2 models). +In week 1, you will implement the necessary components in Python (only Python!) to use the Qwen2 model to generate responses (e.g., attention, RoPE, etc). In week 2, you will implement the inference system which is similar to but a much simpler version of vLLM (e.g., KV cache, continuous batching, flash attention, etc). In week 3, we will cover more advanced topics and how the model interacts with the outside world. + Why MLX: nowadays it's easier to get a macOS-based local development environment than setting up an NVIDIA GPU. Why Qwen2: this was the first LLM I've interacted with -- it's the go-to example in the vllm documentation. I spent some time looking at the vllm source code and built some knowledge around it. @@ -28,6 +30,7 @@ Week 1 is complete. Week 2 is in progress. | Week + Chapter | Topic | Code | Test | Doc | | -------------- | ----------------------------------------------------------- | ---- | ---- | --- | +| | Goal: wire up Qwen and make it generate text | | | | | 1.1 | Attention | ✅ | ✅ | ✅ | | 1.2 | RoPE | ✅ | ✅ | ✅ | | 1.3 | Grouped Query Attention | ✅ | ✅ | ✅ | @@ -36,18 +39,18 @@ Week 1 is complete. Week 2 is in progress. | 1.6 | Generate Responses (aka Decoding) | ✅ | ✅ | ✅ | | 1.7 | Sampling | ✅ | ✅ | ✅ | | 2.1 | Key-Value Cache | ✅ | 🚧 | 🚧 | -| 2.2 | Quantized Matmul and Linear - CPU | ✅ | 🚧 | 🚧 | -| 2.3 | Quantized Matmul and Linear - GPU | ✅ | 🚧 | 🚧 | -| 2.4 | Flash Attention 2 - CPU | ✅ | 🚧 | 🚧 | -| 2.5 | Flash Attention 2 - GPU | ✅ | 🚧 | 🚧 | +| 2.2 | Quantized Matmul and Linear - CPU | ✅ | ✅ | 🚧 | +| 2.3 | Quantized Matmul and Linear - GPU | ✅ | ✅ | 🚧 | +| 2.4 | Flash Attention 2 - CPU | ✅ | ✅ | 🚧 | +| 2.5 | Flash Attention 2 - GPU | ✅ | ✅ | 🚧 | | 2.6 | Continuous Batching | ✅ | 🚧 | 🚧 | | 2.7 | Chunked Prefill | ✅ | 🚧 | 🚧 | | 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 | | 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 | | 3.3 | MoE (Mixture of Experts) | 🚧 | 🚧 | 🚧 | | 3.4 | Speculative Decoding | 🚧 | 🚧 | 🚧 | -| 3.5 | Prefill-Decode Separation (requires two Macintosh devices) | 🚧 | 🚧 | 🚧 | -| 3.6 | Parallelism | 🚧 | 🚧 | 🚧 | -| 3.7 | AI Agent / Tool Calling | 🚧 | 🚧 | 🚧 | +| 3.5 | RAG Pipeline | 🚧 | 🚧 | 🚧 | +| 3.6 | AI Agent / Tool Calling | 🚧 | 🚧 | 🚧 | +| 3.7 | Long Context | 🚧 | 🚧 | 🚧 | Other topics not covered: quantized/compressed kv cache, prefix/prompt cache; sampling, fine tuning; smaller kernels (softmax, silu, etc) From 024d528820df75f66f82e09fb137f59ff9fd3051 Mon Sep 17 00:00:00 2001 From: Zhen Tong <88566085+58191554@users.noreply.github.com> Date: Thu, 7 Aug 2025 18:25:40 -0700 Subject: [PATCH 20/79] update the vllm-RoPE code link in the reading (#39) --- book/src/week1-02-positional-encodings.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/book/src/week1-02-positional-encodings.md b/book/src/week1-02-positional-encodings.md index c105411..7e14dd3 100644 --- a/book/src/week1-02-positional-encodings.md +++ b/book/src/week1-02-positional-encodings.md @@ -79,7 +79,7 @@ frequencies to each half separately. **📚 Readings** -- [vLLM implementation of RoPE](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py) +- [vLLM implementation of RoPE](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding) You can test your implementation by running the following command: From 00ea99003ef08901ada1bf79722cc2c0c2979a14 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 9 Aug 2025 16:01:59 -0400 Subject: [PATCH 21/79] bfloat16 support for matmul Signed-off-by: Alex Chi Z --- src/extensions_ref/src/quantized_matmul.cpp | 25 +++++++++++++------ src/extensions_ref/src/quantized_matmul.metal | 17 +++++++++---- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/src/extensions_ref/src/quantized_matmul.cpp b/src/extensions_ref/src/quantized_matmul.cpp index b599b58..0420414 100644 --- a/src/extensions_ref/src/quantized_matmul.cpp +++ b/src/extensions_ref/src/quantized_matmul.cpp @@ -23,14 +23,17 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale const bool transpose_b, // Whether to transpose b mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation ) { - if (scales.dtype() != mx::float16 || biases.dtype() != mx::float16) { - throw std::runtime_error("quantized_matmul: scales and biases must be float16"); + if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16) { + throw std::runtime_error("quantized_matmul: scales must be float16 or bfloat16"); + } + if (scales.dtype() != biases.dtype()) { + throw std::runtime_error("quantized_matmul: scales and biases must be the same dtype"); } if (b.dtype() != mx::uint32) { throw std::runtime_error("quantized_matmul: b must be uint32"); } - if (a.dtype() != mx::float16) { - throw std::runtime_error("quantized_matmul: a must be float16"); + if (a.dtype() != scales.dtype()) { + throw std::runtime_error("quantized_matmul: a must be the same dtype as scales"); } if (a.shape().size() != 2) { throw std::runtime_error("quantized_matmul: a must be a 2D array"); @@ -68,7 +71,7 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale return mx::array( /* const mx::Shape& shape = */ out_shape, - /* mx::Dtype dtype = */ mx::float16, + /* mx::Dtype dtype = */ a.dtype(), /* std::shared_ptr primitive = */ std::make_shared(to_stream(s), group_size, bits), /* const std::vector& inputs = */ {scales, biases, a, b}); @@ -92,7 +95,7 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con throw std::runtime_error("quantized_matmul: b must be contiguous"); } - // Launch the CPU kernel + // Launch the CPU kernel, TODO: support bfloat16 encoder.dispatch([out_ptr = out.data(), out_shape = out.shape(), out_strides = out.strides(), a = mx::array::unsafe_weak_copy(a), b = mx::array::unsafe_weak_copy(b), scales = mx::array::unsafe_weak_copy(scales), biases = mx::array::unsafe_weak_copy(biases)]() { @@ -164,7 +167,15 @@ void QuantizedMatmul::eval_gpu(const std::vector &inputs, std::vector // Make a kernel from this metal library auto library = d.get_library("tiny_llm_ext_ref"); - auto kernel = d.get_kernel("quantized_matmul_w4a16_g64", library); + const char* kernel_name; + if (a.dtype() == mx::float16) { + kernel_name = "quantized_matmul_w4a16_g64_f16"; + } else if (a.dtype() == mx::bfloat16) { + kernel_name = "quantized_matmul_w4a16_g64_bf16"; + } else { + throw std::runtime_error("quantized_matmul: a must be float16 or bfloat16"); + } + auto kernel = d.get_kernel(kernel_name, library); // Prepare to encode kernel auto &compute_encoder = d.get_command_encoder(s.index); diff --git a/src/extensions_ref/src/quantized_matmul.metal b/src/extensions_ref/src/quantized_matmul.metal index 5a58eb6..3b354bb 100644 --- a/src/extensions_ref/src/quantized_matmul.metal +++ b/src/extensions_ref/src/quantized_matmul.metal @@ -1,9 +1,13 @@ +#include +#include "mlx/backend/metal/kernels/utils.h" + +template [[kernel]] void quantized_matmul_w4a16_g64( - device const half* scales [[buffer(0)]], - device const half* biases [[buffer(1)]], - device const half* a [[buffer(2)]], + device const T* scales [[buffer(0)]], + device const T* biases [[buffer(1)]], + device const T* a [[buffer(2)]], device const uint32_t* b [[buffer(3)]], - device half* out [[buffer(4)]], + device T* out [[buffer(4)]], device const int &M [[buffer(5)]], device const int &N [[buffer(6)]], device const int &K [[buffer(7)]], @@ -43,6 +47,9 @@ } scales_biases_loc += 1; } - out[i * K + k] = sum; + out[i * K + k] = static_cast(sum); } } + +instantiate_kernel("quantized_matmul_w4a16_g64_f16", quantized_matmul_w4a16_g64, float16_t); +instantiate_kernel("quantized_matmul_w4a16_g64_bf16", quantized_matmul_w4a16_g64, bfloat16_t); \ No newline at end of file From 850dd6caacbaefd133301f292105251f45e37c5f Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 9 Aug 2025 16:02:53 -0400 Subject: [PATCH 22/79] model shortcut and dispatcher Signed-off-by: Alex Chi Z --- batch-main.py | 17 ++++++++++--- book/src/week1-06-generate-response.md | 4 +-- book/src/week1-07-sampling-prepare.md | 6 ++--- book/src/week2-overview.md | 2 +- main.py | 29 +++++++++++++-------- src/tiny_llm/__init__.py | 4 +++ src/tiny_llm/models.py | 35 ++++++++++++++++++++++++++ src/tiny_llm_ref/__init__.py | 2 ++ src/tiny_llm_ref/models.py | 1 + 9 files changed, 80 insertions(+), 20 deletions(-) create mode 100644 src/tiny_llm/models.py create mode 120000 src/tiny_llm_ref/models.py diff --git a/batch-main.py b/batch-main.py index 9d7ed3f..6334b5f 100644 --- a/batch-main.py +++ b/batch-main.py @@ -4,7 +4,7 @@ import random parser = argparse.ArgumentParser() -parser.add_argument("--model", type=str, default="Qwen/Qwen2-0.5B-Instruct-MLX") +parser.add_argument("--model", type=str, default="qwen2-0.5b") shanghai_wikipedia = """ Shanghai[a] is a direct-administered municipality and the most populous urban area in China. The city is located on the Chinese shoreline on the southern estuary of the Yangtze River, with the Huangpu River flowing through it. The population of the city proper is the second largest in the world after Chongqing, with around 24.87 million inhabitants in 2023, while the urban area is the most populous in China, with 29.87 million residents. As of 2022, the Greater Shanghai metropolitan area was estimated to produce a gross metropolitan product (nominal) of nearly 13 trillion RMB ($1.9 trillion).[13] Shanghai is one of the world's major centers for finance, business and economics, research, science and technology, manufacturing, transportation, tourism, and culture. The Port of Shanghai is the world's busiest container port. @@ -38,23 +38,31 @@ parser.add_argument("--device", type=str, default="gpu") parser.add_argument("--batch-size", type=int, default=5) parser.add_argument("--prefill-step", type=int, default=128) +parser.add_argument("--enable-flash-attn", action="store_true") +parser.add_argument("--enable-thinking", action="store_true") args = parser.parse_args() if args.solution == "tiny_llm": print("Using your tiny_llm solution") - from tiny_llm import Qwen2ModelWeek2, batch_generate + from tiny_llm import models, batch_generate elif args.solution == "tiny_llm_ref" or args.solution == "ref": print("Using tiny_llm_ref solution") - from tiny_llm_ref import Qwen2ModelWeek2, batch_generate + from tiny_llm_ref import models, batch_generate else: raise ValueError(f"Solution {args.solution} not supported") +args.model = models.shortcut_name_to_full_name(args.model) mlx_model, tokenizer = load(args.model) with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu): - tiny_llm_model = Qwen2ModelWeek2(mlx_model) + print( + f"Using week2 loader with flash_attn={args.enable_flash_attn} thinking={args.enable_thinking} for {args.model}" + ) + tiny_llm_model = models.dispatch_model( + args.model, mlx_model, week=2, enable_flash_attn=args.enable_flash_attn + ) encoded_prompts = [] for idx, prompt in enumerate(prompts): print(f"Prompt {idx}: {prompt}") @@ -66,6 +74,7 @@ messages, tokenize=False, add_generation_prompt=True, + enable_thinking=args.enable_thinking, ) encoded_prompts.append(prompt) result = batch_generate( diff --git a/book/src/week1-06-generate-response.md b/book/src/week1-06-generate-response.md index dd24ccb..8c67875 100644 --- a/book/src/week1-06-generate-response.md +++ b/book/src/week1-06-generate-response.md @@ -64,9 +64,9 @@ We will optimize the `decode` process to use key-value cache to speed up the gen You can test your implementation by running the following command: ```bash -pdm run main --solution tiny_llm --loader week1 --model Qwen/Qwen2-0.5B-Instruct-MLX \ +pdm run main --solution tiny_llm --loader week1 --model qwen2-0.5b \ --prompt "Give me a short introduction to large language model" -pdm run main --solution tiny_llm --loader week1 --model Qwen/Qwen2-7B-Instruct-MLX \ +pdm run main --solution tiny_llm --loader week1 --model qwen2-7b \ --prompt "Give me a short introduction to large language model" ``` diff --git a/book/src/week1-07-sampling-prepare.md b/book/src/week1-07-sampling-prepare.md index 2e266a2..e4d76d2 100644 --- a/book/src/week1-07-sampling-prepare.md +++ b/book/src/week1-07-sampling-prepare.md @@ -24,7 +24,7 @@ To implement temperature sampling, simply divide the logprobs by the temperature randomly select the next token. ```bash -pdm run main --solution tiny_llm --loader week1 --model Qwen/Qwen2-0.5B-Instruct-MLX --sampler-temp 0.5 +pdm run main --solution tiny_llm --loader week1 --model qwen2-0.5b --sampler-temp 0.5 ``` **Top-k Sampling** @@ -36,7 +36,7 @@ You can use `mx.argpartition` to partition the output so that you can know the i mask those logprobs outside the top-k with `-mx.inf`. After that, do temperature sampling. ```bash -pdm run main --solution tiny_llm --loader week1 --model Qwen/Qwen2-0.5B-Instruct-MLX --sampler-temp 0.5 --sampler-top-k 10 +pdm run main --solution tiny_llm --loader week1 --model qwen2-0.5b --sampler-temp 0.5 --sampler-top-k 10 ``` **Top-p (Nucleus) Sampling** @@ -49,7 +49,7 @@ probability to lowest), and then, do a `cumsum` over the sorted logprobs to get those logprobs outside the top-p with `-mx.inf`. After that, do temperature sampling. ```bash -pdm run main --solution tiny_llm --loader week1 --model Qwen/Qwen2-0.5B-Instruct-MLX --sampler-temp 0.5 --sampler-top-p 0.9 +pdm run main --solution tiny_llm --loader week1 --model qwen2-0.5b --sampler-temp 0.5 --sampler-top-p 0.9 ``` ## Task 2: Prepare for Week 2 diff --git a/book/src/week2-overview.md b/book/src/week2-overview.md index 63d4e49..e008678 100644 --- a/book/src/week2-overview.md +++ b/book/src/week2-overview.md @@ -28,4 +28,4 @@ https://huggingface.co/docs/transformers/pad_truncation https://siboehm.com/articles/22/CUDA-MMM https://github.com/ggml-org/llama.cpp/blob/master/ggml/src/ggml-metal/ggml-metal.metal -pdm run batch-main --solution ref --model Qwen/Qwen2-7B-Instruct-MLX --prefill-step 16 +pdm run batch-main --solution ref --model qwen2-7b --prefill-step 16 diff --git a/main.py b/main.py index cb31d5b..05fdd98 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,7 @@ import mlx_lm.sample_utils parser = argparse.ArgumentParser() -parser.add_argument("--model", type=str, default="Qwen/Qwen2-7B-Instruct-MLX") +parser.add_argument("--model", type=str, default="qwen2-7b") parser.add_argument( "--prompt", type=str, @@ -18,14 +18,16 @@ parser.add_argument("--sampler-temp", type=float, default=0) parser.add_argument("--sampler-top-p", type=float, default=None) parser.add_argument("--sampler-top-k", type=int, default=None) +parser.add_argument("--enable-thinking", action="store_true") +parser.add_argument("--enable-flash-attn", action="store_true") + args = parser.parse_args() use_mlx = False if args.solution == "tiny_llm": print("Using your tiny_llm solution") from tiny_llm import ( - Qwen2ModelWeek1, - Qwen2ModelWeek2, + models, simple_generate, simple_generate_with_kv_cache, sampler, @@ -34,8 +36,7 @@ elif args.solution == "tiny_llm_ref" or args.solution == "ref": print("Using tiny_llm_ref solution") from tiny_llm_ref import ( - Qwen2ModelWeek1, - Qwen2ModelWeek2, + models, simple_generate, simple_generate_with_kv_cache, sampler, @@ -49,6 +50,7 @@ else: raise ValueError(f"Solution {args.solution} not supported") +args.model = models.shortcut_name_to_full_name(args.model) mlx_model, tokenizer = load(args.model) with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu): @@ -56,11 +58,15 @@ tiny_llm_model = mlx_model else: if args.loader == "week1": - print("Using Qwen2ModelWeek1 loader") - tiny_llm_model = Qwen2ModelWeek1(mlx_model) + print(f"Using week1 loader for {args.model}") + tiny_llm_model = models.dispatch_model(args.model, mlx_model, week=1) elif args.loader == "week2": - print("Using Qwen2ModelWeek2 loader") - tiny_llm_model = Qwen2ModelWeek2(mlx_model) + print( + f"Using week2 loader with flash_attn={args.enable_flash_attn} thinking={args.enable_thinking} for {args.model}" + ) + tiny_llm_model = models.dispatch_model( + args.model, mlx_model, week=2, enable_flash_attn=args.enable_flash_attn + ) else: raise ValueError(f"Loader {args.loader} not supported") messages = [ @@ -68,7 +74,10 @@ {"role": "user", "content": args.prompt}, ] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=args.enable_thinking, ) if not use_mlx: sampler = sampler.make_sampler( diff --git a/src/tiny_llm/__init__.py b/src/tiny_llm/__init__.py index 99801a2..bb237dd 100644 --- a/src/tiny_llm/__init__.py +++ b/src/tiny_llm/__init__.py @@ -5,7 +5,11 @@ from .positional_encoding import * from .quantize import * from .generate import * +from .kv_cache import * from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 +from .qwen3 import Qwen3Model from .sampler import * from .kv_cache import * +from .batch import * +from .models import * diff --git a/src/tiny_llm/models.py b/src/tiny_llm/models.py new file mode 100644 index 0000000..2a58f74 --- /dev/null +++ b/src/tiny_llm/models.py @@ -0,0 +1,35 @@ +from .qwen2_week1 import Qwen2ModelWeek1 +from .qwen2_week2 import Qwen2ModelWeek2 +from .qwen3 import Qwen3Model + + +def shortcut_name_to_full_name(shortcut_name: str): + lower_shortcut_name = shortcut_name.lower() + if lower_shortcut_name == "qwen2-7b": + return "Qwen/Qwen2-7B-Instruct-MLX" + elif lower_shortcut_name == "qwen2-0.5b": + return "Qwen/Qwen2-0.5B-Instruct-MLX" + elif lower_shortcut_name == "qwen2-1.5b": + return "Qwen/Qwen2-1.5B-Instruct-MLX" + elif lower_shortcut_name == "qwen3-8b": + return "mlx-community/Qwen3-8B-4bit" + elif lower_shortcut_name == "qwen3-0.6b": + return "mlx-community/Qwen3-0.6B-4bit" + elif lower_shortcut_name == "qwen3-1.7b": + return "mlx-community/Qwen3-1.7B-4bit" + elif lower_shortcut_name == "qwen3-4b": + return "mlx-community/Qwen3-4B-4bit" + else: + return shortcut_name + + +def dispatch_model(model_name: str, mlx_model, week: int, **kwargs): + model_name = shortcut_name_to_full_name(model_name) + if week == 1 and model_name.startswith("Qwen/Qwen2"): + return Qwen2ModelWeek1(mlx_model, **kwargs) + elif week == 2 and model_name.startswith("Qwen/Qwen2"): + return Qwen2ModelWeek2(mlx_model, **kwargs) + elif week == 2 and model_name.startswith("mlx-community/Qwen3"): + return Qwen3Model(mlx_model, **kwargs) + else: + raise ValueError(f"{model_name} for week {week} not supported") diff --git a/src/tiny_llm_ref/__init__.py b/src/tiny_llm_ref/__init__.py index 402f98f..bb237dd 100644 --- a/src/tiny_llm_ref/__init__.py +++ b/src/tiny_llm_ref/__init__.py @@ -8,6 +8,8 @@ from .kv_cache import * from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 +from .qwen3 import Qwen3Model from .sampler import * from .kv_cache import * from .batch import * +from .models import * diff --git a/src/tiny_llm_ref/models.py b/src/tiny_llm_ref/models.py new file mode 120000 index 0000000..acf3886 --- /dev/null +++ b/src/tiny_llm_ref/models.py @@ -0,0 +1 @@ +../tiny_llm/models.py \ No newline at end of file From ffbd15d25c6913168bdddd60d91f0bba8c826bdd Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 9 Aug 2025 16:03:14 -0400 Subject: [PATCH 23/79] qwen3 support Signed-off-by: Alex Chi Z --- src/tiny_llm/batch.py | 0 src/tiny_llm/qwen3.py | 131 +++++++++++++ src/tiny_llm_ref/batch.py | 7 +- src/tiny_llm_ref/qwen2_week2.py | 3 +- src/tiny_llm_ref/qwen3.py | 332 ++++++++++++++++++++++++++++++++ 5 files changed, 471 insertions(+), 2 deletions(-) create mode 100644 src/tiny_llm/batch.py create mode 100644 src/tiny_llm/qwen3.py create mode 100644 src/tiny_llm_ref/qwen3.py diff --git a/src/tiny_llm/batch.py b/src/tiny_llm/batch.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tiny_llm/qwen3.py b/src/tiny_llm/qwen3.py new file mode 100644 index 0000000..bb758d2 --- /dev/null +++ b/src/tiny_llm/qwen3.py @@ -0,0 +1,131 @@ +import mlx.core as mx +from .basics import silu +from .attention import ( + scaled_dot_product_attention_grouped, + flash_attention, + causal_mask, +) +from .layer_norm import RMSNorm +from .positional_encoding import RoPE +from typing import Any +from .embedding import Embedding +from .quantize import dequantize_linear, QuantizedWeights, quantized_linear +from .kv_cache import TinyKvCache + + +class Qwen3MultiHeadAttention: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + wq: QuantizedWeights, + wk: QuantizedWeights, + wv: QuantizedWeights, + wo: QuantizedWeights, + q_norm: mx.array, + k_norm: mx.array, + max_seq_len: int = 32768, + theta: int = 1000000, + rms_norm_eps: float = 1e-5, + use_flash_attention: bool = False, + ): + pass + + def __call__( + self, + x: mx.array, + offsets: list[int], + cache: TinyKvCache, + mask: mx.array | str | None = None, + ) -> mx.array: + pass + + +class Qwen3MLP: + def __init__( + self, + dim: int, + hidden_dim: int, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + ): + self.dim = dim + self.hidden_dim = hidden_dim + self.w_gate = w_gate + self.w_up = w_up + self.w_down = w_down + + def __call__(self, x: mx.array) -> mx.array: + pass + + +class Qwen3TransformerBlock: + def __init__( + self, + num_attention_heads: int, + num_kv_heads: int, + hidden_size: int, + head_dim: int, + intermediate_size: int, + rms_norm_eps: float, + wq: QuantizedWeights, + wk: QuantizedWeights, + wv: QuantizedWeights, + wo: QuantizedWeights, + q_norm: mx.array, + k_norm: mx.array, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + w_input_layernorm: mx.array, + w_post_attention_layernorm: mx.array, + max_seq_len: int = 32768, + theta: int = 1000000, + use_flash_attention: bool = False, + ): + pass + + def __call__( + self, + x: mx.array, + offset: int, + cache: TinyKvCache, + mask: mx.array | str | None = None, + ) -> mx.array: + pass + + +def assert_dtype(weights: mx.array, dtype: mx.Dtype): + if weights.dtype != dtype: + raise ValueError(f"{weights.dtype} != {dtype}") + else: + return weights + + +def assert_quantized_weights_dtype(weights: QuantizedWeights, dtype: mx.Dtype): + if weights.scales.dtype != dtype: + raise ValueError(f"{weights.scales.dtype} != {dtype}") + if weights.biases.dtype != dtype: + raise ValueError(f"{weights.biases.dtype} != {dtype}") + else: + return weights + + +class Qwen3Model: + def __init__( + self, + mlx_model: Any, + enable_flash_attn: bool = False, + ): + pass + + def __call__( + self, + inputs: mx.array, + offset: int, + cache: list[TinyKvCache], + ) -> mx.array: + pass diff --git a/src/tiny_llm_ref/batch.py b/src/tiny_llm_ref/batch.py index a990902..659a330 100644 --- a/src/tiny_llm_ref/batch.py +++ b/src/tiny_llm_ref/batch.py @@ -3,6 +3,7 @@ from .kv_cache import * from .qwen2_week2 import Qwen2ModelWeek2 from typing import Callable +from datetime import datetime def _step(model, y, offsets, kv_cache): @@ -83,8 +84,9 @@ def _print_progress( pending_prefill_request: Request | None, queue_size: int, progress_cnt: int, + start_time: datetime, ): - print(" ---") + print(f" --- {datetime.now() - start_time}") animation_frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] animation_frame = animation_frames[progress_cnt % len(animation_frames)] for i in range(len(requests)): @@ -131,6 +133,7 @@ def batch_generate( pending_prefill_request = None next_request_idx = 0 progress_cnt = 0 + start_time = datetime.now() while True: if len(prompts) == 0 and all(is_idle): @@ -173,6 +176,7 @@ def batch_generate( pending_prefill_request, len(prompts), progress_cnt, + start_time, ) progress_cnt += 1 @@ -214,6 +218,7 @@ def batch_generate( pending_prefill_request, len(prompts), progress_cnt, + start_time, ) progress_cnt += 1 return result diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index 636dc4b..540087d 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -191,6 +191,7 @@ class Qwen2ModelWeek2: def __init__( self, mlx_model: Any, + enable_flash_attn: bool = False, ): self.num_hidden_layers = mlx_model.args.num_hidden_layers self.hidden_size = mlx_model.args.hidden_size @@ -252,7 +253,7 @@ def __init__( ].post_attention_layernorm.weight.astype(precision), max_seq_len=mlx_model.args.max_position_embeddings, theta=mlx_model.args.rope_theta, - use_flash_attention=True, + use_flash_attention=enable_flash_attn, ) self.layers_inner.append(layer) self.norm = RMSNorm( diff --git a/src/tiny_llm_ref/qwen3.py b/src/tiny_llm_ref/qwen3.py new file mode 100644 index 0000000..4d28ea8 --- /dev/null +++ b/src/tiny_llm_ref/qwen3.py @@ -0,0 +1,332 @@ +import mlx.core as mx +from .basics import silu +from .attention import ( + scaled_dot_product_attention_grouped, + flash_attention, + causal_mask, +) +from .layer_norm import RMSNorm +from .positional_encoding import RoPE +from typing import Any +from .embedding import Embedding +from .quantize import dequantize_linear, QuantizedWeights, quantized_linear +from .kv_cache import TinyKvCache + + +class Qwen3MultiHeadAttention: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + wq: QuantizedWeights, + wk: QuantizedWeights, + wv: QuantizedWeights, + wo: QuantizedWeights, + q_norm: mx.array, + k_norm: mx.array, + max_seq_len: int = 32768, + theta: int = 1000000, + rms_norm_eps: float = 1e-5, + use_flash_attention: bool = False, + ): + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + assert hidden_size % num_heads == 0, ( + f"hidden_size {hidden_size} must be divisible by num_heads {num_heads}" + ) + assert num_heads % num_kv_heads == 0, ( + f"num_heads {num_heads} must be divisible by num_kv_heads {num_kv_heads}" + ) + self.head_dim = head_dim + self.scale = mx.rsqrt(self.head_dim) + self.wq = wq + self.wk = wk + self.wv = wv + self.wo = wo + self.q_norm = q_norm + self.k_norm = k_norm + self.rope = RoPE(self.head_dim, max_seq_len, theta) + self.q_norm = RMSNorm(self.head_dim, q_norm, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, k_norm, eps=rms_norm_eps) + self.use_flash_attention = use_flash_attention + + def __call__( + self, + x: mx.array, + offsets: list[int], + cache: TinyKvCache, + mask: mx.array | str | None = None, + ) -> mx.array: + B, L, _ = x.shape + projection_q = quantized_linear(x, self.wq).reshape( + B, L, self.num_heads, self.head_dim + ) + projection_k = quantized_linear(x, self.wk).reshape( + B, L, self.num_kv_heads, self.head_dim + ) + projection_q = self.q_norm(projection_q) + projection_k = self.k_norm(projection_k) + projection_v = quantized_linear(x, self.wv).reshape( + B, L, self.num_kv_heads, self.head_dim + ) + # todo: move offsets to kv cache + if isinstance(offsets, int): + offset_slice = [slice(int(offsets), int(offsets + L))] + else: + offset_slice = [slice(int(i), int(i + L)) for i in offsets] + projection_q = self.rope(projection_q, offset=offset_slice) + projection_k = self.rope(projection_k, offset=offset_slice) + projection_q = projection_q.transpose(0, 2, 1, 3) + projection_k = projection_k.transpose(0, 2, 1, 3) + projection_v = projection_v.transpose(0, 2, 1, 3) + projection_k, projection_v, _, mask = cache.update_and_fetch( + projection_k, projection_v, mask_length=L, mask=mask + ) + S = projection_k.shape[-2] + if mask == "causal": + mask = causal_mask(L, S, mx.float32) + if self.use_flash_attention: + x = flash_attention( + projection_q.astype(mx.float32), + projection_k.astype(mx.float32), + projection_v.astype(mx.float32), + scale=self.scale, + mask=mask, + ).astype(x.dtype) + else: + x = scaled_dot_product_attention_grouped( + projection_q.astype(mx.float32), + projection_k.astype(mx.float32), + projection_v.astype(mx.float32), + scale=self.scale, + mask=mask, + ).astype(x.dtype) + x = x.transpose(0, 2, 1, 3).reshape(B, L, self.num_heads * self.head_dim) + return quantized_linear(x, self.wo) + + +class Qwen3MLP: + def __init__( + self, + dim: int, + hidden_dim: int, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + ): + self.dim = dim + self.hidden_dim = hidden_dim + self.w_gate = w_gate + self.w_up = w_up + self.w_down = w_down + + def __call__(self, x: mx.array) -> mx.array: + return quantized_linear( + silu(quantized_linear(x, self.w_gate)) * quantized_linear(x, self.w_up), + self.w_down, + ) + + +class Qwen3TransformerBlock: + def __init__( + self, + num_attention_heads: int, + num_kv_heads: int, + hidden_size: int, + head_dim: int, + intermediate_size: int, + rms_norm_eps: float, + wq: QuantizedWeights, + wk: QuantizedWeights, + wv: QuantizedWeights, + wo: QuantizedWeights, + q_norm: mx.array, + k_norm: mx.array, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + w_input_layernorm: mx.array, + w_post_attention_layernorm: mx.array, + max_seq_len: int = 32768, + theta: int = 1000000, + use_flash_attention: bool = False, + ): + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.mlp = Qwen3MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) + self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + hidden_size, w_post_attention_layernorm, eps=rms_norm_eps + ) + self.self_attn = Qwen3MultiHeadAttention( + num_heads=num_attention_heads, + hidden_size=hidden_size, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + wq=wq, + wk=wk, + wv=wv, + wo=wo, + q_norm=q_norm, + k_norm=k_norm, + max_seq_len=max_seq_len, + theta=theta, + rms_norm_eps=rms_norm_eps, + use_flash_attention=use_flash_attention, + ) + + def __call__( + self, + x: mx.array, + offset: int, + cache: TinyKvCache, + mask: mx.array | str | None = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), offset, cache, mask) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +def assert_dtype(weights: mx.array, dtype: mx.Dtype): + if weights.dtype != dtype: + raise ValueError(f"{weights.dtype} != {dtype}") + else: + return weights + + +def assert_quantized_weights_dtype(weights: QuantizedWeights, dtype: mx.Dtype): + if weights.scales.dtype != dtype: + raise ValueError(f"{weights.scales.dtype} != {dtype}") + if weights.biases.dtype != dtype: + raise ValueError(f"{weights.biases.dtype} != {dtype}") + else: + return weights + + +class Qwen3Model: + def __init__( + self, + mlx_model: Any, + enable_flash_attn: bool = False, + ): + self.num_hidden_layers = mlx_model.args.num_hidden_layers + self.hidden_size = mlx_model.args.hidden_size + self.vocab_size = mlx_model.args.vocab_size + precision = mx.bfloat16 + self.precision = precision + + self.embedding = Embedding( + vocab_size=self.vocab_size, + embedding_dim=self.hidden_size, + weight=assert_dtype( + dequantize_linear(mlx_model.model.embed_tokens), dtype=precision + ), + ) + self.layers_inner = [] + + for i in range(mlx_model.args.num_hidden_layers): + wq = assert_quantized_weights_dtype( + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].self_attn.q_proj + ), + dtype=precision, + ) + wk = assert_quantized_weights_dtype( + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].self_attn.k_proj + ), + dtype=precision, + ) + wv = assert_quantized_weights_dtype( + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].self_attn.v_proj + ), + dtype=precision, + ) + wo = assert_quantized_weights_dtype( + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].self_attn.o_proj + ), + dtype=precision, + ) + w_gate = assert_quantized_weights_dtype( + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate_proj + ), + dtype=precision, + ) + w_up = assert_quantized_weights_dtype( + QuantizedWeights.from_mlx_layer(mlx_model.model.layers[i].mlp.up_proj), + dtype=precision, + ) + w_down = assert_quantized_weights_dtype( + QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.down_proj + ), + dtype=precision, + ) + + layer = Qwen3TransformerBlock( + num_attention_heads=mlx_model.args.num_attention_heads, + num_kv_heads=mlx_model.args.num_key_value_heads, + hidden_size=mlx_model.args.hidden_size, + head_dim=mlx_model.args.head_dim, + intermediate_size=mlx_model.args.intermediate_size, + rms_norm_eps=mlx_model.args.rms_norm_eps, + wq=wq, + wk=wk, + wv=wv, + wo=wo, + q_norm=assert_dtype( + mlx_model.model.layers[i].self_attn.q_norm.weight, dtype=precision + ), + k_norm=assert_dtype( + mlx_model.model.layers[i].self_attn.k_norm.weight, dtype=precision + ), + w_gate=w_gate, + w_up=w_up, + w_down=w_down, + w_input_layernorm=assert_dtype( + mlx_model.model.layers[i].input_layernorm.weight, dtype=precision + ), + w_post_attention_layernorm=assert_dtype( + mlx_model.model.layers[i].post_attention_layernorm.weight, + dtype=precision, + ), + max_seq_len=mlx_model.args.max_position_embeddings, + theta=mlx_model.args.rope_theta, + use_flash_attention=enable_flash_attn, + ) + self.layers_inner.append(layer) + self.norm = RMSNorm( + mlx_model.args.hidden_size, + weight=assert_dtype(mlx_model.model.norm.weight, dtype=precision), + eps=mlx_model.args.rms_norm_eps, + ) + if not mlx_model.args.tie_word_embeddings: + self.w_lm_head = assert_quantized_weights_dtype( + QuantizedWeights.from_mlx_layer(mlx_model.lm_head), dtype=precision + ) + else: + self.w_lm_head = None + self.mlx_model = mlx_model + + def __call__( + self, + inputs: mx.array, + offset: int, + cache: list[TinyKvCache], + ) -> mx.array: + h = self.embedding(inputs) + for layer in range(self.num_hidden_layers): + h = self.layers_inner[layer](h, offset, cache[layer], mask="causal") + h = self.norm(h) + if self.w_lm_head is not None: + return quantized_linear(h, self.w_lm_head) + else: + return self.embedding.as_linear(h) From cd87116770c804f1682deabec89c13274a3a7203 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 9 Aug 2025 16:04:37 -0400 Subject: [PATCH 24/79] update readme Signed-off-by: Alex Chi Z --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 5258f52..7819533 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,6 @@ Week 1 is complete. Week 2 is in progress. | Week + Chapter | Topic | Code | Test | Doc | | -------------- | ----------------------------------------------------------- | ---- | ---- | --- | -| | Goal: wire up Qwen and make it generate text | | | | | 1.1 | Attention | ✅ | ✅ | ✅ | | 1.2 | RoPE | ✅ | ✅ | ✅ | | 1.3 | Grouped Query Attention | ✅ | ✅ | ✅ | From 4e1cced107af7bbee7d8a7e340657ae8bd26f1e0 Mon Sep 17 00:00:00 2001 From: Eric Yue Date: Tue, 12 Aug 2025 08:01:30 +0800 Subject: [PATCH 25/79] fix: resolve f-string syntax error in batch.py (#44) Extract string replacement operation outside f-string expression to avoid backslash in f-string expression part, which is not allowed in Python syntax. - Move .replace('\n', ' ') operation to separate variable - Improves code readability and fixes SyntaxError --- src/tiny_llm_ref/batch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tiny_llm_ref/batch.py b/src/tiny_llm_ref/batch.py index 659a330..a0ff373 100644 --- a/src/tiny_llm_ref/batch.py +++ b/src/tiny_llm_ref/batch.py @@ -93,8 +93,9 @@ def _print_progress( if is_idle[i]: print(f" Decode #{i}: idle", flush=True) else: + text_preview = requests[i].text()[-80:].replace('\n', ' ') print( - f"{animation_frame} Decode [req {requests[i].prompt_idx}, {requests[i].offset}]: {requests[i].text()[-80:].replace('\n', ' ')}", + f"{animation_frame} Decode [req {requests[i].prompt_idx}, {requests[i].offset}]: {text_preview}", flush=True, ) if pending_prefill_request is not None: From 1d7572f0363584bf528d109880891891d29cc33e Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 17 Aug 2025 14:40:18 -0400 Subject: [PATCH 26/79] remove offset in week 1, not used Signed-off-by: Alex Chi Z --- book/src/week1-03-gqa.md | 4 ++-- book/src/week1-06-generate-response.md | 14 ++++---------- src/tiny_llm_ref/qwen2_week1.py | 14 +++++--------- tests_refsol/test_week_1_day_3.py | 2 +- tests_refsol/test_week_1_day_5.py | 2 +- 5 files changed, 13 insertions(+), 23 deletions(-) diff --git a/book/src/week1-03-gqa.md b/book/src/week1-03-gqa.md index 2290a96..dcd9c6c 100644 --- a/book/src/week1-03-gqa.md +++ b/book/src/week1-03-gqa.md @@ -100,8 +100,8 @@ x: B, L, E q = linear(x, wq, bq) -> B, L, H_q, D k = linear(x, wk, bk) -> B, L, H, D v = linear(x, wv, bv) -> B, L, H, D -q = rope(q, offset=slice(offset, offset + L)) -k = rope(k, offset=slice(offset, offset + L)) +q = rope(q, offset=slice(0, L)) +k = rope(k, offset=slice(0, L)) (transpose as needed) x = scaled_dot_product_attention_grouped(q, k, v, scale, mask) -> B, L, H_q, D ; Do this at float32 precision (transpose as needed) diff --git a/book/src/week1-06-generate-response.md b/book/src/week1-06-generate-response.md index 8c67875..8aae630 100644 --- a/book/src/week1-06-generate-response.md +++ b/book/src/week1-06-generate-response.md @@ -13,12 +13,10 @@ src/tiny_llm/generate.py The `simple_generate` function takes a model, a tokenizer, and a prompt, and generates the response. The generation process is done in two parts: first prefill, and then decode. -First thing is to implement the `_step` sub-function. It takes a list of tokens `y`, and the offset of the first token -provided to the model. The model will return the logits: the probability distribution of the next token for each position. +First thing is to implement the `_step` sub-function. It takes a list of tokens `y`. The model will return the logits: the probability distribution of the next token for each position. ``` y: N.. x S, where in week 1 we don't implement batch, so N.. = 1 -offset: int output_logits: N.. x S x vocab_size ``` @@ -42,10 +40,6 @@ With the `_step` function implemented, you can now implement the full `simple_ge first prefill the model with the prompt. As the prompt is a string, you need to first convert it to a list of tokens by using the tokenizer `tokenizer.encode`. -* The prefill step is done by calling the `_step` function with all the tokens in the prompt with `offset=0`. It gives back -the first token in the response. -* The decode step is done by calling the `_step` function with all the previous tokens and the offset of the last token. - You will need to implement a while loop to keep generating the response until the model outputs the EOS `tokenizer.eos_token_id` token. In the loop, you will need to store all previous tokens in a list, and use the detokenizer `tokenizer.detokenizer` to print the response. @@ -53,9 +47,9 @@ An example of the sequences provided to the `_step` function is as below: ``` tokenized_prompt: [1, 2, 3, 4, 5, 6] -prefill: _step(model, [1, 2, 3, 4, 5, 6], 0) # returns 7 -decode: _step(model, [1, 2, 3, 4, 5, 6, 7], 7) # returns 8 -decode: _step(model, [1, 2, 3, 4, 5, 6, 7, 8], 8) # returns 9 +prefill: _step(model, [1, 2, 3, 4, 5, 6]) # returns 7 +decode: _step(model, [1, 2, 3, 4, 5, 6, 7]) # returns 8 +decode: _step(model, [1, 2, 3, 4, 5, 6, 7, 8]) # returns 9 ... ``` diff --git a/src/tiny_llm_ref/qwen2_week1.py b/src/tiny_llm_ref/qwen2_week1.py index d439653..21a3848 100644 --- a/src/tiny_llm_ref/qwen2_week1.py +++ b/src/tiny_llm_ref/qwen2_week1.py @@ -47,7 +47,6 @@ def __init__( def __call__( self, x: mx.array, - offset: int, mask: mx.array | str | None = None, ) -> mx.array: B, L, _ = x.shape @@ -60,8 +59,8 @@ def __call__( projection_v = linear(x, self.wv, bias=self.bv).reshape( B, L, self.num_kv_heads, self.head_dim ) - projection_q = self.rope(projection_q, offset=slice(offset, offset + L)) - projection_k = self.rope(projection_k, offset=slice(offset, offset + L)) + projection_q = self.rope(projection_q, offset=slice(0, L)) + projection_k = self.rope(projection_k, offset=slice(0, L)) projection_q = projection_q.transpose(0, 2, 1, 3) projection_k = projection_k.transpose(0, 2, 1, 3) projection_v = projection_v.transpose(0, 2, 1, 3) @@ -143,10 +142,9 @@ def __init__( def __call__( self, x: mx.array, - offset: int, mask: mx.array | str | None = None, ) -> mx.array: - r = self.self_attn(self.input_layernorm(x), offset, mask) + r = self.self_attn(self.input_layernorm(x), mask) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r @@ -217,16 +215,14 @@ def __init__( self.w_lm_head = None self.mlx_model = mlx_model + def __call__( self, inputs: mx.array, - offset: int, ) -> mx.array: h = self.embedding(inputs) for layer in range(self.num_hidden_layers): - h = self.layers_inner[layer]( - h, offset, mask="causal" if h.shape[1] > 1 else None - ) + h = self.layers_inner[layer](h, mask="causal") h = self.norm(h) if self.w_lm_head is not None: return linear(h, self.w_lm_head) diff --git a/tests_refsol/test_week_1_day_3.py b/tests_refsol/test_week_1_day_3.py index 32c27ae..0b17c81 100644 --- a/tests_refsol/test_week_1_day_3.py +++ b/tests_refsol/test_week_1_day_3.py @@ -190,7 +190,7 @@ def test_task_3_qwen2_grouped_query_attention( theta=theta, ) - user_output = user_attention(x, offset=0, mask=mask) + user_output = user_attention(x, mask=mask) mlx_output = mlx_attention(x, mask=mask, cache=None) assert_allclose(user_output, mlx_output, precision=precision) diff --git a/tests_refsol/test_week_1_day_5.py b/tests_refsol/test_week_1_day_5.py index fca9183..7ac1d1e 100644 --- a/tests_refsol/test_week_1_day_5.py +++ b/tests_refsol/test_week_1_day_5.py @@ -32,7 +32,7 @@ def helper_test_task_3(model_name: str, iters: int = 10): model = Qwen2ModelWeek1(mlx_model) for _ in range(iters): input = mx.random.randint(low=0, high=tokenizer.vocab_size, shape=(1, 10)) - user_output = model(input, 0) + user_output = model(input) user_output = user_output - mx.logsumexp(user_output, keepdims=True) ref_output = mlx_model(input) ref_output = ref_output - mx.logsumexp(ref_output, keepdims=True) From 0b82b7f64a5ab74f573dea68ed2c8b4289932cb6 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 17 Aug 2025 15:00:38 -0400 Subject: [PATCH 27/79] add week2day1 kv cache contents Signed-off-by: Alex Chi Z --- README.md | 2 +- book/src/SUMMARY.md | 18 +-- book/src/week2-01-kv-cache.md | 183 ++++++++++++++++++++++++++++++ book/src/week2-overview.md | 21 ++++ src/tiny_llm/generate.py | 6 +- src/tiny_llm/kv_cache.py | 35 +++--- src/tiny_llm_ref/generate.py | 6 +- src/tiny_llm_ref/kv_cache.py | 14 ++- tests_refsol/test_week_2_day_1.py | 38 ------- 9 files changed, 252 insertions(+), 71 deletions(-) create mode 100644 book/src/week2-01-kv-cache.md diff --git a/README.md b/README.md index 7819533..ea3d2a1 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Week 1 is complete. Week 2 is in progress. | 1.5 | Load the Model | ✅ | ✅ | ✅ | | 1.6 | Generate Responses (aka Decoding) | ✅ | ✅ | ✅ | | 1.7 | Sampling | ✅ | ✅ | ✅ | -| 2.1 | Key-Value Cache | ✅ | 🚧 | 🚧 | +| 2.1 | Key-Value Cache | ✅ | ✅ | ✅ | | 2.2 | Quantized Matmul and Linear - CPU | ✅ | ✅ | 🚧 | | 2.3 | Quantized Matmul and Linear - GPU | ✅ | ✅ | 🚧 | | 2.4 | Flash Attention 2 - CPU | ✅ | ✅ | 🚧 | diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index 23acefa..502a072 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -13,18 +13,12 @@ - [The Qwen2 Model](./week1-05-qwen2-model.md) - [Generating the Response](./week1-06-generate-response.md) - [Sampling and Preparing for Week 2](./week1-07-sampling-prepare.md) - - -- [Week 2: Optimizing]() - +- [Week 2: Tiny vLLM](./week2-overview.md) + - [Key-Value Cache](./week2-01-kv-cache.md) + - [Quantized Matmul (2 Days)]() + - [Flash Attention (2 Days)]() + - [Chunked Prefill]() + - [Continuous Batching]() - [Week 3: Serving]() --- diff --git a/book/src/week2-01-kv-cache.md b/book/src/week2-01-kv-cache.md new file mode 100644 index 0000000..861d89c --- /dev/null +++ b/book/src/week2-01-kv-cache.md @@ -0,0 +1,183 @@ +# Week 2 Day 1: Key-Value Cache + +In this chapter, we will implement the **key-value cache** for the Qwen2 model. The key-value cache is an essential component of the attention mechanism, as it allows the model to reuse previously computed results instead of recomputing them for every new token. + +**📚 Readings** + +- [KV Caching Explained: Optimizing Transformer Inference Efficiency](https://huggingface.co/blog/not-lain/kv-caching) + +Recall from last week how we supplied data to the model: + +```plain +tokenized_prompt: [1, 2, 3, 4, 5, 6] +prefill: _step(model, [1, 2, 3, 4, 5, 6]) # returns 7 +decode: _step(model, [1, 2, 3, 4, 5, 6, 7]) # returns 8 +decode: _step(model, [1, 2, 3, 4, 5, 6, 7, 8]) # returns 9 +... +``` + +```plain +x: B, L, E +q = linear(x, wq, bq) -> B, L, H_q, D +k = linear(x, wk, bk) -> B, L, H, D +v = linear(x, wv, bv) -> B, L, H, D +q = rope(q, offset=slice(offset, offset + L)) +k = rope(k, offset=slice(offset, offset + L)) +(transpose as needed) +x = scaled_dot_product_attention_grouped(q, k, v, scale, mask) -> B, L, H_q, D # at float32 precision +(transpose as needed) +x = linear(x, wo) -> B, L, E +``` + +The attention mechanism is computed as: + +$$ + \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V +$$ + + +Consider two consecutive decoding steps with `L = 3` and `L = 4`, where each head in each layer has an embedding dim of `D = 4`: + +``` +L = 3 +Q x K^T = +1 1 1 1 1 1 1 1 1x1 1x2 1x3 1x4 +2 2 2 2 2 2 2 2 2x1 2x2 2x3 2x4 +3 3 3 3 3 3 3 3 3x1 3x2 3x3 3x4 + +L = 4 +Q x K^T = +1 1 1 1 1 1 1 1 1x1 1x2 1x3 1x4 +2 2 2 2 2 2 2 2 2x1 2x2 2x3 2x4 +3 3 3 3 3 3 3 3 3x1 3x2 3x3 3x4 +4 4 4 4 4 4 4 4 4x1 4x2 4x3 4x4 +``` + +Notice that the first three rows of `Q × K^T` are identical in both steps. The same redundancy applies to the softmax and `V` multiplication. This means we are unnecessarily recomputing results for tokens we’ve already processed. + +The solution is to cache the K and V matrices and only compute new values for incoming tokens: + +``` +K in cache: +1 1 1 1 +2 2 2 2 + +[a b c d] represent cached values + +L = 3 +Q x K^T = + [1 1 1 1] + [2 2 2 2] +3 3 3 3 3 3 3 3 3x1 3x2 3x3 3x4 + +L = 4 +Q x K^T = + [1 1 1 1] + [2 2 2 2] + [3 3 3 3] +4 4 4 4 4 4 4 4 4x1 4x2 4x3 4x4 +``` + +## Task 1: Implement the Key-Value Cache + +``` +src/tiny_llm/kv_cache.py +``` + +Each layer in the model maintains its own key-value cache. The cache has a single API, `update_and_fetch`, which: + +1. Takes the newly computed `K` and `V` for incoming tokens. +2. Concatenates them with the existing cached matrices. +3. Returns the full cached `K` and `V`. + +For week 2 day 1, you only need to handle `key` and `value`. The `mask` and `mask_length` parameters will remain unused. + +You may implement this in `kv_cache.py` as `TinyKvFullCache`: + +```plain +L' = new tokens length +L = total tokens length + +update_and_fetch(key, value) -> key, value + +key: B, L', H, D +value: B, L', H, D + +self.key = concat_or_initialize(self.key, key, on the L' dimension) +self.value = concat_or_initialize(self.value, value, on the L' dimension) + +self.key: B, L, H, D +self.value: B, L, H, D + +return self.key, self.value +``` + +## Task 2: Use the Key-Value Cache + +``` +src/tiny_llm/qwen2_week2.py +``` + +With the cache in place, update your week 1 Qwen2 implementation to support it. Implement the `Qwen2MultiHeadAttention` class in `qwen2_week2.py`. + +* Each layer should use its own cache. +* The model must now accept an `offset` argument, which represents the position of the last token processed. +* This value should match the current sequence length in the cache (you can add assertions to check consistency). +* Both the argument and the cache maintain the offset for debugging purposes. + +Example computation flow: + +```plain +x: B, L', E +q = linear(x, wq, bq) -> B, L', H_q, D +k = linear(x, wk, bk) -> B, L', H, D +v = linear(x, wv, bv) -> B, L', H, D +q = rope(q, offset=slice(offset, offset + L')) +k = rope(k, offset=slice(offset, offset + L')) +(transpose as needed) +k, v = cache.update_and_fetch(k, v) +x = scaled_dot_product_attention_grouped(q, k, v, scale, mask) -> B, L, H_q, D # at float32 precision +(transpose as needed) +x = linear(x, wo) -> B, L, E +``` + +## Task 3: Implement the Model + +``` +src/tiny_llm/qwen2_week2.py +``` + +Complete the rest of the model using your week 1 implementation as a base, but modify all relevant components to use the key-value cache. + +To verify correctness, run the following test (almost identical to week 1’s test): + +```bash +pdm run test --week 2 --day 1 +``` + +## Task 4: Implement Decoding + +``` +src/tiny_llm/generate.py +``` + +Next, implement the decoding logic in `generate.py` by completing the `simple_generate_with_kv_cache` function. This function should call your Week 2 Qwen2 model with both the `offset` and the newly decoded token. + +For example: + +```plain +tokenized_prompt: [1, 2, 3, 4, 5, 6] +prefill: _step(model, [1, 2, 3, 4, 5, 6], 0) # returns 7 +decode: _step(model, [7], 7) # returns 8 +decode: _step(model, [8], 8) # returns 9 +... +``` + +You can test your implementation with: + +```bash +pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b +pdm run main --solution tiny_llm --loader week2 --model qwen2-7b +``` + +{{#include copyright.md}} diff --git a/book/src/week2-overview.md b/book/src/week2-overview.md index e008678..f4df873 100644 --- a/book/src/week2-overview.md +++ b/book/src/week2-overview.md @@ -1,3 +1,23 @@ +# Week 2: Tiny vLLM + +In Week 2 of the course, we will focus on building serving infrastructure for the Qwen2 model. Essentially, this means creating a minimal version of the vLLM project from scratch. By the end of the week, you’ll be able to serve the Qwen2 model efficiently on your Apple Silicon device using the infrastructure we’ve built together. + +## What We’ll Cover + +* Key-value cache implementation +* C++/Metal kernels + * Implementing a quantized matmul kernel + * Implementing a flash attention kernel + * Note: This week, we won’t focus on performance optimization. The kernels you build will likely be around 10x slower than MLX implementations. Optimizing them will be left as an exercise. +* Model serving infrastructure + * Implementing chunked prefill + * Implementing continuous batching + +Additionally, the repo includes skeleton code for the Qwen3 model. If your device supports the bfloat16 data type (note: M1 chips do not), you’re encouraged to try implementing it and experiment with the Qwen3-series models as well. + +{{#include copyright.md}} + + diff --git a/src/tiny_llm/generate.py b/src/tiny_llm/generate.py index 3bdfdba..c201d9f 100644 --- a/src/tiny_llm/generate.py +++ b/src/tiny_llm/generate.py @@ -11,13 +11,15 @@ def simple_generate( prompt: str, sampler: Callable[[mx.array], mx.array] | None, ) -> str: - pass + def _step(model, y): + pass def simple_generate_with_kv_cache( model: Qwen2ModelWeek2, tokenizer: TokenizerWrapper, prompt: str ) -> str: - pass + def _step(model, y, offset, kv_cache): + pass def batch_generate( diff --git a/src/tiny_llm/kv_cache.py b/src/tiny_llm/kv_cache.py index 3ee328b..5ea3434 100644 --- a/src/tiny_llm/kv_cache.py +++ b/src/tiny_llm/kv_cache.py @@ -2,14 +2,31 @@ import mlx.core as mx - class TinyKvCache: def update_and_fetch( - self, key: mx.array, value: mx.array - ) -> tuple[mx.array, mx.array, int]: + self, + key: mx.array, + value: mx.array, + mask_length: int | None = None, + mask: mx.array | str | None = None, + ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: + """ + Update the key-value cache and fetch the updated key-value cache. + + Args: + key: The key to update the cache with. + value: The value to update the cache with. + mask_length: The length of the mask (only used in batching mode) + mask: The mask to use (only used in batching mode) + + Returns: + A tuple of the updated key-value cache, the updated value, the sequence length, and the mask. + In week 2 day 1, we only need to return the updated key-value cache, the updated value. + In week 2 day 6/7, we need to return the updated key-value cache, the updated value, the sequence length, and the mask. + so that the batching kv cache can use this information to generate the mask. + """ pass - class BatchingKvCache(TinyKvCache): def __init__(self, max_active_requests: int, max_seq_len: int): pass @@ -34,13 +51,3 @@ def update_and_fetch( self, key: mx.array, value: mx.array ) -> tuple[mx.array, mx.array, int]: pass - - -class TinyKvRotatingCache(TinyKvCache): - def __init__(self, max_seq_len: int): - pass - - def update_and_fetch( - self, key: mx.array, value: mx.array, offset: int - ) -> tuple[mx.array, mx.array]: - pass diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index 82e59a8..bfb8bc1 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -12,8 +12,8 @@ def simple_generate( prompt: str, sampler: Callable[[mx.array], mx.array] | None, ) -> str: - def _step(model, y, offset): - logits = model(y[None], offset) + def _step(model, y): + logits = model(y[None]) logits = logits[:, -1, :] logprobs = logits - mx.logsumexp( logits, keepdims=True @@ -30,7 +30,7 @@ def _step(model, y, offset): detokenizer.reset() # generate/decode while True: - token = _step(model, tokens, tokens.size) + token = _step(model, tokens) mx.eval(token) tokens = mx.concat([tokens, token]) if token.item() == tokenizer.eos_token_id: diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index 418f996..605de88 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -11,7 +11,19 @@ def update_and_fetch( value: mx.array, mask_length: int | None = None, mask: mx.array | str | None = None, - ) -> tuple[mx.array, mx.array, int, mx.array]: + ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: + """ + Update the key-value cache and fetch the updated key-value cache. + + Args: + key: The key to update the cache with. + value: The value to update the cache with. + mask_length: The length of the mask (only used in batching mode) + mask: The mask to use (only used in batching mode) + + Returns: + A tuple of the updated key-value cache, the updated value, the sequence length, and the mask. + """ pass diff --git a/tests_refsol/test_week_2_day_1.py b/tests_refsol/test_week_2_day_1.py index 3372a0b..4f75028 100644 --- a/tests_refsol/test_week_2_day_1.py +++ b/tests_refsol/test_week_2_day_1.py @@ -2,9 +2,6 @@ from .utils import * from .tiny_llm_base import ( Qwen2ModelWeek2, - Embedding, - dequantize_linear, - qwen2_week2, TinyKvFullCache, ) from mlx_lm import load @@ -45,41 +42,6 @@ def helper_test_task_3(model_name: str, iters: int = 10): ref_output = ref_output - mx.logsumexp(ref_output, keepdims=True) assert_allclose(user_output, ref_output, precision=mx.float16, rtol=1e-1) - -@pytest.mark.skipif( - not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" -) -def test_task_2_embedding_call(): - mlx_model, _ = load("Qwen/Qwen2-0.5B-Instruct-MLX") - embedding = Embedding( - mlx_model.args.vocab_size, - mlx_model.args.hidden_size, - dequantize_linear(mlx_model.model.embed_tokens).astype(mx.float16), - ) - for _ in range(50): - input = mx.random.randint(low=0, high=mlx_model.args.vocab_size, shape=(1, 10)) - user_output = embedding(input) - ref_output = mlx_model.model.embed_tokens(input) - assert_allclose(user_output, ref_output, precision=mx.float16) - - -@pytest.mark.skipif( - not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" -) -def test_task_2_embedding_as_linear(): - mlx_model, _ = load("Qwen/Qwen2-0.5B-Instruct-MLX") - embedding = Embedding( - mlx_model.args.vocab_size, - mlx_model.args.hidden_size, - dequantize_linear(mlx_model.model.embed_tokens).astype(mx.float16), - ) - for _ in range(50): - input = mx.random.uniform(shape=(1, 10, mlx_model.args.hidden_size)) - user_output = embedding.as_linear(input) - ref_output = mlx_model.model.embed_tokens.as_linear(input) - assert_allclose(user_output, ref_output, precision=mx.float16, atol=1e-1) - - @pytest.mark.skipif( not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" ) From 45cff24be9054ba6d58208ef7529425982d71780 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sun, 17 Aug 2025 15:00:48 -0400 Subject: [PATCH 28/79] update benches Signed-off-by: Alex Chi Z --- benches/test_attention.py | 37 ++++++++++++++++++++++++++++++++ benches/test_quantized_matmul.py | 4 ++-- 2 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 benches/test_attention.py diff --git a/benches/test_attention.py b/benches/test_attention.py new file mode 100644 index 0000000..ced71c5 --- /dev/null +++ b/benches/test_attention.py @@ -0,0 +1,37 @@ +import mlx.core as mx +import mlx.nn as nn +import tiny_llm_ref +from .utils import assert_allclose +import pytest + +def get_test_attention_data(): + # Qwen2 7B matrix size + init = nn.init.he_uniform(mx.float32) + q = init(mx.zeros((10, 28, 1024, 128))) + k = init(mx.zeros((10, 4, 1024, 128))) + v = init(mx.zeros((10, 4, 1024, 128))) + res = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + return q, k, v, res + +def test_mlx_attention(benchmark): + with mx.stream(mx.gpu): + q, k, v, res = get_test_attention_data() + result = benchmark(lambda: mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)) + assert_allclose(result, res, precision=mx.float32, rtol=1e-2) + + +def test_refsol_attention(benchmark): + with mx.stream(mx.gpu): + q, k, v, res = get_test_attention_data() + result = benchmark( + lambda: tiny_llm_ref.scaled_dot_product_attention_grouped(q, k, v, scale=1.0) + ) + assert_allclose(result, res, precision=mx.float32, rtol=1e-2) + +def test_refsol_flash_attention(benchmark): + with mx.stream(mx.gpu): + q, k, v, res = get_test_attention_data() + result = benchmark( + lambda: tiny_llm_ref.flash_attention(q, k, v, scale=1.0) + ) + assert_allclose(result, res, precision=mx.float32, rtol=1e-2) diff --git a/benches/test_quantized_matmul.py b/benches/test_quantized_matmul.py index eb0007a..ceb60c9 100644 --- a/benches/test_quantized_matmul.py +++ b/benches/test_quantized_matmul.py @@ -21,7 +21,7 @@ def test_mlx_quantized_matmul(benchmark): result = benchmark( lambda: mx.quantized_matmul(x, w_q, scales=scales, biases=biases) ) - assert_allclose(result, res, precision=np.float16, rtol=1e-2) + assert_allclose(result, res, precision=mx.float16, rtol=1e-2) def test_refsol_quantized_matmul(benchmark): @@ -32,4 +32,4 @@ def test_refsol_quantized_matmul(benchmark): scales, biases, 64, 4, x, w_q, transpose_b=True ) ) - assert_allclose(result, res, precision=np.float16, rtol=1e-2) + assert_allclose(result, res, precision=mx.float16, rtol=1e-2) From 30b68a92eb291c057261e781dde2ee015fffb5ff Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Mon, 18 Aug 2025 02:00:30 -0400 Subject: [PATCH 29/79] small fix about dim Signed-off-by: Alex Chi Z --- book/src/week2-01-kv-cache.md | 31 +++++++++++++++++-------------- tests_refsol/test_week_2_day_1.py | 2 +- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/book/src/week2-01-kv-cache.md b/book/src/week2-01-kv-cache.md index 861d89c..b2ac240 100644 --- a/book/src/week2-01-kv-cache.md +++ b/book/src/week2-01-kv-cache.md @@ -36,24 +36,25 @@ $$ $$ -Consider two consecutive decoding steps with `L = 3` and `L = 4`, where each head in each layer has an embedding dim of `D = 4`: +Consider two consecutive decoding steps with `L = S = 3` and `L = S = 4`, where each head in each layer has an embedding dim of `D = 4`: ``` L = 3 -Q x K^T = -1 1 1 1 1 1 1 1 1x1 1x2 1x3 1x4 -2 2 2 2 2 2 2 2 2x1 2x2 2x3 2x4 -3 3 3 3 3 3 3 3 3x1 3x2 3x3 3x4 +Q x K^T = +1 1 1 1 1 2 3 1x1 -inf -inf +2 2 2 2 1 2 3 2x1 2x2 -inf +3 3 3 3 1 2 3 3x1 3x2 -inf + 1 2 3 L = 4 Q x K^T = -1 1 1 1 1 1 1 1 1x1 1x2 1x3 1x4 -2 2 2 2 2 2 2 2 2x1 2x2 2x3 2x4 -3 3 3 3 3 3 3 3 3x1 3x2 3x3 3x4 -4 4 4 4 4 4 4 4 4x1 4x2 4x3 4x4 +1 1 1 1 1 2 3 4 1x1 -inf -inf -inf +2 2 2 2 1 2 3 4 2x1 2x2 -inf -inf +3 3 3 3 1 2 3 4 3x1 3x2 3x3 -inf +4 4 4 4 1 2 3 4 4x1 4x2 4x3 4x4 ``` -Notice that the first three rows of `Q × K^T` are identical in both steps. The same redundancy applies to the softmax and `V` multiplication. This means we are unnecessarily recomputing results for tokens we’ve already processed. +Notice that the first three rows/cols of `Q × K^T` are identical in both steps. Also given that we are using the causal masks, we do not need to care about the upper triangle of the matrix. The same applies to the softmax function and the multiplication with the V matrix. This means we are unnecessarily recomputing results for tokens we’ve already processed, and the new information only comes from the last row of `Q * K^T`. The solution is to cache the K and V matrices and only compute new values for incoming tokens: @@ -64,14 +65,16 @@ K in cache: [a b c d] represent cached values -L = 3 +L = 1, S = 3 Q x K^T = + (⬇️ is K not transposed) [1 1 1 1] [2 2 2 2] -3 3 3 3 3 3 3 3 3x1 3x2 3x3 3x4 +3 3 3 3 3 3 3 3 3x1 3x2 3x3 -L = 4 -Q x K^T = +L = 1, S = 4 +Q x K^T = + (⬇️ is K not transposed) [1 1 1 1] [2 2 2 2] [3 3 3 3] diff --git a/tests_refsol/test_week_2_day_1.py b/tests_refsol/test_week_2_day_1.py index 4f75028..3032b86 100644 --- a/tests_refsol/test_week_2_day_1.py +++ b/tests_refsol/test_week_2_day_1.py @@ -40,7 +40,7 @@ def helper_test_task_3(model_name: str, iters: int = 10): user_output = user_output - mx.logsumexp(user_output, keepdims=True) ref_output = mlx_model(input) ref_output = ref_output - mx.logsumexp(ref_output, keepdims=True) - assert_allclose(user_output, ref_output, precision=mx.float16, rtol=1e-1) + assert_allclose(user_output, ref_output, precision=mx.float16, rtol=0.1, atol=0.5) @pytest.mark.skipif( not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" From 042acf5516b5c3a125e1ad3718ffa6ea27fea6a9 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Mon, 18 Aug 2025 02:08:32 -0400 Subject: [PATCH 30/79] clearify variables Signed-off-by: Alex Chi Z --- book/src/week2-01-kv-cache.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/book/src/week2-01-kv-cache.md b/book/src/week2-01-kv-cache.md index b2ac240..4f34753 100644 --- a/book/src/week2-01-kv-cache.md +++ b/book/src/week2-01-kv-cache.md @@ -138,12 +138,17 @@ v = linear(x, wv, bv) -> B, L', H, D q = rope(q, offset=slice(offset, offset + L')) k = rope(k, offset=slice(offset, offset + L')) (transpose as needed) -k, v = cache.update_and_fetch(k, v) -x = scaled_dot_product_attention_grouped(q, k, v, scale, mask) -> B, L, H_q, D # at float32 precision +k, v = cache.update_and_fetch(k, v) ; k/v: B, L, H, D, q: B, L', H, D +x = scaled_dot_product_attention_grouped(q, k, v, scale, mask) -> B, L', H_q, D # at float32 precision (transpose as needed) -x = linear(x, wo) -> B, L, E +x = linear(x, wo) -> B, L', E ``` +We use two different variables for the `L'` because they have different meanings in the context of this chapter +and the context of week 1 day 3: in the GQA implementation, k/v's sequence length is `S` (source length), while +q's sequence length is `L`. In the Qwen2 multihead attention implementation, `L'` is the "new token" and `L` is +the total sequence length, which corresponds to `L` and `S` in week 1 respectively. + ## Task 3: Implement the Model ``` From 4cddec27312a914c74513bc501b19b3ed01d8533 Mon Sep 17 00:00:00 2001 From: touale <68891321+touale@users.noreply.github.com> Date: Tue, 19 Aug 2025 10:45:18 +0800 Subject: [PATCH 31/79] fix: Add flash attention option and fix token offset (#46) --- src/tiny_llm/qwen2_week2.py | 6 +++++- src/tiny_llm_ref/generate.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/tiny_llm/qwen2_week2.py b/src/tiny_llm/qwen2_week2.py index a70b7ba..012da01 100644 --- a/src/tiny_llm/qwen2_week2.py +++ b/src/tiny_llm/qwen2_week2.py @@ -88,7 +88,11 @@ def __call__( class Qwen2ModelWeek2: - def __init__(self, mlx_model: Any): + def __init__( + self, + mlx_model: Any, + enable_flash_attn: bool = False, + ): pass def __call__( diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index bfb8bc1..6338579 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -65,5 +65,5 @@ def _step(model, y, offset, kv_cache): print(detokenizer.last_segment, end="", flush=True) if token.item() == tokenizer.eos_token_id: break - offset += tokens.size + offset += token.size tokens = token From 25068d47af231de65eebbfefe06d953e70aa7122 Mon Sep 17 00:00:00 2001 From: Zhen Tong <88566085+58191554@users.noreply.github.com> Date: Tue, 19 Aug 2025 18:31:54 -0400 Subject: [PATCH 32/79] Fix: remove offset parameter to Qwen2MultiHeadAttention.__call__ method in tiny-llm/ (#49) --- src/tiny_llm/qwen2_week1.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/tiny_llm/qwen2_week1.py b/src/tiny_llm/qwen2_week1.py index fcb9bfe..e196fb1 100644 --- a/src/tiny_llm/qwen2_week1.py +++ b/src/tiny_llm/qwen2_week1.py @@ -29,7 +29,6 @@ def __init__( def __call__( self, x: mx.array, - offset: int, mask: mx.array | str | None = None, ) -> mx.array: pass @@ -78,7 +77,6 @@ def __init__( def __call__( self, x: mx.array, - offset: int, mask: mx.array | str | None = None, ) -> mx.array: pass @@ -91,6 +89,5 @@ def __init__(self, mlx_model: Any): def __call__( self, inputs: mx.array, - offset: int, ) -> mx.array: pass From 4a4c752d382b544025d867af0717e40497524fbb Mon Sep 17 00:00:00 2001 From: Gunther Xing Date: Wed, 20 Aug 2025 06:32:12 +0800 Subject: [PATCH 33/79] Fix broken url links of MultiHeadAttention in week1-01-attention.md (#47) --- book/src/week1-01-attention.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/book/src/week1-01-attention.md b/book/src/week1-01-attention.md index 30ba57c..5f73efd 100644 --- a/book/src/week1-01-attention.md +++ b/book/src/week1-01-attention.md @@ -88,8 +88,8 @@ src/tiny_llm/attention.py **📚 Readings** * [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/) -* [PyTorch SimpleMultiHeadAttention API](https://pytorch.org/docs/stable/generated/torch.nn.SimpleMultiHeadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q) -* [MLX SimpleMultiHeadAttention API](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.SimpleMultiHeadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q) +* [PyTorch MultiHeadAttention API](https://docs.pytorch.org/docs/2.8/generated/torch.nn.MultiheadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q) +* [MLX MultiHeadAttention API](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q) * [The Illustrated GPT-2 (Visualizing Transformer Language Models)](https://jalammar.github.io/illustrated-gpt2) helps you better understand what key, value, and query are. Implement `SimpleMultiHeadAttention`. The layer takes a batch of vectors, maps it through the K, V, Q weight matrixes, and use the attention function we implemented in task 1 to compute the result. The output needs to be mapped using the O From 4ae2ad1fc530986d60c5468df184d2ce880f393d Mon Sep 17 00:00:00 2001 From: Zhen Tong <88566085+58191554@users.noreply.github.com> Date: Thu, 21 Aug 2025 00:53:21 -0400 Subject: [PATCH 34/79] Fix MLX Metal API usage and Primitive interface for Axpby; restore successful build (#51) --- src/extensions/src/axpby.cpp | 10 +++++++--- src/extensions/src/axpby.h | 5 ++++- src/extensions/src/utils.cpp | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/extensions/src/axpby.cpp b/src/extensions/src/axpby.cpp index e70027e..ee67a23 100644 --- a/src/extensions/src/axpby.cpp +++ b/src/extensions/src/axpby.cpp @@ -147,9 +147,8 @@ void Axpby::eval_gpu(const std::vector &inputs, std::vector &inputs, std::vector Axpby::jvp(const std::vector &primals, const std::vector &tangents, const std::vector &argnums) { diff --git a/src/extensions/src/axpby.h b/src/extensions/src/axpby.h index 7a56dd2..5ad87b3 100644 --- a/src/extensions/src/axpby.h +++ b/src/extensions/src/axpby.h @@ -63,7 +63,10 @@ class Axpby : public mx::Primitive { const std::vector &axes) override; /** Print the primitive. */ - const char* name() const override { return "Axpby"; } + void print(std::ostream &os) override; + + /** Name of the primitive (not virtual in some MLX versions). */ + const char* name() const { return "Axpby"; } /** Equivalence check **/ bool is_equivalent(const mx::Primitive &other) const override; diff --git a/src/extensions/src/utils.cpp b/src/extensions/src/utils.cpp index 3b52303..97a5b80 100644 --- a/src/extensions/src/utils.cpp +++ b/src/extensions/src/utils.cpp @@ -9,7 +9,7 @@ namespace tiny_llm_ext { void load_library(mx::Device d, const char *path) { #ifdef _METAL_ auto &md = mx::metal::device(d); - md.get_library("tiny_llm_ext", path); + md.register_library("tiny_llm_ext", path); #endif } From efe008a315548a96d08453a455a59189eaed56db Mon Sep 17 00:00:00 2001 From: Weihua Cheung Date: Thu, 21 Aug 2025 13:32:30 -0700 Subject: [PATCH 35/79] s/consequtive/consecutive (#52) --- book/src/week1-02-positional-encodings.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/book/src/week1-02-positional-encodings.md b/book/src/week1-02-positional-encodings.md index 7e14dd3..a5f4ef0 100644 --- a/book/src/week1-02-positional-encodings.md +++ b/book/src/week1-02-positional-encodings.md @@ -33,7 +33,7 @@ x: (N, L, H, D) cos/sin_freqs: (MAX_SEQ_LEN, D // 2) ``` -In the traditional form of RoPE, each head on the dimension of `D` is viewed as consequtive complex pairs. That is to +In the traditional form of RoPE, each head on the dimension of `D` is viewed as consecutive complex pairs. That is to say, if D = 8, then, x[0] and x[1] are a pair, x[2] and x[3] are another pair, and so on. A pair gets the same frequency from `cos/sin_freqs`. From a55a92f3345359ba0eda81b1f1fabb7d22d6956b Mon Sep 17 00:00:00 2001 From: Liu Jinyi <1206668472@qq.com> Date: Fri, 22 Aug 2025 04:32:44 +0800 Subject: [PATCH 36/79] docs: fix some typos (#53) --- book/src/preface.md | 2 +- book/src/setup.md | 2 +- book/src/week1-02-positional-encodings.md | 5 ++--- book/src/week1-04-rmsnorm-and-mlp.md | 4 ++-- book/src/week1-05-qwen2-model.md | 4 +--- 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/book/src/preface.md b/book/src/preface.md index f9e2450..b19731e 100644 --- a/book/src/preface.md +++ b/book/src/preface.md @@ -20,7 +20,7 @@ resources are: ## Environment Setup -This course uses [MLX](https://github.com/ml-explore/mlx), an array/machine learning library for Apple Silicon. Nowaways +This course uses [MLX](https://github.com/ml-explore/mlx), an array/machine learning library for Apple Silicon. Nowadays it's much easier to get an Apple Silicon device than NVIDIA GPUs. In theory you can also do this course with PyTorch or numpy, but we just don't have the test infra to support them. We test your implementation against PyTorch's CPU implementation and MLX's implementation to ensure correctness. diff --git a/book/src/setup.md b/book/src/setup.md index dedf1c3..7139808 100644 --- a/book/src/setup.md +++ b/book/src/setup.md @@ -4,7 +4,7 @@ To follow along this course, you will need a Macintosh device with Apple Silicon ## Install pdm -Please follow the [offcial guide](https://pdm-project.org/en/latest/) to install pdm. +Please follow the [official guide](https://pdm-project.org/en/latest/) to install pdm. ## Clone the Repository diff --git a/book/src/week1-02-positional-encodings.md b/book/src/week1-02-positional-encodings.md index a5f4ef0..176e4fe 100644 --- a/book/src/week1-02-positional-encodings.md +++ b/book/src/week1-02-positional-encodings.md @@ -1,6 +1,6 @@ # Week 1 Day 2: Positional Encodings and RoPE -In day 2, we will implement the positional embedding used in the Qwen2 model: Rotary Postional Encoding. In a transformer +In day 2, we will implement the positional embedding used in the Qwen2 model: Rotary Positional Encoding. In a transformer model, we need a way to embed the information of the position of a token into the input of the attention layers. In Qwen2, positional embedding is applied within the multi head attention layer on the query and key vectors. @@ -9,7 +9,7 @@ positional embedding is applied within the multi head attention layer on the que - [You could have designed state of the art positional encoding](https://huggingface.co/blog/designing-positional-encoding) - [Roformer: Enhanced Transformer with Rotary Positional Encoding](https://arxiv.org/pdf/2104.09864) -## Task 1: Implement Rotary Postional Encoding "RoPE" +## Task 1: Implement Rotary Positional Encoding "RoPE" You will need to modify the following file: @@ -81,7 +81,6 @@ frequencies to each half separately. - [vLLM implementation of RoPE](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding) - You can test your implementation by running the following command: ``` diff --git a/book/src/week1-04-rmsnorm-and-mlp.md b/book/src/week1-04-rmsnorm-and-mlp.md index 2c0c7d4..2ad38aa 100644 --- a/book/src/week1-04-rmsnorm-and-mlp.md +++ b/book/src/week1-04-rmsnorm-and-mlp.md @@ -66,7 +66,7 @@ Modern Transformer architectures, including Qwen2, often employ more advanced FF * [PyTorch SiLU documentation](https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html) * [Qwen2 layers implementation in mlx-lm (includes MLP)](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py) -Essientially, SwiGLU is a combination of GLU and the SiLU (Sigmoid Linear Unit) activation function: +Essentially, SwiGLU is a combination of GLU and the SiLU (Sigmoid Linear Unit) activation function: - GLU is a gating mechanism that allows the model to learn which parts of the input to focus on. It typically involves an element-wise product of two linear projections of the input, one of which might be passed through an activation function. Compared to ReLU used in the original FFN, GLU can help the model learn more complex relationships in the data, deciding which features to keep and which to discard. - SiLU (Sigmoid Linear Unit) is a smooth, non-monotonic activation function that has been shown to perform well in various deep learning tasks. Compared to ReLU and sigmoid used in GLU, it is fully differentiable without the zero-gradient “dead zones”, retains non-zero output even for negative inputs. @@ -116,4 +116,4 @@ pdm run test --week 1 --day 4 ``` -{{#include copyright.md}} +{{#include copyright.md}} \ No newline at end of file diff --git a/book/src/week1-05-qwen2-model.md b/book/src/week1-05-qwen2-model.md index df59346..979b347 100644 --- a/book/src/week1-05-qwen2-model.md +++ b/book/src/week1-05-qwen2-model.md @@ -62,7 +62,7 @@ src/tiny_llm/embedding.py - [LLM Embeddings Explained: A Visual and Intuitive Guide](https://huggingface.co/spaces/hesamation/primer-llm-embedding) -The embedding layer maps one or more tokens (represented as an interger) to one or more vector of dimension `embedding_dim`. +The embedding layer maps one or more tokens (represented as an integer) to one or more vector of dimension `embedding_dim`. In this task, you will implement the embedding layer. ``` @@ -148,7 +148,6 @@ in the next day. You should pass all tests for this task by running: - ```bash # Download the models if you haven't done so huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX @@ -163,5 +162,4 @@ At the end of the day, you should be able to pass all tests of this day: pdm run test --week 1 --day 5 ``` - {{#include copyright.md}} From e05179090b01aaac679b44b4802aef45b0ea3268 Mon Sep 17 00:00:00 2001 From: Zhen Tong <88566085+58191554@users.noreply.github.com> Date: Thu, 21 Aug 2025 16:33:04 -0400 Subject: [PATCH 37/79] fix typo in week2-01 (#54) --- book/src/week2-01-kv-cache.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/book/src/week2-01-kv-cache.md b/book/src/week2-01-kv-cache.md index 4f34753..ab02158 100644 --- a/book/src/week2-01-kv-cache.md +++ b/book/src/week2-01-kv-cache.md @@ -43,7 +43,7 @@ L = 3 Q x K^T = 1 1 1 1 1 2 3 1x1 -inf -inf 2 2 2 2 1 2 3 2x1 2x2 -inf -3 3 3 3 1 2 3 3x1 3x2 -inf +3 3 3 3 1 2 3 3x1 3x2 3x3 1 2 3 L = 4 From b4c14ed22528f55af25b067751bff868a09d6aeb Mon Sep 17 00:00:00 2001 From: Liu Jinyi <1206668472@qq.com> Date: Mon, 25 Aug 2025 02:29:20 +0800 Subject: [PATCH 38/79] ci: add spell check workflow (#55) --- .cspell.json | 31 +++++++++++++++++++++++++++++++ .github/workflows/spell-check.yml | 28 ++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 .cspell.json create mode 100644 .github/workflows/spell-check.yml diff --git a/.cspell.json b/.cspell.json new file mode 100644 index 0000000..8f4136b --- /dev/null +++ b/.cspell.json @@ -0,0 +1,31 @@ +{ + "version": "0.2", + "language": "en", + "words": [ + "skyzh", + "numpy", + "Connor", + "CUDA", + "matmul", + "qwen", + "huggingface", + "dequantize", + "freqs", + "torchtune", + "Jinyi", + "logits", + "argmax", + "logprobs", + "softmax", + "feedforward", + "Convolutional", + "Roformer", + "bfloat", + "multihead", + "vllm", + "silu" + ], + "ignoreRegExpList": [ + "`[^`]*`", + ] +} \ No newline at end of file diff --git a/.github/workflows/spell-check.yml b/.github/workflows/spell-check.yml new file mode 100644 index 0000000..8032fe7 --- /dev/null +++ b/.github/workflows/spell-check.yml @@ -0,0 +1,28 @@ +name: Spell Check + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + workflow_dispatch: + +jobs: + spell-check: + name: Run cspell + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + + - name: Install cspell globally + run: npm install -g cspell + + - name: Run spell check on Markdown files + run: cspell "book/**/*.md" From bf3383db5a0fe8fcbccfc807f93490a289ef3d75 Mon Sep 17 00:00:00 2001 From: Gunther Xing Date: Mon, 8 Sep 2025 00:43:47 +0800 Subject: [PATCH 39/79] fix: Use non-traditional RoPE in Qwen2 test case. (#56) --- book/src/week1-03-gqa.md | 2 ++ tests_refsol/test_week_1_day_3.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/book/src/week1-03-gqa.md b/book/src/week1-03-gqa.md index dcd9c6c..70a05ff 100644 --- a/book/src/week1-03-gqa.md +++ b/book/src/week1-03-gqa.md @@ -108,6 +108,8 @@ x = scaled_dot_product_attention_grouped(q, k, v, scale, mask) -> B, L, H_q, D ; x = linear(x, wo) -> B, L, E ``` +Keep in mind that you should use non-traditional RoPE. + You can test your implementation by running the following command: ```bash diff --git a/tests_refsol/test_week_1_day_3.py b/tests_refsol/test_week_1_day_3.py index 0b17c81..f13eb6c 100644 --- a/tests_refsol/test_week_1_day_3.py +++ b/tests_refsol/test_week_1_day_3.py @@ -158,7 +158,7 @@ def test_task_3_qwen2_grouped_query_attention( rms_norm_eps=1e-6, vocab_size=1000, rope_theta=theta, - rope_traditional=False, + rope_traditional=True, max_position_embeddings=max_seq_len, ) From 1c9369ac2b2c6e9cd14fa24fde3e3efd7bd2c62b Mon Sep 17 00:00:00 2001 From: Gunther Xing Date: Mon, 8 Sep 2025 00:44:38 +0800 Subject: [PATCH 40/79] fix: mlx-llm Qwen2 RMSNorm url link (#57) Refer to another commit cause you can't find RMSNorm impl in the current mlx-llm repo (it's replaced by mlx fast impl). --- book/src/week1-04-rmsnorm-and-mlp.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/book/src/week1-04-rmsnorm-and-mlp.md b/book/src/week1-04-rmsnorm-and-mlp.md index 2ad38aa..a689f4d 100644 --- a/book/src/week1-04-rmsnorm-and-mlp.md +++ b/book/src/week1-04-rmsnorm-and-mlp.md @@ -14,7 +14,7 @@ src/tiny_llm/layer_norm.py **📚 Readings** * [Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467) -* [Qwen2 layers implementation in mlx-lm (includes RMSNorm)](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py) - See `Qwen2RMSNorm`. +* [Qwen2 layers implementation in mlx-lm (includes RMSNorm)]([https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py](https://github.com/ml-explore/mlx-lm/blob/bcb96db87f218453774f8808159012f15fc0dc7b/mlx_lm/models/qwen2.py)) - See `RMSNorm`. RMSNorm is defined as: @@ -116,4 +116,4 @@ pdm run test --week 1 --day 4 ``` -{{#include copyright.md}} \ No newline at end of file +{{#include copyright.md}} From 04149a33173a2aad2c59d89abbe28e790a6ca361 Mon Sep 17 00:00:00 2001 From: Gunther Xing Date: Mon, 8 Sep 2025 00:45:01 +0800 Subject: [PATCH 41/79] add test for week 1 day 5 test 1: Qwen2TransformerBlock (#59) --- tests_refsol/test_week_1_day_5.py | 77 ++++++++++++++++++++++++++++++- 1 file changed, 75 insertions(+), 2 deletions(-) diff --git a/tests_refsol/test_week_1_day_5.py b/tests_refsol/test_week_1_day_5.py index 7ac1d1e..eb5a07b 100644 --- a/tests_refsol/test_week_1_day_5.py +++ b/tests_refsol/test_week_1_day_5.py @@ -3,8 +3,81 @@ from .tiny_llm_base import Qwen2ModelWeek1, Embedding, dequantize_linear, qwen2_week1 from mlx_lm import load -# TODO: task 1 tests - +@pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) +@pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) +@pytest.mark.parametrize("mask", [None, "causal"], ids=["no_mask", "causal_mask"]) +def test_task_1_transformer_block( + stream: mx.Stream, precision: mx.Dtype, mask: str | None +): + with mx.stream(stream): + from mlx_lm.models import qwen2 + + BATCH_SIZE = 1 + SEQ_LEN = 10 + NUM_ATTENTION_HEAD = 4 + NUM_KV_HEADS = 2 + HIDDEN_SIZE = 32 + INTERMEDIATE_SIZE = HIDDEN_SIZE * 4 + + args = qwen2.ModelArgs( + model_type="qwen2", + hidden_size=HIDDEN_SIZE, + num_hidden_layers=1, + intermediate_size=INTERMEDIATE_SIZE, + num_attention_heads=NUM_ATTENTION_HEAD, + num_key_value_heads=NUM_KV_HEADS, + rms_norm_eps=1e-6, + vocab_size=1000, + ) + + mlx_transformer_block = qwen2.TransformerBlock(args) + + mlx_attention = mlx_transformer_block.self_attn + wq = mlx_attention.q_proj.weight + wk = mlx_attention.k_proj.weight + wv = mlx_attention.v_proj.weight + wo = mlx_attention.o_proj.weight + bq = mlx_attention.q_proj.bias + bk = mlx_attention.k_proj.bias + bv = mlx_attention.v_proj.bias + + mlx_mlp = mlx_transformer_block.mlp + w_gate = mlx_mlp.gate_proj.weight + w_up = mlx_mlp.up_proj.weight + w_down = mlx_mlp.down_proj.weight + + w_input_layernorm = mlx_transformer_block.input_layernorm.weight + w_post_attention_layernorm = mlx_transformer_block.post_attention_layernorm.weight + + user_transformer_block = qwen2_week1.Qwen2TransformerBlock( + num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_heads=NUM_KV_HEADS, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + rms_norm_eps=1e-6, + wq=wq, + wk=wk, + wv=wv, + wo=wo, + bq=bq, + bk=bk, + bv=bv, + w_gate=w_gate, + w_up=w_up, + w_down=w_down, + w_input_layernorm=w_input_layernorm, + w_post_attention_layernorm=w_post_attention_layernorm + ) + + mx.random.seed(42) + x = mx.random.uniform( + shape=(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), dtype=precision + ) + + user_output = user_transformer_block(x, mask=mask) + mlx_output = mlx_transformer_block(x, mask=mask, cache=None) + + assert_allclose(user_output, mlx_output, precision=precision, rtol=1e-1) @pytest.mark.skipif( not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" From 81b917d5d60f54d1930e720dc4b059a67d1eb007 Mon Sep 17 00:00:00 2001 From: Eric Zhang Date: Sun, 7 Sep 2025 22:01:00 -0400 Subject: [PATCH 42/79] Possible typo in week1-01-attention (#60) * Possible typo in week1-01-attention Hello, was going through the book! I'm not 100% sure of this, but after going through the tests for day1-task2, it looks like the w_qkv matrices and w_o matrix have their shape reversed. I confirmed by checking the mlx.nn.layers.linear.Linear weight, which is of shape `[Output, Input]`. Since w_qkv's output is HxD and input is E, the shape should be `[H x D, E]`. * Oops fix another typo --- book/src/week1-01-attention.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/book/src/week1-01-attention.md b/book/src/week1-01-attention.md index 5f73efd..7ebf300 100644 --- a/book/src/week1-01-attention.md +++ b/book/src/week1-01-attention.md @@ -118,9 +118,9 @@ H is num_heads D is head_dim L is seq_len, in PyTorch API it's S (source len) -w_q/w_k/w_v: E x (H x D) +w_q/w_k/w_v: (H x D) x E output/input: N x L x E -w_o: (H x D) x E +w_o: E x (H x D) ``` At the end of the task, you should be able to pass the following tests: From fa8b08ee765cb76153f18b4ea6c270efdda68a65 Mon Sep 17 00:00:00 2001 From: Gunther Xing Date: Wed, 10 Sep 2025 10:27:24 +0800 Subject: [PATCH 43/79] Revert "fix: Use non-traditional RoPE in Qwen2 test case. (#56)" (#62) * Revert "fix: Use non-traditional RoPE in Qwen2 test case. (#56)" This reverts commit bf3383db5a0fe8fcbccfc807f93490a289ef3d75. * Update week1-03-gqa.md with RoPE note and test command Added note about using non-traditional RoPE and testing command. --------- Co-authored-by: Alex Chi Z. <4198311+skyzh@users.noreply.github.com> --- tests_refsol/test_week_1_day_3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests_refsol/test_week_1_day_3.py b/tests_refsol/test_week_1_day_3.py index f13eb6c..0b17c81 100644 --- a/tests_refsol/test_week_1_day_3.py +++ b/tests_refsol/test_week_1_day_3.py @@ -158,7 +158,7 @@ def test_task_3_qwen2_grouped_query_attention( rms_norm_eps=1e-6, vocab_size=1000, rope_theta=theta, - rope_traditional=True, + rope_traditional=False, max_position_embeddings=max_seq_len, ) From 919a3e556d44284237205baa9a87ba658759cc66 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 13 Sep 2025 14:51:49 -0400 Subject: [PATCH 44/79] format and warn on different test files Signed-off-by: Alex Chi Z --- benches/test_attention.py | 15 ++++++++++----- scripts/dev-tools.py | 16 +++++++++++++--- src/tiny_llm/kv_cache.py | 2 ++ src/tiny_llm_ref/batch.py | 2 +- src/tiny_llm_ref/qwen2_week1.py | 2 +- tests_refsol/test_week_1_day_5.py | 18 ++++++++++-------- tests_refsol/test_week_2_day_1.py | 5 ++++- 7 files changed, 41 insertions(+), 19 deletions(-) diff --git a/benches/test_attention.py b/benches/test_attention.py index ced71c5..1b0c912 100644 --- a/benches/test_attention.py +++ b/benches/test_attention.py @@ -4,6 +4,7 @@ from .utils import assert_allclose import pytest + def get_test_attention_data(): # Qwen2 7B matrix size init = nn.init.he_uniform(mx.float32) @@ -13,10 +14,13 @@ def get_test_attention_data(): res = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) return q, k, v, res + def test_mlx_attention(benchmark): with mx.stream(mx.gpu): q, k, v, res = get_test_attention_data() - result = benchmark(lambda: mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)) + result = benchmark( + lambda: mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) + ) assert_allclose(result, res, precision=mx.float32, rtol=1e-2) @@ -24,14 +28,15 @@ def test_refsol_attention(benchmark): with mx.stream(mx.gpu): q, k, v, res = get_test_attention_data() result = benchmark( - lambda: tiny_llm_ref.scaled_dot_product_attention_grouped(q, k, v, scale=1.0) + lambda: tiny_llm_ref.scaled_dot_product_attention_grouped( + q, k, v, scale=1.0 + ) ) assert_allclose(result, res, precision=mx.float32, rtol=1e-2) + def test_refsol_flash_attention(benchmark): with mx.stream(mx.gpu): q, k, v, res = get_test_attention_data() - result = benchmark( - lambda: tiny_llm_ref.flash_attention(q, k, v, scale=1.0) - ) + result = benchmark(lambda: tiny_llm_ref.flash_attention(q, k, v, scale=1.0)) assert_allclose(result, res, precision=mx.float32, rtol=1e-2) diff --git a/scripts/dev-tools.py b/scripts/dev-tools.py index fe3d65c..177ab67 100644 --- a/scripts/dev-tools.py +++ b/scripts/dev-tools.py @@ -2,12 +2,21 @@ import shutil import os import pytest +from pathlib import Path -def copy_test(args, skip_if_exists=False): +def copy_test(args, skip_if_exists=False, force=False): source_file = f"tests_refsol/test_week_{args.week}_day_{args.day}.py" target_file = f"tests/test_week_{args.week}_day_{args.day}.py" - if skip_if_exists and os.path.exists(target_file): + if skip_if_exists and os.path.exists(target_file) and not force: + # diff the two files and warn if they are different + if Path(source_file).read_text() != Path(target_file).read_text(): + print( + f"[WARNING] {target_file} already exists and is different from {source_file}" + ) + print( + f"You can run `pdm run copy-test --week {args.week} --day {args.day} --force` to update it" + ) return print(f"copying {source_file} to {target_file}") shutil.copyfile(source_file, target_file) @@ -45,6 +54,7 @@ def main(): copy_test_parser = subparsers.add_parser("copy-test") copy_test_parser.add_argument("--week", type=int, required=True) copy_test_parser.add_argument("--day", type=int, required=True) + copy_test_parser.add_argument("--force", action="store_true") copy_test_parser.set_defaults(copy_test_parser=True) test_parser = subparsers.add_parser("test") test_parser.add_argument("--week", type=int, required=False) @@ -58,7 +68,7 @@ def main(): test_refsol_parser.set_defaults(test_refsol_parser=True) args = parser.parse_args() if hasattr(args, "copy_test_parser"): - copy_test(args) + copy_test(args, args.force) if hasattr(args, "test_parser"): test(args) if hasattr(args, "test_refsol_parser"): diff --git a/src/tiny_llm/kv_cache.py b/src/tiny_llm/kv_cache.py index 5ea3434..ae290d8 100644 --- a/src/tiny_llm/kv_cache.py +++ b/src/tiny_llm/kv_cache.py @@ -2,6 +2,7 @@ import mlx.core as mx + class TinyKvCache: def update_and_fetch( self, @@ -27,6 +28,7 @@ def update_and_fetch( """ pass + class BatchingKvCache(TinyKvCache): def __init__(self, max_active_requests: int, max_seq_len: int): pass diff --git a/src/tiny_llm_ref/batch.py b/src/tiny_llm_ref/batch.py index a0ff373..9e623e7 100644 --- a/src/tiny_llm_ref/batch.py +++ b/src/tiny_llm_ref/batch.py @@ -93,7 +93,7 @@ def _print_progress( if is_idle[i]: print(f" Decode #{i}: idle", flush=True) else: - text_preview = requests[i].text()[-80:].replace('\n', ' ') + text_preview = requests[i].text()[-80:].replace("\n", " ") print( f"{animation_frame} Decode [req {requests[i].prompt_idx}, {requests[i].offset}]: {text_preview}", flush=True, diff --git a/src/tiny_llm_ref/qwen2_week1.py b/src/tiny_llm_ref/qwen2_week1.py index 21a3848..e978427 100644 --- a/src/tiny_llm_ref/qwen2_week1.py +++ b/src/tiny_llm_ref/qwen2_week1.py @@ -59,6 +59,7 @@ def __call__( projection_v = linear(x, self.wv, bias=self.bv).reshape( B, L, self.num_kv_heads, self.head_dim ) + print(x.shape, self.wk.shape) projection_q = self.rope(projection_q, offset=slice(0, L)) projection_k = self.rope(projection_k, offset=slice(0, L)) projection_q = projection_q.transpose(0, 2, 1, 3) @@ -215,7 +216,6 @@ def __init__( self.w_lm_head = None self.mlx_model = mlx_model - def __call__( self, inputs: mx.array, diff --git a/tests_refsol/test_week_1_day_5.py b/tests_refsol/test_week_1_day_5.py index eb5a07b..dfd45fa 100644 --- a/tests_refsol/test_week_1_day_5.py +++ b/tests_refsol/test_week_1_day_5.py @@ -3,6 +3,7 @@ from .tiny_llm_base import Qwen2ModelWeek1, Embedding, dequantize_linear, qwen2_week1 from mlx_lm import load + @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) @pytest.mark.parametrize("mask", [None, "causal"], ids=["no_mask", "causal_mask"]) @@ -18,7 +19,7 @@ def test_task_1_transformer_block( NUM_KV_HEADS = 2 HIDDEN_SIZE = 32 INTERMEDIATE_SIZE = HIDDEN_SIZE * 4 - + args = qwen2.ModelArgs( model_type="qwen2", hidden_size=HIDDEN_SIZE, @@ -47,8 +48,10 @@ def test_task_1_transformer_block( w_down = mlx_mlp.down_proj.weight w_input_layernorm = mlx_transformer_block.input_layernorm.weight - w_post_attention_layernorm = mlx_transformer_block.post_attention_layernorm.weight - + w_post_attention_layernorm = ( + mlx_transformer_block.post_attention_layernorm.weight + ) + user_transformer_block = qwen2_week1.Qwen2TransformerBlock( num_attention_heads=NUM_ATTENTION_HEAD, num_kv_heads=NUM_KV_HEADS, @@ -66,19 +69,18 @@ def test_task_1_transformer_block( w_up=w_up, w_down=w_down, w_input_layernorm=w_input_layernorm, - w_post_attention_layernorm=w_post_attention_layernorm + w_post_attention_layernorm=w_post_attention_layernorm, ) mx.random.seed(42) - x = mx.random.uniform( - shape=(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), dtype=precision - ) + x = mx.random.uniform(shape=(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), dtype=precision) user_output = user_transformer_block(x, mask=mask) mlx_output = mlx_transformer_block(x, mask=mask, cache=None) - + assert_allclose(user_output, mlx_output, precision=precision, rtol=1e-1) + @pytest.mark.skipif( not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" ) diff --git a/tests_refsol/test_week_2_day_1.py b/tests_refsol/test_week_2_day_1.py index 3032b86..ab18ddb 100644 --- a/tests_refsol/test_week_2_day_1.py +++ b/tests_refsol/test_week_2_day_1.py @@ -40,7 +40,10 @@ def helper_test_task_3(model_name: str, iters: int = 10): user_output = user_output - mx.logsumexp(user_output, keepdims=True) ref_output = mlx_model(input) ref_output = ref_output - mx.logsumexp(ref_output, keepdims=True) - assert_allclose(user_output, ref_output, precision=mx.float16, rtol=0.1, atol=0.5) + assert_allclose( + user_output, ref_output, precision=mx.float16, rtol=0.1, atol=0.5 + ) + @pytest.mark.skipif( not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" From 34fb3fe26cc997af0483d3a00acbbd98acdcacb6 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 13 Sep 2025 15:00:20 -0400 Subject: [PATCH 45/79] mention that we have quantized weight now Signed-off-by: Alex Chi Z --- book/src/week2-01-kv-cache.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/book/src/week2-01-kv-cache.md b/book/src/week2-01-kv-cache.md index ab02158..d3a4b6a 100644 --- a/book/src/week2-01-kv-cache.md +++ b/book/src/week2-01-kv-cache.md @@ -149,6 +149,10 @@ and the context of week 1 day 3: in the GQA implementation, k/v's sequence lengt q's sequence length is `L`. In the Qwen2 multihead attention implementation, `L'` is the "new token" and `L` is the total sequence length, which corresponds to `L` and `S` in week 1 respectively. +Note that another refactor of this week's code is that all modules now take `QuantizedWeights` instead of `mx.array` +for some weights. You will need to move the dequantize code from loading the model to each module first, and we +will replace it with our own quantized matmul implementation for the rest of the week. + ## Task 3: Implement the Model ``` From 14498161ce672d625e8958ccbc741059ac11a70e Mon Sep 17 00:00:00 2001 From: Alex Chi Z <4198311+skyzh@users.noreply.github.com> Date: Sat, 13 Sep 2025 15:55:36 -0400 Subject: [PATCH 46/79] add chunked prefill and continuous batching writeup (#64) Signed-off-by: Alex Chi Z --- book/src/SUMMARY.md | 3 +- book/src/week2-06-prefill-and-batch.md | 113 ++++++++++++++++ src/tiny_llm/attention.py | 1 + src/tiny_llm/batch.py | 171 +++++++++++++++++++++++++ src/tiny_llm/generate.py | 11 -- src/tiny_llm/kv_cache.py | 59 ++++++++- src/tiny_llm/qwen2_week2.py | 3 + src/tiny_llm_ref/kv_cache.py | 3 - src/tiny_llm_ref/qwen2_week1.py | 1 - 9 files changed, 341 insertions(+), 24 deletions(-) create mode 100644 book/src/week2-06-prefill-and-batch.md diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index 502a072..ca5b41a 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -17,8 +17,7 @@ - [Key-Value Cache](./week2-01-kv-cache.md) - [Quantized Matmul (2 Days)]() - [Flash Attention (2 Days)]() - - [Chunked Prefill]() - - [Continuous Batching]() + - [Continuous Batching (2 Days)](./week2-06-prefill-and-batch.md) - [Week 3: Serving]() --- diff --git a/book/src/week2-06-prefill-and-batch.md b/book/src/week2-06-prefill-and-batch.md new file mode 100644 index 0000000..ca53b1c --- /dev/null +++ b/book/src/week2-06-prefill-and-batch.md @@ -0,0 +1,113 @@ +# Week 2 Day 6 and 7: Chunked Prefill and Continuous Batching + +In this chapter, we will implement **continuous batching**. The idea is to batch multiple requests together so we can make full use of the compute resources. + +So far, we have assumed that the model only processes a single batch each time it is called. However, a single batch is usually not enough to saturate the compute resources. To address this, we can process multiple requests at the same time. + +The first question is how to batch requests. A naive approach would be to select a fixed number of prompts (for example, 5) from the request queue and perform decoding as before. The problem is that different prompts produce sequences of different lengths. It is possible that 4 out of 5 requests finish decoding quickly, while the remaining one takes much longer. This leads to wasted compute resources and stalls all other requests. + +A smarter approach is **continuous batching**. That is, we set the maximum number of requests we can process at once. When one request finishes, we replace its slot (i.e., its KV cache) with another request. In this way, the pipeline remains fully utilized. + +Another challenge is how to handle decoding and prefilling at the same time. In this chapter, we adopt a simplified approach: we prefill one request, then decode one token for each request in progress. The general idea can be described with the following pseudocode: + +```python +while requests_in_queue_or_in_progress: + if prefill_request exists: + prefill_request.try_prefill() # perform a chunk of chunked prefill + if prefill_request.ready: + if kv_cache.try_add(prefill_request): + prefill_request = next(requests) + tokens = decode(model, kv_cache) + requests.append(tokens) +``` + +We will also implement **chunked prefill** in this chapter. Prefilling a long prompt can take a significant amount of time. Since we are interleaving prefills and decodes, we want to reduce the latency of producing the next token. Ideally, the time slots for prefill and decode should be roughly equal. To achieve this, we can prefill a portion of the request at a time, using multiple slots to finish the entire prefill. + +For prefilling, this essentially means providing a chunk of tokens to the model to populate the KV cache. For example: + +```python +# assume prompt_tokens is a list of 400 tokens and prefill chunk size is 128 +_step(model, prompt_tokens[0:128], offset=0, kv_cache) +_step(model, prompt_tokens[128:256], offset=128, kv_cache) +_step(model, prompt_tokens[256:384], offset=256, kv_cache) +_step(model, prompt_tokens[384:400], offset=384, kv_cache) +``` + +Note that the causal mask generated during prefilling has the shape `LxS`. For example, assume we already have 5 tokens in the KV cache and want to prefill 3 tokens. The mask should look like this: + +``` +0 0 0 -inf -inf +0 0 0 0 -inf +0 0 0 0 0 +``` + +This is the same masking logic you implemented in Week 1. + +## Task 1: Batch RoPE and Causal Mask for Prefill + +``` +src/tiny_llm/positional_encoding.py +src/tiny_llm/attention.py::causal_mask +``` + +Ensure your RoPE implementation accepts a list of offsets. Also, make sure your mask implementation correctly handles the case where `L != S`. + +## Task 2: Batch KV Cache + +``` +src/tiny_llm/kv_cache.py::BatchingKvCache +``` + +The batch KV cache is a collection of KV caches, one for each request. A challenge here is generating a `BxHxLxS` mask for the batch, since requests can have different lengths. + +``` +S = max(S_i of the batch) +L = mask_length (input parameter) +keys: 1, H, S_i, D +values: 1, H, S_i, D +batched_keys: B, H, S, D +batched_values: B, H, S, D +mask: B, 1, L, S +``` + +You should fill the `batched_keys` and `batched_values` arrays so that each request’s data is aligned at the end: + +```python +batched_keys[i, :, (S-S_i):S, :] = keys[i, :, :, :] +batched_values[i, :, (S-S_i):S, :] = values[i, :, :, :] +mask[i, :, 0:L, (S-S_i):S] = causal_mask(L, S_i) +``` + +## Task 3: Handle Batches in the Model + +``` +src/tiny_llm/qwen2_week2.py +``` + +Ensure your model can handle multiple requests simultaneously. You should also use the masks returned by the batch KV cache. + +## Task 4: Batch Generate + +``` +src/tiny_llm/batch.py +``` + +Implement `try_prefill` so that it prefills an entire request at once. Then implement the rest of the code as described in the starter code. + +## Task 5: Chunked Prefill + +``` +src/tiny_llm/batch.py +``` + +Modify `try_prefill` so that it performs prefilling in chunks, rather than all at once. + +You can test your implementation by running: + +```bash +pdm run batch-main +``` + +This will use the `qwen2-0.5b` model with a batch size of 5 to process a fixed set of prompts. + +{{#include copyright.md}} diff --git a/src/tiny_llm/attention.py b/src/tiny_llm/attention.py index be1d0d4..ccb1179 100644 --- a/src/tiny_llm/attention.py +++ b/src/tiny_llm/attention.py @@ -53,5 +53,6 @@ def flash_attention( key: mx.array, value: mx.array, scale: float | None = None, + mask: mx.array | None = None, ) -> mx.array: pass diff --git a/src/tiny_llm/batch.py b/src/tiny_llm/batch.py index e69de29..f8b5a08 100644 --- a/src/tiny_llm/batch.py +++ b/src/tiny_llm/batch.py @@ -0,0 +1,171 @@ +import mlx.core as mx +from mlx_lm.tokenizer_utils import TokenizerWrapper +from .kv_cache import * +from .qwen2_week2 import Qwen2ModelWeek2 +from typing import Callable +from datetime import datetime + + +def _step(model, y, offsets, kv_cache): + logits = model(y, offsets, kv_cache) + logits = logits[:, -1, :] + logprobs = logits - mx.logsumexp(logits, keepdims=True) + sampler = lambda x: mx.argmax(x, axis=-1) + y = sampler(logprobs) + return y + + +class Request: + def __init__( + self, + model: any, + tokenizer: TokenizerWrapper, + prompt: str, + prefill_max_step: int = 128, + prompt_idx: int = 0, + ): + self.prompt = prompt + self.kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + self.model = model + self.detokenizer = tokenizer.detokenizer.__class__(tokenizer._tokenizer) + self.prefill_tokens = mx.array( + tokenizer.encode(prompt, add_special_tokens=False) + ) + self.prefill_max_step = prefill_max_step + self.is_done = False + self.is_prefill_done = False + self.eos_token_id = tokenizer.eos_token_id + self.next_token = None + self.offset = 0 + self.prompt_idx = prompt_idx + + def try_prefill(self): + """ + Prefill this request up to max_step size, returns None if prefill is not done + """ + if self.is_prefill_done: + raise ValueError("prefill called after done") + # TODO: in task 4, prefill the full request at once; in task 5, prefill a chunk at a time + + def decode_done(self, token, update_offset=True): + if self.is_done: + raise ValueError("decode called after done") + if token == self.eos_token_id: + self.is_done = True + return + # TODO: update the offset and add the token to the detokenizer + + def text(self): + return self.detokenizer.text + + +def _print_progress( + requests: list[Request | None], + is_idle: list[bool], + pending_prefill_request: Request | None, + queue_size: int, + progress_cnt: int, + start_time: datetime, +): + print(f" --- {datetime.now() - start_time}") + animation_frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + animation_frame = animation_frames[progress_cnt % len(animation_frames)] + for i in range(len(requests)): + if is_idle[i]: + print(f" Decode #{i}: idle", flush=True) + else: + text_preview = requests[i].text()[-80:].replace("\n", " ") + print( + f"{animation_frame} Decode [req {requests[i].prompt_idx}, {requests[i].offset}]: {text_preview}", + flush=True, + ) + if pending_prefill_request is not None: + if pending_prefill_request.is_prefill_done: + print( + f" Prefill [req {pending_prefill_request.prompt_idx}]: done, waiting for slot, {queue_size} requests in queue", + flush=True, + ) + return + precentage = ( + pending_prefill_request.offset / pending_prefill_request.prefill_tokens.size + ) * 100 + print( + f"{animation_frame} Prefill [req {pending_prefill_request.prompt_idx}]: {precentage:.2f}% ({pending_prefill_request.prefill_tokens.size - pending_prefill_request.offset} remaining tokens)", + flush=True, + ) + else: + print(f" Prefill: idle, {queue_size} requests in queue", flush=True) + + +def batch_generate( + model: any, + tokenizer: TokenizerWrapper, + prompts: list[str], + max_seq_len=512, + batch_size=5, + prefill_step=128, +): + decode_requests: list[Request] = [None] * batch_size + is_idle = [True] * batch_size + kv_cache = [ + BatchingKvCache(max_active_requests=batch_size, max_seq_len=max_seq_len) + for _ in range(model.num_hidden_layers) + ] + result = [] + pending_prefill_request = None + next_request_idx = 0 + progress_cnt = 0 + start_time = datetime.now() + + while True: + if len(prompts) == 0 and all(is_idle): + break + # prefill until no idle slots + if len(prompts) > 0 and pending_prefill_request is None: + prompt = prompts.pop(0) + pending_prefill_request = Request( + model, tokenizer, prompt, prefill_step, next_request_idx + ) + next_request_idx += 1 + + # In every iteration, we do a prefill first + if pending_prefill_request is not None: + made_progress = False + if not pending_prefill_request.is_prefill_done: + pending_prefill_request.try_prefill() + made_progress = True + if pending_prefill_request.is_prefill_done: + # Implement this: find an idle slot and add the request to the decode requests + pass + if made_progress: + _print_progress( + decode_requests, + is_idle, + pending_prefill_request, + len(prompts), + progress_cnt, + start_time, + ) + progress_cnt += 1 + + # After the prefill request moves forward one step, we do the decode + if not all(is_idle): + next_tokens = [] + offsets = [] + # TODO: collect the next tokens and offsets from the decode requests + next_tokens = _step(model, next_tokens.reshape(-1, 1), offsets, kv_cache) + for i in range(batch_size): + # TODO: check if the decode has finished by comparing EOS or the seqlength. If so, + # remove the request from the decode requests and add the result to the result list; + # otherwise, call `decode_done` to update the offset and add the token to the detokenizer + pass + _print_progress( + decode_requests, + is_idle, + pending_prefill_request, + len(prompts), + progress_cnt, + start_time, + ) + progress_cnt += 1 + return result diff --git a/src/tiny_llm/generate.py b/src/tiny_llm/generate.py index c201d9f..67df12c 100644 --- a/src/tiny_llm/generate.py +++ b/src/tiny_llm/generate.py @@ -20,14 +20,3 @@ def simple_generate_with_kv_cache( ) -> str: def _step(model, y, offset, kv_cache): pass - - -def batch_generate( - model: any, - tokenizer: TokenizerWrapper, - prompts: list[str], - max_seq_len=512, - batch_size=5, - prefill_step=128, -): - pass diff --git a/src/tiny_llm/kv_cache.py b/src/tiny_llm/kv_cache.py index ae290d8..fb6e666 100644 --- a/src/tiny_llm/kv_cache.py +++ b/src/tiny_llm/kv_cache.py @@ -31,12 +31,52 @@ def update_and_fetch( class BatchingKvCache(TinyKvCache): def __init__(self, max_active_requests: int, max_seq_len: int): - pass + self.max_active_requests = max_active_requests + self.max_seq_len = max_seq_len + self.kv_caches: list[TinyKvCache] = [None] * max_active_requests + self.HD = None def update_and_fetch( - self, key: mx.array, value: mx.array - ) -> tuple[mx.array, mx.array, int]: - pass + self, + keys: mx.array, + values: mx.array, + mask_length: int | None = None, + mask: mx.array | str | None = None, + ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: + B, H, S, D = keys.shape + assert keys.shape == values.shape + assert S <= self.max_seq_len + assert self.HD == (H, D), f"expect {self.HD} but got {H, D}" + assert B == self.max_active_requests + # Step 1: append the result to the cache + data = [] + for b in range(B): + if self.kv_caches[b] is None: + data.append(None) + continue + key, value = keys[b : b + 1], values[b : b + 1] + new_key, new_value, seq_len, mask = self.kv_caches[b].update_and_fetch( + key, value + ) + data.append((new_key[0], new_value[0], seq_len, mask)) + + # Step 2: compute seq_len of this batch + def get_seq_len(data): + if data is None: + return 0 + _, _, seq_len, _ = data + return seq_len + + seq_len = max(map(get_seq_len, data)) + + # Step 3: generate masks and a single array of keys and values + keys = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=key.dtype) + values = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=value.dtype) + masks = mx.full( + (self.max_active_requests, mask_length, seq_len), -mx.inf, dtype=key.dtype + ) + # TODO: generate masks and a single array of keys and values + return keys, values, None, masks.reshape(B, 1, mask_length, seq_len) def add_request(self, prefilled: TinyKvCache, id: int): pass @@ -47,9 +87,14 @@ def remove_request(self, id: int): class TinyKvFullCache(TinyKvCache): def __init__(self): - pass + self.key_values = None + self.offset = 0 def update_and_fetch( - self, key: mx.array, value: mx.array - ) -> tuple[mx.array, mx.array, int]: + self, + key: mx.array, + value: mx.array, + mask_length: int | None = None, + mask: mx.array | str | None = None, + ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: pass diff --git a/src/tiny_llm/qwen2_week2.py b/src/tiny_llm/qwen2_week2.py index 012da01..1604cba 100644 --- a/src/tiny_llm/qwen2_week2.py +++ b/src/tiny_llm/qwen2_week2.py @@ -24,6 +24,7 @@ def __init__( bv: mx.array, max_seq_len: int = 32768, theta: int = 1000000, + use_flash_attention: bool = False, ): pass @@ -74,6 +75,7 @@ def __init__( w_post_attention_layernorm: mx.array, max_seq_len: int = 32768, theta: int = 1000000, + use_flash_attention: bool = False, ): pass @@ -93,6 +95,7 @@ def __init__( mlx_model: Any, enable_flash_attn: bool = False, ): + self.num_hidden_layers = mlx_model.args.num_hidden_layers pass def __call__( diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index 605de88..9a90bf8 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -138,6 +138,3 @@ def update_and_fetch( self.key_values = (new_keys, new_values) self.offset += S return new_keys, new_values, self.offset, mask - - def get_offset(self): - return self.offset diff --git a/src/tiny_llm_ref/qwen2_week1.py b/src/tiny_llm_ref/qwen2_week1.py index e978427..94f7d92 100644 --- a/src/tiny_llm_ref/qwen2_week1.py +++ b/src/tiny_llm_ref/qwen2_week1.py @@ -59,7 +59,6 @@ def __call__( projection_v = linear(x, self.wv, bias=self.bv).reshape( B, L, self.num_kv_heads, self.head_dim ) - print(x.shape, self.wk.shape) projection_q = self.rope(projection_q, offset=slice(0, L)) projection_k = self.rope(projection_k, offset=slice(0, L)) projection_q = projection_q.transpose(0, 2, 1, 3) From 1fc0752ac87b15d4d26d36900297fdc65786726b Mon Sep 17 00:00:00 2001 From: Alex Chi Z <4198311+skyzh@users.noreply.github.com> Date: Sat, 13 Sep 2025 16:09:50 -0400 Subject: [PATCH 47/79] fix simple kv cache decoding (#65) Signed-off-by: Alex Chi Z --- src/tiny_llm_ref/generate.py | 10 ++++++---- src/tiny_llm_ref/qwen2_week2.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index 6338579..6012ba8 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -56,14 +56,16 @@ def _step(model, y, offset, kv_cache): tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False)) detokenizer = tokenizer.detokenizer detokenizer.reset() - offset = tokens.size + offset = 0 # generate/decode while True: token, _ = _step(model, tokens, offset, kv_cache) mx.eval(token) - detokenizer.add_token(token.item()) - print(detokenizer.last_segment, end="", flush=True) if token.item() == tokenizer.eos_token_id: break - offset += token.size + detokenizer.add_token(token.item()) + print(detokenizer.last_segment, end="", flush=True) + # The first iteration of this loop is prefill. We want to add the offset to the prefilled token size. + # Otherwise, we add the decoded token size (which is always 1). + offset += tokens.size tokens = token diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index 540087d..7db2e27 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -48,7 +48,7 @@ def __init__( self.bq = bq self.bk = bk self.bv = bv - self.rope = RoPE(self.head_dim, max_seq_len, theta) + self.rope = RoPE(self.head_dim, max_seq_len, theta, traditional=False) self.use_flash_attention = use_flash_attention def __call__( From 26aa2ff68a5d953fa250c60beedd142bfbf58a82 Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Sat, 13 Sep 2025 16:11:42 -0400 Subject: [PATCH 48/79] update writeup progress Signed-off-by: Alex Chi Z --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ea3d2a1..cda99fd 100644 --- a/README.md +++ b/README.md @@ -42,8 +42,8 @@ Week 1 is complete. Week 2 is in progress. | 2.3 | Quantized Matmul and Linear - GPU | ✅ | ✅ | 🚧 | | 2.4 | Flash Attention 2 - CPU | ✅ | ✅ | 🚧 | | 2.5 | Flash Attention 2 - GPU | ✅ | ✅ | 🚧 | -| 2.6 | Continuous Batching | ✅ | 🚧 | 🚧 | -| 2.7 | Chunked Prefill | ✅ | 🚧 | 🚧 | +| 2.6 | Continuous Batching | ✅ | 🚧 | ✅ | +| 2.7 | Chunked Prefill | ✅ | 🚧 | ✅ | | 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 | | 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 | | 3.3 | MoE (Mixture of Experts) | 🚧 | 🚧 | 🚧 | From 1f2ab122080b2a54d46ef90957a73f2f3ff7ec2a Mon Sep 17 00:00:00 2001 From: Eric Zhang Date: Sun, 14 Sep 2025 14:48:36 -0400 Subject: [PATCH 49/79] Bump mlx to >=0.27 and fix build-ext from week 1, day 7 (#66) Resolves #50, applies the patch from there and updates pyproject / lockfile to specify newer version of mlx. --- pdm.lock | 2 +- pyproject.toml | 10 ++++------ src/extensions/src/axpby.h | 4 ++-- src/extensions/src/utils.cpp | 2 +- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/pdm.lock b/pdm.lock index ce49d33..6588ddd 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:5daed9b471ae8d795d175b6732b822213acc091f3b969e745cb5a63a69ef87ac" +content_hash = "sha256:57bf554af33b4cc63ec547d0b25307f18a1642bec8ff0c628d71866a926180cd" [[metadata.targets]] requires_python = ">=3.10,<3.13" diff --git a/pyproject.toml b/pyproject.toml index ba3825e..8c29d25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,18 +8,18 @@ version = "0.1.0" requires-python = ">=3.10, <3.13" readme = "README.md" dependencies = [ - "mlx>=0.25.0", + "mlx>=0.27.0", "torch>=2.6.0", "torchtune>=0.6.1", "torchao>=0.10.0", - "mlx-lm>=0.23.0", + "mlx-lm>=0.26.0", "numpy>=2.2.4", "pytest>=8.3.5", "ruff>=0.11.6", # this should not usually appear in a project dependency list but we add it to simplify the setup process "setuptools>=62", "nanobind==2.4.0", - "pytest-benchmark>=5.1.0" + "pytest-benchmark>=5.1.0", ] [tool.pdm.scripts] @@ -48,7 +48,5 @@ copy-test.cmd = "python scripts/dev-tools.py copy-test" book.cmd = "mdbook serve book/" [tool.pytest.ini_options] -addopts = [ - "--import-mode=importlib", -] +addopts = ["--import-mode=importlib"] pythonpath = "src" diff --git a/src/extensions/src/axpby.h b/src/extensions/src/axpby.h index 5ad87b3..cfa9ed3 100644 --- a/src/extensions/src/axpby.h +++ b/src/extensions/src/axpby.h @@ -63,10 +63,10 @@ class Axpby : public mx::Primitive { const std::vector &axes) override; /** Print the primitive. */ - void print(std::ostream &os) override; + void print(std::ostream &os); /** Name of the primitive (not virtual in some MLX versions). */ - const char* name() const { return "Axpby"; } + const char *name() const override { return "Axpby"; } /** Equivalence check **/ bool is_equivalent(const mx::Primitive &other) const override; diff --git a/src/extensions/src/utils.cpp b/src/extensions/src/utils.cpp index 97a5b80..3b52303 100644 --- a/src/extensions/src/utils.cpp +++ b/src/extensions/src/utils.cpp @@ -9,7 +9,7 @@ namespace tiny_llm_ext { void load_library(mx::Device d, const char *path) { #ifdef _METAL_ auto &md = mx::metal::device(d); - md.register_library("tiny_llm_ext", path); + md.get_library("tiny_llm_ext", path); #endif } From 308388eb243f05ce3b7ad5106ecce68790825601 Mon Sep 17 00:00:00 2001 From: Eric Zhang Date: Sun, 14 Sep 2025 17:40:41 -0400 Subject: [PATCH 50/79] CI workflow for pdm setup, build and testing refsol (#67) * Add CI for reference solution / building extensions * Adjust tests to run build-ext-ref before testing * Add sshx for debugging * Fix nanobind in CMake * Change when the workflow runs --- .github/workflows/macos.yml | 40 +++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 .github/workflows/macos.yml diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml new file mode 100644 index 0000000..719da95 --- /dev/null +++ b/.github/workflows/macos.yml @@ -0,0 +1,40 @@ +# Build and test the reference solution automatically on M1 runners. +# This helps prevent breakage of the dev setup. +name: macOS + +on: + push: + branches: + - main + pull_request: + +jobs: + test-refsol: + name: Test reference solution + runs-on: macos-15 # ARM64 + steps: + - uses: actions/checkout@v5 + + - uses: pdm-project/setup-pdm@v4 + with: + python-version: 3.12 + cache: true + + - run: pdm install + + - run: pdm run check-installation + + # Without this, future build steps fail in CMake. + - name: Add nanobind to CMake + run: | + nanobind_dir=$(pdm run python -c 'import nanobind, os; print(os.path.join(nanobind.__path__[0], "cmake"))') + echo "nanobind_DIR=${nanobind_dir}" >> $GITHUB_ENV + + - name: Try building extensions + run: | + pdm run build-ext + pdm run build-ext-test + + - run: pdm run build-ext-ref + + - run: pdm run test-refsol From 136ad7f31f6dabf6b001af71bbc52008d54edfd3 Mon Sep 17 00:00:00 2001 From: Eric Zhang Date: Tue, 16 Sep 2025 22:46:09 -0400 Subject: [PATCH 51/79] Day 6, task 1 tests - RoPE with multiple offsets (#68) This test requires the latest version of mlx 0.29.1, since they just merged support for this in mlx a week ago: https://github.com/ml-explore/mlx/pull/2564 I verified that the other tests still pass with the version upgrade. --- book/src/week2-06-prefill-and-batch.md | 8 +- pdm.lock | 46 ++++++------ pyproject.toml | 4 +- ...t_week_2_day_3.py => test_week_2_day_4.py} | 0 tests_refsol/test_week_2_day_6.py | 75 +++++++++++++++---- 5 files changed, 92 insertions(+), 41 deletions(-) rename tests_refsol/{test_week_2_day_3.py => test_week_2_day_4.py} (100%) diff --git a/book/src/week2-06-prefill-and-batch.md b/book/src/week2-06-prefill-and-batch.md index ca53b1c..b5cfa5d 100644 --- a/book/src/week2-06-prefill-and-batch.md +++ b/book/src/week2-06-prefill-and-batch.md @@ -50,7 +50,13 @@ src/tiny_llm/positional_encoding.py src/tiny_llm/attention.py::causal_mask ``` -Ensure your RoPE implementation accepts a list of offsets. Also, make sure your mask implementation correctly handles the case where `L != S`. +Ensure your RoPE implementation accepts a `list[slice]` of offsets (one slice for sequence in the batch). Also, make sure your mask implementation correctly handles the case where `L != S`. + +You can verify multi-offset RoPE, and that masking works for attention and flash attention with: + +```bash +pdm run test --week 2 --day 6 -- -k task_1 +``` ## Task 2: Batch KV Cache diff --git a/pdm.lock b/pdm.lock index 6588ddd..6dd7b75 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:57bf554af33b4cc63ec547d0b25307f18a1642bec8ff0c628d71866a926180cd" +content_hash = "sha256:1c01c53bb8f2b7383a86ffdc398d66ebc30dce2d762a5e39761f35d152c78222" [[metadata.targets]] requires_python = ">=3.10,<3.13" @@ -783,58 +783,58 @@ files = [ [[package]] name = "mlx" -version = "0.27.1" +version = "0.29.1" requires_python = ">=3.9" summary = "A framework for machine learning on Apple silicon." groups = ["default"] dependencies = [ - "mlx-metal==0.27.1; platform_system == \"Darwin\"", + "mlx-metal==0.29.1; platform_system == \"Darwin\"", ] files = [ - {file = "mlx-0.27.1-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:a033b65fe46425ad5032867d5a71a556a5168108d89aa7092b457556c70d84fc"}, - {file = "mlx-0.27.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:2836c6fd9803dc0c6cd06f204e31b3e0191e6c5b6bc8570b28661d926908cba3"}, - {file = "mlx-0.27.1-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:8b0566054c46d84c470cc99cda2afc3914ad6c7808fbb724dc1ec235e5b2a98c"}, - {file = "mlx-0.27.1-cp310-cp310-manylinux_2_35_x86_64.whl", hash = "sha256:91ef93ce09900c9a8ca662cf34e3c39ab5af2762822ecd6b12fecae518be167f"}, - {file = "mlx-0.27.1-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:d2e5dedbfbcbe558e51a5c476ca6a18e307676f9e49854eb27e53778bc474699"}, - {file = "mlx-0.27.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:9f04b9778897a879c9ca22e5413dfa1efc192d86d7211b184e079efec49dfb8b"}, - {file = "mlx-0.27.1-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:01d794f9e390438ab4f942a18d9a8ca65bef10c2c2007ef38ca988d039d6d9d3"}, - {file = "mlx-0.27.1-cp311-cp311-manylinux_2_35_x86_64.whl", hash = "sha256:fae11432d0639789f1e172b19b35ac8987c8ab9716e55a23fc7a170d6545fc33"}, - {file = "mlx-0.27.1-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:0c570c9afb57c697bd864504115be8a7c4de97f0b80557a597d496ee426a6812"}, - {file = "mlx-0.27.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:ccff7bbbd9df302b26e79013ef6d0c3531c9ba5963ead521e2d85856811b86a0"}, - {file = "mlx-0.27.1-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:9ccadaed449c07dfeae484620992b904c17dfea7564f8df63095c60eed3af02b"}, - {file = "mlx-0.27.1-cp312-cp312-manylinux_2_35_x86_64.whl", hash = "sha256:742c413e75605b71db69379a176da63e32ba19b9e9ad03b8763adbd1fcfcd394"}, + {file = "mlx-0.29.1-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:6a10d589a439346be1b7d8d0bbfe6cba45f1ef592f406b247b2da7131a9242c2"}, + {file = "mlx-0.29.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:673505b631f05041d7634366c8f52bf38d80439af44592f6e4291e59bb17a243"}, + {file = "mlx-0.29.1-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:4fd9e75f778cbd4a1f0060d090824863da8935a0e071f4747a4bc8f4b8d83838"}, + {file = "mlx-0.29.1-cp310-cp310-manylinux_2_35_x86_64.whl", hash = "sha256:b68563e760e70355507f789f62a515a5809f71813540896282a7ce0958b2862a"}, + {file = "mlx-0.29.1-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4da94700d8d19966f56962d16fe12510e41c3a85d3c90c6fb532016a8ab3c6d5"}, + {file = "mlx-0.29.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:0edf6c2c34bdd073741f583ea72ad538bbec26ebf66dc907478bbd68f942b5b8"}, + {file = "mlx-0.29.1-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:03dd298a3818ef32a764edbf602b2233739ab8eb3fd095b8a2ff4c994d373cdf"}, + {file = "mlx-0.29.1-cp311-cp311-manylinux_2_35_x86_64.whl", hash = "sha256:380f005a5c6889496999358ee7c5e8144a3b5c7c1f54efe1f74cfb073b37f1be"}, + {file = "mlx-0.29.1-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:1c437c32931ab69e514dfec25d13e319cbe43b2bd524047cbbe103e8be32d903"}, + {file = "mlx-0.29.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:e0a7b9ac66a8000eef75ec2cb4394a2133a2b7b2290606ade47d7c1a1d05d16e"}, + {file = "mlx-0.29.1-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:6b5795655fe1a313bbfa52f16a4b23226bef32e47bfc81d3885e9487ca34e7f7"}, + {file = "mlx-0.29.1-cp312-cp312-manylinux_2_35_x86_64.whl", hash = "sha256:7d7e15086bf71be3d7c6480816d7988ad91e4c8c3989a251fd12154f2f17f6c4"}, ] [[package]] name = "mlx-lm" -version = "0.26.2" +version = "0.27.1" requires_python = ">=3.8" summary = "LLMs with MLX and the Hugging Face Hub" groups = ["default"] dependencies = [ "jinja2", - "mlx>=0.26.0", + "mlx>=0.29.0", "numpy", "protobuf", "pyyaml", "transformers>=4.39.3", ] files = [ - {file = "mlx_lm-0.26.2-py3-none-any.whl", hash = "sha256:632624c4753a290dfe68f368d21f24883105cd8bba4c6ba5cf0905fecd626c1e"}, - {file = "mlx_lm-0.26.2.tar.gz", hash = "sha256:77e6f875bdea90a71174357363622b90a071d9c279bc958021e17e38d69fbc2a"}, + {file = "mlx_lm-0.27.1-py3-none-any.whl", hash = "sha256:300da6f63d8d392483b62b2abda794730fa04343dcb28a1f6a712f4c3ab60f3c"}, + {file = "mlx_lm-0.27.1.tar.gz", hash = "sha256:36640fb64c909cfd9baddf37b16e7d3b94a1a141033e6b7ea7a0ef5a965fb4ae"}, ] [[package]] name = "mlx-metal" -version = "0.27.1" +version = "0.29.1" requires_python = ">=3.9" summary = "A framework for machine learning on Apple silicon." groups = ["default"] marker = "platform_system == \"Darwin\"" files = [ - {file = "mlx_metal-0.27.1-py3-none-macosx_13_0_arm64.whl", hash = "sha256:c66d9b1adb3c0ea19492fba6493f672bc7542e65dd65f7e2995918815fbeb907"}, - {file = "mlx_metal-0.27.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:fe4415ddd242974d91c7ca0699cd01507d17da8a5ba304122ef137cdb5e7fff4"}, - {file = "mlx_metal-0.27.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:d025dea30bda8baa32c928cfa333eac64a5adc8d07656f8fc55072d99403ebc9"}, + {file = "mlx_metal-0.29.1-py3-none-macosx_13_0_arm64.whl", hash = "sha256:b9dadd432948eab196ed110db0dc745795fd516b7124c0d3c4d176fee678a07a"}, + {file = "mlx_metal-0.29.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:824b939b721a964a455aeea4d0e956e4cc945f3333522c1e72a077ae774bca49"}, + {file = "mlx_metal-0.29.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:ebd9ba8e83213f929663b92b8065b451a4276c7002ed83eae0fc8dde721c50c5"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 8c29d25..5e57cab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,11 +8,11 @@ version = "0.1.0" requires-python = ">=3.10, <3.13" readme = "README.md" dependencies = [ - "mlx>=0.27.0", + "mlx>=0.29.1", "torch>=2.6.0", "torchtune>=0.6.1", "torchao>=0.10.0", - "mlx-lm>=0.26.0", + "mlx-lm>=0.27.1", "numpy>=2.2.4", "pytest>=8.3.5", "ruff>=0.11.6", diff --git a/tests_refsol/test_week_2_day_3.py b/tests_refsol/test_week_2_day_4.py similarity index 100% rename from tests_refsol/test_week_2_day_3.py rename to tests_refsol/test_week_2_day_4.py diff --git a/tests_refsol/test_week_2_day_6.py b/tests_refsol/test_week_2_day_6.py index c6079d7..12745bf 100644 --- a/tests_refsol/test_week_2_day_6.py +++ b/tests_refsol/test_week_2_day_6.py @@ -1,9 +1,54 @@ -import pytest import mlx.core as mx +import numpy as np +import pytest from .tiny_llm_base import * from .utils import * +def rope_helper(stream: mx.Stream, traditional: bool, precision: mx.Dtype): + BATCH_SIZE = 16 + NUM_HEADS = 8 + HEAD_DIM = 4 + MAX_SEQ_LEN = 14 + SEQ_LEN = 9 + BASE = 10000 + with mx.stream(stream): + for _ in range(100): + user_layer = RoPE(HEAD_DIM, MAX_SEQ_LEN, BASE, traditional=traditional) + x = mx.random.uniform( + shape=(BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM), dtype=precision + ) + + input_pos = np.random.randint(0, MAX_SEQ_LEN - SEQ_LEN, size=BATCH_SIZE) + input_pos_mx = mx.array(input_pos, dtype=mx.int32) + input_pos_user = [slice(i, i + SEQ_LEN) for i in input_pos] + + reference_output = mx.fast.rope( + x.transpose(0, 2, 1, 3), + dims=HEAD_DIM, + traditional=traditional, + base=BASE, + scale=1.0, + offset=input_pos_mx, + ).transpose(0, 2, 1, 3) + user_output = user_layer(x, input_pos_user) + assert_allclose( + user_output, + reference_output, + precision, + atol=5e-6 if precision == mx.float32 else 1e-3, + ) + + +@pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS) +@pytest.mark.parametrize("traditional", [False, True], ids=["default", "traditional"]) +@pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS) +def test_task_1_rope_multiple_offsets( + stream: mx.Stream, traditional: bool, precision: mx.Dtype +): + rope_helper(stream, traditional, precision) + + def attention_helper( stream: mx.Stream, H_q, H, L, E, S, BATCH, use_flash_attention: bool = False ): @@ -75,57 +120,57 @@ def attention_helper( ) -def test_flash_attention_with_mask_cpu_small(): +def test_task_1_flash_attention_with_mask_cpu_small(): attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, use_flash_attention=True) -def test_flash_attention_with_mask_cpu(): +def test_task_1_flash_attention_with_mask_cpu(): attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, use_flash_attention=True) -def test_flash_attention_with_mask_cpu_large(): +def test_task_1_flash_attention_with_mask_cpu_large(): attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, use_flash_attention=True) -def test_flash_attention_with_mask_gpu_extra_small(): +def test_task_1_flash_attention_with_mask_gpu_extra_small(): attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, use_flash_attention=True) -def test_flash_attention_with_mask_gpu_small(): +def test_task_1_flash_attention_with_mask_gpu_small(): attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, use_flash_attention=True) -def test_flash_attention_with_mask_gpu(): +def test_task_1_flash_attention_with_mask_gpu(): attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, use_flash_attention=True) -def test_flash_attention_with_mask_gpu_large(): +def test_task_1_flash_attention_with_mask_gpu_large(): attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=True) -def test_attention_with_mask_cpu_small(): +def test_task_1_attention_with_mask_cpu_small(): attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, use_flash_attention=False) -def test_attention_with_mask_cpu(): +def test_task_1_attention_with_mask_cpu(): attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, use_flash_attention=False) -def test_attention_with_mask_cpu_large(): +def test_task_1_attention_with_mask_cpu_large(): attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False) -def test_attention_with_mask_gpu_extra_small(): +def test_task_1_attention_with_mask_gpu_extra_small(): attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, use_flash_attention=False) -def test_attention_with_mask_gpu_small(): +def test_task_1_attention_with_mask_gpu_small(): attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, use_flash_attention=False) -def test_attention_with_mask_gpu(): +def test_task_1_attention_with_mask_gpu(): attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, use_flash_attention=False) -def test_attention_with_mask_gpu_large(): +def test_task_1_attention_with_mask_gpu_large(): attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False) From 6635e4a1fe8a0fd065973cc6e5cdf79838b0a584 Mon Sep 17 00:00:00 2001 From: Eric Zhang Date: Thu, 18 Sep 2025 20:18:34 -0400 Subject: [PATCH 52/79] Add tests for week 2, day 6 - continuous batching (#69) * Add tests for week 2, day 6 - continuous batching * Download model weights in GitHub Actions --- .github/workflows/macos.yml | 5 ++ README.md | 6 +- book/src/setup.md | 1 + book/src/week2-06-prefill-and-batch.md | 6 ++ src/tiny_llm/kv_cache.py | 5 +- src/tiny_llm_ref/kv_cache.py | 29 ++++++---- tests_refsol/test_week_2_day_6.py | 78 ++++++++++++++++++++++++++ 7 files changed, 113 insertions(+), 17 deletions(-) diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 719da95..f991763 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -15,6 +15,11 @@ jobs: steps: - uses: actions/checkout@v5 + - name: Install HuggingFace weights + run: | + brew install huggingface-cli + hf download Qwen/Qwen2-0.5B-Instruct-MLX + - uses: pdm-project/setup-pdm@v4 with: python-version: 3.12 diff --git a/README.md b/README.md index cda99fd..38ba9de 100644 --- a/README.md +++ b/README.md @@ -42,14 +42,14 @@ Week 1 is complete. Week 2 is in progress. | 2.3 | Quantized Matmul and Linear - GPU | ✅ | ✅ | 🚧 | | 2.4 | Flash Attention 2 - CPU | ✅ | ✅ | 🚧 | | 2.5 | Flash Attention 2 - GPU | ✅ | ✅ | 🚧 | -| 2.6 | Continuous Batching | ✅ | 🚧 | ✅ | -| 2.7 | Chunked Prefill | ✅ | 🚧 | ✅ | +| 2.6 | Continuous Batching | ✅ | ✅ | ✅ | +| 2.7 | Chunked Prefill | ✅ | ✅ | ✅ | | 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 | | 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 | | 3.3 | MoE (Mixture of Experts) | 🚧 | 🚧 | 🚧 | | 3.4 | Speculative Decoding | 🚧 | 🚧 | 🚧 | | 3.5 | RAG Pipeline | 🚧 | 🚧 | 🚧 | | 3.6 | AI Agent / Tool Calling | 🚧 | 🚧 | 🚧 | -| 3.7 | Long Context | 🚧 | 🚧 | 🚧 | +| 3.7 | Long Context | 🚧 | 🚧 | 🚧 | Other topics not covered: quantized/compressed kv cache, prefix/prompt cache; sampling, fine tuning; smaller kernels (softmax, silu, etc) diff --git a/book/src/setup.md b/book/src/setup.md index 7139808..9028de9 100644 --- a/book/src/setup.md +++ b/book/src/setup.md @@ -60,6 +60,7 @@ them with: ```bash huggingface-cli login +huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX ``` diff --git a/book/src/week2-06-prefill-and-batch.md b/book/src/week2-06-prefill-and-batch.md index b5cfa5d..172bb34 100644 --- a/book/src/week2-06-prefill-and-batch.md +++ b/book/src/week2-06-prefill-and-batch.md @@ -92,6 +92,12 @@ src/tiny_llm/qwen2_week2.py Ensure your model can handle multiple requests simultaneously. You should also use the masks returned by the batch KV cache. +You should pass all of the tests by running: + +```bash +pdm run test --week 2 --day 6 -- -k task_3 +``` + ## Task 4: Batch Generate ``` diff --git a/src/tiny_llm/kv_cache.py b/src/tiny_llm/kv_cache.py index fb6e666..0b98a5a 100644 --- a/src/tiny_llm/kv_cache.py +++ b/src/tiny_llm/kv_cache.py @@ -1,9 +1,11 @@ +from abc import ABC, abstractmethod from typing import Optional import mlx.core as mx -class TinyKvCache: +class TinyKvCache(ABC): + @abstractmethod def update_and_fetch( self, key: mx.array, @@ -26,7 +28,6 @@ def update_and_fetch( In week 2 day 6/7, we need to return the updated key-value cache, the updated value, the sequence length, and the mask. so that the batching kv cache can use this information to generate the mask. """ - pass class BatchingKvCache(TinyKvCache): diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index 9a90bf8..f9b2499 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -1,10 +1,12 @@ +from abc import ABC, abstractmethod from typing import Optional from .attention import causal_mask import mlx.core as mx -class TinyKvCache: +class TinyKvCache(ABC): + @abstractmethod def update_and_fetch( self, key: mx.array, @@ -24,7 +26,6 @@ def update_and_fetch( Returns: A tuple of the updated key-value cache, the updated value, the sequence length, and the mask. """ - pass class BatchingKvCache(TinyKvCache): @@ -44,7 +45,10 @@ def update_and_fetch( B, H, S, D = keys.shape assert keys.shape == values.shape assert S <= self.max_seq_len - assert self.HD == (H, D), f"expect {self.HD} but got {H, D}" + if self.HD is None: + self.HD = (H, D) + else: + assert self.HD == (H, D), f"expect {self.HD} but got {H, D}" assert B == self.max_active_requests # Step 1: append the result to the cache data = [] @@ -88,19 +92,20 @@ def get_seq_len(data): elif isinstance(mask, mx.array): masks[b, :, seq_len - S : seq_len] = mask else: - raise NotImplemented + raise NotImplementedError return keys, values, None, masks.reshape(B, 1, mask_length, seq_len) def add_request(self, prefilled: TinyKvCache, id: int): if id >= self.max_active_requests: raise ValueError(f"Request id {id} is out of range") - keys, _ = prefilled.key_values - B, H, _, D = keys.shape - assert B == 1 - if self.HD is None: - self.HD = (H, D) - else: - assert self.HD == (H, D) + if getattr(prefilled, "key_values", None) is not None: + keys, _ = prefilled.key_values + B, H, _, D = keys.shape + assert B == 1 + if self.HD is None: + self.HD = (H, D) + else: + assert self.HD == (H, D) self.kv_caches[id] = prefilled def remove_request(self, id: int): @@ -126,7 +131,7 @@ def update_and_fetch( self.key_values = (key, value) B, H, S, D = key.shape self.offset = S - return key, value, 0, mask + return key, value, self.offset, mask else: B, H, S, D = key.shape assert key.shape == value.shape diff --git a/tests_refsol/test_week_2_day_6.py b/tests_refsol/test_week_2_day_6.py index 12745bf..6ad1265 100644 --- a/tests_refsol/test_week_2_day_6.py +++ b/tests_refsol/test_week_2_day_6.py @@ -1,6 +1,8 @@ import mlx.core as mx import numpy as np import pytest +from mlx_lm import load + from .tiny_llm_base import * from .utils import * @@ -174,3 +176,79 @@ def test_task_1_attention_with_mask_gpu(): def test_task_1_attention_with_mask_gpu_large(): attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, use_flash_attention=False) + + +def helper_test_task_3(model_name: str, seq_len: int, iters: int = 1): + """Tests for continuous batching of decode requests.""" + requests = 4 + max_seq_len = seq_len + + mlx_model, tokenizer = load(model_name) + model = Qwen2ModelWeek2(mlx_model) + for _ in range(iters): + cache = [ + BatchingKvCache(requests, max_seq_len) + for _ in range(model.num_hidden_layers) + ] + # Start each request at a staggered token index. + staggered_start = [seq_len * i // requests for i in range(requests)] + inputs = mx.random.randint(0, tokenizer.vocab_size, (requests, seq_len)) + ref_outputs = mlx_model(inputs) + for offset in range(seq_len + staggered_start[-1]): + seq_idx = [offset - start for start in staggered_start] + + # Requests join at the staggered start, and leave when they reach seq_len. + for request_id, sidx in enumerate(seq_idx): + if sidx == 0: + for c in cache: + c.add_request(TinyKvFullCache(), request_id) + elif sidx == seq_len: + for c in cache: + c.remove_request(request_id) + + next_tokens = [] + next_offsets = [] + for request_id, sidx in enumerate(seq_idx): + if 0 <= sidx < seq_len: + next_tokens.append(inputs[request_id, sidx].item()) + next_offsets.append(sidx) + else: + next_tokens.append(0) + next_offsets.append(0) + + user_out = model( + inputs=mx.array(next_tokens, dtype=mx.int32).reshape(-1, 1), + offset=mx.array(next_offsets, dtype=mx.int32), + cache=cache, + ) + + for request_id, sidx in enumerate(seq_idx): + if 0 <= sidx < seq_len: + user_out_r = user_out[request_id, 0, :] + ref_out_r = ref_outputs[request_id, sidx, :] + user_out_r = user_out_r - mx.logsumexp(user_out_r, keepdims=True) + ref_out_r = ref_out_r - mx.logsumexp(ref_out_r, keepdims=True) + assert_allclose( + user_out_r, ref_out_r, precision=mx.float16, rtol=1e-1 + ) + + +@pytest.mark.skipif( + not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found" +) +def test_task_3_qwen_2_05b(): + helper_test_task_3("Qwen/Qwen2-0.5B-Instruct-MLX", seq_len=3) + + +@pytest.mark.skipif( + not qwen_2_7b_model_exists(), reason="Qwen2-7B-Instruct-MLX model not found" +) +def test_task_3_qwen_2_7b(): + helper_test_task_3("Qwen/Qwen2-7B-Instruct-MLX", seq_len=3) + + +@pytest.mark.skipif( + not qwen_2_15b_model_exists(), reason="Qwen2-1.5B-Instruct-MLX model not found" +) +def test_task_3_qwen_2_15b(): + helper_test_task_3("Qwen/Qwen2-1.5B-Instruct-MLX", seq_len=3) From b6a3b0086f10af143df46b43bb6c8a7ec87d79e7 Mon Sep 17 00:00:00 2001 From: yangpeng <594424877@qq.com> Date: Sun, 21 Sep 2025 23:55:08 +0800 Subject: [PATCH 53/79] update dev-tools.py to fix --force in copy-test (#70) --- scripts/dev-tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/dev-tools.py b/scripts/dev-tools.py index 177ab67..20b64ff 100644 --- a/scripts/dev-tools.py +++ b/scripts/dev-tools.py @@ -68,7 +68,7 @@ def main(): test_refsol_parser.set_defaults(test_refsol_parser=True) args = parser.parse_args() if hasattr(args, "copy_test_parser"): - copy_test(args, args.force) + copy_test(args, force=args.force) if hasattr(args, "test_parser"): test(args) if hasattr(args, "test_refsol_parser"): From ad6d97677f99f04bc39724dd9b2cfb6587a9ab2f Mon Sep 17 00:00:00 2001 From: Alex Chi Z <4198311+skyzh@users.noreply.github.com> Date: Fri, 26 Sep 2025 01:30:00 -0400 Subject: [PATCH 54/79] add speculative decoding (#71) * add speculative decoding Signed-off-by: Alex Chi Z * update readme Signed-off-by: Alex Chi Z --------- Signed-off-by: Alex Chi Z --- README.md | 2 +- main.py | 23 ++++++++- src/tiny_llm_ref/generate.py | 95 ++++++++++++++++++++++++++++++++++++ src/tiny_llm_ref/kv_cache.py | 4 ++ 4 files changed, 122 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 38ba9de..55636b6 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ Week 1 is complete. Week 2 is in progress. | 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 | | 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 | | 3.3 | MoE (Mixture of Experts) | 🚧 | 🚧 | 🚧 | -| 3.4 | Speculative Decoding | 🚧 | 🚧 | 🚧 | +| 3.4 | Speculative Decoding | 🚧 | ✅ | 🚧 | | 3.5 | RAG Pipeline | 🚧 | 🚧 | 🚧 | | 3.6 | AI Agent / Tool Calling | 🚧 | 🚧 | 🚧 | | 3.7 | Long Context | 🚧 | 🚧 | 🚧 | diff --git a/main.py b/main.py index 05fdd98..f4ba9c3 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="qwen2-7b") +parser.add_argument("--draft-model", type=str, default=None) parser.add_argument( "--prompt", type=str, @@ -39,6 +40,7 @@ models, simple_generate, simple_generate_with_kv_cache, + speculative_generate, sampler, ) @@ -53,6 +55,15 @@ args.model = models.shortcut_name_to_full_name(args.model) mlx_model, tokenizer = load(args.model) +if args.draft_model: + args.draft_model = models.shortcut_name_to_full_name(args.draft_model) + draft_mlx_model, draft_tokenizer = load(args.draft_model) + if args.loader == "week1": + raise ValueError("Draft model not supported for week1") +else: + draft_mlx_model = None + draft_tokenizer = None + with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu): if use_mlx: tiny_llm_model = mlx_model @@ -67,6 +78,13 @@ tiny_llm_model = models.dispatch_model( args.model, mlx_model, week=2, enable_flash_attn=args.enable_flash_attn ) + if draft_mlx_model is not None: + print(f"Using draft model {args.draft_model}") + draft_tiny_llm_model = models.dispatch_model( + args.draft_model, draft_mlx_model, week=2, enable_flash_attn=args.enable_flash_attn + ) + else: + draft_tiny_llm_model = None else: raise ValueError(f"Loader {args.loader} not supported") messages = [ @@ -86,7 +104,10 @@ if args.loader == "week1": simple_generate(tiny_llm_model, tokenizer, prompt, sampler=sampler) elif args.loader == "week2": - simple_generate_with_kv_cache(tiny_llm_model, tokenizer, prompt) + if draft_tiny_llm_model is not None: + speculative_generate(draft_tiny_llm_model, tiny_llm_model, draft_tokenizer, tokenizer, prompt) + else: + simple_generate_with_kv_cache(tiny_llm_model, tokenizer, prompt) else: sampler = mlx_lm.sample_utils.make_sampler( args.sampler_temp, top_p=args.sampler_top_p, top_k=args.sampler_top_k diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index 6012ba8..117cb3c 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -69,3 +69,98 @@ def _step(model, y, offset, kv_cache): # Otherwise, we add the decoded token size (which is always 1). offset += tokens.size tokens = token + +def speculative_generate( + draft_model: Qwen2ModelWeek2, model: Qwen2ModelWeek2, draft_tokenizer: TokenizerWrapper, tokenizer: TokenizerWrapper, prompt: str +) -> str: + draft_kv_cache = [TinyKvFullCache() for _ in range(draft_model.num_hidden_layers)] + kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + + def _step(model, y, offset, kv_cache, n_tokens=1): + logits = model(y[None], offset, kv_cache) + if n_tokens > 1: + logits = logits[:, -n_tokens:, :] + else: + logits = logits[:, -1, :] + logprobs = logits - mx.logsumexp(logits, keepdims=True) + sampler = lambda x: mx.argmax(x, axis=-1) + y = sampler(logprobs) + return y, logprobs.squeeze(0) + + # prefill with the prompt, using the large model + def _prefill(model, tokenizer, prompt, kv_cache): + prefill_tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False)) + offset = 0 + token, _ = _step(model, prefill_tokens, offset, kv_cache) + mx.eval(token) + if token.item() == tokenizer.eos_token_id: + return + offset = prefill_tokens.size + return token, offset + + draft_token, draft_offset = _prefill(draft_model, draft_tokenizer, prompt, draft_kv_cache) + token, offset = _prefill(model, tokenizer, prompt, kv_cache) + + def _decode_one(token, tokenizer): + if token.item() == tokenizer.eos_token_id: + return False + detokenizer = tokenizer.detokenizer + detokenizer.add_token(token.item()) + return True + + + def draft_generate(model, last_token, offset, kv_cache, num_drafts): + tokens = [] + for _ in range(num_drafts): + token, _ = _step(model, last_token, offset, kv_cache) + mx.eval(token) + tokens.append(token.item()) + last_token = token + return tokens + + num_drafts = 4 + + def _rewind_cache(kv_cache, revert_len): + for layer in kv_cache: + layer.rewind(revert_len) + + def _print_text(text, progress): + print(f"+{progress} {text.replace('\n', ' ')[-80:]}") + + # speculative decode + while True: + draft_tokens = draft_generate(draft_model, token, draft_offset, draft_kv_cache, num_drafts) + draft_offset += num_drafts + # assume both models use the same tokenizer + draft_tokens = mx.concat([token, mx.array(draft_tokens)]) + new_tokens, _ = _step(model, draft_tokens, offset, kv_cache, num_drafts + 1) + new_tokens = new_tokens.tolist()[0] + offset += num_drafts + 1 + last_new_token = new_tokens[-1] + new_tokens = mx.array([token.item()] + new_tokens[:-1]) + assert len(new_tokens) == len(draft_tokens) + accept_all = True + for i in range(len(new_tokens)): + if new_tokens[i] != draft_tokens[i]: + # revert the full draft generation; re-generate next time + # or we matched full, then no rewind and use the last token + assert i >= 1 # first token is always the same + revert_len = len(draft_tokens) - i + _rewind_cache(draft_kv_cache, revert_len - 1) + draft_offset -= revert_len - 1 + _rewind_cache(kv_cache, revert_len) + token = mx.array([new_tokens[i]]) + offset -= revert_len + assert offset == draft_offset + assert offset == kv_cache[0].offset + _print_text(tokenizer._detokenizer.text, i) + accept_all = False + break + if not _decode_one(new_tokens[i], tokenizer): + print(tokenizer._detokenizer.text) + return + if accept_all: + _print_text(tokenizer._detokenizer.text, len(new_tokens)) + draft_generate(draft_model, mx.array(draft_tokens[-1:]), draft_offset, draft_kv_cache, 1) + token = mx.array([last_new_token]) + draft_offset += 1 diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index f9b2499..594ceba 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -143,3 +143,7 @@ def update_and_fetch( self.key_values = (new_keys, new_values) self.offset += S return new_keys, new_values, self.offset, mask + + def rewind(self, n: int): + self.offset -= n + self.key_values = (self.key_values[0][:, :, :self.offset], self.key_values[1][:, :, :self.offset]) \ No newline at end of file From ff5d7d0ecb2565ebc6e8ac9da669a222a5405aea Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Fri, 26 Sep 2025 01:37:04 -0400 Subject: [PATCH 55/79] ensure user solution can run Signed-off-by: Alex Chi Z --- main.py | 14 ++++++++++++-- src/tiny_llm/generate.py | 10 ++++++++++ src/tiny_llm_ref/generate.py | 28 +++++++++++++++++++++------- src/tiny_llm_ref/kv_cache.py | 5 ++++- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index f4ba9c3..2fb26a4 100644 --- a/main.py +++ b/main.py @@ -31,6 +31,7 @@ models, simple_generate, simple_generate_with_kv_cache, + speculative_generate, sampler, ) @@ -81,7 +82,10 @@ if draft_mlx_model is not None: print(f"Using draft model {args.draft_model}") draft_tiny_llm_model = models.dispatch_model( - args.draft_model, draft_mlx_model, week=2, enable_flash_attn=args.enable_flash_attn + args.draft_model, + draft_mlx_model, + week=2, + enable_flash_attn=args.enable_flash_attn, ) else: draft_tiny_llm_model = None @@ -105,7 +109,13 @@ simple_generate(tiny_llm_model, tokenizer, prompt, sampler=sampler) elif args.loader == "week2": if draft_tiny_llm_model is not None: - speculative_generate(draft_tiny_llm_model, tiny_llm_model, draft_tokenizer, tokenizer, prompt) + speculative_generate( + draft_tiny_llm_model, + tiny_llm_model, + draft_tokenizer, + tokenizer, + prompt, + ) else: simple_generate_with_kv_cache(tiny_llm_model, tokenizer, prompt) else: diff --git a/src/tiny_llm/generate.py b/src/tiny_llm/generate.py index 67df12c..d2bc3ed 100644 --- a/src/tiny_llm/generate.py +++ b/src/tiny_llm/generate.py @@ -20,3 +20,13 @@ def simple_generate_with_kv_cache( ) -> str: def _step(model, y, offset, kv_cache): pass + + +def speculative_generate( + draft_model: Qwen2ModelWeek2, + model: Qwen2ModelWeek2, + draft_tokenizer: TokenizerWrapper, + tokenizer: TokenizerWrapper, + prompt: str, +) -> str: + pass diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index 117cb3c..2adc5ae 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -70,8 +70,13 @@ def _step(model, y, offset, kv_cache): offset += tokens.size tokens = token + def speculative_generate( - draft_model: Qwen2ModelWeek2, model: Qwen2ModelWeek2, draft_tokenizer: TokenizerWrapper, tokenizer: TokenizerWrapper, prompt: str + draft_model: Qwen2ModelWeek2, + model: Qwen2ModelWeek2, + draft_tokenizer: TokenizerWrapper, + tokenizer: TokenizerWrapper, + prompt: str, ) -> str: draft_kv_cache = [TinyKvFullCache() for _ in range(draft_model.num_hidden_layers)] kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] @@ -98,7 +103,9 @@ def _prefill(model, tokenizer, prompt, kv_cache): offset = prefill_tokens.size return token, offset - draft_token, draft_offset = _prefill(draft_model, draft_tokenizer, prompt, draft_kv_cache) + draft_token, draft_offset = _prefill( + draft_model, draft_tokenizer, prompt, draft_kv_cache + ) token, offset = _prefill(model, tokenizer, prompt, kv_cache) def _decode_one(token, tokenizer): @@ -108,7 +115,6 @@ def _decode_one(token, tokenizer): detokenizer.add_token(token.item()) return True - def draft_generate(model, last_token, offset, kv_cache, num_drafts): tokens = [] for _ in range(num_drafts): @@ -129,7 +135,9 @@ def _print_text(text, progress): # speculative decode while True: - draft_tokens = draft_generate(draft_model, token, draft_offset, draft_kv_cache, num_drafts) + draft_tokens = draft_generate( + draft_model, token, draft_offset, draft_kv_cache, num_drafts + ) draft_offset += num_drafts # assume both models use the same tokenizer draft_tokens = mx.concat([token, mx.array(draft_tokens)]) @@ -144,7 +152,7 @@ def _print_text(text, progress): if new_tokens[i] != draft_tokens[i]: # revert the full draft generation; re-generate next time # or we matched full, then no rewind and use the last token - assert i >= 1 # first token is always the same + assert i >= 1 # first token is always the same revert_len = len(draft_tokens) - i _rewind_cache(draft_kv_cache, revert_len - 1) draft_offset -= revert_len - 1 @@ -158,9 +166,15 @@ def _print_text(text, progress): break if not _decode_one(new_tokens[i], tokenizer): print(tokenizer._detokenizer.text) - return + return tokenizer._detokenizer.text if accept_all: _print_text(tokenizer._detokenizer.text, len(new_tokens)) - draft_generate(draft_model, mx.array(draft_tokens[-1:]), draft_offset, draft_kv_cache, 1) + draft_generate( + draft_model, + mx.array(draft_tokens[-1:]), + draft_offset, + draft_kv_cache, + 1, + ) token = mx.array([last_new_token]) draft_offset += 1 diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index 594ceba..02be133 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -146,4 +146,7 @@ def update_and_fetch( def rewind(self, n: int): self.offset -= n - self.key_values = (self.key_values[0][:, :, :self.offset], self.key_values[1][:, :, :self.offset]) \ No newline at end of file + self.key_values = ( + self.key_values[0][:, :, : self.offset], + self.key_values[1][:, :, : self.offset], + ) From cf6910ae6005330f064c6f8bc231d3551201955c Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Sun, 5 Oct 2025 21:53:01 -0700 Subject: [PATCH 56/79] add definition hint for model args Signed-off-by: Connor1996 --- book/src/week1-05-qwen2-model.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/book/src/week1-05-qwen2-model.md b/book/src/week1-05-qwen2-model.md index 979b347..b057893 100644 --- a/book/src/week1-05-qwen2-model.md +++ b/book/src/week1-05-qwen2-model.md @@ -129,7 +129,7 @@ Embedding::as_linear OR Linear (lm_head) output ``` -You can access the number of layers, hidden size, and other model parameters from `mlx_model.args`. Note that different +You can access the number of layers, hidden size, and other model parameters from `mlx_model.args` which is defined in [ModelArgs](https://github.com/ml-explore/mlx-lm/blob/f318741784496dc2025dd7a4dea1ae698d21c610/mlx_lm/models/qwen2.py#L14). Note that different size of the Qwen2 models use different strategies to map the embeddings back to the token space. For the 0.5b model, it directly uses the `Embedding::as_linear` layer. For the 7b model, it has a separate `lm_head` linear layer. You can decide which strategy to use based on the `mlx_model.args.tie_word_embeddings` argument. If it is true, then you should From a30f9c2bf8622f91cfc1c470616c039b22e7faff Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Thu, 9 Oct 2025 22:07:10 -0700 Subject: [PATCH 57/79] add more info Signed-off-by: Connor1996 --- book/src/week1-05-qwen2-model.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/book/src/week1-05-qwen2-model.md b/book/src/week1-05-qwen2-model.md index b057893..ad97344 100644 --- a/book/src/week1-05-qwen2-model.md +++ b/book/src/week1-05-qwen2-model.md @@ -103,7 +103,6 @@ src/tiny_llm/qwen2_week1.py **📚 Readings** -- [Qwen2.5-7B-Instruct model parameters](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct?show_file_info=model.safetensors.index.json) In this course, you will not implement the process of loading the model parameters from the tensor files. Instead, we will load the model using the `mlx-lm` library, and then we will place the loaded parameters into our model. Therefore, @@ -129,7 +128,9 @@ Embedding::as_linear OR Linear (lm_head) output ``` -You can access the number of layers, hidden size, and other model parameters from `mlx_model.args` which is defined in [ModelArgs](https://github.com/ml-explore/mlx-lm/blob/f318741784496dc2025dd7a4dea1ae698d21c610/mlx_lm/models/qwen2.py#L14). Note that different +You can access the number of layers, hidden size, and other model parameters from `mlx_model.args` which is defined in [ModelArgs](https://github.com/ml-explore/mlx-lm/blob/f318741784496dc2025dd7a4dea1ae698d21c610/mlx_lm/models/qwen2.py#L14). You can reach the loaded weights from `mlx_model.model` which is defined in [Qwen2Model](https://github.com/ml-explore/mlx-lm/blob/f318741784496dc2025dd7a4dea1ae698d21c610/mlx_lm/models/qwen2.py#L125-L133), the layers structure can be found in [Qwen2.5-7B-Instruct model parameters](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct?show_file_info=model.safetensors.index.json) and [Qwen2-0.5B-Instruct model parameters](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?show_file_info=model.safetensors). + +Note that different size of the Qwen2 models use different strategies to map the embeddings back to the token space. For the 0.5b model, it directly uses the `Embedding::as_linear` layer. For the 7b model, it has a separate `lm_head` linear layer. You can decide which strategy to use based on the `mlx_model.args.tie_word_embeddings` argument. If it is true, then you should From 83762c8179b9ac23de68b22d82f2d46862715923 Mon Sep 17 00:00:00 2001 From: Connor1996 Date: Thu, 9 Oct 2025 22:09:05 -0700 Subject: [PATCH 58/79] rename Signed-off-by: Connor1996 --- book/src/week1-05-qwen2-model.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/book/src/week1-05-qwen2-model.md b/book/src/week1-05-qwen2-model.md index ad97344..b95a7e7 100644 --- a/book/src/week1-05-qwen2-model.md +++ b/book/src/week1-05-qwen2-model.md @@ -128,7 +128,7 @@ Embedding::as_linear OR Linear (lm_head) output ``` -You can access the number of layers, hidden size, and other model parameters from `mlx_model.args` which is defined in [ModelArgs](https://github.com/ml-explore/mlx-lm/blob/f318741784496dc2025dd7a4dea1ae698d21c610/mlx_lm/models/qwen2.py#L14). You can reach the loaded weights from `mlx_model.model` which is defined in [Qwen2Model](https://github.com/ml-explore/mlx-lm/blob/f318741784496dc2025dd7a4dea1ae698d21c610/mlx_lm/models/qwen2.py#L125-L133), the layers structure can be found in [Qwen2.5-7B-Instruct model parameters](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct?show_file_info=model.safetensors.index.json) and [Qwen2-0.5B-Instruct model parameters](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?show_file_info=model.safetensors). +You can access the number of layers, hidden size, and other model parameters from `mlx_model.args` which is defined in [ModelArgs](https://github.com/ml-explore/mlx-lm/blob/f318741784496dc2025dd7a4dea1ae698d21c610/mlx_lm/models/qwen2.py#L14). You can reach the loaded weights from `mlx_model.model` which is defined in [Qwen2Model](https://github.com/ml-explore/mlx-lm/blob/f318741784496dc2025dd7a4dea1ae698d21c610/mlx_lm/models/qwen2.py#L125-L133), the layers structure can be found in [Qwen2.5-7B-Instruct model structure](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct?show_file_info=model.safetensors.index.json) and [Qwen2-0.5B-Instruct model structure](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?show_file_info=model.safetensors). Note that different size of the Qwen2 models use different strategies to map the embeddings back to the token space. For the 0.5b model, it From f1f4f98a737afd48b9a2714bf6a838e465d29854 Mon Sep 17 00:00:00 2001 From: Eikasia30 Date: Sat, 11 Oct 2025 14:53:41 -0400 Subject: [PATCH 59/79] fix: fix link to Qwen2.5 blog in week1 (#72) Co-authored-by: Yangchen Ye --- book/src/week1-overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/book/src/week1-overview.md b/book/src/week1-overview.md index 128172d..83ed382 100644 --- a/book/src/week1-overview.md +++ b/book/src/week1-overview.md @@ -47,7 +47,7 @@ utilize these resources to better understand the internals of the model and what **📚 Readings** -- [Qwen2.5: A Party of Foundation Models!](https://qwenlm.github.io/blog/qwen2.5/) +- [Qwen2.5: A Party of Foundation Models!](https://qwen.ai/blog?id=6da44b4d3b48c53f5719bab9cc18b732a7065647&from=research.research-list) - [Key Concepts of the Qwen2 Model](https://qwen.readthedocs.io/en/latest/getting_started/concepts.html) - [Huggingface Transformers - Qwen2](https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen2) - [vLLM Qwen2](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2.py) From cea892662e7a15c98b7dbbafc37ce8b0a86794d6 Mon Sep 17 00:00:00 2001 From: Jasmine <87114666+jinhuix@users.noreply.github.com> Date: Mon, 13 Oct 2025 02:25:36 +0800 Subject: [PATCH 60/79] docs: add instruction to download Qwen2-1.5B model (#75) * docs: add instruction to download Qwen2-1.5B model --- book/src/week1-05-qwen2-model.md | 4 ++++ book/src/week1-06-generate-response.md | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/book/src/week1-05-qwen2-model.md b/book/src/week1-05-qwen2-model.md index b95a7e7..77cb19e 100644 --- a/book/src/week1-05-qwen2-model.md +++ b/book/src/week1-05-qwen2-model.md @@ -6,6 +6,7 @@ Before we start, please make sure you have downloaded the models: ```bash huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX +huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX ``` @@ -47,6 +48,7 @@ You should pass all tests for this task by running: ```bash # Download the models if you haven't done so huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX +huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX # Run the tests pdm run test --week 1 --day 5 -- -k task_1 @@ -88,6 +90,7 @@ You should pass all tests for this task by running: ```bash # Download the models if you haven't done so; we need to tokenizers huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX +huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX # Run the tests pdm run test --week 1 --day 5 -- -k task_2 @@ -152,6 +155,7 @@ You should pass all tests for this task by running: ```bash # Download the models if you haven't done so huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX +huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX # Run the tests pdm run test --week 1 --day 5 -- -k task_3 diff --git a/book/src/week1-06-generate-response.md b/book/src/week1-06-generate-response.md index 8aae630..2c27459 100644 --- a/book/src/week1-06-generate-response.md +++ b/book/src/week1-06-generate-response.md @@ -58,8 +58,15 @@ We will optimize the `decode` process to use key-value cache to speed up the gen You can test your implementation by running the following command: ```bash +# Download the models if you haven't done so +huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX +huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX +huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX +# Run the tests pdm run main --solution tiny_llm --loader week1 --model qwen2-0.5b \ --prompt "Give me a short introduction to large language model" +pdm run main --solution tiny_llm --loader week1 --model qwen2-1.5b \ + --prompt "Give me a short introduction to large language model" pdm run main --solution tiny_llm --loader week1 --model qwen2-7b \ --prompt "Give me a short introduction to large language model" ``` From 5dc71b8be0c6b8b6e88b1404b3fc6584bd5b1703 Mon Sep 17 00:00:00 2001 From: Connor Date: Sat, 1 Nov 2025 21:40:27 -0700 Subject: [PATCH 61/79] perform pdm sync before running (#76) Signed-off-by: Connor1996 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 5e57cab..9e5cae1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ ] [tool.pdm.scripts] +pre_run = { shell = "pdm sync --no-self > /dev/null" } build-ext.cmd = "python build.py" build-ext.working_dir = "src/extensions" build-ext-test.cmd = "python test.py" From ace6e45b756095e8e12406094d6ae375bbe247c9 Mon Sep 17 00:00:00 2001 From: Gao Date: Thu, 18 Dec 2025 14:06:13 +0800 Subject: [PATCH 62/79] Fix f-string syntax (#81) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract newline character to a variable to avoid backslash in f-string expression part. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude --- src/tiny_llm_ref/generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index 2adc5ae..1b2238c 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -131,7 +131,8 @@ def _rewind_cache(kv_cache, revert_len): layer.rewind(revert_len) def _print_text(text, progress): - print(f"+{progress} {text.replace('\n', ' ')[-80:]}") + newline = '\n' + print(f"+{progress} {text.replace(newline, ' ')[-80:]}") # speculative decode while True: From 5b6fdc3eae7b59d74a65d77ca971f061b99aed83 Mon Sep 17 00:00:00 2001 From: Liu Jinyi Date: Thu, 18 Dec 2025 14:06:31 +0800 Subject: [PATCH 63/79] fix: draft-generate offset (#83) Signed-off-by: KKKZOZ --- src/tiny_llm_ref/generate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index 1b2238c..92931d2 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -117,11 +117,13 @@ def _decode_one(token, tokenizer): def draft_generate(model, last_token, offset, kv_cache, num_drafts): tokens = [] + current_offset = offset for _ in range(num_drafts): - token, _ = _step(model, last_token, offset, kv_cache) + token, _ = _step(model, last_token, current_offset, kv_cache) mx.eval(token) tokens.append(token.item()) last_token = token + current_offset += 1 return tokens num_drafts = 4 From 16f55c7fcb535f04913c5b8871919d2d2e0a70dc Mon Sep 17 00:00:00 2001 From: yangpeng Date: Thu, 18 Dec 2025 14:06:46 +0800 Subject: [PATCH 64/79] fix mx.logsumexp with the right dim (#80) --- src/tiny_llm/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tiny_llm/batch.py b/src/tiny_llm/batch.py index f8b5a08..329971c 100644 --- a/src/tiny_llm/batch.py +++ b/src/tiny_llm/batch.py @@ -9,7 +9,7 @@ def _step(model, y, offsets, kv_cache): logits = model(y, offsets, kv_cache) logits = logits[:, -1, :] - logprobs = logits - mx.logsumexp(logits, keepdims=True) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) sampler = lambda x: mx.argmax(x, axis=-1) y = sampler(logprobs) return y From 685caf50e7b5c7eb70d54af12452d8fba92a172f Mon Sep 17 00:00:00 2001 From: Lingching <92594709+Elubrazione@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:07:08 +0800 Subject: [PATCH 65/79] feat: implement quantized_matmul with typed CPU implementation (#77) - Add complete quantized_matmul_impl_typed template function for CPU, which support float16, float32, and bfloat16 data types - Add float32 test cases for quantized_matmul - Adjust float32 tolerance in test utils for better precision --- src/extensions_ref/src/quantized_matmul.cpp | 97 +++++++++++++++++++-- src/extensions_ref/src/tiny_llm_ext.h | 2 +- tests/utils.py | 2 +- tests_refsol/test_week_2_day_2.py | 8 ++ 4 files changed, 100 insertions(+), 9 deletions(-) diff --git a/src/extensions_ref/src/quantized_matmul.cpp b/src/extensions_ref/src/quantized_matmul.cpp index 0420414..86886b9 100644 --- a/src/extensions_ref/src/quantized_matmul.cpp +++ b/src/extensions_ref/src/quantized_matmul.cpp @@ -1,7 +1,8 @@ #include -#include -#include +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/dtype.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/utils.h" @@ -9,7 +10,6 @@ #ifdef _METAL_ #include "mlx/backend/metal/device.h" -#include "mlx/backend/metal/utils.h" #endif namespace tiny_llm_ext_ref { @@ -23,8 +23,8 @@ mx::array quantized_matmul(const mx::array &scales, // Input array scale const bool transpose_b, // Whether to transpose b mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation ) { - if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16) { - throw std::runtime_error("quantized_matmul: scales must be float16 or bfloat16"); + if (scales.dtype() != mx::float16 && scales.dtype() != mx::bfloat16 && scales.dtype() != mx::float32) { + throw std::runtime_error("quantized_matmul: scales must be float16 or bfloat16 or float32"); } if (scales.dtype() != biases.dtype()) { throw std::runtime_error("quantized_matmul: scales and biases must be the same dtype"); @@ -143,6 +143,77 @@ void quantized_matmul_impl(const mx::array &scales, const mx::array &biases, con }); } +template +void quantized_matmul_impl_typed( + const mx::array &scales, const mx::array &biases, + const mx::array &a, const mx::array &b, + mx::array &out, int group_size, int bits, mx::Stream stream) { + + out.set_data(mx::allocator::malloc(out.nbytes())); + auto &encoder = mx::cpu::get_command_encoder(stream); + encoder.set_input_array(scales); + encoder.set_input_array(biases); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + encoder.dispatch([out_ptr = out.data(), out_shape = out.shape(), out_strides = out.strides(), + group_size = group_size, bits = bits, + a = mx::array::unsafe_weak_copy(a), b = mx::array::unsafe_weak_copy(b), + scales = mx::array::unsafe_weak_copy(scales), biases = mx::array::unsafe_weak_copy(biases)]() { + + // each `group_size` continuous weighted elements are packed into a group and each weight is quantized into `bits` bits + // thus each `group_size` continuous weighted elements takes `group_size * bits / 32` uint32_t elements in b + // when decoding the group of weights, the scales and biases are repeated for `group_size` times (shared by all elements in the group) + int m = a.shape()[0], n = a.shape()[1], k = b.shape()[0]; + + // row => group => item => pack + const int group_per_row = n / group_size; // b[k, :] = [ group_0, group_1, ..., group_(group_per_row-1) ] + const int packs_per_item = 32 / bits; // each uint32_t element can store `packs_per_item` packed elements + const int items_per_group = group_size / packs_per_item; // each group contains `items_per_group` uint32_t elements + + const T *a_ptr = a.data(), + *scales_ptr = scales.data(), *biases_ptr = biases.data(); + const uint32_t *b_ptr = b.data(); + + uint32_t pack_mask = (1 << bits) - 1; + + for (int i = 0; i < m; i++) { + for (int j = 0; j < k; j++) { + float sum = 0; + for (int group_idx = 0; group_idx < group_per_row; group_idx++) { + int64_t scales_idx = mx::elem_to_loc(j * group_per_row + group_idx, scales.shape(), scales.strides()); + int64_t biases_idx = mx::elem_to_loc(j * group_per_row + group_idx, biases.shape(), biases.strides()); + T scale = scales_ptr[scales_idx], bias = biases_ptr[biases_idx]; + + int64_t a_idx = mx::elem_to_loc(i * n + group_idx * group_size, a.shape(), a.strides()); + int64_t b_idx = mx::elem_to_loc((j * n + group_idx * group_size) / packs_per_item, b.shape(), b.strides()); + + for (int item_idx = 0; item_idx < items_per_group; item_idx++) { + uint32_t b_val = b_ptr[b_idx]; // fetch one uint32_t element in current group (item), so we use type uint32_t to store it + uint8_t *b_bytes = reinterpret_cast(&b_val); // reinterpret the uint32_t element as a byte array (32 = one byte * 4) + + for (int pack_idx = 0; pack_idx < packs_per_item; pack_idx++) { + // extract the pack(4 bits) from the byte array + // pack_idx / 2 is the index of the byte array, and (pack_idx % 2) * bits is the shift amount + // when pack_idx is even, extract the low 4 bits, otherwise extract the high 4 bits + // (pack_7, pack_6, pack_5, pack_4, pack_3, pack_2, pack_1, pack_0) => (b_bytes[3], b_bytes[2], b_bytes[1], b_bytes[0]) + uint8_t item_val = (b_bytes[pack_idx / 2] >> ((pack_idx % 2) * bits)) & pack_mask; + float a_val = static_cast(a_ptr[a_idx]); + float b_val_real = static_cast(item_val) * static_cast(scale) + static_cast(bias); + sum += a_val * b_val_real; + a_idx += 1; + } + b_idx += 1; + } + } + int64_t out_idx = mx::elem_to_loc(i * k + j, out_shape, out_strides); + out_ptr[out_idx] = static_cast(sum); + } + } + }); +} + void QuantizedMatmul::eval_cpu(const std::vector &inputs, std::vector &outputs) { auto &scales = inputs[0]; auto &biases = inputs[1]; @@ -150,8 +221,20 @@ void QuantizedMatmul::eval_cpu(const std::vector &inputs, std::vector auto &b = inputs[3]; auto &out = outputs[0]; - // TODO: dispatch to f32, f16, bf16 - quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream()); + // quantized_matmul_impl(scales, biases, a, b, out, group_size_, bits_, stream()); + switch (a.dtype()) { + case mx::float16: + quantized_matmul_impl_typed(scales, biases, a, b, out, group_size_, bits_, stream()); + break; + case mx::float32: + quantized_matmul_impl_typed(scales, biases, a, b, out, group_size_, bits_, stream()); + break; + case mx::bfloat16: + quantized_matmul_impl_typed(scales, biases, a, b, out, group_size_, bits_, stream()); + break; + default: + throw std::runtime_error("Unsupported dtype for quantized_matmul"); + } } void QuantizedMatmul::eval_gpu(const std::vector &inputs, std::vector &outputs) { diff --git a/src/extensions_ref/src/tiny_llm_ext.h b/src/extensions_ref/src/tiny_llm_ext.h index 0fe663e..599ba9f 100644 --- a/src/extensions_ref/src/tiny_llm_ext.h +++ b/src/extensions_ref/src/tiny_llm_ext.h @@ -1,6 +1,6 @@ #pragma once -#include "mlx/ops.h" +#include "mlx/utils.h" #include "mlx/primitives.h" namespace mx = mlx::core; diff --git a/tests/utils.py b/tests/utils.py index c34584f..aef7e25 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,7 @@ def assert_allclose( b = np.array(b) if precision == mx.float32: rtol = rtol or 1.0e-5 - atol = atol or 1.0e-8 + atol = atol or 1.0e-6 elif precision == mx.float16: rtol = rtol or 3.0e-2 atol = atol or 1.0e-5 diff --git a/tests_refsol/test_week_2_day_2.py b/tests_refsol/test_week_2_day_2.py index fcb7f1c..56d918b 100644 --- a/tests_refsol/test_week_2_day_2.py +++ b/tests_refsol/test_week_2_day_2.py @@ -49,3 +49,11 @@ def test_task_2_quantized_matmul_simple_f16_gpu(): def test_task_2_quantized_matmul_complex_f16_gpu(): quantized_matmul_helper(mx.gpu, False, mx.float16) + + +def test_task_1_quantized_matmul_simple_f32_cpu(): + quantized_matmul_helper(mx.cpu, True, mx.float32) + + +def test_task_1_quantized_matmul_complex_f32_cpu(): + quantized_matmul_helper(mx.cpu, False, mx.float32) \ No newline at end of file From c9f05de49a6c7cf55a54f76ec58302af7ff0d848 Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 8 Feb 2026 12:37:27 -0800 Subject: [PATCH 66/79] book: remove deprecated mdbook multilingual key (#86) --- book/book.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/book/book.toml b/book/book.toml index b6f8140..8300bac 100644 --- a/book/book.toml +++ b/book/book.toml @@ -1,7 +1,6 @@ [book] authors = ["Alex Chi", "Connor Zhang"] language = "en" -multilingual = false src = "src" title = "Tiny LLM - LLM Serving in a Week" From e34dc7e2a55515309a69c41d10e16e6eaaf4b69f Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 8 Feb 2026 13:01:12 -0800 Subject: [PATCH 67/79] ci: update mdbook preprocessors for 0.5 pipeline (#87) --- .github/workflows/main.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 970febd..a022869 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -16,7 +16,8 @@ jobs: - name: setup rust toolchain run: rustup update && rustup toolchain install - uses: dtolnay/rust-toolchain@stable - - run: cargo install mdbook-katex + - run: cargo install mdbook-toc + - run: cargo install mdbook-katex --version 0.10.0-alpha - uses: taiki-e/install-action@mdbook - name: patch for gh-pages build run: mv book/theme/head.hbs._ book/theme/head.hbs From 0c952670e87a49e5cf93fa284e7bcdc2ef352334 Mon Sep 17 00:00:00 2001 From: Connor Date: Mon, 9 Feb 2026 23:16:30 -0800 Subject: [PATCH 68/79] add AGENTS.md (#85) Signed-off-by: Connor1996 --- AGENTS.md | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..9b575e2 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,81 @@ +# AGENTS.md + +## Scope + +- This file applies to the entire repository. +- Use this as the default test-running policy for coding agents. + +## Objective + +- Run and verify tests in a way that matches the book workflow (`book/src/*.md`). +- Prefer `pdm` entrypoints defined in `pyproject.toml`. + +## Environment Requirements + +- macOS on Apple Silicon is expected by the project. +- Install dependencies first: + +```bash +pdm install -v +pdm run check-installation +``` + +- Optional baseline check from the setup chapter (reference solution, Week 1): + +```bash +pdm run test-refsol -- -- -k week_1 +``` + +## Agent Test Workflow + +1. Start with the smallest relevant scope (`--week` + `--day`). +2. Use pytest filters via `-- -k ...` to isolate failing tasks. +3. Run broader suites only after targeted tests pass. +4. If extension code changed, rebuild extensions before testing. + +## Canonical Commands + +Run all tests: + +```bash +pdm run test +``` + +Run a specific chapter/day: + +```bash +pdm run test --week --day +``` + +Run with pytest filters: + +```bash +pdm run test --week 1 --day 3 -- -k task_2 +pdm run test --week 2 --day 2 -- -k cpu +pdm run test --week 2 --day 2 -- -k gpu +``` + +Run reference-solution tests: + +```bash +pdm run test-refsol +pdm run test-refsol --week 2 --day 2 -- -k cpu +``` + +## Extension Rebuild Rule + +Rebuild before tests if these changed: + +- `src/extensions/src/*` + +Commands: + +```bash +pdm run build-ext +``` + +## Guardrails + +- Use `--` before pytest args (`-k`, `-q`, `--collect-only`, etc.). +- `pdm run test --week X --day Y` auto-copies `tests_refsol/test_week_X_day_Y.py` into `tests/`. +- Model-dependent tests (0.5B/1.5B/7B) skip when models are not downloaded locally. From b2393a2ded5677e95326c581f464ce0dc3f66b0a Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 15 Feb 2026 12:12:17 -0800 Subject: [PATCH 69/79] docs: add Week 2 Day 2-3 Quantized Matmul chapter CPU part (#88) * docs: add Week 2 Day 2-3 Quantized Matmul chapter - Add quantized matmul documentation (week2-02-quantized-matmul.md) Signed-off-by: Connor1996 --- .cspell.json | 8 +- README.md | 2 +- book/src/SUMMARY.md | 2 +- book/src/week2-02-quantized-matmul.md | 266 ++++++++++++++++++++++++++ tests_refsol/test_week_2_day_2.py | 20 +- 5 files changed, 284 insertions(+), 14 deletions(-) create mode 100644 book/src/week2-02-quantized-matmul.md diff --git a/.cspell.json b/.cspell.json index 8f4136b..b1f953f 100644 --- a/.cspell.json +++ b/.cspell.json @@ -23,9 +23,13 @@ "bfloat", "multihead", "vllm", - "silu" + "silu", + "GFLOPS", + "TFLOPS", + "dequantized", + "dequantization" ], "ignoreRegExpList": [ "`[^`]*`", ] -} \ No newline at end of file +} diff --git a/README.md b/README.md index 55636b6..3438a50 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Week 1 is complete. Week 2 is in progress. | 1.6 | Generate Responses (aka Decoding) | ✅ | ✅ | ✅ | | 1.7 | Sampling | ✅ | ✅ | ✅ | | 2.1 | Key-Value Cache | ✅ | ✅ | ✅ | -| 2.2 | Quantized Matmul and Linear - CPU | ✅ | ✅ | 🚧 | +| 2.2 | Quantized Matmul and Linear - CPU | ✅ | ✅ | ✅ | | 2.3 | Quantized Matmul and Linear - GPU | ✅ | ✅ | 🚧 | | 2.4 | Flash Attention 2 - CPU | ✅ | ✅ | 🚧 | | 2.5 | Flash Attention 2 - GPU | ✅ | ✅ | 🚧 | diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index ca5b41a..f3c342a 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -15,7 +15,7 @@ - [Sampling and Preparing for Week 2](./week1-07-sampling-prepare.md) - [Week 2: Tiny vLLM](./week2-overview.md) - [Key-Value Cache](./week2-01-kv-cache.md) - - [Quantized Matmul (2 Days)]() + - [Quantized Matmul (2 Days)](./week2-02-quantized-matmul.md) - [Flash Attention (2 Days)]() - [Continuous Batching (2 Days)](./week2-06-prefill-and-batch.md) - [Week 3: Serving]() diff --git a/book/src/week2-02-quantized-matmul.md b/book/src/week2-02-quantized-matmul.md new file mode 100644 index 0000000..a590995 --- /dev/null +++ b/book/src/week2-02-quantized-matmul.md @@ -0,0 +1,266 @@ +# Week 2 Day 2-3: Quantized Matmul + +In this chapter, we will implement the quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth. + +## Readings + +- [Model Compression and Quantization](https://huggingface.co/blog/hf-bitsandbytes-integration) +- [MLX Extensions Development Guide](https://ml-explore.github.io/mlx/build/html/dev/extensions.html) +- [Quantized Matmul on CPU (Video)](https://www.youtube.com/watch?v=es6s6T1bTtI) +- [Quantized Matmul on GPU (Video)](https://www.youtube.com/watch?v=jYCxVirq4d0) + +## Why Quantization? + +As we learned in the KV Cache chapter, the decode phase of LLM inference is **memory-bandwidth bound**. Let's revisit the arithmetic intensity calculation for the Qwen2-0.5B model: + +```plain +Per-token computation in decode phase: +- Input: 1 token × 896 dimensions = 896 float16 values = 1.792 KB +- MLP weights: 896 × 4864 × 3 matrices × 2 bytes = ~25 MB per layer +- Attention weights: 896 × 896 × 4 matrices × 2 bytes = ~6 MB per layer +- Total weights per layer: ~31 MB +- Total for 24 layers: ~750 MB + +FLOPs (2 per multiply-accumulate): +- MLP per layer: 2 × 3 × 896 × 4864 ≈ 26M +- Attention per layer: 2 × 4 × 896 × 896 ≈ 6.4M +- 24 layers: ~780 million per token + +Memory access: ~750 MB +Arithmetic intensity: 780M FLOPs / 750 MB ≈ 1.0 FLOPs/Byte +``` + +With M3 Max's 400 GB/s memory bandwidth and ~10 TFLOPS compute: + +```plain +Memory-bound throughput: 400 GB/s × 1.0 FLOPs/Byte = 400 GFLOPS +Compute-bound throughput: 10 TFLOPS + +We're using only ~4% of available compute! +``` + +### The Solution: Quantization + +By compressing weights from 16 bits (float16/bfloat16) to 4 bits (int4), we: + +- **Reduce memory bandwidth by 4×**: 750 MB → ~190 MB per token +- **Improve arithmetic intensity by 4×**: 1.0 → ~4.0 FLOPs/Byte +- **Increase throughput by ~4×**: 400 GFLOPS → ~1.6 TFLOPS + +The tradeoff is minimal accuracy loss with proper quantization techniques. + +### Group-wise Quantization + +Instead of quantizing all weights uniformly, we divide them into **groups** and quantize each group independently. This preserves more information about the weight distribution. + +For a weight matrix $W$ of shape $(K, N)$, we divide each row into groups of size $G$ (typically 64 or 128): + +```plain +Original weight matrix W: K × N (float16/bfloat16) + +Group size G = 64 +Number of groups per row = N / G + +For each group of 64 consecutive values in a row: + 1. Find min and max values + 2. Compute scale and bias to map [min, max] → [0, 15] (4-bit range) + 3. Quantize each value using: quantized = round((value - bias) / scale) +``` + +### Affine Quantization + +We use **affine (asymmetric) quantization** which maps a floating-point range to the full integer range: + +$$ +\text{quantized} = \text{round}\left(\frac{\text{value} - \text{bias}}{\text{scale}}\right) +$$ + +$$ +\text{dequantized} = \text{quantized} \times \text{scale} + \text{bias} +$$ + +For 4-bit quantization, the quantized values are in the range $[0, 15]$. + +Given a group with minimum value $v_{min}$ and maximum value $v_{max}$: + +$$ +\text{scale} = \frac{v_{max} - v_{min}}{2^{\text{bits}} - 1} = \frac{v_{max} - v_{min}}{15} +$$ + +$$ +\text{bias} = v_{min} +$$ + +**Example:** + +```plain +Group values: [-0.5, -0.3, 0.1, 0.4, 0.8] +min = -0.5, max = 0.8 + +scale = (0.8 - (-0.5)) / 15 = 1.3 / 15 ≈ 0.0867 +bias = -0.5 + +Quantization: + -0.5 → round((-0.5 - (-0.5)) / 0.0867) = 0 + -0.3 → round((-0.3 - (-0.5)) / 0.0867) = 2 + 0.1 → round((0.1 - (-0.5)) / 0.0867) = 7 + 0.4 → round((0.4 - (-0.5)) / 0.0867) = 10 + 0.8 → round((0.8 - (-0.5)) / 0.0867) = 15 + +Quantized: [0, 2, 7, 10, 15] (4 bits each) +``` + +### Storage Format + +For efficient storage and computation, quantized weights are packed: + +```plain +Original: K × N float16 (2 bytes each) = 2KN bytes +Quantized: K × N int4 (0.5 bytes each) = 0.5KN bytes + +Packing: 8 × 4-bit values fit in one uint32 (32 bits) + +Weight matrix shape: K × N +Quantized storage shape: K × (N / 8) uint32 +Scales shape: K × (N / 64) float16 +Biases shape: K × (N / 64) float16 +``` + +Example packing for 8 consecutive 4-bit values `[a, b, c, d, e, f, g, h]`: + +```plain +uint32_value = (h << 28) | (g << 24) | (f << 20) | (e << 16) | + (d << 12) | (c << 8) | (b << 4) | a + +Unpacking: + a = (uint32_value >> 0) & 0xF + b = (uint32_value >> 4) & 0xF + c = (uint32_value >> 8) & 0xF + ... + h = (uint32_value >> 28) & 0xF +``` + +## Quantized Matrix Multiplication + +### Mathematical Formulation + +For standard matrix multiplication $C = AB^T$ where: + +- $A$: shape $(M, N)$, float16/bfloat16 (activations) +- $B$: shape $(K, N)$, **quantized** to int4 (weights) +- $C$: shape $(M, K)$, float16/bfloat16 (output) + +Each element $C[i, k]$ is computed as: + +$$ +C[i, k] = \sum_{j=0}^{N-1} A[i, j] \times B[k, j] +$$ + +With quantization, $B[k, j]$ is represented as: + +$$ +B[k, j] = B_{\text{quantized}}[k, j] \times \text{scale}[k, g] + \text{bias}[k, g] +$$ + +where $g = \lfloor j / G \rfloor$ is the group index. + +Substituting: + +$$ +C[i, k] = \sum_{g=0}^{N/G-1} \sum_{j'=0}^{G-1} A[i, g \times G + j'] \times (B_{\text{quantized}}[k, g \times G + j'] \times \text{scale}[k, g] + \text{bias}[k, g]) +$$ + +Rearranging: + +$$ +C[i, k] = \sum_{g=0}^{N/G-1} \left( \text{scale}[k, g] \sum_{j'=0}^{G-1} A[i, g \times G + j'] \times B_{\text{quantized}}[k, g \times G + j'] + \text{bias}[k, g] \sum_{j'=0}^{G-1} A[i, g \times G + j'] \right) +$$ + +This shows we can factor out the scale and bias per group, reducing the number of floating-point operations. + +### Computation Flow + +```plain +Input: + A: M × N (float16, activations) + B_quantized: K × (N/8) (uint32, packed weights) + scales: K × (N/64) (float16) + biases: K × (N/64) (float16) + +Output: + C: M × K (float16) + +For each output element C[i, k]: + sum = 0 + for each group g in 0..(N/64 - 1): + scale = scales[k, g] + bias = biases[k, g] + + # Process 64 values in the group (8 uint32 packs) + for each pack p in 0..7: + packed_value = B_quantized[k, g*8 + p] + + # Unpack 8 × 4-bit values + for bit_offset in [0, 4, 8, 12, 16, 20, 24, 28]: + quantized = (packed_value >> bit_offset) & 0xF + b_value = quantized * scale + bias + a_value = A[i, g*64 + p*8 + bit_offset/4] + sum += a_value * b_value + + C[i, k] = sum +``` + +## Task 1: Implement QuantizedWeights + +``` +src/tiny_llm/quantize.py +``` + +First, familiarize yourself with the `QuantizedWeights` class, which stores quantized weight information: + +| Field | Shape | Description | +|-------|-------|-------------| +| `weight` | $(K, N/8)$ uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape $(K, N)$, and after packing, it becomes $(K, N/8)$. | +| `scales` | $(K, N/G)$ float16 | Per-group scale factors for dequantization. Each group of $G$ consecutive values shares one scale. Recall: $\text{scale} = (v_{max} - v_{min}) / 15$ | +| `biases` | $(K, N/G)$ float16 | Per-group bias (offset) for dequantization. Recall: $\text{bias} = v_{min}$ | +| `group_size` | int | Number of consecutive values that share the same scale/bias (typically 64) | +| `bits` | int | Quantization bit width (typically 4, meaning values are in range $[0, 15]$) | + +The `from_mlx_layer` static method extracts these fields from MLX's quantized linear layers when loading the model. + +Next, implement the `quantized_linear` function, which is a wrapper around `quantized_matmul` that mimics the standard `linear` function interface. And we'll implement `quantized_matmul` in the next task. + +## Task 2: Implement `quantized_matmul` (CPU version) + +In this task, we will implement the quantized matmul as an MLX C++ extension. The pattern is identical to the existing `axpby` example in the codebase — read through `axpby.h`, `axpby.cpp`, and the corresponding binding in `bindings.cpp` first as your reference. + +``` +src/extensions/src/tiny_llm_ext.h +src/extensions/bindings.cpp +src/extensions/src/quantized_matmul.cpp +src/extensions/CMakeLists.txt +``` + +You need to touch three files, all within the `tiny_llm_ext` namespace: + +- **`tiny_llm_ext.h`** — Declare the `quantized_matmul(...)` function signature and define a `QuantizedMatmul` primitive class (inheriting `mx::Primitive`). Store `group_size` and `bits` as private members. +- **`bindings.cpp`** — Add an `m.def(...)` call to expose the function to Python. +- **`quantized_matmul.cpp`** — Implement the `quantized_matmul(...)` function (validate inputs, compute output shape, return a lazy `mx::array`) and the `eval_cpu` method (allocate output, register arrays with the CPU encoder, dispatch the compute kernel). + +The `eval_cpu` implementation follows the same CPU encoder pattern as `axpby`: allocate output memory with `out.set_data(mx::allocator::malloc(out.nbytes()))`, register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above — iterate over each output element `(i, k)`, accumulate in `float` (fp32) to avoid precision loss, and cast the result back to `float16` when writing to the output. + +Don't forget to add `src/quantized_matmul.cpp` to `target_sources` in `CMakeLists.txt`. + +You can test your implementation by running: + +```bash +pdm run build-ext +pdm run test --week 2 --day 2 -- -k task_2 +``` + +## Task 3: Implement Metal Kernel + +TBD... + +{{#include copyright.md}} + diff --git a/tests_refsol/test_week_2_day_2.py b/tests_refsol/test_week_2_day_2.py index 56d918b..043aa51 100644 --- a/tests_refsol/test_week_2_day_2.py +++ b/tests_refsol/test_week_2_day_2.py @@ -35,25 +35,25 @@ def quantized_matmul_helper( assert_allclose(user_out, ref_out, precision) -def test_task_1_quantized_matmul_simple_f16_cpu(): +def test_task_2_quantized_matmul_simple_f16_cpu(): quantized_matmul_helper(mx.cpu, True, mx.float16) -def test_task_1_quantized_matmul_complex_f16_cpu(): +def test_task_2_quantized_matmul_complex_f16_cpu(): quantized_matmul_helper(mx.cpu, False, mx.float16) -def test_task_2_quantized_matmul_simple_f16_gpu(): - quantized_matmul_helper(mx.gpu, True, mx.float16) +def test_task_2_quantized_matmul_simple_f32_cpu(): + quantized_matmul_helper(mx.cpu, True, mx.float32) -def test_task_2_quantized_matmul_complex_f16_gpu(): - quantized_matmul_helper(mx.gpu, False, mx.float16) +def test_task_2_quantized_matmul_complex_f32_cpu(): + quantized_matmul_helper(mx.cpu, False, mx.float32) -def test_task_1_quantized_matmul_simple_f32_cpu(): - quantized_matmul_helper(mx.cpu, True, mx.float32) +def test_task_3_quantized_matmul_simple_f16_gpu(): + quantized_matmul_helper(mx.gpu, True, mx.float16) -def test_task_1_quantized_matmul_complex_f32_cpu(): - quantized_matmul_helper(mx.cpu, False, mx.float32) \ No newline at end of file +def test_task_3_quantized_matmul_complex_f16_gpu(): + quantized_matmul_helper(mx.gpu, False, mx.float16) \ No newline at end of file From 0688e96f98743fcd463a6afbcba36db25ab598f6 Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 15 Feb 2026 15:17:31 -0800 Subject: [PATCH 70/79] docs: add Week 2 Day 2-3 Quantized Matmul chapter GPU part (#89) * docs: add week2 quantized matmul GPU part Signed-off-by: Connor1996 --- .cspell.json | 4 +- book/src/week2-02-quantized-matmul.md | 60 +++++++++++++++++++++++++-- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/.cspell.json b/.cspell.json index b1f953f..1616579 100644 --- a/.cspell.json +++ b/.cspell.json @@ -27,7 +27,9 @@ "GFLOPS", "TFLOPS", "dequantized", - "dequantization" + "dequantization", + "dequantizes", + "dtype", ], "ignoreRegExpList": [ "`[^`]*`", diff --git a/book/src/week2-02-quantized-matmul.md b/book/src/week2-02-quantized-matmul.md index a590995..8914b35 100644 --- a/book/src/week2-02-quantized-matmul.md +++ b/book/src/week2-02-quantized-matmul.md @@ -2,7 +2,7 @@ In this chapter, we will implement the quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth. -## Readings +**📚 Readings** - [Model Compression and Quantization](https://huggingface.co/blog/hf-bitsandbytes-integration) - [MLX Extensions Development Guide](https://ml-explore.github.io/mlx/build/html/dev/extensions.html) @@ -258,9 +258,61 @@ pdm run build-ext pdm run test --week 2 --day 2 -- -k task_2 ``` -## Task 3: Implement Metal Kernel +## Task 3: Implement `quantized_matmul` (GPU version) -TBD... +``` +src/extensions/src/quantized_matmul.metal +src/extensions/src/quantized_matmul.cpp +``` -{{#include copyright.md}} +In this task, you will write the Metal kernel for quantized matmul **and** wire up the `eval_gpu` method to dispatch it. Keep the math exactly the same as Task 2 (CPU); only the execution model changes. + +### Metal Kernel + +You need to implement one kernel entry in `quantized_matmul.metal`: + +- Use a **one-thread-per-output-element** mapping: each thread computes `out[i, k]`. +- The kernel should be templated on the data type (to support both `half` and `bfloat16_t`). +- Apply the same group-wise dequantization loop as the CPU version: + - Iterate over groups (`group_size = 64`) + - Unpack int4 values from packed `uint32` + - Dequantize with `q * scale + bias` + - Accumulate in `float` and cast to the output dtype at the end +- Add boundary checks (`i < M`, `k < K`) before writing output. + +### GPU Dispatch +Complete the `eval_gpu` method in `quantized_matmul.cpp` to dispatch your Metal kernel. Follow the same pattern as `axpby`'s GPU dispatch: + +1. Get the Metal device and command encoder from the stream. +2. Select the correct kernel name based on the activation dtype (`float16` → `half`, `bfloat16` → `bfloat16_t`). +3. Set input/output buffers and dimension constants (`M`, `N`, `K`) on the encoder — make sure the buffer order matches your kernel signature. +4. Calculate a 2D thread group configuration: use `kernel->maxTotalThreadsPerThreadgroup()` to determine the total threads, then split between the M and K dimensions (e.g., 32 threads for M, the rest for K). +5. Dispatch with `dispatchThreadgroups`. + +You can test your implementation by running: + +```bash +pdm run build-ext +pdm run test --week 2 --day 2 -- -k task_3 +``` + +## Task 4: Model Integration + +``` +src/tiny_llm/qwen2_week2.py +``` + +Integrate your quantized matmul into the Week 2 Qwen2 model so that inference runs on quantized weights end-to-end. + +Change the weight type from `mx.array` to `QuantizedWeights` for all linear layers in attention (`wq/wk/wv/wo`) and MLP (`w_gate/w_up/w_down`). Replace every `linear(x, w)` call with `quantized_linear(x, w)`. In the model loading code, use `QuantizedWeights.from_mlx_layer(...)` to extract quantized weight information from each MLX linear layer, instead of calling `mx.dequantize` to get a full float16 matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain `mx.array`), while the Week 2 loader does **not** dequantize. + +Note that MLX loads quantized models with `scales` and `biases` stored in **bfloat16** by default, while the activation tensors are typically **float16**. Since we have not implemented bfloat16 support in our kernel, you will need to convert the scales and biases to float16 with `mx.astype` before calling the kernel. If you see `nan` or garbage output, a dtype mismatch is the most likely cause. + +You can test your implementation by running: + +```bash +pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b +``` + +{{#include copyright.md}} From 1cd513bf6dca66e61b7c6f07bce466ccd346674a Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 15 Feb 2026 16:12:58 -0800 Subject: [PATCH 71/79] doc: add tokenizer definition reference (#90) Signed-off-by: Connor1996 --- book/src/week1-06-generate-response.md | 1 + 1 file changed, 1 insertion(+) diff --git a/book/src/week1-06-generate-response.md b/book/src/week1-06-generate-response.md index 2c27459..5907c7b 100644 --- a/book/src/week1-06-generate-response.md +++ b/book/src/week1-06-generate-response.md @@ -35,6 +35,7 @@ pick the token with the highest probability. - 📚 [The Log-Sum-Exp Trick](https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/) - 📚 [Decoding Strategies in Large Language Models](https://mlabonne.github.io/blog/posts/2023-06-07-Decoding_strategies.html) +- 📚 [Tokenizer definition](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer) With the `_step` function implemented, you can now implement the full `simple_generate` function. The function will first prefill the model with the prompt. As the prompt is a string, you need to first convert it to a list of tokens From e9b90bd55cfdc60c04edd5abc6d81aa7ba20026d Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 15 Feb 2026 16:17:29 -0800 Subject: [PATCH 72/79] docs: mark week 2.3 tiny_llm status as complete (#91) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3438a50..3c3e153 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ Week 1 is complete. Week 2 is in progress. | 1.7 | Sampling | ✅ | ✅ | ✅ | | 2.1 | Key-Value Cache | ✅ | ✅ | ✅ | | 2.2 | Quantized Matmul and Linear - CPU | ✅ | ✅ | ✅ | -| 2.3 | Quantized Matmul and Linear - GPU | ✅ | ✅ | 🚧 | +| 2.3 | Quantized Matmul and Linear - GPU | ✅ | ✅ | ✅ | | 2.4 | Flash Attention 2 - CPU | ✅ | ✅ | 🚧 | | 2.5 | Flash Attention 2 - GPU | ✅ | ✅ | 🚧 | | 2.6 | Continuous Batching | ✅ | ✅ | ✅ | From f4dc967978196b281d45efc541efbe917d7756b7 Mon Sep 17 00:00:00 2001 From: Connor Date: Mon, 16 Feb 2026 17:25:10 -0800 Subject: [PATCH 73/79] Add bench-main command and week2 benchmark instructions (#93) --- bench.py | 280 ++++++++++++++++++++++++++ book/src/week2-01-kv-cache.md | 7 + book/src/week2-02-quantized-matmul.md | 7 + pyproject.toml | 3 +- 4 files changed, 296 insertions(+), 1 deletion(-) create mode 100644 bench.py diff --git a/bench.py b/bench.py new file mode 100644 index 0000000..aac1cf6 --- /dev/null +++ b/bench.py @@ -0,0 +1,280 @@ +import argparse +from dataclasses import dataclass +from random import Random +from time import perf_counter + +import mlx.core as mx +from mlx_lm import load +from tqdm.auto import tqdm + + +@dataclass +class BenchRequest: + prompt_token_ids: list[int] + max_new_tokens: int + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark tiny-llm token throughput with synthetic token IDs." + ) + parser.add_argument("--model", type=str, default="qwen2-0.5b") + parser.add_argument("--solution", type=str, default="tiny_llm") + parser.add_argument("--loader", type=str, default="week2", choices=["week1", "week2"]) + parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu"]) + parser.add_argument("--num-seqs", type=int, default=16) + parser.add_argument("--min-input-len", type=int, default=64) + parser.add_argument("--max-input-len", type=int, default=256) + parser.add_argument("--min-output-len", type=int, default=64) + parser.add_argument("--max-output-len", type=int, default=256) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--warmup", type=int, default=1) + return parser.parse_args() + + +def validate_args(args: argparse.Namespace) -> None: + if args.num_seqs <= 0: + raise ValueError("--num-seqs must be > 0") + if args.min_input_len <= 0 or args.max_input_len <= 0: + raise ValueError("input lengths must be > 0") + if args.min_output_len <= 0 or args.max_output_len <= 0: + raise ValueError("output lengths must be > 0") + if args.min_input_len > args.max_input_len: + raise ValueError("--min-input-len cannot be greater than --max-input-len") + if args.min_output_len > args.max_output_len: + raise ValueError("--min-output-len cannot be greater than --max-output-len") + if args.warmup < 0: + raise ValueError("--warmup must be >= 0") + + +def load_solution_modules(solution: str): + if solution == "tiny_llm": + from tiny_llm import models + from tiny_llm.kv_cache import TinyKvFullCache + + return "tiny_llm", models, TinyKvFullCache + if solution in {"tiny_llm_ref", "ref"}: + from tiny_llm_ref import models + from tiny_llm_ref.kv_cache import TinyKvFullCache + + return "tiny_llm_ref", models, TinyKvFullCache + raise ValueError(f"Solution {solution} not supported for bench") + + +def random_token_id(rng: Random, low: int, high: int, eos_token_id: int) -> int: + if low == high: + return low + token = rng.randint(low, high) + if token != eos_token_id: + return token + if token == low: + return low + 1 + return token - 1 + + +def build_requests( + *, + rng: Random, + num_seqs: int, + vocab_size: int, + eos_token_id: int, + min_input_len: int, + max_input_len: int, + min_output_len: int, + max_output_len: int, +) -> list[BenchRequest]: + token_low = 256 if vocab_size > 512 else 0 + token_high = vocab_size - 1 + if token_low > token_high: + token_low = 0 + requests = [] + for _ in range(num_seqs): + prompt_len = rng.randint(min_input_len, max_input_len) + max_new_tokens = rng.randint(min_output_len, max_output_len) + prompt_token_ids = [ + random_token_id(rng, token_low, token_high, eos_token_id) + for _ in range(prompt_len) + ] + requests.append(BenchRequest(prompt_token_ids, max_new_tokens)) + return requests + + +def sample_next_week1(model, y: mx.array) -> mx.array: + output_logits = model(y[None, :]) + logits = output_logits[:, -1, :] + return mx.argmax(logits, axis=-1) + + +def sample_next_week2(model, y: mx.array, offset: int, kv_cache: list) -> mx.array: + output_logits = model(y[None, :], offset, kv_cache) + logits = output_logits[:, -1, :] + return mx.argmax(logits, axis=-1) + + +def run_one_request_week1( + model, + request: BenchRequest, +) -> tuple[int, float, float]: + context = mx.array(request.prompt_token_ids, dtype=mx.int32) + t0 = perf_counter() + token = sample_next_week1(model, context) + mx.eval(token) + prefill_time = perf_counter() - t0 + + generated_tokens = 1 + decode_time = 0.0 + + for _ in range(request.max_new_tokens - 1): + t1 = perf_counter() + context = mx.concat([context, token]) + token = sample_next_week1(model, context) + mx.eval(token) + decode_time += perf_counter() - t1 + generated_tokens += 1 + return generated_tokens, prefill_time, decode_time + + +def run_one_request_week2( + model, + request: BenchRequest, + kv_cache_cls, +) -> tuple[int, float, float]: + kv_cache = [kv_cache_cls() for _ in range(model.num_hidden_layers)] + context = mx.array(request.prompt_token_ids, dtype=mx.int32) + offset = 0 + + t0 = perf_counter() + token = sample_next_week2(model, context, offset, kv_cache) + mx.eval(token) + prefill_time = perf_counter() - t0 + offset += context.size + + generated_tokens = 1 + decode_time = 0.0 + + for _ in range(request.max_new_tokens - 1): + t1 = perf_counter() + token = sample_next_week2(model, token, offset, kv_cache) + mx.eval(token) + decode_time += perf_counter() - t1 + offset += 1 + generated_tokens += 1 + return generated_tokens, prefill_time, decode_time + + +def safe_div(num: float, den: float) -> float: + return num / den if den > 0 else 0.0 + + +def main() -> None: + args = parse_args() + validate_args(args) + + rng = Random(args.seed) + solution_name, models, kv_cache_cls = load_solution_modules(args.solution) + model_name = models.shortcut_name_to_full_name(args.model) + print( + f"Solution={solution_name} Loader={args.loader} Device={args.device} " + f"Model={model_name}" + ) + mlx_model, tokenizer = load(model_name) + + with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu): + if args.loader == "week1": + model = models.dispatch_model(model_name, mlx_model, week=1) + + def run_one_request(request: BenchRequest) -> tuple[int, float, float]: + return run_one_request_week1( + model, + request, + ) + else: + model = models.dispatch_model( + model_name, + mlx_model, + week=2, + ) + + def run_one_request(request: BenchRequest) -> tuple[int, float, float]: + return run_one_request_week2( + model, + request, + kv_cache_cls, + ) + + requests = build_requests( + rng=rng, + num_seqs=args.num_seqs, + vocab_size=mlx_model.args.vocab_size, + eos_token_id=tokenizer.eos_token_id, + min_input_len=args.min_input_len, + max_input_len=args.max_input_len, + min_output_len=args.min_output_len, + max_output_len=args.max_output_len, + ) + + if args.warmup > 0: + print(f"Warmup runs: {args.warmup}") + warmup_iter = range(args.warmup) + warmup_iter = tqdm( + warmup_iter, + total=args.warmup, + desc="Warmup", + dynamic_ncols=True, + leave=False, + ) + for i in warmup_iter: + run_one_request(requests[i % len(requests)]) + + total_prompt_tokens = 0 + total_generated_tokens = 0 + total_decode_tokens = 0 + total_prefill_time = 0.0 + total_decode_time = 0.0 + + progress = tqdm(total=len(requests), desc="Bench", dynamic_ncols=True) + + t0 = perf_counter() + for request in requests: + generated_tokens, prefill_time, decode_time = run_one_request(request) + total_prompt_tokens += len(request.prompt_token_ids) + total_generated_tokens += generated_tokens + total_decode_tokens += max(0, generated_tokens - 1) + total_prefill_time += prefill_time + total_decode_time += decode_time + elapsed = perf_counter() - t0 + progress.update(1) + progress.set_postfix( + { + "out_tok/s": f"{safe_div(total_generated_tokens, elapsed):.1f}", + "decode_tok/s": f"{safe_div(total_decode_tokens, total_decode_time):.1f}", + } + ) + total_time = perf_counter() - t0 + progress.close() + + total_model_tokens = total_prompt_tokens + total_generated_tokens + print( + f"Requests: {args.num_seqs}, Prompt tokens: {total_prompt_tokens}, " + f"Generated tokens: {total_generated_tokens}" + ) + print( + f"Time: {total_time:.2f}s, Output throughput: " + f"{safe_div(total_generated_tokens, total_time):.2f} tok/s" + ) + print( + f"Total throughput (prompt+output): " + f"{safe_div(total_model_tokens, total_time):.2f} tok/s" + ) + print( + f"Prefill throughput: " + f"{safe_div(total_prompt_tokens, total_prefill_time):.2f} tok/s" + ) + print( + f"Decode throughput: " + f"{safe_div(total_decode_tokens, total_decode_time):.2f} tok/s" + ) + + +if __name__ == "__main__": + main() diff --git a/book/src/week2-01-kv-cache.md b/book/src/week2-01-kv-cache.md index d3a4b6a..ed342d3 100644 --- a/book/src/week2-01-kv-cache.md +++ b/book/src/week2-01-kv-cache.md @@ -192,4 +192,11 @@ pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b pdm run main --solution tiny_llm --loader week2 --model qwen2-7b ``` +You can also benchmark throughput and compare your implementation with the reference solution: + +```bash +pdm bench --solution tiny_llm --loader week2 --model qwen2-0.5b +pdm bench --solution tiny_llm_ref --loader week2 --model qwen2-0.5b +``` + {{#include copyright.md}} diff --git a/book/src/week2-02-quantized-matmul.md b/book/src/week2-02-quantized-matmul.md index 8914b35..52f40cd 100644 --- a/book/src/week2-02-quantized-matmul.md +++ b/book/src/week2-02-quantized-matmul.md @@ -315,4 +315,11 @@ You can test your implementation by running: pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b ``` +You can also benchmark throughput and compare your implementation with the reference solution: + +```bash +pdm bench --solution tiny_llm --loader week2 --model qwen2-0.5b +pdm bench --solution tiny_llm_ref --loader week2 --model qwen2-0.5b +``` + {{#include copyright.md}} diff --git a/pyproject.toml b/pyproject.toml index 9e5cae1..f02f317 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,11 +37,12 @@ clean-ext-ref.working_dir = "src/extensions_ref" main.cmd = "python main.py" main-week1.cmd = "python main.py --loader week1" main-week2.cmd = "python main.py --loader week2" +bench.cmd = "python bench.py" batch-main.cmd = "python batch-main.py" test.cmd = "python scripts/dev-tools.py test" check-installation.cmd = "python scripts/check-installation.py" test-refsol.cmd = "python scripts/dev-tools.py test-refsol" -bench.cmd = "pytest benches" +bench-test.cmd = "pytest benches" format = "ruff format" format-cpp-ref.shell = "find src/extensions_ref -type file \\( -name '*.h' -or -name '*.cpp' \\) | xargs -n1 clang-format -i" format-cpp.shell = "find src/extensions -type file \\( -name '*.h' -or -name '*.cpp' \\) | xargs -n1 clang-format -i" From ed8ac9e5f60b18683bbfe11a461be5520b5fa8e4 Mon Sep 17 00:00:00 2001 From: jhsong233 Date: Tue, 17 Feb 2026 13:13:22 +0800 Subject: [PATCH 74/79] fix(ref): correct attention weight shape asserts (#92) --- src/tiny_llm_ref/attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tiny_llm_ref/attention.py b/src/tiny_llm_ref/attention.py index bc2d30c..a53a68d 100644 --- a/src/tiny_llm_ref/attention.py +++ b/src/tiny_llm_ref/attention.py @@ -121,10 +121,10 @@ def __init__( assert hidden_size % num_heads == 0 self.head_dim = hidden_size // num_heads self.scale = mx.rsqrt(self.head_dim) - assert wq.shape == (hidden_size, num_heads * self.head_dim) - assert wk.shape == (hidden_size, num_heads * self.head_dim) - assert wv.shape == (hidden_size, num_heads * self.head_dim) - assert wo.shape == (num_heads * self.head_dim, hidden_size) + assert wq.shape == (num_heads * self.head_dim, hidden_size) + assert wk.shape == (num_heads * self.head_dim, hidden_size) + assert wv.shape == (num_heads * self.head_dim, hidden_size) + assert wo.shape == (hidden_size, num_heads * self.head_dim) self.wq = wq self.wk = wk self.wv = wv From 9b40133b0ebaa2b1e1b0f3d7865226c59fd1f651 Mon Sep 17 00:00:00 2001 From: Eric Fu Date: Fri, 20 Feb 2026 10:47:58 +0800 Subject: [PATCH 75/79] bugfix: way to get_kernel from library --- src/extensions/src/axpby.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/extensions/src/axpby.cpp b/src/extensions/src/axpby.cpp index ee67a23..9cbf4bd 100644 --- a/src/extensions/src/axpby.cpp +++ b/src/extensions/src/axpby.cpp @@ -148,7 +148,8 @@ void Axpby::eval_gpu(const std::vector &inputs, std::vector Date: Sun, 22 Feb 2026 17:16:02 +0900 Subject: [PATCH 76/79] book: replace huggingface-cli with hf Signed-off-by: you06 --- book/src/setup.md | 10 +++++----- book/src/week1-05-qwen2-model.md | 24 ++++++++++++------------ book/src/week1-06-generate-response.md | 7 +++---- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/book/src/setup.md b/book/src/setup.md index 9028de9..dd4401b 100644 --- a/book/src/setup.md +++ b/book/src/setup.md @@ -52,16 +52,16 @@ pdm run test We will use the Qwen2-7B-Instruct model for this course. It takes ~20GB of memory in week 1 to load the model parameters. If you do not have enough memory, you can consider using the smaller 0.5B model. -Follow the guide of [this page](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) to install the huggingface -cli. +Follow the guide of [this page](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli) to install the Hugging Face +CLI (`hf`). The model parameters are hosted on Hugging Face. Once you authenticated your cli with the credentials, you can download them with: ```bash -huggingface-cli login -huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX +hf login +hf download Qwen/Qwen2-0.5B-Instruct-MLX +hf download Qwen/Qwen2-7B-Instruct-MLX ``` Then, you can run: diff --git a/book/src/week1-05-qwen2-model.md b/book/src/week1-05-qwen2-model.md index 77cb19e..5767a86 100644 --- a/book/src/week1-05-qwen2-model.md +++ b/book/src/week1-05-qwen2-model.md @@ -5,9 +5,9 @@ In day 5, we will implement the Qwen2 model. Before we start, please make sure you have downloaded the models: ```bash -huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX +hf download Qwen/Qwen2-0.5B-Instruct-MLX +hf download Qwen/Qwen2-1.5B-Instruct-MLX +hf download Qwen/Qwen2-7B-Instruct-MLX ``` Otherwise, some of the tests will be skipped. @@ -47,9 +47,9 @@ You should pass all tests for this task by running: ```bash # Download the models if you haven't done so -huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX +hf download Qwen/Qwen2-0.5B-Instruct-MLX +hf download Qwen/Qwen2-1.5B-Instruct-MLX +hf download Qwen/Qwen2-7B-Instruct-MLX # Run the tests pdm run test --week 1 --day 5 -- -k task_1 ``` @@ -89,9 +89,9 @@ You should pass all tests for this task by running: ```bash # Download the models if you haven't done so; we need to tokenizers -huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX +hf download Qwen/Qwen2-0.5B-Instruct-MLX +hf download Qwen/Qwen2-1.5B-Instruct-MLX +hf download Qwen/Qwen2-7B-Instruct-MLX # Run the tests pdm run test --week 1 --day 5 -- -k task_2 ``` @@ -154,9 +154,9 @@ You should pass all tests for this task by running: ```bash # Download the models if you haven't done so -huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX +hf download Qwen/Qwen2-0.5B-Instruct-MLX +hf download Qwen/Qwen2-1.5B-Instruct-MLX +hf download Qwen/Qwen2-7B-Instruct-MLX # Run the tests pdm run test --week 1 --day 5 -- -k task_3 ``` diff --git a/book/src/week1-06-generate-response.md b/book/src/week1-06-generate-response.md index 2c27459..7634dad 100644 --- a/book/src/week1-06-generate-response.md +++ b/book/src/week1-06-generate-response.md @@ -59,9 +59,9 @@ You can test your implementation by running the following command: ```bash # Download the models if you haven't done so -huggingface-cli download Qwen/Qwen2-0.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-1.5B-Instruct-MLX -huggingface-cli download Qwen/Qwen2-7B-Instruct-MLX +hf download Qwen/Qwen2-0.5B-Instruct-MLX +hf download Qwen/Qwen2-1.5B-Instruct-MLX +hf download Qwen/Qwen2-7B-Instruct-MLX # Run the tests pdm run main --solution tiny_llm --loader week1 --model qwen2-0.5b \ --prompt "Give me a short introduction to large language model" @@ -75,4 +75,3 @@ It should gives you a reasonable response of "what is a large language model". R `--solution ref` to use the reference solution. {{#include copyright.md}} - From 2ace66cf959b11084ff9acb3912ddf5c36e2e1ce Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 22 Feb 2026 13:56:47 -0800 Subject: [PATCH 77/79] tests: parametrize flash attention mask coverage (#96) --- tests_refsol/test_week_2_day_4.py | 41 ++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/tests_refsol/test_week_2_day_4.py b/tests_refsol/test_week_2_day_4.py index 8e8aba4..3e8b3d4 100644 --- a/tests_refsol/test_week_2_day_4.py +++ b/tests_refsol/test_week_2_day_4.py @@ -4,56 +4,67 @@ from .utils import * -def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH): +def attention_helper(stream: mx.Stream, H_q, H, L, E, S, BATCH, with_mask: bool): precision = mx.float32 with mx.stream(stream): q_shape = (BATCH, H_q, L, E) kv_shape = (BATCH, H, S, E) + mask_shape = (BATCH, H_q, L, S) scale = 0.9 for _ in range(100): query = mx.random.uniform(shape=q_shape, dtype=precision) key = mx.random.uniform(shape=kv_shape, dtype=precision) value = mx.random.uniform(shape=kv_shape, dtype=precision) + mask = mx.random.uniform(shape=mask_shape, dtype=precision) if with_mask else None reference_output = mx.fast.scaled_dot_product_attention( q=query, k=key, v=value, scale=scale, + mask=mask, ) user_output = flash_attention( query, key, value, scale=scale, + mask=mask, ) mx.eval(user_output) # so that any error will be caught here assert_allclose(user_output, reference_output, precision=mx.float16) -def test_flash_attention_cpu_small(): - attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1) +@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) +def test_task_2_flash_attention_cpu_small(with_mask: bool): + attention_helper(mx.cpu, 6, 3, 2, 5, 3, 1, with_mask) -def test_flash_attention_cpu(): - attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10) +@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) +def test_task_2_flash_attention_cpu(with_mask: bool): + attention_helper(mx.cpu, 18, 6, 7, 5, 3, 10, with_mask) -def test_flash_attention_cpu_large(): - attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3) +@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) +def test_task_2_flash_attention_cpu_large(with_mask: bool): + attention_helper(mx.cpu, 28, 4, 16, 128, 16, 3, with_mask) -def test_flash_attention_gpu_extra_small(): - attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1) +@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) +def test_task_3_flash_attention_gpu_extra_small(with_mask: bool): + attention_helper(mx.gpu, 1, 1, 5, 7, 4, 1, with_mask) -def test_flash_attention_gpu_small(): - attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1) +@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) +def test_task_3_flash_attention_gpu_small(with_mask: bool): + attention_helper(mx.gpu, 6, 3, 2, 5, 3, 1, with_mask) -def test_flash_attention_gpu(): - attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10) +@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) +def test_task_3_flash_attention_gpu(with_mask: bool): + attention_helper(mx.gpu, 18, 6, 7, 5, 3, 10, with_mask) -def test_flash_attention_gpu_large(): - attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3) +@pytest.mark.parametrize("with_mask", [False, True], ids=["no_mask", "mask"]) +def test_task_3_flash_attention_gpu_large(with_mask: bool): + attention_helper(mx.gpu, 28, 4, 16, 128, 16, 3, with_mask) From bb1e902c78919778be13ab5f39a376f3ee1e62a2 Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 22 Feb 2026 16:40:31 -0800 Subject: [PATCH 78/79] docs: add week2 flash-attention CPU part (#97) * docs: add week2 flash-attention chapter links and draft Signed-off-by: Connor1996 --- .cspell.json | 1 + README.md | 2 +- book/src/SUMMARY.md | 2 +- book/src/glossary.md | 1 + book/src/week2-04-flash-attention.md | 112 +++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 book/src/week2-04-flash-attention.md diff --git a/.cspell.json b/.cspell.json index 1616579..9fc2d00 100644 --- a/.cspell.json +++ b/.cspell.json @@ -30,6 +30,7 @@ "dequantization", "dequantizes", "dtype", + "threadgroups", ], "ignoreRegExpList": [ "`[^`]*`", diff --git a/README.md b/README.md index 3c3e153..7e4c753 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ Week 1 is complete. Week 2 is in progress. | 2.1 | Key-Value Cache | ✅ | ✅ | ✅ | | 2.2 | Quantized Matmul and Linear - CPU | ✅ | ✅ | ✅ | | 2.3 | Quantized Matmul and Linear - GPU | ✅ | ✅ | ✅ | -| 2.4 | Flash Attention 2 - CPU | ✅ | ✅ | 🚧 | +| 2.4 | Flash Attention 2 - CPU | ✅ | ✅ | ✅ | | 2.5 | Flash Attention 2 - GPU | ✅ | ✅ | 🚧 | | 2.6 | Continuous Batching | ✅ | ✅ | ✅ | | 2.7 | Chunked Prefill | ✅ | ✅ | ✅ | diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index f3c342a..f7fb814 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -16,7 +16,7 @@ - [Week 2: Tiny vLLM](./week2-overview.md) - [Key-Value Cache](./week2-01-kv-cache.md) - [Quantized Matmul (2 Days)](./week2-02-quantized-matmul.md) - - [Flash Attention (2 Days)]() + - [Flash Attention (2 Days)](./week2-04-flash-attention.md) - [Continuous Batching (2 Days)](./week2-06-prefill-and-batch.md) - [Week 3: Serving]() diff --git a/book/src/glossary.md b/book/src/glossary.md index b9a4a0d..134016f 100644 --- a/book/src/glossary.md +++ b/book/src/glossary.md @@ -14,5 +14,6 @@ - [Qwen2 Transformer Block](./week1-05-qwen2-model.md) - [Week 1 Qwen2 Model](./week1-05-qwen2-model.md) - [dequantize_linear](./week1-05-qwen2-model.md) +- [Flash Attention 2](./week2-04-flash-attention.md) {{#include copyright.md}} diff --git a/book/src/week2-04-flash-attention.md b/book/src/week2-04-flash-attention.md new file mode 100644 index 0000000..3578e82 --- /dev/null +++ b/book/src/week2-04-flash-attention.md @@ -0,0 +1,112 @@ +# Week 2 Day 4-5: Flash Attention 2 + +In this chapter, we will implement Flash Attention 2 for the Week 2 Qwen2 serving pipeline. The goal is to replace the regular attention path with a tiled implementation to reduce memory bandwidth and increase throughput, especially for long contexts. + +**📚 Readings** + +- [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135) +- [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691) +- [MLX Extension Development Guide](https://ml-explore.github.io/mlx/build/html/dev/extensions.html) +- [MLX steel attention kernel (reference)](https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h) + +## Why Flash Attention? + +The key idea from the FlashAttention papers is that attention is often **IO-bound**, not FLOP-bound. + +In the standard implementation, we compute: + +1. `S = QK^T` +2. `P = softmax(S + mask)` +3. `O = PV` + +This path materializes large `L x S` tensors (`S` and often `P`) in global memory. For long contexts, repeatedly writing and reading these tensors dominates runtime. + +For example, if `L = S = 4096`: + +```plain +One L x S matrix: 4096 x 4096 = 16,777,216 elements +float32 storage: ~64 MB per matrix per head +Scores + probabilities: ~128 MB temporary memory per head +``` + +So even before counting Q/K/V and output tensors, memory traffic is already huge. + +### IO-Aware Exact Attention + +FlashAttention avoids this bottleneck by tiling Q/K/V into on-chip memory (cache / shared memory), and combining each tile with **online softmax** updates. Instead of storing the full attention matrix, it keeps only per-row running statistics (`m`, `l`) and partial output (`o`). + +This gives three practical benefits: + +- **Exactness**: same result as standard softmax attention (not an approximation). +- **Lower memory**: activation memory scales linearly with sequence length instead of quadratically. +- **Higher throughput**: fewer high-bandwidth-memory accesses, which is usually the real bottleneck. + +## Online Softmax Recap + +For one query row, split keys/values into tiles `j = 1..T`: + +$$ +m^{(j)} = \max\left(m^{(j-1)}, \max(s^{(j)})\right) +$$ + +$$ +l^{(j)} = e^{m^{(j-1)} - m^{(j)}} l^{(j-1)} + \sum e^{s^{(j)} - m^{(j)}} +$$ + +$$ +o^{(j)} = e^{m^{(j-1)} - m^{(j)}} o^{(j-1)} + \sum e^{s^{(j)} - m^{(j)}} v^{(j)} +$$ + +At the end: + +$$ +o = \frac{o^{(T)}}{l^{(T)}} +$$ + +This is the core numerical trick used by both the CPU and GPU kernels in this chapter, and the rest of the implementation is mostly about mapping this update rule to CPU loops and Metal threadgroups. + +## Task 1: Implement `flash_attention` Wrapper + +``` +src/tiny_llm/attention.py +``` + +Implement `flash_attention(query, key, value, scale=None, mask=None)` so it matches the extension API in `tiny_llm_ext`. + +Follow the same shape convention as Week 1 and Week 2 attention: + +```plain +query: B..., H_q, L, E +key: B..., H, S, E +value: B..., H, S, E +mask: B..., H_q, L, S +out: B..., H_q, L, E +``` + +The wrapper should compute `factor` using `mx.rsqrt` when `scale` is `None`, flatten batch and head dimensions before calling into C++, and reshape the output back to the original layout. Make sure `query`, `key`, and `value` are contiguous before calling the extension. For `mask`, always broadcast to `B..., H_q, L, S`, reshape to `(N, L, S)`, and cast to `float32` so that CPU and GPU kernels receive exactly the same dtype. + +## Task 2: Implement `flash_attention` (CPU version) + +``` +src/extensions/src/tiny_llm_ext.h +src/extensions/bindings.cpp +src/extensions/src/flash_attention.cpp +src/extensions/CMakeLists.txt +``` + +In this task, add the new MLX primitive and its CPU implementation. The structure is the same as the quantized matmul chapter: declare the primitive in `tiny_llm_ext.h`, expose it in `bindings.cpp`, and register `flash_attention.cpp` in `CMakeLists.txt`. + +Before creating the lazy output array, validate all shape and dtype constraints in C++: inputs should be 3D float32 tensors, `num_heads` must be divisible by `num_kv_heads`, and head mapping between Q and KV batches must be consistent. + +Then implement `FlashAttention::eval_cpu(...)` with tiled online softmax. Use `Br = 32` and `Bc = 32`, iterate over `(n, i, j)` tiles, map query heads to KV heads with `q_kv_heads_ratio = num_heads / num_kv_heads`, and accumulate in float32. Mask values should be applied in each tile before updating `m_i` and `l_i`. + +You can test your implementation by running: + +```bash +pdm run build-ext +pdm run test --week 2 --day 4 -- -k task_2 +``` + +## Task 3: Implement `flash_attention` (GPU version) + +{{#include copyright.md}} From ddfcaedfef0b597b0e5401bbce5a5b3a25b7e082 Mon Sep 17 00:00:00 2001 From: Connor Date: Sun, 22 Feb 2026 17:02:16 -0800 Subject: [PATCH 79/79] docs: add week2 entries to glossary (#98) Signed-off-by: Connor1996 --- book/src/glossary.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/book/src/glossary.md b/book/src/glossary.md index 134016f..64de7a7 100644 --- a/book/src/glossary.md +++ b/book/src/glossary.md @@ -14,6 +14,8 @@ - [Qwen2 Transformer Block](./week1-05-qwen2-model.md) - [Week 1 Qwen2 Model](./week1-05-qwen2-model.md) - [dequantize_linear](./week1-05-qwen2-model.md) -- [Flash Attention 2](./week2-04-flash-attention.md) +- [KV Cache](./week2-01-kv-cache.md) +- [Quantized Matmul](./week2-02-quantized-matmul.md) +- [Flash Attention](./week2-04-flash-attention.md) {{#include copyright.md}}