diff --git a/mistral_example.py b/mistral_example.py
index ae46e75..b7bc06c 100644
--- a/mistral_example.py
+++ b/mistral_example.py
@@ -9,6 +9,9 @@
 from transformers.models.mistral.modeling_mistral import MistralAttention
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import transformers
+import torch
+
+device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # fall back to CPU when MPS is unavailable
 
 original_mistral_forward = MistralAttention.forward
 self_extend_forward = partial(MistralSE.self_extend_forward, group_size_1=4, group_size_2=1024)
@@ -17,18 +20,18 @@
 model_path = 'mistralai/Mistral-7B-Instruct-v0.1'
 config = transformers.AutoConfig.from_pretrained(model_path)
 config.sliding_window = 200000000 # disable mistral's default SWA mechanism (4096), mistral's true window is 8192.
 
-model = AutoModelForCausalLM.from_pretrained(model_path, config=config, device_map="auto")
+model = AutoModelForCausalLM.from_pretrained(model_path, config=config, device_map=device)
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 
-model.eval()
+model.eval().to(device)
 
 # In the example task file, the passkey is placed within the last 4096 tokens, this means, if you use SWA, mistral will successfully find the passkey.
 for line in open("passkey_examples_10k.jsonl", "r"):
     example = json.loads(line)
     prompt_postfix = "What is the pass key? The pass key is "
     prompt = example["input"] + prompt_postfix
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     print( "-----------------------------------" )
     print( f"#Tokens of Prompt:", input_ids.shape[1], end=" " )
     print( "Passkey target:", example["target"] )
@@ -47,6 +50,9 @@
     answer = answer.replace("\n", "\\n")
     print( answer )
     print( "-----------------------------------\n" )
+
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
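
Note on the device selection above: torch.device() only accepts concrete backend names such as "mps", "cuda", or "cpu", so the patched one-liner drops to the CPU whenever MPS is unavailable. If machines with CUDA GPUs should keep using them rather than falling back to the CPU, a slightly fuller selection helper can be used instead. The sketch below is illustrative only; the helper name pick_device is not part of the example script.

import torch

def pick_device() -> torch.device:
    # Prefer Apple's MPS backend, then CUDA, then plain CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = pick_device()
print(f"Selected device: {device}")

Passing a single device as device_map pins the whole model to that device, whereas the original device_map="auto" lets accelerate choose the placement, so the fallback only matters on machines without MPS. On Apple Silicon, if an operator used during generation is not yet implemented for MPS, setting PYTORCH_ENABLE_MPS_FALLBACK=1 before launching the script makes PyTorch run that operator on the CPU instead of raising an error.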