diff --git a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java index aa3a389..3fe6265 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java @@ -108,7 +108,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr } // Validate supported types - if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0 && ggmlType != GGMLType.Q4_0) { throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java index 0b9ba3d..cd1a632 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java @@ -102,7 +102,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr } // Validate supported types - if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0 && ggmlType != GGMLType.Q4_0) { throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index ce8e6ca..20f1ae9 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -9,6 +9,7 @@ import org.beehive.gpullama3.tensor.standard.*; import org.beehive.gpullama3.tensor.tornado.FP16TornadoTensor; import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import 
org.beehive.gpullama3.tensor.tornado.Q4_0TornadoTensor; import org.beehive.gpullama3.tensor.tornado.Q8_0TornadoTensor; import org.beehive.gpullama3.tensor.tornado.TornadoTensor; import uk.ac.manchester.tornado.api.types.HalfFloat; @@ -130,7 +131,7 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) { case F32 -> new FP32TornadoTensor(size, entry.memorySegment()); case F16 -> new FP16TornadoTensor(size, entry.memorySegment()); case Q8_0 -> Q8_0TornadoTensor.create(entry); - case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet"); + case Q4_0 -> Q4_0TornadoTensor.create(entry); default -> throw new UnsupportedOperationException("Quantization format " + ggmlType); }; } @@ -203,6 +204,14 @@ public static Q8_0TornadoTensor[] loadArrayAsQ8_0TornadoTensor(int size, IntFunc return array; } + public static Q4_0TornadoTensor[] loadArrayAsQ4_0TornadoTensor(int size, IntFunction getTensorEntry) { + Q4_0TornadoTensor[] array = new Q4_0TornadoTensor[size]; + for (int i = 0; i < size; i++) { + array[i] = Q4_0TornadoTensor.create(getTensorEntry.apply(i)); + } + return array; + } + public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) { if (tensorEntry.ggmlType() == GGMLType.F32) { FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java index 745367c..5e485a7 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java @@ -122,7 +122,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr } // Validate supported types - if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0 && ggmlType != GGMLType.Q4_0) { throw new 
UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java index a3abe14..fae0b90 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java @@ -122,7 +122,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr } // Validate supported types - if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0 && ggmlType != GGMLType.Q4_0) { throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); } diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/Q4_0TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/Q4_0TornadoTensor.java new file mode 100644 index 0000000..a96811e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/Q4_0TornadoTensor.java @@ -0,0 +1,131 @@ +package org.beehive.gpullama3.tensor.tornado; + +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; + +public class Q4_0TornadoTensor extends TornadoTensor { + + private final HalfFloatArray scales; // One per 32-element block + private final ByteArray quants; // Packed 4-bit quantized values (2 per byte) + private MemorySegment segment; + + public Q4_0TornadoTensor(int size, HalfFloatArray scales, ByteArray quants, MemorySegment segment) { + 
super(size); + this.scales = scales; + this.quants = quants; + this.segment = segment; + } + + /** + * Returns the scale factors for GPU kernels. + * + * @return HalfFloatArray containing fp16 scale factors + */ + public HalfFloatArray getScales() { + return scales; + } + + /** + * Returns the quantized values for GPU kernels. + * + * @return ByteArray containing packed 4-bit quantized values + */ + public ByteArray getQuants() { + return quants; + } + + @Override + public int size() { + return size; + } + + @Override + public GGMLType type() { + return GGMLType.Q4_0; + } + + public MemorySegment asMemorySegment() { + return segment; + } + + /** + * Dequantizes and returns a single float value. + * + * @param index Element index + * @return Dequantized float value + */ + public float getFloat(int index) { + assert 0 <= index && index < size; + int blockIdx = index / GGMLType.Q4_0.getBlockSize(); + int withinBlockIdx = index % GGMLType.Q4_0.getBlockSize(); + + float scale = scales.get(blockIdx).getFloat32(); + + // Each byte contains 2 4-bit values + int byteIdx = withinBlockIdx / 2; + byte packedByte = quants.get(blockIdx * 16 + byteIdx); + + // Extract the 4-bit value (lower or upper nibble) + byte quant; + if (withinBlockIdx % 2 == 0) { + // Lower 4 bits + quant = (byte) (packedByte & 0x0F); + } else { + // Upper 4 bits + quant = (byte) ((packedByte >>> 4) & 0x0F); + } + + // Q4_0 stores unsigned nibbles in [0, 15]; subtract the +8 bias to recover the signed range [-8, 7] (unlike Q8_0, which stores signed bytes with no offset) + quant -= 8; + + return quant * scale; + } + + public static Q4_0TornadoTensor create(GGMLTensorEntry entry) { + if (entry.ggmlType() != GGMLType.Q4_0) { + throw new IllegalArgumentException("Expected Q4_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name()); + } + + int[] shape = entry.shape(); + int size = FloatTensor.numberOfElements(shape); + int numBlocks = size / GGMLType.Q4_0.getBlockSize(); + + if (size % GGMLType.Q4_0.getBlockSize() != 0) { + throw new IllegalArgumentException("Q4_0 tensor size must be multiple of " + 
GGMLType.Q4_0.getBlockSize() + ", got: " + size + " for tensor: " + entry.name()); + } + + MemorySegment q4Segment = entry.memorySegment(); + + // allocate the arrays for quantized data (packed 4-bit) and scales (fp16) + HalfFloatArray scales = new HalfFloatArray(numBlocks); + ByteArray quants = new ByteArray(numBlocks * 16); // 32 4-bit values = 16 bytes per block + + // unpack Q4_0 blocks: [2 bytes fp16 scale][16 bytes packed 4-bit quants] + ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE; + + for (int block = 0; block < numBlocks; block++) { + long blockOffset = block * GGMLType.Q4_0.getTypeSize(); // 18 bytes per block + + // read fp16 scale (first 2 bytes of block) + short scaleRaw = q4Segment.get(shortLayout, blockOffset); + scales.set(block, new HalfFloat(scaleRaw)); + + // read 16 bytes of packed 4-bit quantized values (remaining bytes of block) + for (int i = 0; i < 16; i++) { + byte quantValue = q4Segment.get(byteLayout, blockOffset + 2 + i); + quants.set(block * 16 + i, quantValue); + } + } + + return new Q4_0TornadoTensor(size, scales, quants, q4Segment); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java index 1684a5b..2f57eb6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java @@ -16,6 +16,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner; +import 
org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0.LlamaQ4_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0.Phi3Q4_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0.Qwen2Q4_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0.Qwen3Q4_0LayerPlanner; /** * Factory class responsible for creating appropriate layer planners based on model type and quantization. @@ -77,8 +81,16 @@ private static GenericLayerPlanner createFP32Planner(State state, Model model) { throw new UnsupportedOperationException("FP32 planners not yet implemented"); } + // ============ Q4_0 Planners ============ private static GenericLayerPlanner createQ4_0Planner(State state, Model model) { - throw new UnsupportedOperationException("Q4 planners not yet implemented"); + return switch (model.getModelType()) { + case LLAMA_3, MISTRAL -> new LlamaQ4_0LayerPlanner((LlamaState) state, model); + case QWEN_2 -> new Qwen2Q4_0LayerPlanner((Qwen2State) state, model); + case QWEN_3 -> new Qwen3Q4_0LayerPlanner((Qwen3State) state, model); + case PHI_3 -> new Phi3Q4_0LayerPlanner((Phi3State) state, model); + case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q4_0LayerPlanner((Qwen2State) state, model); + default -> throw new UnsupportedOperationException("Q4_0 not supported for model: " + model.getModelType()); + }; } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/LlamaQ4_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/LlamaQ4_0LayerPlanner.java new file mode 100644 index 0000000..777e5cd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/LlamaQ4_0LayerPlanner.java @@ -0,0 +1,27 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import 
org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q4_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LlamaQ4_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer; + +public class LlamaQ4_0LayerPlanner extends Q4_0LayerPlanner { + + public LlamaQ4_0LayerPlanner(LlamaState state, Model model) { + super(state, model); + validateQuantizationType(); + setupTornadoForwardPlan(); + } + + @Override + protected void initializeLayerComponents() { + this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); + this.ffnLayers = new LlamaQ4_0FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType); + this.logitsLayer = new LogitsQ4_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Phi3Q4_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Phi3Q4_0LayerPlanner.java new file mode 100644 index 0000000..cf91f1b --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Phi3Q4_0LayerPlanner.java @@ -0,0 +1,38 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q4_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.q4_0.Phi3Q4_0FFNLayers; +import 
org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer; + +/** + * Phi3Q4_0LayerPlanner: Phi3 model with Q4_0-quantized weights. + * + * Follows the same pattern as Qwen3Q4_0LayerPlanner but with: + * - Phi3-specific FFN layers (combined QKV + gate/up FFN) + * - Phi3TornadoWeights (4-bit integer quantization) + * - Phi3Configuration + * - 4x memory compression vs FP16, 2x vs Q8_0 + * + * Inherits from Q4_0LayerPlanner + */ +public class Phi3Q4_0LayerPlanner extends Q4_0LayerPlanner { + + public Phi3Q4_0LayerPlanner(Phi3State state, Model model) { + super(state, model); + validateQuantizationType(); + setupTornadoForwardPlan(); + } + + @Override + protected void initializeLayerComponents() { + this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); + this.ffnLayers = new Phi3Q4_0FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType); + this.logitsLayer = new LogitsQ4_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Qwen2Q4_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Qwen2Q4_0LayerPlanner.java new file mode 100644 index 0000000..546b6c3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Qwen2Q4_0LayerPlanner.java @@ -0,0 +1,27 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q4_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.q4_0.Qwen2Q4_0FFNLayers; 
+import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer; + +public class Qwen2Q4_0LayerPlanner extends Q4_0LayerPlanner { + + public Qwen2Q4_0LayerPlanner(Qwen2State state, Model model) { + super(state, model); + validateQuantizationType(); + setupTornadoForwardPlan(); + } + + @Override + protected void initializeLayerComponents() { + this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); + this.ffnLayers = new Qwen2Q4_0FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType); + this.logitsLayer = new LogitsQ4_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Qwen3Q4_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Qwen3Q4_0LayerPlanner.java new file mode 100644 index 0000000..5e7dcad --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q4_0/Qwen3Q4_0LayerPlanner.java @@ -0,0 +1,37 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q4_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.q4_0.Qwen3Q4_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer; + +/** + * Qwen3Q4_0LayerPlanner: Qwen3 model with Q4_0-quantized weights. 
+ * + * Follows the same pattern as LlamaQ4_0LayerPlanner but with: + * - Qwen3-specific FFN layers (supports GQA) + * - Qwen3TornadoWeights (4-bit integer quantization) + * - Qwen3Configuration + * - 4x memory compression vs FP16, 2x vs Q8_0 + * + * Inherits from Q4_0LayerPlanner + */ +public class Qwen3Q4_0LayerPlanner extends Q4_0LayerPlanner { + + public Qwen3Q4_0LayerPlanner(Qwen3State state, Model model) { + super(state, model); + validateQuantizationType(); + setupTornadoForwardPlan(); + } + + @Override + protected void initializeLayerComponents() { + this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config); + this.ffnLayers = new Qwen3Q4_0FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType); + this.logitsLayer = new LogitsQ4_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(),this.schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q4_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q4_0LayerPlanner.java new file mode 100644 index 0000000..4aed6a7 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q4_0LayerPlanner.java @@ -0,0 +1,96 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.quantization; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizedLayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer; +import uk.ac.manchester.tornado.api.GridScheduler; +import 
uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import java.util.ArrayList; +import java.util.List; + +/** + * Base for all Q4_0-quantized layer planners. + * + * Subclasses: LlamaQ4_0LayerPlanner, Qwen2Q4_0LayerPlanner, etc. + * + * Q4_0 Specific: + * - Uses 4-bit integer quantization with uniform scaling per 32-element block + * - Weights: weights.xxxByteArray arrays (packed 4-bit values) + * - Compute: dequantize on-the-fly during matmul + * - Memory: 4x compression vs FP16, 2x vs Q8_0 + */ +public abstract class Q4_0LayerPlanner extends QuantizedLayerPlanner { + + protected Activation activationLayer; + protected AbstractFFNLayers ffnLayers; + protected LogitsQ4_0Layer logitsLayer; + + // Cache for task graphs and scheduler (set once, reused) + protected List cachedTaskGraphs; + protected GridScheduler cachedScheduler; + + protected Q4_0LayerPlanner(S state, Model model) { + super(state, model); + initializeLayerComponents(); + } + + @Override + protected void validateQuantizationType() { + if (this.weights.getWeightType() != GGMLType.Q4_0) { + throw new IllegalArgumentException("Q4_0LayerPlanner requires GGMLType.Q4_0, got: " + this.weights.getWeightType()); + } + } + + @Override + protected void initializeLayerComponents() { + // Override in subclasses (LlamaQ4_0LayerPlanner, etc.) + } + + protected final void setupTornadoForwardPlan() { + List allTaskGraphs = new ArrayList<>(); + GridScheduler masterScheduler = new GridScheduler(); + + // 1. Activation layer (common to all models) + allTaskGraphs.add(activationLayer.getImmutableTaskGraph()); + activationLayer.updateGridScheduler(masterScheduler); + + // 2. FFN layers (N transformer layers - model-specific) + allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs()); + ffnLayers.updateGridScheduler(masterScheduler); + + // 3. 
Logits layer (common to all models) + allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot()); + logitsLayer.updateGridScheduler(masterScheduler); + + // Cache for future retrievals + this.cachedTaskGraphs = allTaskGraphs; + this.cachedScheduler = masterScheduler; + } + + /** + * Returns cached task graphs (used by hardware strategy pattern). + * + * Removed from all model-specific planners - centralized here. + */ + public final List getImmutableTaskGraphs() { + return this.cachedTaskGraphs; + } + + /** + * Returns cached scheduler (used by hardware strategy pattern). + * + * Removed from all model-specific planners - centralized here. + */ + @Override + public final GridScheduler getGridScheduler() { + return this.cachedScheduler; + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/LlamaQ4_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/LlamaQ4_0FFNLayers.java new file mode 100644 index 0000000..51f8b46 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/LlamaQ4_0FFNLayers.java @@ -0,0 +1,174 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q4_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +public class 
LlamaQ4_0FFNLayers extends AbstractFFNLayers { + + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + public LlamaQ4_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, Configuration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return null; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + List setupFFNLayered() { + state.temp.init(0.0f); + state.tempFFN.init(0.0f); + var numLayers = config.numberOfLayers(); + + return IntStream.range(0, numLayers).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); + if (i == numLayers - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + return ffnLayer.snapshot(); + }).toList(); + } + + TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + //Copy-in weights per layer for batched-layered layout + weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(), weights.wkLayered[layerIndex].getQuants(), + weights.wkLayered[layerIndex].getScales(), weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(), weights.woLayered[layerIndex].getQuants(), + weights.woLayered[layerIndex].getScales(), weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(), + weights.w2Layered[layerIndex].getQuants(), 
weights.w2Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, + config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) + .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].getQuants(), + weights.wqLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].getQuants(), + weights.wkLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].getQuants(), + weights.wvLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) + .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), + layerIndex, config.contextLength()); + configureAttention(unifiedLayer, 
layerIndex); + unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(), + weights.woLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, + config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) + .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].getQuants(), + weights.w1Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), config.dim(), config.hiddenDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].getQuants(), + weights.w2Layered[layerIndex].getScales(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + // First layer: Transfer initial data to device (one-time transfer) + if (layerIndex == 0) { + // Transfer all attention-related data: query, key, value matrices and their caches + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, 
state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb); // + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + state.positionHolder // + ); + } + return unifiedLayer; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); + + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); + 
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + } + return tornadoForwardScheduler; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), + state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } + } 
+} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/LogitsQ4_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/LogitsQ4_0Layer.java new file mode 100644 index 0000000..db825f7 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/LogitsQ4_0Layer.java @@ -0,0 +1,90 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q4_0; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.SequencedCollection; + +public class LogitsQ4_0Layer extends AbstractLayer { + + private String lastTaskGraphID; + private TaskGraph logitsTaskGraph; + private ImmutableTaskGraph immutableLogitsGraph; + private GridScheduler scheduler; + private SchedulerType schedulerType; + + public LogitsQ4_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config); + this.lastTaskGraphID = lastTaskGraphID; + state.tempLogits.init(0.0f); + var 
tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ4_0Layer", "TornadoTensor"); + this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); + this.schedulerType = schedulerType; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid logitsRMS; + if (weights instanceof Qwen2TornadoWeights) { + logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32); + } else { + logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + } + + var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + + tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + return tornadoForwardScheduler; + } + + private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + TaskGraph logits = new TaskGraph("logits"); + logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.getQuants(), weights.wclsByteArray.getScales(), + weights.rms_final_weight_as_floatArray) + .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); + } + logits.task("mapContextLogits", 
TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) + .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, // + context, state.wrapX, state.wrapLogits, weights.wclsByteArray.getQuants(), weights.wclsByteArray.getScales(), // + config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + return logits; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return logitsTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return immutableLogitsGraph; + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/Phi3Q4_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/Phi3Q4_0FFNLayers.java new file mode 100644 index 0000000..c21b67d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q4_0/Phi3Q4_0FFNLayers.java @@ -0,0 +1,317 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q4_0; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import 
package org.beehive.gpullama3.tornadovm.layers.type.q4_0;

