Skip to content

Commit b3d56d0

Browse files
authored
refactor: split model loader from model definitions (#1619)
1 parent 2a07540 commit b3d56d0

25 files changed

Lines changed: 106 additions & 84 deletions

src/conditioning/conditioner.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
#include "model/te/clip.hpp"
1010
#include "model/te/llm.hpp"
1111
#include "model/te/t5.hpp"
12+
#include "model_loader.h"
1213

1314
struct SDCondition {
1415
sd::Tensor<float> c_crossattn;

src/convert.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,9 @@
33
#include <regex>
44
#include <vector>
55

6-
#include "model.h"
76
#include "model_io/gguf_io.h"
87
#include "model_io/safetensors_io.h"
8+
#include "model_loader.h"
99
#include "util.h"
1010

1111
#include "ggml_extend_backend.h"

src/extensions/generation_extension.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
#include "conditioning/conditioner.hpp"
1111
#include "core/ggml_extend_backend.h"
12-
#include "model.h"
12+
#include "model_loader.h"
1313
#include "stable-diffusion.h"
1414

1515
struct GenerationExtensionInitContext {

src/model.h

Lines changed: 1 addition & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,8 @@
11
#ifndef __MODEL_H__
22
#define __MODEL_H__
33

4-
#include <functional>
5-
#include <map>
6-
#include <memory>
7-
#include <set>
84
#include <string>
5+
#include <utility>
96
#include <vector>
107

118
#include "core/ordered_map.hpp"
@@ -238,73 +235,4 @@ enum PMVersion {
238235
typedef OrderedMap<std::string, TensorStorage> String2TensorStorage;
239236
using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>;
240237

241-
TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules);
242-
243-
class MmapWrapper;
244-
245-
struct ModelFileData {
246-
std::string path;
247-
std::vector<TensorStorage> tensors;
248-
std::shared_ptr<MmapWrapper> mmapped;
249-
std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
250-
bool is_zip;
251-
};
252-
253-
struct MmapTensorStore {
254-
std::shared_ptr<MmapWrapper> mmapped;
255-
std::shared_ptr<struct ggml_backend_buffer> mmbuffer;
256-
};
257-
258-
class ModelLoader {
259-
protected:
260-
SDVersion version_ = VERSION_COUNT;
261-
std::vector<std::string> file_paths_;
262-
std::vector<ModelFileData> file_data;
263-
bool model_files_processed = false;
264-
String2TensorStorage tensor_storage_map;
265-
266-
void add_tensor_storage(const TensorStorage& tensor_storage);
267-
268-
bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
269-
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
270-
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
271-
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
272-
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
273-
274-
public:
275-
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
276-
void convert_tensors_name();
277-
bool init_from_file_and_convert_name(const std::string& file_path,
278-
const std::string& prefix = "",
279-
SDVersion version = VERSION_COUNT);
280-
SDVersion get_sd_version();
281-
std::map<ggml_type, uint32_t> get_wtype_stat();
282-
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
283-
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
284-
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
285-
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
286-
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
287-
void process_model_files(bool enable_mmap = false, bool writable_mmap = true);
288-
std::vector<MmapTensorStore> mmap_tensors(std::map<std::string, ggml_tensor*>& tensors,
289-
std::set<std::string> ignore_tensors = {},
290-
bool writable = true);
291-
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false);
292-
bool load_tensors(std::map<std::string, ggml_tensor*>& tensors,
293-
std::set<std::string> ignore_tensors = {},
294-
int n_threads = 0,
295-
bool use_mmap = false);
296-
297-
std::vector<std::string> get_tensor_names() const {
298-
std::vector<std::string> names;
299-
for (const auto& [name, tensor_storage] : tensor_storage_map) {
300-
names.push_back(name);
301-
}
302-
return names;
303-
}
304-
305-
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
306-
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
307-
~ModelLoader() = default;
308-
};
309-
310238
#endif // __MODEL_H__

src/model/adapter/lora.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
#include <mutex>
55
#include "core/ggml_extend.hpp"
6+
#include "model_loader.h"
67

78
#define LORA_GRAPH_BASE_SIZE 10240
89

src/model/adapter/pmid.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
#include "model/adapter/lora.hpp"
77
#include "model/common/block.hpp"
88
#include "model/te/clip.hpp"
9+
#include "model_loader.h"
910

1011
struct FuseBlock : public GGMLBlock {
1112
// network hparams

src/model/diffusion/control.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
#ifndef __SD_MODEL_DIFFUSION_CONTROL_HPP__
22
#define __SD_MODEL_DIFFUSION_CONTROL_HPP__
33

4-
#include "model.h"
54
#include "model/common/block.hpp"
5+
#include "model_loader.h"
66

77
#define CONTROL_NET_GRAPH_SIZE 1536
88

src/model/diffusion/flux.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,10 @@
44
#include <memory>
55
#include <vector>
66

7-
#include "model.h"
87
#include "model/common/rope.hpp"
98
#include "model/diffusion/dit.hpp"
109
#include "model/diffusion/model.hpp"
10+
#include "model_loader.h"
1111

1212
#define FLUX_GRAPH_SIZE 10240
1313

src/model/diffusion/ltxv.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
#include "model/common/rope.hpp"
1414
#include "model/diffusion/flux.hpp"
1515
#include "model/diffusion/model.hpp"
16+
#include "model_loader.h"
1617

1718
namespace LTXV {
1819

src/model/diffusion/mmdit.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@
77
#include <vector>
88

99
#include "core/ggml_extend.hpp"
10-
#include "model.h"
1110
#include "model/common/block.hpp"
1211
#include "model/diffusion/model.hpp"
12+
#include "model_loader.h"
1313

1414
#define MMDIT_GRAPH_SIZE 10240
1515

0 commit comments

Comments
 (0)