How to save and load ipex optimized model? #686

Open

Description

@benja-matic

Describe the issue

Hi IPEX team,

I have an application where I want to serve multiple models concurrently, and I want to share weights across the concurrent instances. I normally do this with torch.load(path, mmap=True). However, calling ipex.llm.optimize interferes with weight sharing because IPEX manipulates the weights in memory (it does a deep copy, from what I understand). I would instead like to save the IPEX-optimized model and load it with something like torch.load(ipex_model, mmap=True). However, I can't figure out how to do this and was hoping you could provide an example.
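For reference, this is roughly the pattern I use today for non-IPEX models, and what I am hoping to do for an IPEX-optimized one (the paths are just placeholders):

import torch

# Works today: every concurrent instance maps the same weight pages
model = torch.load("model.pt", mmap=True)

# What I am looking for: the same thing for an IPEX-optimized artifact
# ipex_model = torch.load("ipex_optimized_model.pt", mmap=True)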

How to reproduce:

My miniconda env.yml file is listed below. pip install -r requirements.txt may not work here, but you can recreate this env easily with conda create -n ipex_issue python=3.10 && conda activate ipex_issue, followed by the install instructions here and pip install transformers==4.38.1. I am using Python 3.10 on an AWS c7i.2xlarge instance.

certifi==2024.7.4
charset-normalizer==3.3.2
filelock==3.13.1
fsspec==2024.2.0
huggingface-hub==0.24.5
idna==3.7
intel_extension_for_pytorch==2.3.100
Jinja2==3.1.3
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.3
oneccl-bind-pt==2.3.0+cpu
packaging==24.1
pillow==10.2.0
psutil==6.0.0
PyYAML==6.0.2
regex==2024.7.24
requests==2.32.3
safetensors==0.4.4
sympy==1.12
tokenizers==0.15.2
torch==2.3.0+cpu
torchaudio==2.3.0+cpu
torchvision==0.18.0+cpu
tqdm==4.66.5
transformers==4.38.1
typing_extensions==4.9.0
urllib3==2.2.2

Here are the things I have tried:

import os
import torch
from transformers import AutoConfig, AutoModelForCausalLM
import intel_extension_for_pytorch as ipex

config = AutoConfig.from_pretrained("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_ACCESS_TOKEN"))

# Just use a tiny model with random weights so it uses less mem, faster to test

config_micro = dict(
    hidden_size=768,
    intermediate_size=int(768*3.5),
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=8,
)

for k, v in config_micro.items():
    setattr(config, k, v)

# NOTE: this is where we want to load the model with mmap=True to enable weight sharing
# EG model = AutoModelForCausalLM.from_pretrained(path, mmap=True)
model = AutoModelForCausalLM.from_config(config)
model.eval()
model.to(torch.bfloat16)


# NOTE: optimize deepcopies the model, breaks weights sharing / memory mapping
model = ipex.llm.optimize(
    model, dtype=torch.bfloat16, inplace=True, deployment_mode=True
)

###
# Several attempts to save / load
###

# 0) model.save_pretrained("save_pretrained")
model.save_pretrained("save_pretrained")
# error: RecursionError: maximum recursion depth exceeded while calling a Python object

# 1) model.save
model.save("model_save")
# error: AttributeError: 'MistralForCausalLM' object has no attribute 'save'

# 2) torch.save
torch.save(model, "torch_save_model.pt")
# error: RuntimeError: Tried to serialize object __torch__.transformers.models.mistral.modeling_mistral.___torch_mangle_65.MistralForCausalLM which does not have a __getstate__ method defined!

# 3) model.trace_graph.save()
model.trace_graph.save("model_trace_graph")
m3 = torch.jit.load("model_trace_graph")
inputs = torch.randint(low=500, high=1_000, size=(1, 16), dtype=torch.int64)
m3(inputs)
# error: RuntimeError: forward() is missing value for argument 'attention_mask'.

# 4) save jit traced
with torch.no_grad():
    traced_model = torch.jit.trace(model, inputs)
    # error: RecursionError: maximum recursion depth exceeded in comparison

As a side note, I understand you normally use subprocess to deploy multiple concurrent models, but this is not an option in my case because the logic that decides how and when to fork processes is separate from the part of the code that loads the model.

At some point I think I was able to get option 0) above to work, but the loaded model was a vanilla transformer without IPEX optimizations, and I can't seem to reproduce even that behavior, at least in this env.

Any help would be much appreciated.

Activity

ZhaoqiongZ self-assigned this on Aug 12, 2024

ZhaoqiongZ (Contributor) commented on Aug 12, 2024

Hi @benja-matic, thank you for taking the time to try IPEX and for reporting the issue.
The correct way to save is the third method you used, model.trace_graph.save(). For the load part, however, a small change is needed; refer to https://github.com/intel/intel-extension-for-pytorch/blob/main/examples/cpu/inference/python/llm/single_instance/run_quantization.py#L1128-L1147
Your code could look something like this:

import os
import torch
from transformers import AutoConfig, AutoModelForCausalLM
import intel_extension_for_pytorch as ipex

config = AutoConfig.from_pretrained("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_ACCESS_TOKEN"))

# Just use a tiny model with random weights so it uses less mem, faster to test

config_micro = dict(
    hidden_size=768,
    intermediate_size=int(768*3.5),
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=8,
)

for k, v in config_micro.items():
    setattr(config, k, v)

# NOTE: this is where we want to load the model with mmap=True to enable weight sharing
# EG model = AutoModelForCausalLM.from_pretrained(path, mmap=True)
model = AutoModelForCausalLM.from_config(config)
model.eval()
model.to(torch.bfloat16)


# NOTE: optimize deepcopies the model, breaks weights sharing / memory mapping
model = ipex.llm.optimize(
    model, dtype=torch.bfloat16, inplace=True, deployment_mode=True
)

# 3) model.trace_graph.save()
model.trace_graph.save("model_trace_graph")
m3 = torch.jit.load("model_trace_graph")
m3 = torch.jit.freeze(m3.eval())
ipex._set_optimized_model_for_generation(model, optimized_model=m3)
inputs = torch.randint(low=500, high=1_000, size=(1, 16), dtype=torch.int64)
model(inputs)

benja-matic (Author) commented on Aug 16, 2024

Hi @ZhaoqiongZ , thanks for the fast reply. I tested the code you provided, and it runs without any errors. However, it doesn't solve the underlying issue of weight sharing across concurrent instances of a model. Weight sharing was the motivator for figuring out how to save and load IPEX models. I've been thrilled with the latency numbers I'm getting when using individual IPEX optimized models, so I'd love to be able to use IPEX for concurrent models. My use case involves two things: 1) a python script similar to the one above to load the model and process requests, and 2) a load balancer that runs parallel versions of the python script. The way I do weight sharing for non-ipex optimized models is torch.load("model.pt", mmap=True) in the python script (1). I'm curious what your thoughts are on how I could enable ipex in this use case.

ZhaoqiongZ (Contributor) commented on Aug 21, 2024

Hi @benja-matic , I hope the following approach meets your requirement: load the model once in the main process and run the parallel instances as threads within that same process. Refer to the following code, and see also the standalone sketch after the snippet.

+    if args.benchmark:
+        if args.use_share_weight:
+            threads = []
+            num_instances = args.total_cores // args.cores_per_instance
+            for i in range(0, num_instances):
+               t = threading.Thread(target=benchmark_evaluate, args=(args, model, eval_dataloader))
+               threads.append(t)
+               t.start()
+            for t in threads:
+                t.join()
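For completeness, a rough standalone sketch of the same idea (the run_instance helper, thread count, and prompt below are illustrative, not taken from the example script):

import threading

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load and optimize the model once in the main process
model_id = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model.eval()
model = ipex.llm.optimize(model, dtype=torch.bfloat16, inplace=True, deployment_mode=True)

def run_instance(prompt):
    # Every thread calls into the same model object, so the weights are shared
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt")
        out = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(out[0], skip_special_tokens=True))

threads = [threading.Thread(target=run_instance, args=("Hello",)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()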
girishponkiya commented on Aug 21, 2024

I'm a bit confused here; maybe I didn't understand it clearly.

Just to load a quantized model, I need to run the following code:

m3 = torch.jit.load("model_trace_graph")
m3 = torch.jit.freeze(m3.eval())
ipex._set_optimized_model_for_generation(model, optimized_model=m3)

But what is model here? If I use additional code (like model = ipex.llm.optimize(...)) to create the model variable, then I'm doing the optimization again, aren't I? If so, the whole point of saving the optimized model, namely avoiding quantizing it again, is lost. Please help me understand what I'm missing here.

ZhaoqiongZ (Contributor) commented on Aug 22, 2024

Hi @jianan-gu , please help with this issue.

azhuvath commented on Aug 27, 2024

m3 = torch.jit.load("model_trace_graph")
m3 = torch.jit.freeze(m3.eval())
ipex._set_optimized_model_for_generation(model, optimized_model=m3)

I am also confused. I perform the quantization as below. It is not clear how to load these models, or why we load the original model and not just the quantized one.

OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python run.py --benchmark -m mistralai/Mistral-7B-Instruct-v0.2 --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --output-dir "saved_results/INT8"

OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python run.py --benchmark -m mistralai/Mistral-7B-Instruct-v0.2 --ipex-weight-only-quantization --weight-dtype INT4 --gptq --quant-with-amp --output-dir "saved_results/INT4"

Two models are created after the above steps:
./saved_results/INT8/best_model.pt
./saved_results/INT4/best_model.pt

girishponkiya commented on Sep 2, 2024

I'm a bit confused here; maybe I didn't understand it clearly.

Just to load a quantized model, I need to run the following code:

m3 = torch.jit.load("model_trace_graph")
m3 = torch.jit.freeze(m3.eval())
ipex._set_optimized_model_for_generation(model, optimized_model=m3)

But what is model here? If I use additional code (like model = ipex.llm.optimize(...)) to create the model variable, then I'm doing the optimization again, aren't I? If so, the whole point of saving the optimized model, namely avoiding quantizing it again, is lost. Please help me understand what I'm missing here.

Hi @jianan-gu, please help with this issue.

benja-matic (Author) commented on Sep 3, 2024

Hi @benja-matic , I hope the following approach meets your requirement: load the model once in the main process and run the parallel instances as threads within that same process. Refer to the following code.

+    if args.benchmark:
+        if args.use_share_weight:
+            threads = []
+            num_instances = args.total_cores // args.cores_per_instance
+            for i in range(0, num_instances):
+               t = threading.Thread(target=benchmark_evaluate, args=(args, model, eval_dataloader))
+               threads.append(t)
+               t.start()
+            for t in threads:
+                t.join()

Hi @ZhaoqiongZ , in my use case I don't have control over forking processes. Is there any way to share weights across multiple models if another process is responsible for forking?

My application is based on the NVIDIA Triton Inference Server. I typically have multiple models running concurrently: the Triton server calls the initialize method on each concurrent model instance to load it into memory, and as requests come in, the server schedules them across those instances. For a non-IPEX-optimized model, I put torch.load("model.pt", mmap=True) in the initialize method of the Python backend. I'm looking for something comparable I can do with IPEX. I'm also looking into whether this is possible with a different backend, like vLLM or LibTorch.
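For concreteness, this is roughly what my initialize looks like today for the non-IPEX case (the model path is illustrative):

import torch

class TritonPythonModel:
    def initialize(self, args):
        # mmap=True lets concurrent instances map the same weight pages
        # instead of each holding a private copy in memory
        self.model = torch.load("/models/my_model/1/model.pt", mmap=True)
        self.model.eval()

    def execute(self, requests):
        # ... run self.model on the incoming requests ...
        pass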

Does that make sense? Please let me know if I can further clarify. Thanks again for the support.

jianan-gu (Contributor) commented on Sep 5, 2024

I'm a bit confused here; maybe I didn't understand it clearly.

Just to load a quantized model, I need to run the following code:

m3 = torch.jit.load("model_trace_graph")
m3 = torch.jit.freeze(m3.eval())
ipex._set_optimized_model_for_generation(model, optimized_model=m3)

But what is model here? If I use additional code (like model = ipex.llm.optimize(...)) to create the model variable, then I'm doing the optimization again, aren't I? If so, the whole point of saving the optimized model, namely avoiding quantizing it again, is lost. Please help me understand what I'm missing here.

Hi @girishponkiya

When doing deployment or benchmarking only (quantization is already done), we still need a model object to apply our optimizations for generation (e.g., for beam search), but the model optimization itself is done and does not need to be repeated as long as you have saved best_model.pt or model_trace_graph.

More details:
We can init the HF model on the meta device (no real weights are loaded; we just need the object carrying the generation functions from the HF model class), see here link. Then we use model = ipex.llm.optimize(...) to apply the optimizations for generation, and it will not quantize or optimize the model again, see link and link.
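In outline, the flow looks roughly like the sketch below. This is an untested outline of the idea rather than the exact code from run_quantization.py; the config and paths are illustrative, and the exact arguments to ipex.llm.optimize for this path should be double-checked against the example script:

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Build the model object without loading real weights; we only need the
# HF generation machinery (beam search, sampling, ...) attached to it.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)
model.eval()

# Apply the generation-side optimizations; per the note above, this does not
# quantize or optimize the weights again.
model = ipex.llm.optimize(model, dtype=torch.bfloat16, inplace=True, deployment_mode=False)

# Attach the previously saved optimized graph.
m3 = torch.jit.load("model_trace_graph")
m3 = torch.jit.freeze(m3.eval())
ipex._set_optimized_model_for_generation(model, optimized_model=m3)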

jianan-gu (Contributor) commented on Sep 5, 2024

m3 = torch.jit.load("model_trace_graph")
m3 = torch.jit.freeze(m3.eval())
ipex._set_optimized_model_for_generation(model, optimized_model=m3)

I am also confused. I perform the quantization as below. It is not clear how to load these models, or why we load the original model and not just the quantized one.

OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python run.py --benchmark -m mistralai/Mistral-7B-Instruct-v0.2 --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --output-dir "saved_results/INT8"

OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python run.py --benchmark -m mistralai/Mistral-7B-Instruct-v0.2 --ipex-weight-only-quantization --weight-dtype INT4 --gptq --quant-with-amp --output-dir "saved_results/INT4"

Two models are created after the above steps: ./saved_results/INT8/best_model.pt and ./saved_results/INT4/best_model.pt

Hi @azhuvath
If you have the quantized model "best_model.pt", see the notes here: just add the argument --quantized-model-path <output_dir + "best_model.pt"> to your command, and you can skip the quantization stage and run the benchmark directly with your locally optimized model (this is what is described in the details above, #686 (comment)).
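For example, your INT8 command with the extra argument added would look something like this:

OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python run.py --benchmark -m mistralai/Mistral-7B-Instruct-v0.2 --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --output-dir "saved_results/INT8" --quantized-model-path "./saved_results/INT8/best_model.pt"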

Labels: CPU (CPU specific issues), Documentation (Improvements or additions to documentation), LLM, NotAnIssue