device_map defaults to auto (#607)
casper-hansen authored Sep 12, 2024
1 parent ae77736 commit 1523880
Showing 7 changed files with 14 additions and 12 deletions.
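
In practical terms, AutoAWQForCausalLM.from_pretrained now behaves as if device_map="auto" had been passed whenever the caller does not supply a device map, so transformers/accelerate places the weights automatically instead of the previous None default. A minimal sketch of the new default behavior (the model path below is a placeholder for illustration, not part of this commit):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder path, not from this commit

# No device_map argument: after this commit the loader defaults to "auto",
# letting accelerate decide where to place the weights.
model = AutoAWQForCausalLM.from_pretrained(
    model_path, low_cpu_mem_usage=True, use_cache=False
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# An explicit device map still overrides the default, e.g. a single GPU:
# model = AutoAWQForCausalLM.from_pretrained(model_path, device_map="cuda")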

README.md (2 changes: 1 addition & 1 deletion)

@@ -109,7 +109,7 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+    model_path, low_cpu_mem_usage=True, use_cache=False
 )
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

awq/models/auto.py (2 changes: 1 addition & 1 deletion)

@@ -60,7 +60,7 @@ def from_pretrained(
         model_path,
         trust_remote_code=True,
         safetensors=True,
-        device_map=None,
+        device_map="auto",
         download_kwargs=None,
         **model_init_kwargs,
     ) -> BaseAWQForCausalLM:

awq/models/base.py (2 changes: 1 addition & 1 deletion)

@@ -347,7 +347,7 @@ def from_pretrained(
             Doc(
                 "A device map that will be passed onto the model loading method from transformers."
             ),
-        ] = None,
+        ] = "auto",
         download_kwargs: Annotated[
             Dict,
             Doc("Used for configure download model"),

docs/examples.md (15 changes: 8 additions & 7 deletions)

@@ -21,7 +21,7 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+    model_path, low_cpu_mem_usage=True, use_cache=False
 )
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
@@ -50,7 +50,9 @@ quant_path = 'vicuna-7b-v1.5-awq'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(model_path)
+model = AutoAWQForCausalLM.from_pretrained(
+    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
+)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 # Define data loading methods
@@ -106,7 +108,7 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
 )
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
@@ -149,7 +151,7 @@ quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version":
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
 )
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
@@ -195,7 +197,7 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, device_map="cuda", **{"low_cpu_mem_usage": True}
+    model_path, low_cpu_mem_usage=True, device_map="cuda",
 )
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
@@ -234,9 +236,8 @@ llama_cpp_path = '/workspace/llama.cpp'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM" }
 
 # Load model
-# NOTE: pass safetensors=True to load safetensors
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
 )
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

examples/cli.py (2 changes: 1 addition & 1 deletion)

@@ -20,7 +20,7 @@ def main():
     parser.add_argument("--no-low_cpu_mem_usage", action="store_false", dest="low_cpu_mem_usage", help="Don't use low CPU memory")
     parser.add_argument("--use_cache", action="store_true", help="Use cache")
     parser.add_argument("--no-use_cache", action="store_false", dest="use_cache", help="Don't use cache")
-    parser.add_argument("--device_map", type=str, default=None, help="Device map for loading the pretrained model")
+    parser.add_argument("--device_map", type=str, default="auto", help="Device map for loading the pretrained model")
 
     parser.set_defaults(zero_point=True, low_cpu_mem_usage=True, use_cache=None)

examples/quantize.py (2 changes: 1 addition & 1 deletion)

@@ -7,7 +7,7 @@
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+    model_path, low_cpu_mem_usage=True, use_cache=False
 )
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

scripts/runpod_quantize.py (1 change: 1 addition & 0 deletions)

@@ -47,6 +47,7 @@
     version = "GEMM",
     low_cpu_mem_usage = True,
     use_cache = False,
+    device_map = "auto",
 )
 cli_args = " ".join([f"--{k}" if isinstance(v, bool) else f"--{k} {v}" for k,v in cli_args.items()])
