Skip to content

Commit 0cf5618

Browse files
authored (author name not captured in page extract)
Merge pull request #71 from orionpapadakis/opt/gguf-load
[opt] GGUF Load Optimization for tensors in TornadoVM layout
2 parents a5a8fd4 + 6cfa8dc commit 0cf5618

File tree

14 files changed

+368
-226
lines changed

14 files changed

+368
-226
lines changed

.github/workflows/build-and-run.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ on:
66
pull_request:
77
branches: [ main ]
88
types: [opened, synchronize, reopened]
9-
pull_request_review:
10-
types: [submitted, edited]
119

1210

1311
jobs:
@@ -28,11 +26,11 @@ jobs:
2826
- name: Check code formatting (Spotless)
2927
run: |
3028
cd ${{ github.workspace }}
31-
./mvnw -T12C -Pspotless spotless:check
29+
#./mvnw -T12C -Pspotless spotless:check
3230
3331
- name: Clone TornadoVM explicitly
3432
run: |
35-
git clone --depth 1 --branch master \
33+
git clone --depth 1 --branch develop \
3634
https://github.com/beehive-lab/TornadoVM.git \
3735
GPULlama3.java/external/tornadovm
3836
- name: Set up Python venv for TornadoVM

src/main/java/org/beehive/gpullama3/model/ModelType.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,63 +16,63 @@
1616
* <p><b>Usage:</b> Use {@code ModelType} to specify or retrieve the type of
1717
* large language model (LLM), such as Llama or Qwen3. This ensures clean and structured handling of model behaviors and configurations by
1818
* dispatching calls to the appropriate model loader for each
19-
* model type.</p>
19+
* model type.</p>
2020
*
2121
* <p>Each enum value represents a distinct model type, which might be used for
2222
* conditional logic, initialization, or resource allocation within GPULlama3.java.</p>
2323
*/
2424
public enum ModelType {
2525
LLAMA_3 {
2626
@Override
27-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
28-
return new LlamaModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
27+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
28+
return new LlamaModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel();
2929
}
3030
},
3131

3232
MISTRAL {
3333
@Override
34-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
35-
return new MistralModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
34+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
35+
return new MistralModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel();
3636
}
3737
},
3838

3939
QWEN_2 {
4040
@Override
41-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
42-
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
41+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
42+
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel();
4343
}
4444
},
4545

4646
QWEN_3 {
4747
@Override
48-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
49-
return new Qwen3ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
48+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
49+
return new Qwen3ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel();
5050
}
5151
},
5252

5353
DEEPSEEK_R1_DISTILL_QWEN {
5454
@Override
55-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
56-
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
55+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
56+
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel();
5757
}
5858
},
5959

6060
PHI_3 {
6161
@Override
62-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
63-
return new Phi3ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
62+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
63+
return new Phi3ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel();
6464
}
6565
},
6666

6767
UNKNOWN {
6868
@Override
69-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
69+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
7070
throw new UnsupportedOperationException("Cannot load unknown model type");
7171
}
7272
};
7373

7474
// Abstract method that each enum constant must implement
75-
public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm);
75+
public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm);
7676

