|
1 | 1 | #ifndef __MODEL_H__ |
2 | 2 | #define __MODEL_H__ |
3 | 3 |
|
4 | | -#include <functional> |
5 | | -#include <map> |
6 | | -#include <memory> |
7 | | -#include <set> |
8 | 4 | #include <string> |
| 5 | +#include <utility> |
9 | 6 | #include <vector> |
10 | 7 |
|
11 | 8 | #include "core/ordered_map.hpp" |
@@ -238,73 +235,4 @@ enum PMVersion { |
238 | 235 | typedef OrderedMap<std::string, TensorStorage> String2TensorStorage; |
239 | 236 | using TensorTypeRules = std::vector<std::pair<std::string, ggml_type>>; |
240 | 237 |
|
241 | | -TensorTypeRules parse_tensor_type_rules(const std::string& tensor_type_rules); |
242 | | - |
243 | | -class MmapWrapper; |
244 | | - |
245 | | -struct ModelFileData { |
246 | | - std::string path; |
247 | | - std::vector<TensorStorage> tensors; |
248 | | - std::shared_ptr<MmapWrapper> mmapped; |
249 | | - std::shared_ptr<struct ggml_backend_buffer> mmbuffer; |
250 | | - bool is_zip; |
251 | | -}; |
252 | | - |
253 | | -struct MmapTensorStore { |
254 | | - std::shared_ptr<MmapWrapper> mmapped; |
255 | | - std::shared_ptr<struct ggml_backend_buffer> mmbuffer; |
256 | | -}; |
257 | | - |
258 | | -class ModelLoader { |
259 | | -protected: |
260 | | - SDVersion version_ = VERSION_COUNT; |
261 | | - std::vector<std::string> file_paths_; |
262 | | - std::vector<ModelFileData> file_data; |
263 | | - bool model_files_processed = false; |
264 | | - String2TensorStorage tensor_storage_map; |
265 | | - |
266 | | - void add_tensor_storage(const TensorStorage& tensor_storage); |
267 | | - |
268 | | - bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); |
269 | | - bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); |
270 | | - bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = ""); |
271 | | - bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = ""); |
272 | | - bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); |
273 | | - |
274 | | -public: |
275 | | - bool init_from_file(const std::string& file_path, const std::string& prefix = ""); |
276 | | - void convert_tensors_name(); |
277 | | - bool init_from_file_and_convert_name(const std::string& file_path, |
278 | | - const std::string& prefix = "", |
279 | | - SDVersion version = VERSION_COUNT); |
280 | | - SDVersion get_sd_version(); |
281 | | - std::map<ggml_type, uint32_t> get_wtype_stat(); |
282 | | - std::map<ggml_type, uint32_t> get_conditioner_wtype_stat(); |
283 | | - std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat(); |
284 | | - std::map<ggml_type, uint32_t> get_vae_wtype_stat(); |
285 | | - String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; } |
286 | | - void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = ""); |
287 | | - void process_model_files(bool enable_mmap = false, bool writable_mmap = true); |
288 | | - std::vector<MmapTensorStore> mmap_tensors(std::map<std::string, ggml_tensor*>& tensors, |
289 | | - std::set<std::string> ignore_tensors = {}, |
290 | | - bool writable = true); |
291 | | - bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0, bool use_mmap = false); |
292 | | - bool load_tensors(std::map<std::string, ggml_tensor*>& tensors, |
293 | | - std::set<std::string> ignore_tensors = {}, |
294 | | - int n_threads = 0, |
295 | | - bool use_mmap = false); |
296 | | - |
297 | | - std::vector<std::string> get_tensor_names() const { |
298 | | - std::vector<std::string> names; |
299 | | - for (const auto& [name, tensor_storage] : tensor_storage_map) { |
300 | | - names.push_back(name); |
301 | | - } |
302 | | - return names; |
303 | | - } |
304 | | - |
305 | | - bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); |
306 | | - int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); |
307 | | - ~ModelLoader() = default; |
308 | | -}; |
309 | | - |
310 | 238 | #endif // __MODEL_H__ |
0 commit comments