import org.beehive.gpullama3.inference.state.Phi3State;
import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;

import java.util.ArrayList;
import java.util.List;

/**
 * Phi3Q4_0FFNLayers: Q4_0-quantized FFN layers for Phi3 with Group Query Attention (GQA) support.
 *
 * Key Differences from Phi3FP16FFNLayers:
 * - Uses Q4_0-quantized weights (getQuants() and getScales())
 * - Same attention and RoPE kernels as FP16 version
 * - 4-bit integer computations with dequantization
 * - 4x memory compression vs FP16, 2x vs Q8_0
 * - Same combined QKV and gate/up FFN structure
 *
 * Works directly with Phi3State to access and mutate Phi3-specific state fields.
 */
public class Phi3Q4_0FFNLayers extends AbstractFFNLayers {

    // NOTE(review): ffnLayerTaskGraph and scheduler are never assigned in this class;
    // getTaskGraph()/getGridScheduler() therefore return null. Confirm callers tolerate this.
    TaskGraph ffnLayerTaskGraph;
    GridScheduler scheduler;
    // One immutable snapshot per transformer layer (filled by setupFFNLayered()).
    List ffnLayerTaskGraphs;

    // Typed references to Phi3-specific state and config
    private final Phi3State phi3State;
    private final Phi3Configuration phi3Config;