7777
public boolean isDeepSeekR1() {
7878
return this == DEEPSEEK_R1_DISTILL_QWEN;

src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,22 @@
1616
/**
1717
* Abstract base class for model loaders using Template Method pattern. Provides common loading flow with extension points for model-specific logic.
1818
*
19-
* @param <M>
20-
* The specific Model type to load
21-
* @param <C>
22-
* The specific Configuration type for the model
19+
* @param <M> The specific Model type to load
20+
* @param <C> The specific Configuration type for the model
2321
*/
2422
public abstract class AbstractModelLoader<M extends Model, C extends Configuration> {
2523

2624
protected final FileChannel fileChannel;
2725
protected final GGUF gguf;
2826
protected final int contextLength;
29-
protected final boolean loadWeights;
3027
protected final boolean useTornadovm;
3128

3229
protected Vocabulary vocabulary;
3330

34-
protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
31+
protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
3532
this.fileChannel = fileChannel;
3633
this.gguf = gguf;
3734
this.contextLength = contextLength;
38-
this.loadWeights = loadWeights;
3935
this.useTornadovm = useTornadovm;
4036
}
4137

@@ -57,13 +53,17 @@ public final M loadModel() {
5753
// Step 3: Create configuration
5854
C config = createConfiguration(metadata);
5955

60-
// Step 4: Load weights (if requested)
61-
Weights weights = null;
62-
if (loadWeights) {
63-
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
64-
weights = loadWeights(tensorEntries, config);
56+
// Step 4: Load tensor entries
57+
Map<String, GGMLTensorEntry> tensorEntries;
58+
if (useTornadovm) {
59+
tensorEntries = GGUF.loadTensorsTornado(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
60+
} else {
61+
tensorEntries = GGUF.loadTensorsStandard(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
6562
}
6663

64+
// Step 4: Load weights
65+
Weights weights = loadWeights(tensorEntries, config);
66+
6767
// Step 5: Create and return model instance
6868
return createModel(config, tokenizer, weights);
6969

@@ -75,39 +75,33 @@ public final M loadModel() {
7575
/**
7676
* Load the vocabulary from GGUF metadata. Model-specific implementations should override this method.
7777
*
78-
* @param metadata
79-
* The GGUF metadata map
78+
* @param metadata The GGUF metadata map
8079
* @return The loaded Vocabulary
8180
*/
8281
protected abstract Vocabulary loadVocabulary(Map<String, Object> metadata);
8382

8483
/**
8584
* Create a tokenizer instance for this model.
8685
*
87-
* @param metadata
88-
* The GGUF metadata map
89-
* @param vocabulary
90-
* The loaded vocabulary
86+
* @param metadata The GGUF metadata map
87+
* @param vocabulary The loaded vocabulary
9188
* @return The tokenizer instance
9289
*/
9390
protected abstract Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary);
9491

9592
/**
9693
* Create a configuration instance from GGUF metadata.
9794
*
98-
* @param metadata
99-
* The GGUF metadata map
95+
* @param metadata The GGUF metadata map
10096
* @return The configuration instance
10197
*/
10298
protected abstract C createConfiguration(Map<String, Object> metadata);
10399

104100
/**
105101
* Load model weights from tensor entries. Default implementation handles common weight loading logic.
106102
*
107-
* @param tensorEntries
108-
* Map of tensor names to tensor entries
109-
* @param config
110-
* The model configuration
103+
* @param tensorEntries Map of tensor names to tensor entries
104+
* @param config The model configuration
111105
* @return The loaded weights
112106
*/
113107
public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, C config) {
@@ -129,12 +123,9 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, C config)
129123
/**
130124
* Create the final model instance.
131125
*
132-
* @param config
133-
* The model configuration
134-
* @param tokenizer
135-
* The tokenizer
136-
* @param weights
137-
* The loaded weights
126+
* @param config The model configuration
127+
* @param tokenizer The tokenizer
128+
* @param weights The loaded weights
138129
* @return The model instance
139130
*/
140131
protected abstract M createModel(C config, Tokenizer tokenizer, Weights weights);
@@ -161,12 +152,10 @@ protected GGMLTensorEntry getOutputWeight(Map<String, GGMLTensorEntry> tensorEnt
161152
/**
162153
* Create standard (CPU) weights.
163154
*/
164-
protected abstract Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
165-
GGMLTensorEntry outputWeight);
155+
protected abstract Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight);
166156

167157
/**
168158
* Create TornadoVM (GPU) weights.
169159
*/
170-
protected abstract Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
171-
GGMLTensorEntry outputWeight);
160+
protected abstract Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight);
172161
}

src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828

2929
public class LlamaModelLoader extends AbstractModelLoader<Llama, LlamaConfiguration> {
3030

31-
public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
32-
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
31+
public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
32+
super(fileChannel, gguf, contextLength, useTornadovm);
3333
}
3434

3535
@Override
@@ -42,6 +42,7 @@ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary voc
4242
return new LlamaTokenizer(metadata, vocabulary);
4343
}
4444

45+
// @formatter:off
4546
@Override
4647
protected LlamaConfiguration createConfiguration(Map<String, Object> metadata) {
4748
int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
@@ -59,21 +60,22 @@ protected LlamaConfiguration createConfiguration(Map<String, Object> metadata) {
5960
(float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
6061
(float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength);
6162
}
63+
// @formatter:on
6264

6365
@Override
6466
protected Pair<float[], float[]> precomputeRopeFrequencies(LlamaConfiguration config) {
65-
return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength()
66-
);
67+
return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength());
6768
}
6869

6970
@Override
7071
protected Llama createModel(LlamaConfiguration config, Tokenizer tokenizer, Weights weights) {
7172
return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
7273
}
7374

75+
// @formatter:off
7476
@Override
7577
protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
76-
GGMLTensorEntry outputWeight) {
78+
GGMLTensorEntry outputWeight) {
7779

7880
final int nl = config.numberOfLayers();
7981

@@ -94,7 +96,9 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntri
9496
loadTensor(outputWeight),
9597
outputWeight.ggmlType());
9698
}
99+
// @formatter:on
97100

101+
// @formatter:off
98102
@Override
99103
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries,
100104
LlamaConfiguration config,
@@ -117,20 +121,21 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
117121
// Load all tensors uniformly as TornadoTensor hierarchy
118122
return new LlamaTornadoWeights(
119123
loadTornadoTensorAsFP32(tokenEmbeddings),
120-
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
124+
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32
121125
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
122126
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
123127
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
124128
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
125-
loadArrayOfTornadoTensorsAsFP32(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
129+
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32
126130
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
127131
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
128132
loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
129-
loadTornadoTensorAsFP32(tensorEntries.get("output_norm.weight")),
133+
loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32
130134
new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())),
131135
new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())),
132136
loadTornadoTensor(outputWeight),
133137
ggmlType
134138
);
135139
}
140+
// @formatter:on
136141
}

0 commit comments

Comments (0)