@@ -108,7 +108,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
         }
 
         // Validate supported types
-        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {
+        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0 && ggmlType != GGMLType.Q4_0) {
             throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
         }
 
==========
@@ -102,7 +102,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
         }
 
         // Validate supported types
-        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {
+        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0 && ggmlType != GGMLType.Q4_0) {
             throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
         }
 
==========
@@ -9,6 +9,7 @@
 import org.beehive.gpullama3.tensor.standard.*;
 import org.beehive.gpullama3.tensor.tornado.FP16TornadoTensor;
 import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor;
+import org.beehive.gpullama3.tensor.tornado.Q4_0TornadoTensor;
 import org.beehive.gpullama3.tensor.tornado.Q8_0TornadoTensor;
 import org.beehive.gpullama3.tensor.tornado.TornadoTensor;
 import uk.ac.manchester.tornado.api.types.HalfFloat;
@@ -130,7 +131,7 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
             case F32 -> new FP32TornadoTensor(size, entry.memorySegment());
             case F16 -> new FP16TornadoTensor(size, entry.memorySegment());
             case Q8_0 -> Q8_0TornadoTensor.create(entry);
-            case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet");
+            case Q4_0 -> Q4_0TornadoTensor.create(entry);
             default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
         };
     }
@@ -203,6 +204,14 @@ public static Q8_0TornadoTensor[] loadArrayAsQ8_0TornadoTensor(int size, IntFunc
         return array;
     }
 
+    public static Q4_0TornadoTensor[] loadArrayAsQ4_0TornadoTensor(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+        Q4_0TornadoTensor[] array = new Q4_0TornadoTensor[size];
+        for (int i = 0; i < size; i++) {
+            array[i] = Q4_0TornadoTensor.create(getTensorEntry.apply(i));
+        }
+        return array;
+    }
+
     public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) {
         if (tensorEntry.ggmlType() == GGMLType.F32) {
             FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
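(For context: a hypothetical call site for the new array loader, following the usual GGUF per-layer tensor naming; `numLayers` and `tensorEntries` are illustrative and not from this diff.)

    // Hypothetical usage sketch: load one Q4_0 attention matrix per layer.
    Q4_0TornadoTensor[] wq = loadArrayAsQ4_0TornadoTensor(
            numLayers,
            i -> tensorEntries.get("blk." + i + ".attn_q.weight"));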
==========
@@ -122,7 +122,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
         }
 
         // Validate supported types
-        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {
+        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0 && ggmlType != GGMLType.Q4_0) {
             throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
         }
 
==========
@@ -122,7 +122,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
         }
 
         // Validate supported types
-        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {
+        if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0 && ggmlType != GGMLType.Q4_0) {
             throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
         }
 
==========
@@ -0,0 +1,131 @@
package org.beehive.gpullama3.tensor.tornado;

import org.beehive.gpullama3.tensor.GGMLTensorEntry;
import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;

public class Q4_0TornadoTensor extends TornadoTensor {

    private final HalfFloatArray scales; // one fp16 scale per 32-element block
    private final ByteArray quants;      // packed 4-bit quantized values (2 per byte)
    private MemorySegment segment;

    public Q4_0TornadoTensor(int size, HalfFloatArray scales, ByteArray quants, MemorySegment segment) {
        super(size);
        this.scales = scales;
        this.quants = quants;
        this.segment = segment;
    }

    /**
     * Returns the scale factors for GPU kernels.
     *
     * @return HalfFloatArray containing fp16 scale factors
     */
> Copilot AI (Nov 13, 2025): This method overrides TornadoTensor.getScales; it is advisable to add an @Override annotation.
> Suggested change:
>      */
> +    @Override
    public HalfFloatArray getScales() {
        return scales;
    }

    /**
     * Returns the quantized values for GPU kernels.
     *
     * @return ByteArray containing packed 4-bit quantized values
     */
> Copilot AI (Nov 13, 2025): This method overrides TornadoTensor.getQuants; it is advisable to add an @Override annotation.
> Suggested change:
>      */
> +    @Override
    public ByteArray getQuants() {
        return quants;
    }

    @Override
    public int size() {
        return size;
    }

    @Override
    public GGMLType type() {
        return GGMLType.Q4_0;
    }

    public MemorySegment asMemorySegment() {
        return segment;
    }

    /**
     * Dequantizes and returns a single float value.
     *
     * @param index Element index
     * @return Dequantized float value
     */
    public float getFloat(int index) {
        assert 0 <= index && index < size;
        int blockIdx = index / GGMLType.Q4_0.getBlockSize();
        int withinBlockIdx = index % GGMLType.Q4_0.getBlockSize();

        float scale = scales.get(blockIdx).getFloat32();

        // Each byte packs two 4-bit values
        int byteIdx = withinBlockIdx / 2;
        byte packedByte = quants.get(blockIdx * 16 + byteIdx);

        // Extract the 4-bit value (lower or upper nibble)
        byte quant;
        if (withinBlockIdx % 2 == 0) {
            // Lower 4 bits
            quant = (byte) (packedByte & 0x0F);
        } else {
            // Upper 4 bits
            quant = (byte) ((packedByte >>> 4) & 0x0F);
        }

        // Center the unsigned nibble around zero: [0, 15] -> [-8, 7]
        // (unlike Q8_0, which stores signed int8 values directly)
        quant -= 8;

        return quant * scale;
    }

public static Q4_0TornadoTensor create(GGMLTensorEntry entry) {
if (entry.ggmlType() != GGMLType.Q4_0) {
throw new IllegalArgumentException("Expected Q4_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name());
}

int[] shape = entry.shape();
int size = FloatTensor.numberOfElements(shape);
int numBlocks = size / GGMLType.Q4_0.getBlockSize();

if (size % GGMLType.Q4_0.getBlockSize() != 0) {
throw new IllegalArgumentException("Q4_0 tensor size must be multiple of " + GGMLType.Q4_0.getBlockSize() + ", got: " + size + " for tensor: " + entry.name());
}

MemorySegment q4Segment = entry.memorySegment();

// allocate the arrays for quantized data (packed 4-bit) and scales (fp16)
HalfFloatArray scales = new HalfFloatArray(numBlocks);
ByteArray quants = new ByteArray(numBlocks * 16); // 32 4-bit values = 16 bytes per block

// unpack Q4_0 blocks: [2 bytes fp16 scale][16 bytes packed 4-bit quants]
ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;

for (int block = 0; block < numBlocks; block++) {
long blockOffset = block * GGMLType.Q4_0.getTypeSize(); // 18 bytes per block
> Copilot AI (Nov 13, 2025): Potential overflow in int multiplication before it is converted to long by use in an assignment context.
> Suggested change:
> -            long blockOffset = block * GGMLType.Q4_0.getTypeSize(); // 18 bytes per block
> +            long blockOffset = ((long) block) * GGMLType.Q4_0.getTypeSize(); // 18 bytes per block

            // read fp16 scale (first 2 bytes of block)
            short scaleRaw = q4Segment.get(shortLayout, blockOffset);
            scales.set(block, new HalfFloat(scaleRaw));

            // read 16 bytes of packed 4-bit quantized values (remaining bytes of block)
            for (int i = 0; i < 16; i++) {
                byte quantValue = q4Segment.get(byteLayout, blockOffset + 2 + i);
                quants.set(block * 16 + i, quantValue);
            }
        }

        return new Q4_0TornadoTensor(size, scales, quants, q4Segment);
    }
}
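(Aside: to make the 18-byte block layout concrete, here is a minimal standalone sketch of the same pack/unpack arithmetic using only the JDK (20+) foreign-memory API; the project types GGMLTensorEntry, HalfFloatArray and ByteArray are deliberately avoided, and all values are illustrative.)

    import java.lang.foreign.Arena;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.ValueLayout;
    import java.nio.ByteOrder;

    public class Q4_0BlockDemo {
        public static void main(String[] args) {
            try (Arena arena = Arena.ofConfined()) {
                // One Q4_0 block: 2-byte fp16 scale + 16 bytes holding 32 nibbles.
                MemorySegment block = arena.allocate(18); // zero-initialized
                var shortLE = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
                block.set(shortLE, 0, Float.floatToFloat16(0.5f)); // scale = 0.5

                // Pack quant 11 at element 5 (odd index, so upper nibble of byte 2).
                int element = 5, nibble = 11;
                long byteOff = 2 + element / 2;
                byte packed = block.get(ValueLayout.JAVA_BYTE, byteOff);
                packed = (byte) ((element % 2 == 0)
                        ? (packed & 0xF0) | nibble
                        : (packed & 0x0F) | (nibble << 4));
                block.set(ValueLayout.JAVA_BYTE, byteOff, packed);

                // Decode, mirroring Q4_0TornadoTensor.getFloat.
                float scale = Float.float16ToFloat(block.get(shortLE, 0));
                byte b = block.get(ValueLayout.JAVA_BYTE, byteOff);
                int q = (element % 2 == 0) ? (b & 0x0F) : ((b >>> 4) & 0x0F);
                System.out.println((q - 8) * scale); // prints 1.5 = (11 - 8) * 0.5
            }
        }
    }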
==========
@@ -16,6 +16,10 @@
 import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner;
 import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0.LlamaQ4_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0.Phi3Q4_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0.Qwen2Q4_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0.Qwen3Q4_0LayerPlanner;
 
 /**
  * Factory class responsible for creating appropriate layer planners based on model type and quantization.
@@ -77,8 +81,16 @@ private static GenericLayerPlanner createFP32Planner(State state, Model model) {
         throw new UnsupportedOperationException("FP32 planners not yet implemented");
     }
 
+    // ============ Q4_0 Planners ============
     private static GenericLayerPlanner createQ4_0Planner(State state, Model model) {
-        throw new UnsupportedOperationException("Q4 planners not yet implemented");
+        return switch (model.getModelType()) {
+            case LLAMA_3, MISTRAL -> new LlamaQ4_0LayerPlanner((LlamaState) state, model);
+            case QWEN_2 -> new Qwen2Q4_0LayerPlanner((Qwen2State) state, model);
+            case QWEN_3 -> new Qwen3Q4_0LayerPlanner((Qwen3State) state, model);
+            case PHI_3 -> new Phi3Q4_0LayerPlanner((Phi3State) state, model);
+            case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q4_0LayerPlanner((Qwen2State) state, model);
+            default -> throw new UnsupportedOperationException("Q4_0 not supported for model: " + model.getModelType());
+        };
     }
 
 }
==========
@@ -0,0 +1,27 @@
package org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0;

import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q4_0LayerPlanner;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LlamaQ4_0FFNLayers;
import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer;

public class LlamaQ4_0LayerPlanner extends Q4_0LayerPlanner<LlamaState, LlamaConfiguration, LlamaTornadoWeights> {

    public LlamaQ4_0LayerPlanner(LlamaState state, Model model) {
        super(state, model);
        validateQuantizationType();
        setupTornadoForwardPlan();
    }

    @Override
    protected void initializeLayerComponents() {
        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
        this.ffnLayers = new LlamaQ4_0FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
        this.logitsLayer = new LogitsQ4_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
    }

}
==========
@@ -0,0 +1,38 @@
package org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0;

import org.beehive.gpullama3.inference.state.Phi3State;
import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q4_0LayerPlanner;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.type.q4_0.Phi3Q4_0FFNLayers;
import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer;

/**
 * Phi3Q4_0LayerPlanner: Phi3 model with Q4_0-quantized weights.
 *
 * Follows the same pattern as Qwen3Q4_0LayerPlanner but with:
 * - Phi3-specific FFN layers (combined QKV + gate/up FFN)
 * - Phi3TornadoWeights (4-bit integer quantization)
 * - Phi3Configuration
 * - 4x memory compression vs FP16, 2x vs Q8_0
 *
 * Inherits from Q4_0LayerPlanner<Phi3State, Phi3Configuration, Phi3TornadoWeights>
 */
public class Phi3Q4_0LayerPlanner extends Q4_0LayerPlanner<Phi3State, Phi3Configuration, Phi3TornadoWeights> {

    public Phi3Q4_0LayerPlanner(Phi3State state, Model model) {
        super(state, model);
        validateQuantizationType();
        setupTornadoForwardPlan();
    }

    @Override
    protected void initializeLayerComponents() {
        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
        this.ffnLayers = new Phi3Q4_0FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
        this.logitsLayer = new LogitsQ4_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
    }

}
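(A quick back-of-envelope check of the compression figures quoted in the javadoc above, pasteable into jshell; it assumes the standard GGML block layouts: Q4_0 is a 2-byte fp16 scale plus 16 packed bytes per 32 elements, Q8_0 is a 2-byte scale plus 32 int8 values.)

    // Effective bits per element, including the per-block scales.
    double q4 = 18.0 * 8 / 32;   // 4.5 bits/element
    double q8 = 34.0 * 8 / 32;   // 8.5 bits/element
    System.out.printf("vs FP16: %.2fx, vs Q8_0: %.2fx%n", 16 / q4, q8 / q4);
    // => vs FP16: 3.56x, vs Q8_0: 1.89x, i.e. the "4x" and "2x" above are rounded.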
==========
@@ -0,0 +1,27 @@
package org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0;

import org.beehive.gpullama3.inference.state.Qwen2State;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q4_0LayerPlanner;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.type.q4_0.Qwen2Q4_0FFNLayers;
import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer;

public class Qwen2Q4_0LayerPlanner extends Q4_0LayerPlanner<Qwen2State, Qwen2Configuration, Qwen2TornadoWeights> {

    public Qwen2Q4_0LayerPlanner(Qwen2State state, Model model) {
        super(state, model);
        validateQuantizationType();
        setupTornadoForwardPlan();
    }

    @Override
    protected void initializeLayerComponents() {
        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
        this.ffnLayers = new Qwen2Q4_0FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
        this.logitsLayer = new LogitsQ4_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
    }

}
==========
@@ -0,0 +1,37 @@
package org.beehive.gpullama3.tornadovm.layerplanner.model.q4_0;

import org.beehive.gpullama3.inference.state.Qwen3State;
import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q4_0LayerPlanner;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.type.q4_0.Qwen3Q4_0FFNLayers;
import org.beehive.gpullama3.tornadovm.layers.type.q4_0.LogitsQ4_0Layer;

/**
 * Qwen3Q4_0LayerPlanner: Qwen3 model with Q4_0-quantized weights.
 *
 * Follows the same pattern as LlamaQ4_0LayerPlanner but with:
 * - Qwen3-specific FFN layers (supports GQA)
 * - Qwen3TornadoWeights (4-bit integer quantization)
 * - Qwen3Configuration
 * - 4x memory compression vs FP16, 2x vs Q8_0
 *
 * Inherits from Q4_0LayerPlanner<Qwen3State, Qwen3Configuration, Qwen3TornadoWeights>
 */
public class Qwen3Q4_0LayerPlanner extends Q4_0LayerPlanner<Qwen3State, Qwen3Configuration, Qwen3TornadoWeights> {

    public Qwen3Q4_0LayerPlanner(Qwen3State state, Model model) {
        super(state, model);
        validateQuantizationType();
        setupTornadoForwardPlan();
    }

    @Override
    protected void initializeLayerComponents() {
        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
        this.ffnLayers = new Qwen3Q4_0FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
        this.logitsLayer = new LogitsQ4_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
    }
}