    // Phi3-specific dimension for combined QKV buffer: Q rows (dim) plus K and V rows
    // (numberOfKeyValueHeads * headSize each).
    private final int opSize;

    /**
     * Builds the per-layer task graphs eagerly in the constructor.
     *
     * @param taskGraphName name passed through to AbstractFFNLayers
     * @param state         Phi3-specific inference state (device buffers)
     * @param weights       Q4_0 layered weights (quants + scales per layer)
     * @param config        Phi3 model configuration
     * @param schedulerType device scheduler flavor (stored by the superclass)
     */
    public Phi3Q4_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) {
        super(taskGraphName, state, weights, config, schedulerType);
        this.phi3State = state;
        this.phi3Config = config;
        this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());
        ffnLayerTaskGraphs = setupFFNLayered();
    }

    /**
     * Registers one worker grid per task per layer. The string names ("layer_i.taskName")
     * must match the task names created in setupSinglePhi3Q4_0FFNLayer exactly.
     */
    @Override
    public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
        WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
        WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128);

        int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
        WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);

        // Shadows the opSize field with the same formula (kept as-is).
        final int opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());

        int qkvmatmulDimRowMajorGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC;
        WorkerGrid qkvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(qkvmatmulDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);

        // Combined gate+up projection outputs 2 * hiddenDim rows.
        int wgetUPDimRowMajor = 2 * config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
        WorkerGrid wgetHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(wgetUPDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC);

        WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
        WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128);
        WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128);
        WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128);
        for (int i = 0; i < config.numberOfLayers(); i++) {
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wDown", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wGateUp", wgetHiddenDimRowMajorWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
        }
        return tornadoForwardScheduler;
    }

    /** @return the scheduler field (never assigned in this class — may be null). */
    @Override
    public GridScheduler getGridScheduler() {
        return scheduler;
    }

    /** @return the single-task-graph field (never assigned in this class — may be null). */
    @Override
    public TaskGraph getTaskGraph() {
        return ffnLayerTaskGraph;
    }

    @Override
    public ImmutableTaskGraph getImmutableTaskGraph() {
        return null;
    }

    public List getFfnLayerTaskGraphs() {
        return ffnLayerTaskGraphs;
    }

    /**
     * Setup all FFN layers for all transformer layers
     */
    List setupFFNLayered() {
        List ffnGraphs = new ArrayList<>();

        // Initialize buffers using Phi3State directly
        phi3State.temp.init(0.0f);
        phi3State.tempFFN.init(0.0f);

        for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) {
            TaskGraph ffnLayer = setupSinglePhi3Q4_0FFNLayer((Phi3TornadoWeights) weights, layerIndex);
            if (layerIndex == phi3Config.numberOfLayers() - 1) {
                // Remember the last layer's graph id so the logits graph can consume wrapX from it.
                setupLastID(ffnLayer.getTaskGraphName());
            }
            ffnGraphs.add(ffnLayer.snapshot());
        }

        return ffnGraphs;
    }

    /**
     * Setup a single transformer layer for Phi3 with Q4_0 quantization, combined QKV and gate/up FFN.
     *
     * Task order defines the on-device dataflow:
     * RMSNorm -> combined QKV matmul -> split Q/K/V -> RoPE -> KV-cache copy -> attention
     * -> output projection (+residual) -> FFN RMSNorm -> gate/up matmul -> SiLU/GLU split
     * -> down projection (+residual). Do not reorder.
     */
    TaskGraph setupSinglePhi3Q4_0FFNLayer(Phi3TornadoWeights weights, int layerIndex) {

        TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex);
        unifiedLayer.consumeFromDevice(phi3State.wrapX);
        unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                // Copy-in quantized weights per layer
                weights.rms_att_weightLayered[layerIndex].asFloatArray(),
                weights.wqkvLayered[layerIndex].getQuants(),
                weights.wqkvLayered[layerIndex].getScales(),
                weights.woLayered[layerIndex].getQuants(),
                weights.woLayered[layerIndex].getScales(),
                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
                weights.wUpLayered[layerIndex].getQuants(),
                weights.wUpLayered[layerIndex].getScales(),
                weights.wDownLayered[layerIndex].getQuants(),
                weights.wDownLayered[layerIndex].getScales()
        );
        unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);

        // RMSNorm for attention input
        unifiedLayer.task("reductionsOneBlock",
                        TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                        context,
                        phi3State.temp,
                        phi3State.wrapX,
                        phi3Config.dim(),
                        phi3Config.rmsNormEps(),
                        phi3State.localSize)
                .task("mapContext",
                        TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
                        context,
                        phi3State.wrapXb,
                        phi3State.wrapX,
                        weights.rms_att_weightLayered[layerIndex].asFloatArray(),
                        phi3State.temp);

        // Combined QKV projection (quantized): wrapXb (dim) -> wrapQkv (opSize)
        unifiedLayer.task("qkvmatmul",
                        TransformerComputeKernelsLayered::matrixVectorGeneric,
                        context,
                        phi3State.wrapXb,
                        phi3State.wrapQkv,
                        weights.wqkvLayered[layerIndex].getQuants(),
                        weights.wqkvLayered[layerIndex].getScales(),
                        phi3Config.dim(),
                        opSize,
                        LOCAL_WORK_GROUP_SIZE_ALLOC)
                .task("splitQKV",
                        TransformerComputeKernelsLayered::splitQKV,
                        phi3State.wrapQkv,
                        phi3State.wrapQ,
                        phi3State.wrapK,
                        phi3State.wrapV,
                        phi3Config.dim(),
                        phi3Config.headSize() * phi3Config.numberOfKeyValueHeads());

        // RoPE rotation (Phi3-specific kernel)
        unifiedLayer.task("rope",
                TransformerComputeKernelsLayered::ropeRotationPhi3,
                context,
                phi3State.positionHolder,
                phi3State.wrapQ,
                phi3State.wrapK,
                phi3Config.kvDim(),
                phi3Config.headSize());

        // Copy K/V at the current position into the layer's KV caches
        unifiedLayer.task("copyToCaches",
                TransformerComputeKernelsLayered::copyToCache,
                phi3State.wrapKeyCache,
                phi3State.wrapK,
                phi3State.wrapValueCache,
                phi3State.wrapV,
                phi3State.positionHolder,
                phi3Config.kvDim(),
                layerIndex,
                phi3Config.contextLength());

        // Parallel attention
        // NOTE(review): always uses processHeadsFlashAttention here, with no schedulerType
        // branch — confirm this is intended for non-NVIDIA devices as well.
        unifiedLayer.task("parallel-attention",
                TransformerComputeKernelsLayered::processHeadsFlashAttention,
                context,
                phi3State.wrapQ,
                phi3State.wrapKeyCache,
                phi3State.wrapValueCache,
                phi3State.wrapXb,
                phi3Config.numberOfHeads(),
                phi3Config.headSize(),
                phi3Config.kvDim(),
                phi3Config.kvMul(),
                phi3State.positionHolder,
                layerIndex,
                phi3Config.contextLength());

        // Output projection (quantized) with residual add back into wrapX
        unifiedLayer.task("matmul1",
                TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
                context,
                phi3State.wrapXb,
                phi3State.wrapX,
                weights.woLayered[layerIndex].getQuants(),
                weights.woLayered[layerIndex].getScales(),
                phi3Config.dim(),
                phi3Config.dim(),
                LOCAL_WORK_GROUP_SIZE_ALLOC);

        // FFN section: RMSNorm
        unifiedLayer.task("reductionsOneBlockFFN",
                        TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                        context,
                        phi3State.tempFFN,
                        phi3State.wrapX,
                        phi3Config.dim(),
                        phi3Config.rmsNormEps(),
                        phi3State.localSize)
                .task("mapContextFFN",
                        TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
                        context,
                        phi3State.wrapXb,
                        phi3State.wrapX,
                        weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
                        phi3State.tempFFN);

        // FFN: combined Up and Gate projection (outputs 2 * hiddenDim, quantized)
        unifiedLayer.task("wGateUp",
                        TransformerComputeKernelsLayered::matrixVectorGeneric,
                        context,
                        phi3State.wrapXb,
                        phi3State.wrapHb,
                        weights.wUpLayered[layerIndex].getQuants(),
                        weights.wUpLayered[layerIndex].getScales(),
                        phi3Config.dim(),
                        2 * phi3Config.hiddenDim(),
                        LOCAL_WORK_GROUP_SIZE_ALLOC)
                .task("gateUpSiLU",
                        TransformerComputeKernelsLayered::splitGateUpAndSiLU,
                        phi3State.wrapHb,
                        phi3State.wrapHbG,
                        phi3State.wrapHbU,
                        phi3Config.hiddenDim());

        // FFN: Down projection with residual (quantized).
        // NOTE(review): consumes wrapHbU only — assumes splitGateUpAndSiLU leaves the
        // combined silu(gate)*up activation in wrapHbU; confirm against the kernel.
        unifiedLayer.task("wDown",
                        TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
                        context,
                        phi3State.wrapHbU,
                        phi3State.wrapX,
                        weights.wDownLayered[layerIndex].getQuants(),
                        weights.wDownLayered[layerIndex].getScales(),
                        phi3Config.hiddenDim(),
                        phi3Config.dim(),
                        LOCAL_WORK_GROUP_SIZE_ALLOC)
                .persistOnDevice(
                        phi3State.wrapX
                );
        return unifiedLayer;
    }

    /**
     * Configure data transfers for first and subsequent layers.
     * Layer 0 allocates/uploads the shared workspace once (FIRST_EXECUTION) and refreshes
     * position/temp buffers every execution; later layers consume the same device buffers.
     */
    protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
        // First layer: Transfer initial data to device (one-time transfer)
        if (layerIndex == 0) {
            // Transfer all attention-related data: query, key, value matrices and their caches
            unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
            unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
                    context, state.wrapXb, state.wrapXb2, //
                    state.wrapQ, state.wrapK, state.wrapV, //
                    state.wrapKeyCache, state.wrapValueCache, //
                    state.wrapAtt, state.wrapHb, //
                    phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); //
        } else {
            // Subsequent layers: Consume data already on device from previous layer
            unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
                    state.wrapQ, state.wrapK, state.wrapV, //
                    state.wrapKeyCache, state.wrapValueCache, //
                    state.wrapAtt, state.wrapHb, //
                    state.positionHolder, //
                    phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv);
        }
        return unifiedLayer;
    }

}
package org.beehive.gpullama3.tornadovm.layers.type.q4_0;

