Commit 82173d1 (merge of 2 parents: 8e5e645 + d46ed5e)

Updated stable diffusion version.

34 files changed: +3065 / -757 lines

.github/workflows/build.yml (+9 -9)

@@ -155,17 +155,17 @@ jobs:
       matrix:
         include:
           - build: "noavx"
-            defines: "-DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DSD_BUILD_SHARED_LIBS=ON"
+            defines: "-DGGML_NATIVE=OFF -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DSD_BUILD_SHARED_LIBS=ON"
           - build: "avx2"
-            defines: "-DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON"
+            defines: "-DGGML_NATIVE=OFF -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON"
           - build: "avx"
-            defines: "-DGGML_AVX2=OFF -DSD_BUILD_SHARED_LIBS=ON"
+            defines: "-DGGML_NATIVE=OFF -DGGML_AVX=ON -DGGML_AVX2=OFF -DSD_BUILD_SHARED_LIBS=ON"
           - build: "avx512"
-            defines: "-DGGML_AVX512=ON -DSD_BUILD_SHARED_LIBS=ON"
+            defines: "-DGGML_NATIVE=OFF -DGGML_AVX512=ON -DGGML_AVX=ON -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON"
           - build: "cuda12"
-            defines: "-DSD_CUBLAS=ON -DSD_BUILD_SHARED_LIBS=ON"
-          - build: "rocm5.5"
-            defines: '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030" -DSD_BUILD_SHARED_LIBS=ON'
+            defines: "-DSD_CUDA=ON -DSD_BUILD_SHARED_LIBS=ON -DCMAKE_CUDA_ARCHITECTURES=90;89;80;75"
+          # - build: "rocm5.5"
+          #   defines: '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030" -DSD_BUILD_SHARED_LIBS=ON'
           - build: 'vulkan'
             defines: "-DSD_VULKAN=ON -DSD_BUILD_SHARED_LIBS=ON"
     steps:
@@ -178,9 +178,9 @@ jobs:
       - name: Install cuda-toolkit
         id: cuda-toolkit
         if: ${{ matrix.build == 'cuda12' }}
-        uses: Jimver/cuda-toolkit@v0.2.11
+        uses: Jimver/cuda-toolkit@v0.2.19
        with:
-          cuda: "12.2.0"
+          cuda: "12.6.2"
          method: "network"
          sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]'

CMakeLists.txt (+15 -11)

@@ -24,20 +24,20 @@ endif()
 # general
 #option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE})
 option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE})
-option(SD_CUBLAS "sd: cuda backend" OFF)
+option(SD_CUDA "sd: cuda backend" OFF)
 option(SD_HIPBLAS "sd: rocm backend" OFF)
 option(SD_METAL "sd: metal backend" OFF)
 option(SD_VULKAN "sd: vulkan backend" OFF)
 option(SD_SYCL "sd: sycl backend" OFF)
-option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
+option(SD_MUSA "sd: musa backend" OFF)
 option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
 option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
 #option(SD_BUILD_SERVER "sd: build server example" ON)

-if(SD_CUBLAS)
-    message("-- Use CUBLAS as backend stable-diffusion")
+if(SD_CUDA)
+    message("-- Use CUDA as backend stable-diffusion")
     set(GGML_CUDA ON)
-    add_definitions(-DSD_USE_CUBLAS)
+    add_definitions(-DSD_USE_CUDA)
 endif()

 if(SD_METAL)
@@ -54,21 +54,25 @@ endif ()

 if (SD_HIPBLAS)
     message("-- Use HIPBLAS as backend stable-diffusion")
-    set(GGML_HIPBLAS ON)
-    add_definitions(-DSD_USE_CUBLAS)
+    set(GGML_HIP ON)
+    add_definitions(-DSD_USE_CUDA)
     if(SD_FAST_SOFTMAX)
         set(GGML_CUDA_FAST_SOFTMAX ON)
     endif()
 endif ()

-if(SD_FLASH_ATTN)
-    message("-- Use Flash Attention for memory optimization")
-    add_definitions(-DSD_USE_FLASH_ATTENTION)
+if(SD_MUSA)
+    message("-- Use MUSA as backend stable-diffusion")
+    set(GGML_MUSA ON)
+    add_definitions(-DSD_USE_CUDA)
+    if(SD_FAST_SOFTMAX)
+        set(GGML_CUDA_FAST_SOFTMAX ON)
+    endif()
 endif()

 set(SD_LIB stable-diffusion)

-file(GLOB SD_LIB_SOURCES
+file(GLOB SD_LIB_SOURCES
     "*.h"
     "*.cpp"
     "*.hpp"

Dockerfile.musa (new file, +19)

@@ -0,0 +1,19 @@
+ARG MUSA_VERSION=rc3.1.0
+
+FROM mthreads/musa:${MUSA_VERSION}-devel-ubuntu22.04 as build
+
+RUN apt-get update && apt-get install -y cmake
+
+WORKDIR /sd.cpp
+
+COPY . .
+
+RUN mkdir build && cd build && \
+    cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release && \
+    cmake --build . --config Release
+
+FROM mthreads/musa:${MUSA_VERSION}-runtime-ubuntu22.04 as runtime
+
+COPY --from=build /sd.cpp/build/bin/sd /sd
+
+ENTRYPOINT [ "/sd" ]

README.md (+1 -1)

@@ -63,4 +63,4 @@ To pull the latest changes from the upstream repository (`leejet/stable-diffusion.cpp`):

 ```bash
 git push origin master
-```
+```

assets/sd3.5_large.png (binary, 1.81 MB)

clip.hpp (+52 -25)

@@ -343,6 +343,13 @@ class CLIPTokenizer {
         }
     }

+    std::string clean_up_tokenization(std::string& text) {
+        std::regex pattern(R"( ,)");
+        // Replace " ," with ","
+        std::string result = std::regex_replace(text, pattern, ",");
+        return result;
+    }
+
     std::string decode(const std::vector<int>& tokens) {
         std::string text = "";
         for (int t : tokens) {
@@ -351,8 +358,12 @@
             std::u32string ts = decoder[t];
             // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
             std::string s = utf32_to_utf8(ts);
-            if (s.length() >= 4 && ends_with(s, "</w>")) {
-                text += " " + s.replace(s.length() - 4, s.length() - 1, "");
+            if (s.length() >= 4) {
+                if (ends_with(s, "</w>")) {
+                    text += s.replace(s.length() - 4, s.length() - 1, "") + " ";
+                } else {
+                    text += s;
+                }
             } else {
                 text += " " + s;
             }
@@ -364,6 +375,7 @@

         // std::string s((char *)bytes.data());
         // std::string s = "";
+        text = clean_up_tokenization(text);
         return trim(text);
     }
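Note: the tokenizer hunks above change how decoded BPE pieces are joined (a piece ending in "</w>" now contributes its text plus a trailing space instead of a leading one) and add a clean_up_tokenization() pass that removes the " ," artifacts this joining can produce. A standalone sketch of that joining logic, assuming plain std::string pieces rather than the real CLIPTokenizer decoder state:

```cpp
// Standalone sketch of the new decode joining + cleanup, not the CLIPTokenizer class itself.
#include <iostream>
#include <regex>
#include <string>
#include <vector>

static bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Mirrors the new CLIPTokenizer::clean_up_tokenization(): replace " ," with ",".
static std::string clean_up_tokenization(const std::string& text) {
    return std::regex_replace(text, std::regex(R"( ,)"), ",");
}

int main() {
    std::vector<std::string> pieces = {"a</w>", "cat</w>", ",</w>", "sitting</w>"};
    std::string text;
    for (const std::string& s : pieces) {
        if (s.length() >= 4 && ends_with(s, "</w>")) {
            text += s.substr(0, s.length() - 4) + " ";  // word boundary becomes a trailing space
        } else {
            text += s;  // continuation piece, no space
        }
    }
    // Prints "a cat, sitting " (trim() in clip.hpp then drops the trailing space).
    std::cout << clean_up_tokenization(text) << "\n";
    return 0;
}
```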

@@ -533,9 +545,12 @@ class CLIPEmbeddings : public GGMLBlock {
     int64_t vocab_size;
     int64_t num_positions;

-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, wtype, embed_dim, vocab_size);
-        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions);
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type token_wtype = (tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
+
+        params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
+        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
     }

 public:
@@ -579,11 +594,14 @@ class CLIPVisionEmbeddings : public GGMLBlock {
     int64_t image_size;
     int64_t num_patches;
     int64_t num_positions;
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type patch_wtype = GGML_TYPE_F16;  // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16;
+        enum ggml_type class_wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32;
+        enum ggml_type position_wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;

-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, patch_size, patch_size, num_channels, embed_dim);
-        params["class_embedding"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, embed_dim);
-        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions);
+        params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim);
+        params["class_embedding"] = ggml_new_tensor_1d(ctx, class_wtype, embed_dim);
+        params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
     }

 public:
@@ -639,9 +657,10 @@ enum CLIPVersion {

 class CLIPTextModel : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
         if (version == OPEN_CLIP_VIT_BIGG_14) {
-            params["text_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size);
+            enum ggml_type wtype = GGML_TYPE_F32;  // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
+            params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
         }
     }
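Note: several init_params() overloads in the hunks above now receive a map of per-tensor ggml types plus a name prefix instead of a single wtype, so each weight can be created with the type actually recorded for it, falling back to GGML_TYPE_F32 when the key is absent (or pinned to F32/F16 where a fixed type is wanted). A standalone sketch of that lookup pattern; the enum and the tensor key below are stand-ins, not the real ggml_type or model names:

```cpp
// Standalone sketch of the "type from tensor_types[prefix + name], else F32" pattern.
#include <iostream>
#include <map>
#include <string>

enum class wtype { F16, F32, Q8_0 };  // stand-in for enum ggml_type

static wtype type_or_f32(const std::map<std::string, wtype>& tensor_types,
                         const std::string& prefix,
                         const std::string& name) {
    auto it = tensor_types.find(prefix + name);
    return it != tensor_types.end() ? it->second : wtype::F32;
}

int main() {
    // Hypothetical key, for illustration only.
    std::map<std::string, wtype> tensor_types = {
        {"text_model.token_embedding.weight", wtype::Q8_0},
    };
    wtype token_wtype = type_or_f32(tensor_types, "text_model.", "token_embedding.weight");
    wtype pos_wtype   = type_or_f32(tensor_types, "text_model.", "position_embedding.weight");
    std::cout << (token_wtype == wtype::Q8_0) << " "   // 1: stored type is used
              << (pos_wtype == wtype::F32) << "\n";    // 1: missing key falls back to F32
    return 0;
}
```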

@@ -711,8 +730,12 @@ class CLIPTextModel : public GGMLBlock {
         if (return_pooled) {
             auto text_projection = params["text_projection"];
             ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
-            pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, text_projection)), pooled);
-            return pooled;
+            if (text_projection != NULL) {
+                pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
+            } else {
+                LOG_DEBUG("Missing text_projection matrix, assuming identity...");
+            }
+            return pooled;  // [hidden_size, 1, 1]
         }

         return x;  // [N, n_token, hidden_size]
@@ -761,14 +784,17 @@ class CLIPVisionModel : public GGMLBlock {
         auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
         x = pre_layernorm->forward(ctx, x);
         x = encoder->forward(ctx, x, -1, false);
-        x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
+        // print_ggml_tensor(x, true, "ClipVisionModel x: ");
+        auto last_hidden_state = x;
+        x = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]

         GGML_ASSERT(x->ne[3] == 1);
         if (return_pooled) {
             ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
             return pooled;  // [N, hidden_size]
         } else {
-            return x;  // [N, n_token, hidden_size]
+            // return x;  // [N, n_token, hidden_size]
+            return last_hidden_state;  // [N, n_token, hidden_size]
         }
     }
 };
@@ -779,9 +805,9 @@ class CLIPProjection : public UnaryBlock {
     int64_t out_features;
     bool transpose_weight;

-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
+        enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
         if (transpose_weight) {
-            LOG_ERROR("transpose_weight");
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
         } else {
             params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
@@ -842,12 +868,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
     CLIPTextModel model;

     CLIPTextModelRunner(ggml_backend_t backend,
-                        ggml_type wtype,
+                        std::map<std::string, enum ggml_type>& tensor_types,
+                        const std::string prefix,
                         CLIPVersion version = OPENAI_CLIP_VIT_L_14,
                         int clip_skip_value = 1,
                         bool with_final_ln = true)
-        : GGMLRunner(backend, wtype), model(version, clip_skip_value, with_final_ln) {
-        model.init(params_ctx, wtype);
+        : GGMLRunner(backend), model(version, clip_skip_value, with_final_ln) {
+        model.init(params_ctx, tensor_types, prefix);
     }

     std::string get_desc() {
@@ -889,13 +916,13 @@ struct CLIPTextModelRunner : public GGMLRunner {
         struct ggml_tensor* embeddings = NULL;

         if (num_custom_embeddings > 0 && custom_embeddings_data != NULL) {
-            auto custom_embeddings = ggml_new_tensor_2d(compute_ctx,
-                                                        wtype,
-                                                        model.hidden_size,
-                                                        num_custom_embeddings);
+            auto token_embed_weight = model.get_token_embed_weight();
+            auto custom_embeddings = ggml_new_tensor_2d(compute_ctx,
+                                                        token_embed_weight->type,
+                                                        model.hidden_size,
+                                                        num_custom_embeddings);
             set_backend_tensor_data(custom_embeddings, custom_embeddings_data);

-            auto token_embed_weight = model.get_token_embed_weight();
             // concatenate custom embeddings
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }

common.hpp (+23 -14)

@@ -182,9 +182,11 @@ class GEGLU : public GGMLBlock {
     int64_t dim_in;
     int64_t dim_out;

-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
-        params["proj.bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim_out * 2);
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+        enum ggml_type wtype = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32;
+        enum ggml_type bias_wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32;
+        params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
+        params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
     }

 public:
@@ -245,16 +247,19 @@ class CrossAttention : public GGMLBlock {
     int64_t context_dim;
     int64_t n_head;
     int64_t d_head;
+    bool flash_attn;

 public:
     CrossAttention(int64_t query_dim,
                    int64_t context_dim,
                    int64_t n_head,
-                   int64_t d_head)
+                   int64_t d_head,
+                   bool flash_attn = false)
         : n_head(n_head),
           d_head(d_head),
           query_dim(query_dim),
-          context_dim(context_dim) {
+          context_dim(context_dim),
+          flash_attn(flash_attn) {
         int64_t inner_dim = d_head * n_head;

         blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
@@ -283,7 +288,7 @@
         auto k = to_k->forward(ctx, context);  // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context);  // [N, n_context, inner_dim]

-        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false);  // [N, n_token, inner_dim]
+        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn);  // [N, n_token, inner_dim]

         x = to_out_0->forward(ctx, x);  // [N, n_token, query_dim]
         return x;
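Note: the flash_attn flag introduced above is stored on CrossAttention at construction time and forwarded as the last argument of ggml_nn_attention_ext, so each block picks its attention path once, when it is built. A simplified standalone sketch of that flag threading; the class and method below are illustrative, not the real GGMLBlock types:

```cpp
// Illustrative sketch of threading a per-block flash_attn flag down to the attention call.
#include <iostream>

struct CrossAttentionSketch {
    bool flash_attn;
    explicit CrossAttentionSketch(bool flash_attn = false) : flash_attn(flash_attn) {}

    void forward() const {
        // In common.hpp this is where the flag reaches
        // ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn).
        std::cout << (flash_attn ? "flash attention path\n" : "default attention path\n");
    }
};

int main() {
    CrossAttentionSketch attn1(true);   // e.g. a BasicTransformerBlock built with flash_attn = true
    CrossAttentionSketch attn2(false);  // default keeps the previous behavior
    attn1.forward();
    attn2.forward();
    return 0;
}
```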
@@ -301,15 +306,16 @@ class BasicTransformerBlock : public GGMLBlock {
                          int64_t n_head,
                          int64_t d_head,
                          int64_t context_dim,
-                         bool ff_in = false)
+                         bool ff_in = false,
+                         bool flash_attn = false)
         : n_head(n_head), d_head(d_head), ff_in(ff_in) {
         // disable_self_attn is always False
         // disable_temporal_crossattention is always False
         // switch_temporal_ca_to_sa is always False
         // inner_dim is always None or equal to dim
         // gated_ff is always True
-        blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
-        blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
+        blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
+        blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
         blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
         blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
@@ -374,7 +380,8 @@ class SpatialTransformer : public GGMLBlock {
                        int64_t n_head,
                        int64_t d_head,
                        int64_t depth,
-                       int64_t context_dim)
+                       int64_t context_dim,
+                       bool flash_attn = false)
         : in_channels(in_channels),
           n_head(n_head),
           d_head(d_head),
@@ -388,7 +395,7 @@

         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
-            blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim));
+            blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
         }

         blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
@@ -433,8 +440,10 @@

 class AlphaBlender : public GGMLBlock {
 protected:
-    void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+    void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+        // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix
+        enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32;
+        params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1);
     }

     float get_alpha() {
@@ -511,4 +520,4 @@ class VideoResBlock : public ResBlock {
     }
 };

-#endif  // __COMMON_HPP__
+#endif  // __COMMON_HPP__