import org.beehive.gpullama3.inference.state.Qwen2State;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.tornadovm.kernels.Qwen2Kernels;
import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.WorkerGrid2D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;

import java.util.ArrayList;
import java.util.List;

/**
 * Qwen2Q4_0FFNLayers: Q4_0-quantized FFN layers for Qwen2 with Group Query Attention (GQA) support.
 *
 * Key Differences from Qwen2FP16FFNLayers:
 * - Uses Q4_0-quantized weights (getQuants() and getScales())
 * - Same attention and RoPE kernels as FP16 version
 * - 4-bit integer computations with dequantization
 * - 4x memory compression vs FP16, 2x vs Q8_0
 * - Includes bias terms for Q, K, V projections
 *
 * Works directly with Qwen2State to access and mutate Qwen2-specific state fields.
 */
public class Qwen2Q4_0FFNLayers extends AbstractFFNLayers {

    // NOTE(review): ffnLayerTaskGraph and scheduler are never assigned in this class;
    // getTaskGraph()/getGridScheduler() therefore return null. Confirm callers tolerate this.
    TaskGraph ffnLayerTaskGraph;
    GridScheduler scheduler;
    // One immutable snapshot per transformer layer (filled by setupFFNLayered()).
    List ffnLayerTaskGraphs;

    // Typed references to Qwen2-specific state and config
    private final Qwen2State qwen2State;
    private final Qwen2Configuration qwen2Config;

    /**
     * Builds the per-layer task graphs eagerly in the constructor.
     *
     * @param taskGraphName name passed through to AbstractFFNLayers
     * @param state         Qwen2-specific inference state (device buffers)
     * @param weights       Q4_0 layered weights (quants + scales per layer, plus QKV biases)
     * @param config        Qwen2 model configuration
     * @param schedulerType device scheduler flavor (stored by the superclass)
     */
    public Qwen2Q4_0FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) {
        super(taskGraphName, state, weights, config, schedulerType);
        this.qwen2State = state;
        this.qwen2Config = config;
        ffnLayerTaskGraphs = setupFFNLayered();
    }

    /**
     * Registers one worker grid per task per layer. The string names ("layer_i.taskName")
     * must match the task names created in setupSingleQwen2Q4_0FFNLayer exactly.
     */
    @Override
    public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
        // 2D RoPE grid: one work-item per (head, rotation-pair).
        int h = config.numberOfHeads();
        int ic = config.headSize() / 2;
        WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
        ropeWorker.setGlobalWork(h, ic, 1);
        ropeWorker.setLocalWork(1, 1, 1);


        int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
        WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal);
        configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);

        int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
        WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal);
        configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);

        // NOTE(review): local sizes assume dim is divisible by 8 and kvDim by 32 — confirm
        // these hold for all supported Qwen2 configurations.
        WorkerGrid qBiasWorker = new WorkerGrid1D(config.dim());
        qBiasWorker.setGlobalWork(config.dim(), 1, 1);
        qBiasWorker.setLocalWork(config.dim() / 8, 1, 1);
        WorkerGrid kvBiasWorker = new WorkerGrid1D(config.kvDim());
        kvBiasWorker.setGlobalWork(config.kvDim(), 1, 1);
        kvBiasWorker.setLocalWork(32, 1, 1);

        int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
        WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor);
        configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);

        WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);

        // Find the largest divisor of headSize that is <= 64 to use as the per-head local size.
        int optimalLocalSize = Math.min(config.headSize(), 64); // Start with 64 threads per head
        if (config.headSize() % optimalLocalSize != 0) {
            // Find largest divisor of headSize <= 64
            for (int size = 64; size >= 1; size--) {
                if (config.headSize() % size == 0) {
                    optimalLocalSize = size;
                    break;
                }
            }
        }

        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
        parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1);
        parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1);

        WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim());
        copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1);
        copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches)

        // Map workers to tasks
        for (int i = 0; i < config.numberOfLayers(); i++) {
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
        }
        return tornadoForwardScheduler;
    }

    /** @return the scheduler field (never assigned in this class — may be null). */
    @Override
    public GridScheduler getGridScheduler() {
        return scheduler;
    }

    /** @return the single-task-graph field (never assigned in this class — may be null). */
    @Override
    public TaskGraph getTaskGraph() {
        return ffnLayerTaskGraph;
    }

    @Override
    public ImmutableTaskGraph getImmutableTaskGraph() {
        return null;
    }

    public List getFfnLayerTaskGraphs() {
        return ffnLayerTaskGraphs;
    }

    /**
     * Setup all FFN layers for all transformer layers
     */
    List setupFFNLayered() {
        List ffnGraphs = new ArrayList<>();
        qwen2State.temp.init(0.0f);
        qwen2State.tempFFN.init(0.0f);

        for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) {
            TaskGraph ffnLayer = setupSingleQwen2Q4_0FFNLayer((Qwen2TornadoWeights) weights, layerIndex);
            if (layerIndex == qwen2Config.numberOfLayers() - 1) {
                // Remember the last layer's graph id so the logits graph can consume wrapX from it.
                setupLastID(ffnLayer.getTaskGraphName());
            }
            ffnGraphs.add(ffnLayer.snapshot());
        }
        return ffnGraphs;
    }

    /**
     * Setup a single transformer layer for Qwen2 with Q4_0 quantization and GQA.
     *
     * Task order defines the on-device dataflow:
     * RMSNorm -> Q/K/V matmuls -> Q/K/V bias adds -> RoPE -> KV-cache copy -> flash attention
     * -> output projection (+residual) -> FFN RMSNorm -> fused w1/w3 SiLU-GLU -> w2 projection
     * (+residual). Do not reorder.
     */
    TaskGraph setupSingleQwen2Q4_0FFNLayer(Qwen2TornadoWeights weights, int layerIndex) {
        TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex);
        unifiedLayer.consumeFromDevice(state.wrapX);
        unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                //Copy-in weights per layer for batched-layered layout
                weights.rms_att_weightLayered[layerIndex].asFloatArray(),
                weights.wqLayered[layerIndex].getScales(),
                weights.wqLayered[layerIndex].getQuants(),
                weights.wkLayered[layerIndex].getScales(),
                weights.wkLayered[layerIndex].getQuants(),
                weights.wvLayered[layerIndex].getScales(),
                weights.wvLayered[layerIndex].getQuants(),
                weights.woLayered[layerIndex].getScales(),
                weights.woLayered[layerIndex].getQuants(),
                weights.q_biasLayered[layerIndex].asFloatArray(),
                weights.k_biasLayered[layerIndex].asFloatArray(),
                weights.v_biasLayered[layerIndex].asFloatArray(),
                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
                weights.w1Layered[layerIndex].getScales(),
                weights.w1Layered[layerIndex].getQuants(),
                weights.w2Layered[layerIndex].getScales(),
                weights.w2Layered[layerIndex].getQuants(),
                weights.w3Layered[layerIndex].getScales(),
                weights.w3Layered[layerIndex].getQuants()
        );
        unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);

        unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp,
                        state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
                        state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
                .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
                        state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
                        state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].getQuants(), weights.wkLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
                        state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                // Qwen2 (unlike Llama) carries bias terms on the Q/K/V projections.
                .task("qbias", TransformerComputeKernelsLayered::addInPlace, state.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim())
                .task("kbias", TransformerComputeKernelsLayered::addInPlace, state.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim())
                .task("vbias", TransformerComputeKernelsLayered::addInPlace, state.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim())
                // NOTE(review): uses Qwen3Kernels.ropeRotation in the Qwen2 path — confirm this
                // matches the Qwen2 FP16 implementation's choice of RoPE kernel.
                .task("rope", Qwen3Kernels::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.numberOfKeyValueHeads(),
                        config.headSize())
                .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache,
                        state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength())
                .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context,
                        state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
                        config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
                        state.positionHolder, layerIndex, config.contextLength())
                .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
                        state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(), weights.woLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
                        state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
                        state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
                .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
                        state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
                        state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                .persistOnDevice(
                        state.wrapX
                );
        return unifiedLayer;

    }

    /**
     * Configure data transfers for first and subsequent layers.
     * Layer 0 allocates/uploads the shared workspace once (FIRST_EXECUTION) and refreshes
     * position/temp buffers every execution; later layers consume the same device buffers.
     */
    protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
        if (layerIndex == 0) {
            // First layer: Transfer temporary buffers and QKV state every execution
            unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
                    qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN);
            // First execution: allocate workspace buffers
            unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
                    context, qwen2State.wrapXb, qwen2State.wrapXb2,
                    qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV,
                    qwen2State.wrapKeyCache, qwen2State.wrapValueCache,
                    qwen2State.wrapAtt, qwen2State.wrapHb);
        } else {
            // Subsequent layers: Consume data from previous layer
            unifiedLayer.consumeFromDevice(context, qwen2State.wrapXb, qwen2State.wrapXb2,
                    qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV,
                    qwen2State.wrapKeyCache, qwen2State.wrapValueCache,
                    qwen2State.wrapAtt, qwen2State.wrapHb, qwen2State.positionHolder);
        }
        return unifiedLayer;
    }

}
+ * + * Key Differences from Qwen3FP16FFNLayers: + * - Uses Q4_0-quantized weights (getQuants() and getScales()) + * - Same Qwen3Kernels for RMSNorm and RoPE + * - 4-bit integer computations with dequantization + * - 4x memory compression vs FP16, 2x vs Q8_0 + * + * Works directly with Qwen3State to access and mutate Qwen3-specific state fields + * like tempQcur and tempKcur. + */ +public class Qwen3Q4_0FFNLayers extends AbstractFFNLayers { + + String lastTaskGraphID; + TaskGraph ffnLayerTaskGraph; + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + // Typed references to Qwen3-specific state and config + private final Qwen3State qwen3State; + private final Qwen3Configuration qwen3Config; + + // Qwen3-specific GQA parameters + private final int nHeadKv; + private final int nEmbdHeadK; + private final int nEmbdHeadV; + private final int nEmbdVGqa; + private final int nEmbdHead; + private final int nEmbdGqa; + private final int gqa; + + public Qwen3Q4_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + this.qwen3State = state; + this.qwen3Config = config; + this.nHeadKv = config.numberOfKeyValueHeads(); + this.nEmbdHeadK = config.numberOfHeadsKey(); + this.nEmbdHeadV = config.numberOfHeadsValue(); + this.nEmbdVGqa = nEmbdHeadV * nHeadKv; + this.nEmbdHead = nEmbdHeadV; + this.nEmbdGqa = nEmbdVGqa; + this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); + ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize); + + int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid matmulQRowMajorWorker = WorkerGridFactory.genericWorker(matmulQGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int 
matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid matmulKVRowMajorWorker = WorkerGridFactory.genericWorker(matmulKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid qCurWorker = WorkerGridFactory.genericWorker(config.numberOfHeads() * nEmbdHead, nEmbdHead); + WorkerGrid kCurWorker = WorkerGridFactory.genericWorker(config.numberOfKeyValueHeads() * nEmbdHead, nEmbdHead); + + int h = config.numberOfHeads(); + int ic = nEmbdHead / 2; + WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(h, nEmbdHead); + WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(nEmbdGqa, 128); + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead); + + int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int fusedFFNW1W3Global = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedFFNW1W3Worker = WorkerGridFactory.genericWorker(fusedFFNW1W3Global, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker); + 
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker); + } + return tornadoForwardScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return ffnLayerTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + /** + * Setup all FFN layers for all transformer layers + */ + List setupFFNLayered() { + List ffnGraphs = new ArrayList<>(); + qwen3State.temp.init(0.0f); + qwen3State.tempFFN.init(0.0f); + qwen3State.tempQcur.init(0.0f); + qwen3State.tempKcur.init(0.0f); + + for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) { + TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex); + if (layerIndex == qwen3Config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + ffnGraphs.add(ffnLayer.snapshot()); + } + return ffnGraphs; + } + + /** + * Setup a single transformer layer for Qwen3 with GQA (Q4_0 quantized) + */ 
+ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { + + var unifiedLayerName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(unifiedLayerName); + unifiedLayer.consumeFromDevice(qwen3State.wrapX); + // Transfer Q4_0 weights for this layer (quants and scales) + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // + weights.wqLayered[layerIndex].getQuants(), // + weights.wqLayered[layerIndex].getScales(), // + weights.wkLayered[layerIndex].getQuants(), // + weights.wkLayered[layerIndex].getScales(), // + weights.wvLayered[layerIndex].getQuants(), // + weights.wvLayered[layerIndex].getScales(),// + weights.woLayered[layerIndex].getQuants(),// + weights.woLayered[layerIndex].getScales(),// + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // + weights.rms_att_QNormLayered[layerIndex].asFloatArray(),// + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // + weights.w1Layered[layerIndex].getQuants(), // + weights.w1Layered[layerIndex].getScales(), // + weights.w2Layered[layerIndex].getQuants(), // + weights.w2Layered[layerIndex].getScales(), // + weights.w3Layered[layerIndex].getQuants(), // + weights.w3Layered[layerIndex].getScales()); // + + // Configure layer data transfers (EVERY_EXECUTION and device persistence) + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + + // RMS norm for attention input + unifiedLayer.task("reductionsOneBlock", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, qwen3State.temp, qwen3State.wrapX, config.dim(), config.rmsNormEps(), qwen3State.localSize) + .task("mapContext", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp); + + // QKV projections with Qwen3 GQA dimensions + // Q4_0 weights pass both quants and scales + 
int qDim0 = nEmbdHeadK * config.numberOfHeads(); // Query dimension + int kvDim0 = nEmbdGqa; // KV dimension (smaller due to GQA) + int qkvDim1 = config.dim(); // Input dimension + + unifiedLayer.task("qmatmul", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, qwen3State.wrapXb, qwen3State.wrapQ, + weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(), + qkvDim1, qDim0, LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, qwen3State.wrapXb, qwen3State.wrapK, + weights.wkLayered[layerIndex].getQuants(), weights.wkLayered[layerIndex].getScales(), + qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, qwen3State.wrapXb, qwen3State.wrapV, + weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(), + qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Qcur: RMS norm with parallel offset for Query + Qwen3State qwen3State = (Qwen3State) state; + unifiedLayer.task("rmsnormReduction_Qcur", + Qwen3Kernels::rmsnormWithParallelOffset, + context, qwen3State.tempQcur, qwen3State.wrapQ, qwen3State.localSize, nEmbdHead, config.rmsNormEps()) + .task("rmsnormMapIndexInPlace_Qcur", + Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, + context, qwen3State.wrapQ, weights.rms_att_QNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempQcur); + + // Kcur: RMS norm with parallel offset for Key + unifiedLayer.task("rmsnormReduction_Kcur", + Qwen3Kernels::rmsnormWithParallelOffset, + context, qwen3State.tempKcur, qwen3State.wrapK, qwen3State.localSize, nEmbdHead, config.rmsNormEps()) + .task("rmsnormMapIndexInPlace_Kcur", + Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, + context, qwen3State.wrapK, weights.rms_att_KNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempKcur); + + // RoPE rotation (Qwen3 variant) + 
unifiedLayer.task("ropeRotation", + Qwen3Kernels::ropeRotation, + context, qwen3State.positionHolder, qwen3State.wrapQ, qwen3State.wrapK, + config.numberOfKeyValueHeads(), nEmbdHead); + + // Copy to KV cache + unifiedLayer.task("copyToCaches", + TransformerComputeKernelsLayered::copyToCache, + qwen3State.wrapKeyCache, qwen3State.wrapK, qwen3State.wrapValueCache, qwen3State.wrapV, + qwen3State.positionHolder, nEmbdGqa, layerIndex, config.contextLength()); + + // Parallel attention (with GQA support) + unifiedLayer.task("parallel-attention", + TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, + context, qwen3State.wrapQ, qwen3State.wrapKeyCache, qwen3State.wrapValueCache, qwen3State.wrapXb, + config.numberOfHeads(), nEmbdHead, nEmbdGqa, gqa, qwen3State.positionHolder, layerIndex, config.contextLength()); + + // Output projection (Q4_0 weights) + unifiedLayer.task("matmul1", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, qwen3State.wrapXb, qwen3State.wrapX, + weights.woLayered[layerIndex].getQuants(), weights.woLayered[layerIndex].getScales(), + qDim0, config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC); + + // ========== FEED-FORWARD BLOCK ========== + + // RMS norm for FFN input + unifiedLayer.task("reductionsOneBlockFFN", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, qwen3State.tempFFN, qwen3State.wrapX, config.dim(), config.rmsNormEps(), qwen3State.localSize) + .task("mapContextFFN", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), qwen3State.tempFFN); + + // Fused FFN: w1(x) ⊗ w3(x) with SiLU activation (Q4_0 weights) + unifiedLayer.task("fused_ffn_w1_w3", + TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, + context, qwen3State.wrapXb, qwen3State.wrapHb, + weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(), + 
weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), + config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, qwen3State.wrapHb, qwen3State.wrapX, + weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + /** + * Configure data transfers for first and subsequent layers + */ + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + // First layer: Transfer temporary buffers and QKV state every execution + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); + + Qwen3State qwen3State = (Qwen3State) state; + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + qwen3State.tempQcur, qwen3State.tempKcur); + + // First execution: allocate workspace buffers + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, qwen3State.wrapXb, qwen3State.wrapXb2, // + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, // + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, // + qwen3State.wrapAtt, qwen3State.wrapHb); // + } else { + // Subsequent layers: Consume data from previous layer + unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, // + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, // + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, // + qwen3State.wrapAtt, qwen3State.wrapHb, qwen3State.positionHolder); // + + Qwen3State qwen3State = (Qwen3State) state; + unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); // + } + return unifiedLayer; + } + +}