diff --git a/config.py b/config.py index ec34664..3b6f75b 100644 --- a/config.py +++ b/config.py @@ -1,146 +1,327 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -MNLI_LABEL = ['entailment', 'neutral', 'contradiction', - 'entailment\'', 'neutral\'', 'contradiction\''] -EQ_LABEL = ['equivalent', 'not_equivalent', 'equivalent\'', 'not_equivalent\''] -ENTAIL_LABEL = ['entailment', 'not_entailment', 'entailment\'', - 'not_entailment\'', '0', '1', '0\'', '1\'', 0, 1] +TIME_LIMIT = (3 * 60 + 50) * 60 # 3 hours and 50 minutes + +MNLI_LABEL = [ + "entailment", + "neutral", + "contradiction", + "entailment'", + "neutral'", + "contradiction'", +] +EQ_LABEL = ["equivalent", "not_equivalent", "equivalent'", "not_equivalent'"] +ENTAIL_LABEL = [ + "entailment", + "not_entailment", + "entailment'", + "not_entailment'", + "0", + "1", + "0'", + "1'", + 0, + 1, +] LABEL_SET = { # 'positive\'', 'negative\'' is used for label constraint due to a bug of TextAttack repo. - 'sst2': ['positive', 'negative', 'positive\'', 'negative\'', '0', '1', '0\'', '1\'', 0, 1], - 'mnli': MNLI_LABEL, - 'mnli_mismatched': MNLI_LABEL, - 'mnli_matched': MNLI_LABEL, - 'qqp': EQ_LABEL, - 'qnli': ENTAIL_LABEL, - 'rte': ENTAIL_LABEL, - 'cola': ['unacceptable', 'acceptable', 'unacceptable\'', 'acceptable\''], - 'mrpc': EQ_LABEL, - 'wnli': ENTAIL_LABEL, - 'mmlu': ['A', 'B', 'C', 'D', 'A\'', 'B\'', 'C\'', 'D\'', 'a', 'b', 'c', 'd', 'a\'', 'b\'', 'c\'', 'd\''], + "sst2": [ + "positive", + "negative", + "positive'", + "negative'", + "0", + "1", + "0'", + "1'", + 0, + 1, + ], + "mnli": MNLI_LABEL, + "mnli_mismatched": MNLI_LABEL, + "mnli_matched": MNLI_LABEL, + "qqp": EQ_LABEL, + "qnli": ENTAIL_LABEL, + "rte": ENTAIL_LABEL, + "cola": ["unacceptable", "acceptable", "unacceptable'", "acceptable'"], + "mrpc": EQ_LABEL, + "wnli": ENTAIL_LABEL, + "mmlu": [ + "A", + "B", + "C", + "D", + "A'", + "B'", + "C'", + "D'", + "a", + "b", + "c", + "d", + "a'", + "b'", + "c'", + "d'", + ], # do not change the word 'nothing' in prompts. 
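+    # The entries below hold task keywords rather than class labels (e.g.
+    # 'unanswerable', 'translate'); listing them here keeps the attack's label
+    # constraint from rewriting these words. Assumption inferred from how
+    # LABEL_SET is consumed by the attack constraints; not stated in this diff.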
- 'squad_v2': ['unanswerable', 'unanswerable\''], - 'iwslt': ['translate', 'translate\''], - 'un_multi': ['translate', 'translate\''], - 'math': ['math', 'math\''], - 'bool_logic': ['True', 'False', 'True\'', 'False\'', "bool", "boolean", "bool\'", "boolean\'"], - 'valid_parentheses': ['Valid', 'Invalid', 'Valid\'', 'Invalid\'', 'matched', 'matched\'', 'valid', 'invalid', 'valid\'', 'invalid\''], + "squad_v2": ["unanswerable", "unanswerable'"], + "iwslt": ["translate", "translate'"], + "un_multi": ["translate", "translate'"], + "math": ["math", "math'"], + "bool_logic": [ + "True", + "False", + "True'", + "False'", + "bool", + "boolean", + "bool'", + "boolean'", + ], + "valid_parentheses": [ + "Valid", + "Invalid", + "Valid'", + "Invalid'", + "matched", + "matched'", + "valid", + "invalid", + "valid'", + "invalid'", + ], } GENERATE_LEN = { - 'sst2': {'google/flan-t5-large': 20, 'vicuna-13b': 5, 'google/flan-ul2': 20, "chatgpt": 2, 'llama2-13b': 2, 'llama2-13b-chat': 2, 'llama2-7b-chat': 2, 'llama2-7b-chat': 2}, - 'mnli': {'google/flan-t5-large': 20, 'vicuna-13b': 5, 'google/flan-ul2': 20, "chatgpt": 3, 'llama2-13b': 3, 'llama2-13b-chat': 3, 'llama2-7b-chat': 3, 'llama2-7b-chat': 3}, - 'qqp': {'google/flan-t5-large': 20, 'vicuna-13b': 5, 'google/flan-ul2': 20, "chatgpt": 3, 'llama2-13b': 4, 'llama2-13b-chat': 4, 'llama2-7b-chat': 4, 'llama2-7b-chat': 4}, - 'qnli': {'google/flan-t5-large': 20, 'vicuna-13b': 5, 'google/flan-ul2': 20, "chatgpt": 4, 'llama2-13b': 2, 'llama2-13b-chat': 2, 'llama2-7b-chat': 2, 'llama2-7b-chat': 2}, - 'rte': {'google/flan-t5-large': 20, 'vicuna-13b': 5, 'google/flan-ul2': 20, "chatgpt": 4, 'llama2-13b': 3, 'llama2-13b-chat': 3, 'llama2-7b-chat': 3, 'llama2-7b-chat': 3}, - 'cola': {'google/flan-t5-large': 20, 'vicuna-13b': 5, 'google/flan-ul2': 20, "chatgpt": 3, 'llama2-13b': 3, 'llama2-13b-chat': 3, 'llama2-7b-chat': 3, 'llama2-7b-chat': 3}, - 'mrpc': {'google/flan-t5-large': 20, 'vicuna-13b': 5, 'google/flan-ul2': 20, "chatgpt": 3, 'llama2-13b': 2, 'llama2-13b-chat': 2, 'llama2-7b-chat': 2, 'llama2-7b-chat': 2}, - 'wnli': {'google/flan-t5-large': 20, 'vicuna-13b': 5, 'google/flan-ul2': 20, "chatgpt": 4, 'llama2-13b': 3, 'llama2-13b-chat': 3, 'llama2-7b-chat': 3, 'llama2-7b-chat': 3}, - 'mmlu': {'google/flan-t5-large': 2, 'vicuna-13b': 2, 'google/flan-ul2': 2, "chatgpt": 2, 'llama2-13b': 3, 'llama2-13b-chat': 3, 'llama2-7b-chat': 3, 'llama2-7b-chat': 3}, - 'squad_v2': {'google/flan-t5-large': 20, 'google/flan-ul2': 20, "chatgpt": 20}, - 'iwslt': {'google/flan-t5-large': 70, 'google/flan-ul2': 70, 'chatgpt': 70}, - 'un_multi': {'google/flan-t5-large': 140, 'google/flan-ul2': 140, 'chatgpt': 140}, - 'math': {'google/flan-t5-large': 20, 'google/flan-ul2': 20, 'chatgpt': 20}, - 'bool_logic': {'google/flan-t5-large': 4, }, + "sst2": { + "google/flan-t5-large": 20, + "vicuna-13b": 5, + "google/flan-ul2": 20, + "chatgpt": 2, + "llama2-13b": 2, + "llama2-13b-chat": 2, + "llama2-7b-chat": 2, + "llama2-7b-chat": 2, + }, + "mnli": { + "google/flan-t5-large": 20, + "vicuna-13b": 5, + "google/flan-ul2": 20, + "chatgpt": 3, + "llama2-13b": 3, + "llama2-13b-chat": 3, + "llama2-7b-chat": 3, + "llama2-7b-chat": 3, + }, + "qqp": { + "google/flan-t5-large": 20, + "vicuna-13b": 5, + "google/flan-ul2": 20, + "chatgpt": 3, + "llama2-13b": 4, + "llama2-13b-chat": 4, + "llama2-7b-chat": 4, + "llama2-7b-chat": 4, + }, + "qnli": { + "google/flan-t5-large": 20, + "vicuna-13b": 5, + "google/flan-ul2": 20, + "chatgpt": 4, + "llama2-13b": 2, + "llama2-13b-chat": 2, + "llama2-7b-chat": 2, + 
"llama2-7b-chat": 2, + }, + "rte": { + "google/flan-t5-large": 20, + "vicuna-13b": 5, + "google/flan-ul2": 20, + "chatgpt": 4, + "llama2-13b": 3, + "llama2-13b-chat": 3, + "llama2-7b-chat": 3, + "llama2-7b-chat": 3, + }, + "cola": { + "google/flan-t5-large": 20, + "vicuna-13b": 5, + "google/flan-ul2": 20, + "chatgpt": 3, + "llama2-13b": 3, + "llama2-13b-chat": 3, + "llama2-7b-chat": 3, + "llama2-7b-chat": 3, + }, + "mrpc": { + "google/flan-t5-large": 20, + "vicuna-13b": 5, + "google/flan-ul2": 20, + "chatgpt": 3, + "llama2-13b": 2, + "llama2-13b-chat": 2, + "llama2-7b-chat": 2, + "llama2-7b-chat": 2, + }, + "wnli": { + "google/flan-t5-large": 20, + "vicuna-13b": 5, + "google/flan-ul2": 20, + "chatgpt": 4, + "llama2-13b": 3, + "llama2-13b-chat": 3, + "llama2-7b-chat": 3, + "llama2-7b-chat": 3, + }, + "mmlu": { + "google/flan-t5-large": 2, + "vicuna-13b": 2, + "google/flan-ul2": 2, + "chatgpt": 2, + "llama2-13b": 3, + "llama2-13b-chat": 3, + "llama2-7b-chat": 3, + "llama2-7b-chat": 3, + }, + "squad_v2": {"google/flan-t5-large": 20, "google/flan-ul2": 20, "chatgpt": 20}, + "iwslt": {"google/flan-t5-large": 70, "google/flan-ul2": 70, "chatgpt": 70}, + "un_multi": {"google/flan-t5-large": 140, "google/flan-ul2": 140, "chatgpt": 140}, + "math": {"google/flan-t5-large": 20, "google/flan-ul2": 20, "chatgpt": 20}, + "bool_logic": { + "google/flan-t5-large": 4, + }, } MODEL_SET = [ - 'google/flan-t5-large', - 'EleutherAI/gpt-neox-20b', - 'tiiuae/falcon-40b-instruct', + "google/flan-t5-large", + "EleutherAI/gpt-neox-20b", + "tiiuae/falcon-40b-instruct", # 'facebook/opt-66b', - 'llama-13b', - 'llama2-13b', - 'llama2-13b-chat', - 'llama2-7b', - 'llama2-7b-chat', - 'vicuna-13b', - 'vicuna-13b-v1.3', - 'google/flan-ul2', - 'cerebras/Cerebras-GPT-13B', - 'databricks/dolly-v1-6b', - 'chatgpt', - 'gpt4', - 'nemo', + "llama-13b", + "llama2-13b", + "llama2-13b-chat", + "llama2-7b", + "llama2-7b-chat", + "vicuna-13b", + "vicuna-13b-v1.3", + "google/flan-ul2", + "cerebras/Cerebras-GPT-13B", + "databricks/dolly-v1-6b", + "chatgpt", + "gpt4", + "nemo", ] NEMO_TRT_MODELS = ["GPT-8B-SFT", "GPT-8B-RLHF"] -NEMO_PROMPT = "System\n\nUser\n{prompt}\nAssistant\n" -NEMO_STEERLM_PROMPT = """\x00System A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. \x11User {prompt} \x11Assistant \x12quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:0""" # Verbosity was changed to 0 from default 4 +NEMO_PROMPT = ( + "System\n\nUser\n{prompt}\nAssistant\n" +) +NEMO_STEERLM_PROMPT = """\x00System A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. \x11User {prompt} \x11Assistant \x12quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:0""" # Verbosity? 
@Evelina Bakhturina LABEL_TO_ID = { - 'mmlu': {'A': 'A', 'B': 'B', 'C': 'C', 'D': 'D'}, - 'sst2': {'negative': 0, 'positive': 1, '0': 0, '1': 1, 0: 0, 1: 1}, - 'mnli': {'entailment': 0, 'neutral': 1, 'contradiction': 2, '0': 0, '1': 1, '2': 2, 0: 0, 1: 1, 2: 2}, - 'mnli_mismatched': {'entailment': 0, 'neutral': 1, 'contradiction': 2, '0': 0, '1': 1, '2': 2, 0: 0, 1: 1, 2: 2}, - 'mnli_matched': {'entailment': 0, 'neutral': 1, 'contradiction': 2, '0': 0, '1': 1, '2': 2, 0: 0, 1: 1, 2: 2}, - 'qqp': {'equivalent': 1, 'not_equivalent': 0, '0': 0, '1': 1, 0: 0, 1: 1}, - 'qnli': {'entailment': 0, 'not_entailment': 1, '0': 0, '1': 1, 0: 0, 1: 1}, - 'rte': {'entailment': 0, 'not_entailment': 1, '0': 0, '1': 1, 0: 0, 1: 1}, - 'cola': {'unacceptable': 0, 'acceptable': 1, '0': 0, '1': 1, 0: 0, 1: 1}, - 'mrpc': {'equivalent': 1, 'not_equivalent': 0, '0': 0, '1': 1, 0: 0, 1: 1}, - 'wnli': {'entailment': 1, 'not_entailment': 0, '0': 0, '1': 1, 0: 0, 1: 1}, + "mmlu": {"A": "A", "B": "B", "C": "C", "D": "D"}, + "sst2": {"negative": 0, "positive": 1, "0": 0, "1": 1, 0: 0, 1: 1}, + "mnli": { + "entailment": 0, + "neutral": 1, + "contradiction": 2, + "0": 0, + "1": 1, + "2": 2, + 0: 0, + 1: 1, + 2: 2, + }, + "mnli_mismatched": { + "entailment": 0, + "neutral": 1, + "contradiction": 2, + "0": 0, + "1": 1, + "2": 2, + 0: 0, + 1: 1, + 2: 2, + }, + "mnli_matched": { + "entailment": 0, + "neutral": 1, + "contradiction": 2, + "0": 0, + "1": 1, + "2": 2, + 0: 0, + 1: 1, + 2: 2, + }, + "qqp": {"equivalent": 1, "not_equivalent": 0, "0": 0, "1": 1, 0: 0, 1: 1}, + "qnli": {"entailment": 0, "not_entailment": 1, "0": 0, "1": 1, 0: 0, 1: 1}, + "rte": {"entailment": 0, "not_entailment": 1, "0": 0, "1": 1, 0: 0, 1: 1}, + "cola": {"unacceptable": 0, "acceptable": 1, "0": 0, "1": 1, 0: 0, 1: 1}, + "mrpc": {"equivalent": 1, "not_equivalent": 0, "0": 0, "1": 1, 0: 0, 1: 1}, + "wnli": {"entailment": 1, "not_entailment": 0, "0": 0, "1": 1, 0: 0, 1: 1}, } ID_TO_LABEL = { - 'mmlu': {'A': 'A', 'B': 'B', 'C': 'C', 'D': 'D'}, - 'sst2': {0: 'negative', 1: 'positive'}, - 'mnli': {0: 'entailment', 1: 'neutral', 2: 'contradiction'}, - 'mnli_matched': {0: 'entailment', 1: 'neutral', 2: 'contradiction'}, - 'mnli_mismatched': {0: 'entailment', 1: 'neutral', 2: 'contradiction'}, - 'qqp': {1: 'equivalent', 0: 'not_equivalent'}, - 'qnli': {0: 'entailment', 1: 'not_entailment'}, - 'rte': {0: 'entailment', 1: 'not_entailment'}, - 'cola': {0: 'unacceptable', 1: 'acceptable'}, - 'mrpc': {1: 'equivalent', 0: 'not_equivalent'}, - 'wnli': {1: 'entailment', 0: 'not_entailment'}, + "mmlu": {"A": "A", "B": "B", "C": "C", "D": "D"}, + "sst2": {0: "negative", 1: "positive"}, + "mnli": {0: "entailment", 1: "neutral", 2: "contradiction"}, + "mnli_matched": {0: "entailment", 1: "neutral", 2: "contradiction"}, + "mnli_mismatched": {0: "entailment", 1: "neutral", 2: "contradiction"}, + "qqp": {1: "equivalent", 0: "not_equivalent"}, + "qnli": {0: "entailment", 1: "not_entailment"}, + "rte": {0: "entailment", 1: "not_entailment"}, + "cola": {0: "unacceptable", 1: "acceptable"}, + "mrpc": {1: "equivalent", 0: "not_equivalent"}, + "wnli": {1: "entailment", 0: "not_entailment"}, } SUPPORTED_LANGUAGES = { - 'google/flan-t5-large': ['en', 'de', 'fr'], - 'google/flan-ul2': ['en', 'de', 'fr'], - 'vicuna-13b': ['en', 'de', 'fr'], - 'llama2-13b-chat': ['en', 'de', 'fr'], - 'chatgpt': ['en', 'de', 'fr'], - 'nemo': ['en'] + "google/flan-t5-large": ["en", "de", "fr"], + "google/flan-ul2": ["en", "de", "fr"], + "vicuna-13b": ["en", "de", "fr"], + "llama2-13b-chat": ["en", 
"de", "fr"], + "chatgpt": ["en", "de", "fr"], + "nemo": ["en"], } LANGUAGES = { - 'ar': 'Arabic', - 'de': 'German', - 'en': 'English', - 'es': 'Spanish', - 'fr': 'French', - 'ru': 'Russian', - 'zh': 'Chinese', - 'it': 'Italian', - 'nl': 'Dutch', - 'ro': 'Romanian', - 'ja': 'Japanese', - 'ko': 'Korean', + "ar": "Arabic", + "de": "German", + "en": "English", + "es": "Spanish", + "fr": "French", + "ru": "Russian", + "zh": "Chinese", + "it": "Italian", + "nl": "Dutch", + "ro": "Romanian", + "ja": "Japanese", + "ko": "Korean", } MATH_QUESTION_TYPES = { - 'algebra_linear_1d': ' linear algebra ', - 'algebra_linear_2d': ' linear algebra ', - 'algebra_sequence_next_term': ' given a sequence predict the next term ', - 'arithmetic_addition_sub_multiple': ' arithmetic addition and subtraction ', - 'arithmetic_mul_div_multiple': ' arithmetic multiplication and division ', - 'arithmetic_mixed': ' arithmetic addition, subtraction, multiplication and division ', - 'arithmetic_nearest_integer_root': ' arithmetic nearest integer root ', - 'comparison_closest': ' compare which one of given numbers is closest to target number ', - 'comparison_kth_biggest': ' compare which one of given numbers is kth biggest or smallest ', - 'comparison_pair': ' comparison which one of given numbers is bigger or smaller ', - 'measurement_conversion': ' measurement conversion ', - 'numbers_base_conversion': ' numbers base conversion ', - 'numbers_div_remainder': ' numbers division and remainder ', - 'numbers_gcd': ' numbers greatest common divisor ', - 'numbers_is_factor': ' if one number is a factor of antoher number ', - 'number_is_prime': ' if a number is prime ', - 'numbers_lcm': ' least common multiple ', - 'numbers_place_value': ' place value ', - 'numbers_round_number': ' round number ', - 'polynomials_evaluate': ' polynomials evaluate ', + "algebra_linear_1d": " linear algebra ", + "algebra_linear_2d": " linear algebra ", + "algebra_sequence_next_term": " given a sequence predict the next term ", + "arithmetic_addition_sub_multiple": " arithmetic addition and subtraction ", + "arithmetic_mul_div_multiple": " arithmetic multiplication and division ", + "arithmetic_mixed": " arithmetic addition, subtraction, multiplication and division ", + "arithmetic_nearest_integer_root": " arithmetic nearest integer root ", + "comparison_closest": " compare which one of given numbers is closest to target number ", + "comparison_kth_biggest": " compare which one of given numbers is kth biggest or smallest ", + "comparison_pair": " comparison which one of given numbers is bigger or smaller ", + "measurement_conversion": " measurement conversion ", + "numbers_base_conversion": " numbers base conversion ", + "numbers_div_remainder": " numbers division and remainder ", + "numbers_gcd": " numbers greatest common divisor ", + "numbers_is_factor": " if one number is a factor of antoher number ", + "number_is_prime": " if a number is prime ", + "numbers_lcm": " least common multiple ", + "numbers_place_value": " place value ", + "numbers_round_number": " round number ", + "polynomials_evaluate": " polynomials evaluate ", } diff --git a/inference.py b/inference.py index 04319d5..3c97bd7 100644 --- a/inference.py +++ b/inference.py @@ -4,19 +4,28 @@ try: import openai except ImportError: - print("OpenAI API is not installed, please install it by running: pip install openai") + print( + "OpenAI API is not installed, please install it by running: pip install openai" + ) -import sys -import os import math +import os +import sys + +from 
nemo_utils.run_time_tracker import RunTimeTracker +from nemo_utils.state_manager import StateManager +from utils import generate_predict_step, handle_timeout_error + dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.join(dir_path, "nemo_utils")) try: # to use Nemo generation directly import os + + from megatron_gpt_eval import nemo_generate, nemo_init_model from omegaconf import OmegaConf - from megatron_gpt_eval import nemo_init_model, nemo_generate + NEMO_AVAILABLE = True except: NEMO_AVAILABLE = False @@ -29,11 +38,13 @@ except: TRT_AVAILABLE = False -from config import LABEL_SET, LABEL_TO_ID, NEMO_PROMPT, NEMO_STEERLM_PROMPT -from tqdm import tqdm -from typing import List from collections import defaultdict +from typing import List + from joblib import Parallel, delayed +from tqdm import tqdm + +from config import LABEL_SET, LABEL_TO_ID, NEMO_PROMPT, NEMO_STEERLM_PROMPT """ This clss implements the inference of the model (including create the model). @@ -41,7 +52,6 @@ class Inference(object): - def __init__(self, args): self.error_analysis = False self.args = args @@ -54,28 +64,34 @@ def create_model(self): ChatGPT is a special case, we use the openai api to create the model. """ - if self.model not in ['chatgpt', 'gpt4']: - import torch + if self.model not in ["chatgpt", "gpt4"]: import os + import torch + """ Here you can add you own model. """ - if self.model == 'google/flan-t5-large': - from transformers import T5Tokenizer, T5ForConditionalGeneration + if self.model == "google/flan-t5-large": + from transformers import T5ForConditionalGeneration, T5Tokenizer self.tokenizer = T5Tokenizer.from_pretrained( - self.model, device_map="cuda") - self.pipe = T5ForConditionalGeneration.from_pretrained(self.model, device_map="cuda") + self.model, device_map="cuda" + ) + self.pipe = T5ForConditionalGeneration.from_pretrained( + self.model, device_map="cuda" + ) - elif self.model == 'EleutherAI/gpt-neox-20b': + elif self.model == "EleutherAI/gpt-neox-20b": from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast self.tokenizer = GPTNeoXTokenizerFast.from_pretrained( - self.model, device_map="auto") + self.model, device_map="auto" + ) self.pipe = GPTNeoXForCausalLM.from_pretrained( - self.model, device_map="auto", torch_dtype=torch.float16) + self.model, device_map="auto", torch_dtype=torch.float16 + ) # elif self.model.lower() == 'facebook/opt-66b': # from transformers import AutoModelForCausalLM, AutoTokenizer @@ -84,58 +100,76 @@ def create_model(self): # self.tokenizer = AutoTokenizer.from_pretrained(model, device_map="auto", use_fast=False) # self.pipe = AutoModelForCausalLM.from_pretrained(model, device_map="auto", torch_dtype=torch.float16) - elif self.model.lower() in ["llama-13b", "llama2-13b", 'llama2-13b-chat', 'llama2-7b', 'llama2-7b-chat']: - + elif self.model.lower() in [ + "llama-13b", + "llama2-13b", + "llama2-13b-chat", + "llama2-7b", + "llama2-7b-chat", + ]: from transformers import LlamaForCausalLM, LlamaTokenizer model_dir = os.path.join(self.args.model_dir, self.model) self.tokenizer = LlamaTokenizer.from_pretrained( - model_dir, device_map="auto") + model_dir, device_map="auto" + ) self.pipe = LlamaForCausalLM.from_pretrained( - model_dir, device_map="auto", torch_dtype=torch.float16) + model_dir, device_map="auto", torch_dtype=torch.float16 + ) elif self.model.lower() in ["vicuna-13b", "vicuna-13b-v1.3"]: - from transformers import AutoModelForCausalLM, AutoTokenizer model_dir = os.path.join(self.args.model_dir, self.model) 
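+                # Vicuna checkpoints reuse the LLaMA architecture; use_fast=False
+                # below keeps the slow SentencePiece tokenizer, which avoids
+                # fast-tokenizer conversion issues seen with LLaMA-family
+                # vocabularies (assumption: model_dir follows the HF layout).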
self.tokenizer = AutoTokenizer.from_pretrained( - model_dir, device_map="auto", use_fast=False) + model_dir, device_map="auto", use_fast=False + ) self.pipe = AutoModelForCausalLM.from_pretrained( - model_dir, device_map="auto", torch_dtype=torch.float16) + model_dir, device_map="auto", torch_dtype=torch.float16 + ) elif self.model == "google/flan-ul2": - - from transformers import T5ForConditionalGeneration, AutoTokenizer + from transformers import AutoTokenizer, T5ForConditionalGeneration self.tokenizer = AutoTokenizer.from_pretrained(self.model) self.pipe = T5ForConditionalGeneration.from_pretrained( - self.model, torch_dtype=torch.bfloat16, device_map="auto") + self.model, torch_dtype=torch.bfloat16, device_map="auto" + ) elif self.model == "tiiuae/falcon-40b-instruct": - from transformers import AutoTokenizer, AutoModelForCausalLM + from transformers import AutoModelForCausalLM, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.model) self.pipe = AutoModelForCausalLM.from_pretrained( - self.model, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto",) + self.model, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ) elif self.model == "cerebras/Cerebras-GPT-13B": - from transformers import AutoTokenizer, AutoModelForCausalLM + from transformers import AutoModelForCausalLM, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained( - self.model, device_map="auto") + self.model, device_map="auto" + ) self.pipe = AutoModelForCausalLM.from_pretrained( - self.model, device_map="auto", torch_dtype=torch.float16) + self.model, device_map="auto", torch_dtype=torch.float16 + ) elif self.model == "databricks/dolly-v1-6b": from transformers import AutoModelForCausalLM, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained( - "databricks/dolly-v1-6b", device_map="auto", padding_side="left") + "databricks/dolly-v1-6b", device_map="auto", padding_side="left" + ) self.pipe = AutoModelForCausalLM.from_pretrained( - "databricks/dolly-v1-6b", device_map="auto", torch_dtype=torch.float16) + "databricks/dolly-v1-6b", + device_map="auto", + torch_dtype=torch.float16, + ) elif self.model == "nemo": if self.args.nemo_use_server: @@ -147,20 +181,29 @@ def create_model(self): if not NEMO_AVAILABLE: raise ImportError("NeMo is not installed") - + dir_path = os.path.dirname(os.path.realpath(__file__)) - cfg = os.path.abspath(f"{dir_path}/nemo_utils/nemo_cfgs/megatron_gpt_inference.yaml") + cfg = os.path.abspath( + f"{dir_path}/nemo_utils/nemo_cfgs/megatron_gpt_inference.yaml" + ) cfg = OmegaConf.load(cfg) cfg.inference.tokens_to_generate = self.args.generate_len cfg.inference.batch_size = self.args.batch_size - + # update NeMo config if non None nemo-* args provided for arg_name in vars(self.args): - if arg_name.startswith("nemo_") and getattr(self.args, arg_name) is not None: + if ( + arg_name.startswith("nemo_") + and getattr(self.args, arg_name) is not None + ): if arg_name == "nemo_model_path": - if self.args.nemo_model_path is None or not os.path.exists(self.args.nemo_model_path): - raise ValueError(f"NeMo model path {self.args.nemo_model_path} does not exist") + if self.args.nemo_model_path is None or not os.path.exists( + self.args.nemo_model_path + ): + raise ValueError( + f"NeMo model path {self.args.nemo_model_path} does not exist" + ) else: cfg.gpt_model_file = self.args.nemo_model_path @@ -168,51 +211,80 @@ def create_model(self): cfg.trainer.devices = self.args.nemo_devices else: cfg.inference[arg_name] = getattr(self.args, 
arg_name) - + self.nemo_cfg = cfg self.pipe, self.nemo_trainer = nemo_init_model(cfg) else: raise NotImplementedError("The model is not implemented!") def process_input(self, prompt, raw_data): - if self.args.dataset in ["cola", "sst2", "mrpc", "qqp", "mnli", "qnli", "rte", "wnli"]: + if self.args.dataset in [ + "cola", + "sst2", + "mrpc", + "qqp", + "mnli", + "qnli", + "rte", + "wnli", + ]: return self._process_cls_input(prompt, raw_data) elif self.args.dataset == "mmlu": return self._process_qa_input(prompt, raw_data) elif self.args.dataset == "squad_v2": return self._process_squad_v2_input(prompt, raw_data) - elif self.args.dataset in ['iwslt', 'un_multi']: + elif self.args.dataset in ["iwslt", "un_multi"]: return self._process_trans_input(prompt, raw_data) - elif self.args.dataset == 'math': + elif self.args.dataset == "math": return self._process_math_input(prompt, raw_data) - elif self.args.dataset == 'bool_logic': + elif self.args.dataset == "bool_logic": return self._process_bool_logic_input(prompt, raw_data) - elif self.args.dataset == 'valid_parentheses': + elif self.args.dataset == "valid_parentheses": return self._process_valid_parentheses_input(prompt, raw_data) else: raise NotImplementedError("The dataset is not implemented!") def process_pred(self, pred): - if self.args.dataset in ["cola", "sst2", "mrpc", "qqp", "mnli", "qnli", "rte", "wnli"]: + if self.args.dataset in [ + "cola", + "sst2", + "mrpc", + "qqp", + "mnli", + "qnli", + "rte", + "wnli", + ]: return self._process_cls_pred(pred) elif self.args.dataset == "mmlu": return self._process_qa_pred(pred) elif self.args.dataset == "squad_v2": return self._process_squad_v2_pred(pred) - elif self.args.dataset in ['iwslt', 'un_multi']: + elif self.args.dataset in ["iwslt", "un_multi"]: return self._process_trans_pred(pred) - elif self.args.dataset == 'math': + elif self.args.dataset == "math": return self._process_math_pred(pred) - elif self.args.dataset == 'bool_logic': + elif self.args.dataset == "bool_logic": return self._process_bool_logic_pred(pred) - elif self.args.dataset == 'valid_parentheses': + elif self.args.dataset == "valid_parentheses": return self._process_valid_parentheses_pred(pred) else: raise NotImplementedError("The dataset is not implemented!") def eval(self, preds, gts): - - if self.args.dataset in ["cola", "sst2", "mrpc", "qqp", "mnli", "qnli", "rte", "wnli", "mmlu", "bool_logic", "valid_parentheses"]: + if self.args.dataset in [ + "cola", + "sst2", + "mrpc", + "qqp", + "mnli", + "qnli", + "rte", + "wnli", + "mmlu", + "bool_logic", + "valid_parentheses", + ]: if self.args.dataset == "mmlu": preds = [pred.lower() for pred in preds] gts = [gt.lower() for gt in gts] @@ -224,14 +296,13 @@ def eval(self, preds, gts): return sum(a == b for a, b in zip(preds, gts)) / len(preds) elif self.args.dataset == "squad_v2": - from metrics.squad_v2.squad_v2 import SquadV2 + metric = SquadV2() model_output = [] for id, pred in zip(gts, preds): - if pred == "unanswerable": no_ans_prob = 1 pred = "" @@ -239,25 +310,28 @@ def eval(self, preds, gts): no_ans_prob = 0 model_output.append( - {"id": id, "prediction_text": pred, "no_answer_probability": no_ans_prob}) + { + "id": id, + "prediction_text": pred, + "no_answer_probability": no_ans_prob, + } + ) references = self.args.data.get_reference() - score = metric.compute( - predictions=model_output, references=references) + score = metric.compute(predictions=model_output, references=references) return score["f1"] / 100 - elif self.args.dataset in ['iwslt', 'un_multi']: - + elif 
self.args.dataset in ["iwslt", "un_multi"]:
             from metrics.bleu.bleu import Bleu
+
             metric = Bleu()
             results = metric.compute(predictions=preds, references=gts)
 
             # it needs to be divided by 100 to get the proper BLEU score (in alignment with other datasets, e.g., GLUE)
-            return results['bleu'] / 100
-
-        elif self.args.dataset == 'math':
+            return results["bleu"] / 100
+        elif self.args.dataset == "math":
             processed_preds = []
             processed_gts = []
             for pred, gt in zip(preds, gts):
@@ -270,14 +344,16 @@ def eval(self, preds, gts):
             processed_preds.append(pred.lower())
             processed_gts.append(gt.lower())
 
-            acc = sum(a == b for a, b in zip(processed_preds,
-                      processed_gts)) / len(processed_gts)
+            acc = sum(a == b for a, b in zip(processed_preds, processed_gts)) / len(
+                processed_gts
+            )
 
             return acc
 
         else:
             raise NotImplementedError(
-                "Eval this dataset {self.args.dataset} is not implemented!")
+                f"Eval this dataset {self.args.dataset} is not implemented!"
+            )
 
     def predict(self, prompt=None, max_samples=1000):
         assert self.args.data is not None, "Please load data first!"
@@ -292,11 +368,15 @@ def predict_batch(self, prompt: List[str], max_samples=1000):
         assert self.args.data is not None, "Please load data first!"
 
         if self.model in ["chatgpt", "gpt4"]:
-            raise NotImplementedError("Batch inference is not implemented for openai api, use predict() instead.")
+            raise NotImplementedError(
+                "Batch inference is not implemented for openai api, use predict() instead."
+            )
         else:
-            results = self.predict_by_local_inference_batch(self.model, prompt, max_samples)
+            results = self.predict_by_local_inference_batch(
+                self.model, prompt, max_samples
+            )
         return results
-
+
     def predict_by_openai_api(self, model, prompt):
         data_len = len(self.args.data)
         if data_len > 1000:
@@ -308,9 +388,7 @@ def predict_by_openai_api(self, model, prompt):
         gts = []
 
         for idx in tqdm(range(data_len)):
-
-            raw_data = self.args.data.get_content_by_idx(
-                idx, self.args.dataset)
+            raw_data = self.args.data.get_content_by_idx(idx, self.args.dataset)
             input_text, gt = self.process_input(prompt, raw_data)
             raw_pred = self.call_openai_api(model, input_text)
@@ -329,7 +407,6 @@ def predict_by_openai_api(self, model, prompt):
         score = self.eval(preds, gts)
         return score
-
     def predict_by_local_inference(self, model, prompt, max_samples=1000):
 
         data_len = len(self.args.data)
@@ -350,7 +427,7 @@ def predict_by_local_inference(self, model, prompt, max_samples=1000):
             if isinstance(raw_pred, list) and len(raw_pred) == 1:
                 raw_pred = raw_pred[0]
             pred = self.process_pred(raw_pred)
-
+
             preds.append(pred)
             gts.append(gt)
             if check_correctness > 0 and self.args.verbose:
@@ -362,14 +439,35 @@ def predict_by_local_inference(self, model, prompt, max_samples=1000):
         score = self.eval(preds, gts)
         return score
 
-    def predict_by_local_inference_batch(self, model: str, prompts: List[str], max_samples: int=1000):
+    @handle_timeout_error(state_generator_func=generate_predict_step)
+    def predict_step(self, batch_id, gts, preds, model, all_data=None):
+        """
+        Predicts the output for a single batch of prompts, and updates the state_manager with the results.
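+
+        When RunTimeTracker().has_sufficient_time() returns False this method
+        raises TimeoutError; the handle_timeout_error decorator is then expected
+        to snapshot (batch_id, gts, preds) via generate_predict_step so a later
+        run can resume from the last finished batch (see
+        predict_by_local_inference_batch).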
+ """ + if not RunTimeTracker().has_sufficient_time(): + raise TimeoutError("Time limit exceeded.") + start_idx = batch_id * self.args.batch_size + batch = all_data[start_idx : start_idx + self.args.batch_size] + input_texts, batch_gts = zip(*batch) + gts.extend(batch_gts) + raw_batch_preds = self.pred_by_generation(input_text=input_texts, model=model) + preds.extend([self.process_pred(raw_pred) for raw_pred in raw_batch_preds]) + + def predict_by_local_inference_batch( + self, model: str, prompts: List[str], max_samples: int = 1000 + ): + state_manager = StateManager() + data_len = len(self.args.data) if data_len > max_samples: data_len = max_samples scores = [] # TODO: why are we re-doing this multiple times? - raw_data = [self.args.data.get_content_by_idx(idx, self.args.dataset) for idx in range(data_len)] + raw_data = [ + self.args.data.get_content_by_idx(idx, self.args.dataset) + for idx in range(data_len) + ] if isinstance(prompts, str): prompts = [prompts] @@ -378,52 +476,69 @@ def predict_by_local_inference_batch(self, model: str, prompts: List[str], max_s # so that we can use batching across all prompts all_data = [] for prompt in prompts: - all_data.extend([self.process_input(prompt, raw_data[idx]) for idx in range(len(raw_data))]) + all_data.extend( + [ + self.process_input(prompt, raw_data[idx]) + for idx in range(len(raw_data)) + ] + ) total_num_samples = len(all_data) raw_dataset_size = len(raw_data) assert total_num_samples == len(prompts) * raw_dataset_size - gts = [] - preds = [] + gts = state_manager.state.get("gts", []) + preds = state_manager.state.get("preds", []) num_iter = math.ceil(1.0 * total_num_samples / self.args.batch_size) - for batch_id in tqdm(range(num_iter)): - start_idx = batch_id * self.args.batch_size - batch = all_data[start_idx : start_idx + self.args.batch_size] - input_texts, batch_gts = zip(*batch) - gts.extend(batch_gts) - - raw_batch_preds = self.pred_by_generation(input_text=input_texts, model=model) - preds.extend([self.process_pred(raw_pred) for raw_pred in raw_batch_preds]) - + start_idx = state_manager.state.get("batch_id", 0) if state_manager.state else 0 + for batch_id in tqdm( + range(num_iter)[start_idx:], desc=f"Predicting {self.args.dataset}" + ): + self.predict_step( + batch_id=batch_id, gts=gts, preds=preds, model=model, all_data=all_data + ) assert len(preds) == total_num_samples # split preds and gts into lists of lists, where each sublist is the preds/gts for a single prompt - preds = [preds[i:i+raw_dataset_size] for i in range(0, len(preds), raw_dataset_size)] - gts = [gts[i:i+raw_dataset_size] for i in range(0, len(gts), raw_dataset_size)] + preds = [ + preds[i : i + raw_dataset_size] + for i in range(0, len(preds), raw_dataset_size) + ] + gts = [ + gts[i : i + raw_dataset_size] for i in range(0, len(gts), raw_dataset_size) + ] # calculate scores for each prompt - scores = [self.eval(prompt_preds, prompt_gts) for prompt_preds, prompt_gts in zip(preds, gts)] + scores = [ + self.eval(prompt_preds, prompt_gts) + for prompt_preds, prompt_gts in zip(preds, gts) + ] + state_manager.update_state( + {"batch_id": 0, "gts": [], "preds": [], "all_data": []} + ) + state_manager.save_state() return scores - + def call_openai_api(self, model, prompt): import openai + from config import OPENAI_API + openai.api_key = OPENAI_API - if model in ['chatgpt']: + if model in ["chatgpt"]: response = openai.Completion.create( model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=20, - temperature=0 + temperature=0, ) - result = 
response['choices'][0]['text'] + result = response["choices"][0]["text"] else: response = openai.ChatCompletion.create( - model='gpt-4-0613', - messages=[ - {"role": "user", "content": prompt}, - ] - ) - result = response['choices'][0]['message']['content'] + model="gpt-4-0613", + messages=[ + {"role": "user", "content": prompt}, + ], + ) + result = response["choices"][0]["message"]["content"] return result def pred_by_generation(self, input_text: List[str], model: str) -> List[str]: @@ -434,49 +549,71 @@ def pred_by_generation(self, input_text: List[str], model: str) -> List[str]: input_text (str or List[str]): the input text model (str): the model name """ - out = 'error!' + out = "error!" if model == "nemo": if self.args.nemo_use_server and TRT_AVAILABLE: - preds = query_llm(url=self.args.nemo_url, model_name=self.args.nemo_model_path,prompts=input_text, - max_output_token=self.args.generate_len, - top_k=self.args.nemo_top_k, - top_p=self.args.nemo_top_p, - temperature=self.args.nemo_temperature, - init_timeout=self.args.nemo_init_timeout) + preds = query_llm( + url=self.args.nemo_url, + model_name=self.args.nemo_model_path, + prompts=input_text, + max_output_token=self.args.generate_len, + top_k=self.args.nemo_top_k, + top_p=self.args.nemo_top_p, + temperature=self.args.nemo_temperature, + init_timeout=self.args.nemo_init_timeout, + ) return [p[0] for p in preds] else: - out = nemo_generate(model=self.pipe, prompts=input_text,trainer=self.nemo_trainer,cfg=self.nemo_cfg,batch_size=self.args.batch_size) + out = nemo_generate( + model=self.pipe, + prompts=input_text, + trainer=self.nemo_trainer, + cfg=self.nemo_cfg, + batch_size=self.args.batch_size, + ) preds = [] for pred in out: preds.extend(pred["sentences"]) - + assert len(preds) == len(input_text) # remove context from output for i in range(len(input_text)): - preds[i] = preds[i][len(input_text[i]):] - preds[i] = preds[i].replace("System", "").replace("system", "").replace("", "").strip() + preds[i] = preds[i][len(input_text[i]) :] + preds[i] = ( + preds[i] + .replace("System", "") + .replace("system", "") + .replace("", "") + .strip() + ) preds[i] = preds[i].split("\n")[0].strip() return preds # pad to the longest sequence in the batch and truncate all the sequences to the max model's length if not self.tokenizer.pad_token: self.tokenizer.pad_token = self.tokenizer.eos_token - input_ids = self.tokenizer(input_text, padding="longest", truncation=True, return_tensors="pt").input_ids.to("cuda") - - if 't5' in model or 'ul2' in model: - outputs = self.pipe.generate(input_ids, max_length=self.args.generate_len, early_stopping=True) + input_ids = self.tokenizer( + input_text, padding="longest", truncation=True, return_tensors="pt" + ).input_ids.to("cuda") + + if "t5" in model or "ul2" in model: + outputs = self.pipe.generate( + input_ids, max_length=self.args.generate_len, early_stopping=True + ) out = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) - elif model == 'EleutherAI/gpt-neox-20b': - outputs = self.pipe.generate(input_ids, - # do_sample=True, - temperature=0.00001, - # max_length=50, - max_new_tokens=self.args.generate_len, - early_stopping=True, - pad_token_id=self.tokenizer.eos_token_id) + elif model == "EleutherAI/gpt-neox-20b": + outputs = self.pipe.generate( + input_ids, + # do_sample=True, + temperature=0.00001, + # max_length=50, + max_new_tokens=self.args.generate_len, + early_stopping=True, + pad_token_id=self.tokenizer.eos_token_id, + ) out = self.tokenizer.batch_decode(outputs, 
skip_special_tokens=True) @@ -484,96 +621,111 @@ def pred_by_generation(self, input_text: List[str], model: str) -> List[str]: outputs = self.pipe.generate(input_ids) out = self.tokenizer.decode(outputs[0], skip_special_tokens=True) - elif model in ["llama-13b", "llama2-13b", 'llama2-13b-chat', "vicuna-13b", "vicuna-13b-v1.3", "llama2-7b", "llama2-7b-chat"]: - outputs = self.pipe.generate(input_ids, - # temperature=1.0, - max_new_tokens=self.args.generate_len, - early_stopping=True) + elif model in [ + "llama-13b", + "llama2-13b", + "llama2-13b-chat", + "vicuna-13b", + "vicuna-13b-v1.3", + "llama2-7b", + "llama2-7b-chat", + ]: + outputs = self.pipe.generate( + input_ids, + # temperature=1.0, + max_new_tokens=self.args.generate_len, + early_stopping=True, + ) out = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) - elif model in ['databricks/dolly-v1-6b', 'cerebras/Cerebras-GPT-13B']: - outputs = self.pipe.generate(input_ids, - temperature=0, - max_new_tokens=self.args.generate_len, - pad_token_id=self.tokenizer.eos_token_id, - early_stopping=True) + elif model in ["databricks/dolly-v1-6b", "cerebras/Cerebras-GPT-13B"]: + outputs = self.pipe.generate( + input_ids, + temperature=0, + max_new_tokens=self.args.generate_len, + pad_token_id=self.tokenizer.eos_token_id, + early_stopping=True, + ) out = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) elif model == "tiiuae/falcon-40b-instruct": - outputs = self.pipe.generate(input_ids, - temperature=0, - max_new_tokens=self.args.generate_len, - pad_token_id=self.tokenizer.eos_token_id, - early_stopping=True) + outputs = self.pipe.generate( + input_ids, + temperature=0, + max_new_tokens=self.args.generate_len, + pad_token_id=self.tokenizer.eos_token_id, + early_stopping=True, + ) out = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) return out def _process_valid_parentheses_input(self, prompt, raw_data): - question, label = raw_data['question'], raw_data['answer'] - input_text = prompt + '\n' + question, label = raw_data["question"], raw_data["answer"] + input_text = prompt + "\n" if self.args.shot > 0: - input_text += "\n" + \ - self.args.data.get_few_shot_examples(raw_data['task']) + input_text += "\n" + self.args.data.get_few_shot_examples(raw_data["task"]) - input_text += ("Question: " + question) + input_text += "Question: " + question if self.args.model == "nemo": input_text = NEMO_PROMPT.replace("{prompt}", input_text) if not self.args.steerlm else NEMO_STEERLM_PROMPT.replace("{prompt}", input_text) else: - input_text += '\nAnswer: ' + input_text += "\nAnswer: " return input_text, label def _process_bool_logic_input(self, prompt, raw_data): - question, label = raw_data['question'], raw_data['answer'] - input_text = prompt + '\n' + question, label = raw_data["question"], raw_data["answer"] + input_text = prompt + "\n" if self.args.shot > 0: - input_text += "\n" + \ - self.args.data.get_few_shot_examples(raw_data['task']) + input_text += "\n" + self.args.data.get_few_shot_examples(raw_data["task"]) - input_text += ("Question: " + question) + input_text += "Question: " + question if self.args.model == "nemo": input_text = NEMO_PROMPT.replace("{prompt}", input_text) if not self.args.steerlm else NEMO_STEERLM_PROMPT.replace("{prompt}", input_text) else: - input_text += '\nAnswer: ' + input_text += "\nAnswer: " return input_text, label def _process_math_input(self, prompt, raw_data): from config import MATH_QUESTION_TYPES - question_type, question, label = MATH_QUESTION_TYPES[raw_data['task'] - ], 
raw_data['question'], raw_data['answer']
-        input_text = prompt.format(question_type) + '\n'
+
+        question_type, question, label = (
+            MATH_QUESTION_TYPES[raw_data["task"]],
+            raw_data["question"],
+            raw_data["answer"],
+        )
+        input_text = prompt.format(question_type) + "\n"
 
         if self.args.shot > 0:
-            input_text += "\n" + \
-                self.args.data.get_few_shot_examples(raw_data['task'])
+            input_text += "\n" + self.args.data.get_few_shot_examples(raw_data["task"])
 
-        input_text += ("Question: " + question)
+        input_text += "Question: " + question
 
         if self.args.model == "nemo":
             input_text = NEMO_PROMPT.replace("{prompt}", input_text) if not self.args.steerlm else NEMO_STEERLM_PROMPT.replace("{prompt}", input_text)
         else:
-            input_text += '\nAnswer: '
+            input_text += "\nAnswer: "
 
         return input_text, label
 
     def _process_trans_input(self, prompt, raw_data):
         from config import LANGUAGES
-        source, target, task = raw_data['source'], raw_data['target'], raw_data['task']
-        src_lang, des_lang = task.split('-')
-        input_text = prompt.format(
-            LANGUAGES[src_lang], LANGUAGES[des_lang]) + '\n'
+
+        source, target, task = raw_data["source"], raw_data["target"], raw_data["task"]
+        src_lang, des_lang = task.split("-")
+        input_text = prompt.format(LANGUAGES[src_lang], LANGUAGES[des_lang]) + "\n"
 
         if self.args.shot > 0:
-            input_text += "\n"+self.args.data.get_few_shot_examples(task)
+            input_text += "\n" + self.args.data.get_few_shot_examples(task)
 
         input_text += source
 
         if self.args.model == "nemo":
             input_text = NEMO_PROMPT.replace("{prompt}", input_text) if not self.args.steerlm else NEMO_STEERLM_PROMPT.replace("{prompt}", input_text)
         else:
-            input_text += '\nAnswer: '
+            input_text += "\nAnswer: "
 
         return input_text, target
 
     def _process_squad_v2_input(self, prompt, raw_data):
@@ -581,8 +733,7 @@ def _process_squad_v2_input(self, prompt, raw_data):
         input_text = prompt
 
         if self.args.shot > 0:
-            input_text += "\n" + \
-                self.args.data.get_few_shot_examples(self.args.dataset)
+            input_text += "\n" + self.args.data.get_few_shot_examples(self.args.dataset)
 
         input_text += content
         if self.args.model == "nemo":
@@ -598,8 +749,9 @@ def _process_qa_input(self, prompt, raw_data):
         input_text = prompt.format(task) + "\n"
 
         if self.args.shot > 0:
-            input_text += "\n" + \
-                self.args.data.get_few_shot_examples(task.replace(" ", "_"))
+            input_text += "\n" + self.args.data.get_few_shot_examples(
+                task.replace(" ", "_")
+            )
 
         input_text += content
         if self.args.model == "nemo":
@@ -616,24 +768,28 @@ def _process_cls_input(self, prompt, raw_data):
 
         if self.args.shot > 0:
             few_shot_examples = self.args.data.get_few_shot_examples(self.args.dataset)
-            input_text += "\n"+few_shot_examples
+            input_text += "\n" + few_shot_examples
 
         if self.args.dataset == "sst2" or self.args.dataset == "cola":
             input_text += "Sentence: "
         input_text += content
         # TODO fix few shot examples for NeMo prompt
         if self.args.model == "nemo":
-            input_text = NEMO_PROMPT.replace("{prompt}", input_text) if not self.args.steerlm else NEMO_STEERLM_PROMPT.replace("{prompt}", input_text)
+            input_text = (
+                NEMO_PROMPT.replace("{prompt}", input_text)
+                if not self.args.steerlm
+                else NEMO_STEERLM_PROMPT.replace("{prompt}", input_text)
+            )
         else:
-            input_text += ' Answer: '
+            input_text += " Answer: "
 
         return input_text, label
 
     def _process_bool_logic_pred(self, raw_pred):
         pred = raw_pred.lower()
         pred = pred.replace("<pad>", "")
         pred = pred.replace("</s>", "")
-        pred = pred.replace("<extra_id_1>", "") # for nemo
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.replace("<extra_id_1>", "")  # for nemo
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
 
         return pred
 
@@ -641,7 +797,7 @@ def _process_valid_parentheses_pred(self, raw_pred):
         pred = raw_pred.lower()
         pred = pred.replace("<pad>", "")
         pred = pred.replace("</s>", "")
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
 
         return pred
 
@@ -649,7 +805,7 @@ def _process_math_pred(self, raw_pred):
         pred = raw_pred.lower()
         pred = pred.replace("<pad>", "")
         pred = pred.replace("</s>", "")
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
 
         return pred
 
@@ -657,7 +813,7 @@ def _process_trans_pred(self, raw_pred):
         pred = raw_pred.lower()
         pred = pred.replace("<pad>", "")
         pred = pred.replace("</s>", "")
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
 
         return pred
 
@@ -665,7 +821,7 @@ def _process_squad_v2_pred(self, raw_pred):
         pred = raw_pred.lower()
         pred = pred.replace("<pad>", "")
         pred = pred.replace("</s>", "")
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
 
         return pred
 
@@ -674,14 +830,16 @@ def _process_cls_pred(self, raw_pred):
         pred = pred.replace("<pad>", "")
         pred = pred.replace("</s>", "")
 
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
         pred = pred.split(" ")[-1]
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
 
         if pred in LABEL_SET[self.args.dataset]:
             pred = LABEL_TO_ID[self.args.dataset][pred]
         else:
-            self.args.logger.debug(f"The raw_pred label ({raw_pred}) -> processed_pred {pred} is not in label set: {LABEL_TO_ID[self.args.dataset]}")
+            self.args.logger.debug(
+                f"The raw_pred label ({raw_pred}) -> processed_pred {pred} is not in label set: {LABEL_TO_ID[self.args.dataset]}"
+            )
             pred = -1
 
         return pred
 
@@ -691,15 +849,15 @@ def _process_qa_pred(self, raw_pred):
         pred = pred.replace("<pad>", "")
         pred = pred.replace("</s>", "")
 
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
         pred = pred.split(" ")[-1]
-        pred = pred.strip(",._\"\'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
+        pred = pred.strip(",._\"'-+=!?()&^%$#@:\\|\{\}[]<>/`\n\t\r\v\f ")
 
         if pred not in LABEL_SET[self.args.dataset]:
+            self.args.logger.warn("The original label : '{}'.".format(raw_pred))
             self.args.logger.warn(
-                "The original label : '{}'.".format(raw_pred))
-            self.args.logger.warn(
-                "The predicted label: '{}' is not in label set.".format(pred))
-            pred = 'no_answer'
+                "The predicted label: '{}' is not in label set.".format(pred)
+            )
+            pred = "no_answer"
 
         return pred
 
diff --git a/main.py b/main.py
index c438a57..ca37494 100644
--- a/main.py
+++ b/main.py
@@ -3,174 +3,295 @@
 import argparse
 import os
-import logging
 
 from config import *
+from config import MODEL_SET, NEMO_TRT_MODELS, TIME_LIMIT
 from dataload import create_dataset
 from inference import Inference
+from nemo_utils.run_time_tracker import RunTimeTracker
+from nemo_utils.state_manager import StateManager
 from prompt_attack.attack import create_attack
 from prompt_attack.goal_function import create_goal_function
-from config import MODEL_SET, NEMO_TRT_MODELS
 from prompt_attack.utils import CLASS_REGISTRY
+from utils import generate_semantic_attack_state, 
handle_timeout_error, serialize_args -def create_logger(log_path): - - logging.getLogger().handlers = [] - - logger = logging.getLogger(__name__) - logger.setLevel(logging.INFO) - - formatter = logging.Formatter( - '%(asctime)s - %(levelname)s - %(message)s') - - file_handler = logging.FileHandler(log_path) - file_handler.setFormatter(formatter) - file_handler.setLevel(logging.INFO) - logger.addHandler(file_handler) - - return logger - def _add_nemo_args(parser): """Add NeMo arguments to update inference config, see promptbench/nemo_utils/nemo_cfgs/megatron_gpt_inference.yaml""" - group = parser.add_argument_group(title='NeMo arguments') - group.add_argument('--nemo_model_path', type=str, default=None, help='path to .nemo model file') - group.add_argument('--nemo_greedy', action='store_true', help='Whether or not to use sampling ; use greedy decoding otherwise') - group.add_argument('--nemo_top_k', type=int, default=0, help='The number of highest probability vocabulary tokens to keep for top-k-filtering.') - group.add_argument('--nemo_top_p', type=float, default=0.9, help='If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.') - group.add_argument('--nemo_temperature', type=float, default=1.0, help='sampling temperature') - group.add_argument('--nemo_add_BOS', action='store_true', help='add the bos token at the begining of the prompt') - group.add_argument('--nemo_all_probs', action='store_true', help='whether return the log prob for all the tokens in vocab') - group.add_argument('--nemo_repetition_penalty', type=float, default=1.2, help='The parameter for repetition penalty. 1.0 means no penalty.') - group.add_argument('--nemo_devices', type=int, default=1, help='Number of GPUs to use for inference') - group.add_argument('--nemo_url', type=str, default="localhost:8000", help='url for server inference') - group.add_argument('--nemo_init_timeout', type=float, default=600.0, help='timeout for server inference') - group.add_argument('--nemo_use_server', action='store_true', help='enable server inference') - group.add_argument('--nemo_use_prompt', action='store_true', help='use NeMo prompt for aligned models') - group.add_argument('--steerlm', action='store_true', help='use steerlm prompt for aligned models') + group = parser.add_argument_group(title="NeMo arguments") + group.add_argument( + "--nemo_model_path", type=str, default=None, help="path to .nemo model file" + ) + group.add_argument( + "--nemo_greedy", + action="store_true", + help="Whether or not to use sampling ; use greedy decoding otherwise", + ) + group.add_argument( + "--nemo_top_k", + type=int, + default=0, + help="The number of highest probability vocabulary tokens to keep for top-k-filtering.", + ) + group.add_argument( + "--nemo_top_p", + type=float, + default=0.9, + help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.", + ) + group.add_argument( + "--nemo_temperature", type=float, default=1.0, help="sampling temperature" + ) + group.add_argument( + "--nemo_add_BOS", + action="store_true", + help="add the bos token at the begining of the prompt", + ) + group.add_argument( + "--nemo_all_probs", + action="store_true", + help="whether return the log prob for all the tokens in vocab", + ) + group.add_argument( + "--nemo_repetition_penalty", + type=float, + default=1.2, + help="The parameter for repetition penalty. 
1.0 means no penalty.", + ) + group.add_argument( + "--nemo_devices", + type=int, + default=1, + help="Number of GPUs to use for inference", + ) + group.add_argument( + "--nemo_url", + type=str, + default="localhost:8000", + help="url for server inference", + ) + group.add_argument( + "--nemo_init_timeout", + type=float, + default=600.0, + help="timeout for server inference", + ) + group.add_argument( + "--nemo_use_server", action="store_true", help="enable server inference" + ) + group.add_argument( + "--nemo_use_prompt", + action="store_true", + help="use NeMo prompt for aligned models", + ) + group.add_argument( + "--steerlm", action="store_true", help="use steerlm prompt for aligned models" + ) return parser + def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, default='google/flan-t5-large', choices=MODEL_SET, - help="model name. For LLAMA also specify `--model_dir`, and NeMo models: `--nemo_model_path` and `--nemo_infer_cfg`.") - parser.add_argument('--dataset', type=str, default='bool_logic', choices=["sst2", "cola", "qqp", - "mnli", "mnli_matched", "mnli_mismatched", - "qnli", "wnli", "rte", "mrpc", - "mmlu", "squad_v2", "un_multi", "iwslt", "math", - "bool_logic", "valid_parentheses", - ]) - - parser.add_argument('--query_budget', type=float, default=float("inf")) - parser.add_argument('--attack', type=str, default='deepwordbug', choices=[ - 'textfooler', - 'textbugger', - 'bertattack', - 'deepwordbug', - 'checklist', - 'stresstest', - 'semantic', - 'no', - 'noattack', - 'clean', - 'nemo', - 'flexible_attack' - ]) + parser.add_argument( + "--model", + type=str, + default="google/flan-t5-large", + choices=MODEL_SET, + help="model name. For LLAMA also specify `--model_dir`, and NeMo models: `--nemo_model_path` and `--nemo_infer_cfg`.", + ) + parser.add_argument( + "--dataset", + type=str, + default="bool_logic", + choices=[ + "sst2", + "cola", + "qqp", + "mnli", + "mnli_matched", + "mnli_mismatched", + "qnli", + "wnli", + "rte", + "mrpc", + "mmlu", + "squad_v2", + "un_multi", + "iwslt", + "math", + "bool_logic", + "valid_parentheses", + ], + ) + + parser.add_argument("--query_budget", type=float, default=float("inf")) + parser.add_argument( + "--attack", + type=str, + default="deepwordbug", + choices=[ + "textfooler", + "textbugger", + "bertattack", + "deepwordbug", + "checklist", + "stresstest", + "semantic", + "no", + "noattack", + "clean", + "nemo", + "flexible_attack", + ], + ) parser.add_argument("--verbose", action="store_true") - parser.add_argument('--output_dir', type=str, default='./') - parser.add_argument('--model_dir', type=str, default=None, help="path to the model directory for LLAMA and NeMo models") - parser.add_argument('--shot', type=int, default=0) - parser.add_argument('--generate_len', type=int, default=4) - parser.add_argument('--prompt_selection', action='store_true') - parser.add_argument('--max_samples', type=int, default=1000, help="max number of samples to use from the dataset") - parser.add_argument('--batch_size', type=int, default=32, help='batch size for inference') - parser.add_argument('--transforms', nargs='*', type=str, help=f'List of transformations for the flexible attack, list of available transformations: {CLASS_REGISTRY["transformations"]}', default=[]) - parser.add_argument('--constraints', nargs='*', type=str, help=f'List of constraints for the flexible attack, list of available constraints: {CLASS_REGISTRY["constraints"]}', default=[]) - parser.add_argument('--search_method', type=str, 
help=f'Search method for the flexible attack, list of available search methods: {CLASS_REGISTRY["search_methods"]}', default='') - + parser.add_argument("--output_dir", type=str, default="./") + parser.add_argument( + "--model_dir", + type=str, + default=None, + help="path to the model directory for LLAMA and NeMo models", + ) + parser.add_argument("--shot", type=int, default=0) + parser.add_argument("--generate_len", type=int, default=4) + parser.add_argument("--prompt_selection", action="store_true") + parser.add_argument( + "--max_samples", + type=int, + default=1000, + help="max number of samples to use from the dataset", + ) + parser.add_argument( + "--batch_size", type=int, default=32, help="batch size for inference" + ) + parser.add_argument( + "--transforms", + nargs="*", + type=str, + help=f'List of transformations for the flexible attack, list of available transformations: {CLASS_REGISTRY["transformations"]}', + default=[], + ) + parser.add_argument( + "--constraints", + nargs="*", + type=str, + help=f'List of constraints for the flexible attack, list of available constraints: {CLASS_REGISTRY["constraints"]}', + default=[], + ) + parser.add_argument( + "--search_method", + type=str, + help=f'Search method for the flexible attack, list of available search methods: {CLASS_REGISTRY["search_methods"]}', + default="", + ) parser = _add_nemo_args(parser) args = parser.parse_args() return args -def prompt_selection(logger, inference_model, prompts, max_samples=1000): - """Select the top 3 prompts to attack based on the accuracy - """ - - import time - - # start_time = time.time() - # prompt_dict = {} - # for prompt in prompts: - # acc = inference_model.predict(prompt, max_samples=max_samples) - # prompt_dict[prompt] = acc - # logger.info("{:.2f}, {}\n".format(acc*100, prompt)) - # print("{:.2f}, {}\n".format(acc*100, prompt)) - # print("Default Time: ", time.time() - start_time) -# - # start_time = time.time() +def prompt_selection(logger, inference_model, prompts, max_samples): + """Select the top 3 prompts to attack based on the accuracy""" + if "predict_batch" in dir(inference_model): acc = inference_model.predict_batch(prompts, max_samples=max_samples) prompt_dict = {prompt: acc[idx] for idx, prompt in enumerate(prompts)} else: - logger.warning("The model does not support batch inference! Running sequentially...") + logger.warning( + "The model does not support batch inference! Running sequentially..." 
+        )
         prompt_dict = {}
         for prompt in prompts:
             acc = inference_model.predict(prompt)
             prompt_dict[prompt] = acc
-            logger.info("{:.2f}, {}\n".format(acc*100, prompt))
+            logger.info("{:.2f}, {}\n".format(acc * 100, prompt))
 
-    sorted_prompts = sorted(prompt_dict.items(),
-                            key=lambda x: x[1], reverse=True)
+    sorted_prompts = sorted(prompt_dict.items(), key=lambda x: x[1], reverse=True)
     return sorted_prompts
 
 
+@handle_timeout_error(generate_semantic_attack_state)
+def semantic_attack(language, inference_model, results_dir, dataset):
+    state_manager = StateManager()
+    from prompts.semantic_atk_prompts import SEMANTIC_ADV_PROMPT_SET
+
+    prompts_dict = SEMANTIC_ADV_PROMPT_SET[dataset]
+    prompts = prompts_dict[language]
+    if "predict_batch" in dir(inference_model):
+        acc = inference_model.predict_batch(prompts)
+        for idx in range(len(prompts)):
+            state_manager.logger.info(
+                "Language: {}, acc: {:.2f}%, prompt: {}\n".format(
+                    language, acc[idx] * 100, prompts[idx]
+                )
+            )
+            with open(results_dir + args.save_file_name + ".txt", "a+") as f:
+                f.write(
+                    "Language: {}, acc: {:.2f}%, prompt: {}\n".format(
+                        language, acc[idx] * 100, prompts[idx]
+                    )
+                )
+    else:
+        state_manager.logger.warning(
+            "The model does not support batch inference! Running sequentially..."
+        )
+        for prompt in prompts:
+            acc = inference_model.predict(prompt)
+            state_manager.logger.info(
+                "Language: {}, acc: {:.2f}%, prompt: {}\n".format(
+                    language, acc * 100, prompt
+                )
+            )
+            with open(results_dir + args.save_file_name + ".txt", "a+") as f:
+                f.write(
+                    "Language: {}, acc: {:.2f}%, prompt: {}\n".format(
+                        language, acc * 100, prompt
+                    )
+                )
+
+
 def attack(args, inference_model, RESULTS_DIR):
+    state_manager = StateManager()
+
     if args.attack == "semantic":
         from prompts.semantic_atk_prompts import SEMANTIC_ADV_PROMPT_SET
 
         prompts_dict = SEMANTIC_ADV_PROMPT_SET[args.dataset]
-
+        latest_language = state_manager.state.get("language", None)
         for language in prompts_dict.keys():
-            prompts = prompts_dict[language]
-            if "predict_batch" in dir(inference_model):
-                acc = inference_model.predict_batch(prompts)
-                for idx in range(len(prompts)):
-                    args.logger.info("Language: {}, acc: {:.2f}%, prompt: {}\n".format(language, acc[idx]*100, prompts[idx]))
-
-                with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f:
-                    f.write("Language: {}, acc: {:.2f}%, prompt: {}\n".format(language, acc*100, prompt))
-            else:
-                args.logger.warninig("The model does not support batch inference! 
Running sequentially...") - for prompt in prompts: - acc = inference_model.predict(prompt) - args.logger.info("Language: {}, acc: {:.2f}%, prompt: {}\n".format( - language, acc*100, prompt)) - - with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f: - f.write("Language: {}, acc: {:.2f}%, prompt: {}\n".format( - language, acc*100, prompt)) - - - elif args.attack in ['no', 'noattack', 'clean']: + if latest_language is not None and language < latest_language: + continue + semantic_attack( + language=language, + inference_model=inference_model, + results_dir=RESULTS_DIR, + dataset=args.dataset, + ) + + elif args.attack in [ + "no", + "noattack", + "clean", + ]: # Not adding saving state for this because it's done in one pass from config import PROMPT_SET_Promptbench_advglue as prompt_raw - prompt = prompt_raw['clean'][args.dataset][0] + + prompt = prompt_raw["clean"][args.dataset][0] acc = inference_model.predict(prompt) - args.logger.info(f"Prompt: {prompt}, acc: {acc}%\n") - with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f: - f.write("Prompt: {}, acc: {:.2f}%\n".format(prompt, acc*100)) + state_manager.logger.info(f"Prompt: {prompt}, acc: {acc}%\n") + with open(RESULTS_DIR + args.save_file_name + ".txt", "a+") as f: + f.write("Prompt: {}, acc: {:.2f}%\n".format(prompt, acc * 100)) else: if args.shot == 0: - from prompts.zero_shot.task_oriented import TASK_ORIENTED_PROMPT_SET from prompts.zero_shot.role_oriented import ROLE_ORIENTED_PROMPT_SET + from prompts.zero_shot.task_oriented import TASK_ORIENTED_PROMPT_SET elif args.shot == 3: - from prompts.three_shot.task_oriented import TASK_ORIENTED_PROMPT_SET from prompts.three_shot.role_oriented import ROLE_ORIENTED_PROMPT_SET + from prompts.three_shot.task_oriented import TASK_ORIENTED_PROMPT_SET else: raise NotImplementedError( - "Currently we only implemented zero-shot and three-shot!") + "Currently we only implemented zero-shot and three-shot!" 
+        if state_manager.state.get("sorted_prompts", None) is None:
+            sorted_prompts = prompt_selection(
+                state_manager.logger, inference_model, prompts, args.max_samples
+            )
+            state_manager.add_to_state("sorted_prompts", sorted_prompts)
+            state_manager.logger.info(f"Sorted prompts: {sorted_prompts}")
+            state_manager.save_state()
+        else:
+            sorted_prompts = state_manager.state["sorted_prompts"]
+
         if args.prompt_selection:
             for prompt, acc in sorted_prompts:
-                args.logger.info(
-                    "Prompt: {}, acc: {:.2f}%\n".format(prompt, acc*100))
-                with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f:
-                    f.write("Prompt: {}, acc: {:.2f}%\n".format(prompt, acc*100))
+                state_manager.logger.info(
+                    "Prompt: {}, acc: {:.2f}%\n".format(prompt, acc * 100)
+                )
+                with open(RESULTS_DIR + args.save_file_name + ".txt", "a+") as f:
+                    f.write("Prompt: {}, acc: {:.2f}%\n".format(prompt, acc * 100))
             continue
 
-        for init_prompt, init_acc in sorted_prompts[:3]:
+        if state_manager.state.get("prompt_idx", None) is None:
+            prompt_idx = 0
+            state_manager.add_to_state("prompt_idx", prompt_idx)
+            state_manager.save_state()
+        else:
+            prompt_idx = state_manager.state["prompt_idx"]
+
+        # Attack the top 3 prompts, resuming from prompt_idx after a restart.
+        for i, (init_prompt, init_acc) in enumerate(
+            sorted_prompts[prompt_idx:3], start=prompt_idx
+        ):
             if init_acc > 0:
-                args.logger.info("Init prompt: {}".format(init_prompt))
-                init_acc, attacked_prompt, attacked_acc, dropped_acc = attack.attack(init_prompt)
-                args.logger.info("Original prompt: {}".format(init_prompt))
-                args.logger.info("Attacked prompt: {}".format(attacked_prompt.encode('utf-8')))
-                args.logger.info("Original acc: {:.2f}%, attacked acc: {:.2f}%, dropped acc: {:.2f}%".format(init_acc*100, attacked_acc*100, dropped_acc*100))
-
-                with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f:
+                state_manager.logger.info("Init prompt: {}".format(init_prompt))
+                (
+                    init_acc,
+                    attacked_prompt,
+                    attacked_acc,
+                    dropped_acc,
+                ) = attack.attack(init_prompt)
+                state_manager.logger.info("Original prompt: {}".format(init_prompt))
+                state_manager.logger.info(
+                    "Attacked prompt: {}".format(attacked_prompt.encode("utf-8"))
+                )
+                state_manager.logger.info(
+                    "Original acc: {:.2f}%, attacked acc: {:.2f}%, dropped acc: {:.2f}%".format(
+                        init_acc * 100, attacked_acc * 100, dropped_acc * 100
+                    )
+                )
+
+                with open(RESULTS_DIR + args.save_file_name + ".txt", "a+") as f:
                     f.write("Original prompt: {}\n".format(init_prompt))
-                    f.write("Attacked prompt: {}\n".format(
-                        attacked_prompt.encode('utf-8')))
-                    f.write("Original acc: {:.2f}%, attacked acc: {:.2f}%, dropped acc: {:.2f}%\n\n".format(
-                        init_acc*100, attacked_acc*100, dropped_acc*100))
+                    f.write(
+                        "Attacked prompt: {}\n".format(
+                            attacked_prompt.encode("utf-8")
+                        )
+                    )
+                    f.write(
+                        "Original acc: {:.2f}%, attacked acc: {:.2f}%, dropped acc: {:.2f}%\n\n".format(
+                            init_acc * 100, attacked_acc * 100, dropped_acc * 100
+                        )
+                    )
             else:
-                with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f:
+                with open(RESULTS_DIR + args.save_file_name + ".txt", "a+") as f:
                     f.write("Init acc is 0, skip this prompt\n")
                     f.write("Original prompt: {}\n".format(init_prompt))
-                    f.write("Original acc: {:.2f}% \n\n".format(
-                        init_acc*100, init_prompt))
+                    f.write("Original acc: {:.2f}%\n\n".format(init_acc * 100))
+            # Record the next prompt to attack so a resumed run skips this one.
+            state_manager.add_to_state("prompt_idx", i + 1)
+            state_manager.save_state()
 
 
 def main(args):
@@ -221,6 +382,7 @@ def main(args):
 
     if args.dataset == "iwslt" or args.dataset == "un_multi":
         from config import SUPPORTED_LANGUAGES
+
         supported_languages = SUPPORTED_LANGUAGES[args.model]
 
     save_dir += "/"
@@ -231,11 +393,13 @@
     for DIR in [LOGS_DIR, RESULTS_DIR]:
         os.makedirs(DIR, exist_ok=True)
 
-    log_model_name = args.model.replace('/', '_')
+    log_model_name = args.model.replace("/", "_")
 
     if args.model == "nemo":
         if args.nemo_use_server:
             if args.nemo_model_path not in NEMO_TRT_MODELS:
-                raise ValueError("Please specify a valid NeMo model for server inference!")
+                raise ValueError(
+                    "Please specify a valid NeMo model for server inference!"
+                )
         elif args.nemo_model_path is None or not os.path.exists(args.nemo_model_path):
             raise ValueError(f"{args.nemo_model_path} not found. Please specify a valid .nemo path")
 
@@ -251,9 +415,17 @@
         file_name += "/" + "_".join(args.transforms)
 
-    args.save_file_name = file_name
+    # save_file_name is read throughout attack(), so it must stay set.
+    args.save_file_name = file_name
+    state_manager: StateManager = StateManager(
+        LOGS_DIR + f"/benchmark_state_{serialize_args(args)}.pkl",
+        LOGS_DIR + file_name + ".log",
+    )
+    # restore_state() already loads the pickle into state_manager.state.
+    restored_state = state_manager.restore_state()
+    if restored_state:
+        state_manager.logger.info("Restored state from a previous run.")
 
+    time_tracker = RunTimeTracker(TIME_LIMIT)
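+    # RunTimeTracker is a singleton: this first call fixes the start time and
+    # the limit; later RunTimeTracker() calls elsewhere (not shown in this
+    # diff) return the same instance to check has_sufficient_time().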
"a+") as f: + with open(RESULTS_DIR + args.save_file_name + ".txt", "a+") as f: f.write("Init acc is 0, skip this prompt\n") f.write("Original prompt: {}\n".format(init_prompt)) - f.write("Original acc: {:.2f}% \n\n".format( - init_acc*100, init_prompt)) + f.write( + "Original acc: {:.2f}% \n\n".format( + init_acc * 100, init_prompt + ) + ) + state_manager.add_to_state("prompt_idx", i) + state_manager.save_state() def main(args): @@ -221,6 +382,7 @@ def main(args): if args.dataset == "iwslt" or args.dataset == "un_multi": from config import SUPPORTED_LANGUAGES + supported_languages = SUPPORTED_LANGUAGES[args.model] save_dir += "/" @@ -231,11 +393,13 @@ def main(args): for DIR in [LOGS_DIR, RESULTS_DIR]: os.makedirs(DIR, exist_ok=True) - log_model_name = args.model.replace('/', '_') + log_model_name = args.model.replace("/", "_") if args.model == "nemo": if args.nemo_use_server: if args.nemo_model_path not in NEMO_TRT_MODELS: - raise ValueError("Please specify a valid NeMo model for server inference!") + raise ValueError( + "Please specify a valid NeMo model for server inference!" + ) elif args.nemo_model_path is None or not os.path.exists(args.nemo_model_path): raise ValueError(f"{args.nemo_model_path} not found. Please specify a valid .nemo path") @@ -251,9 +415,17 @@ def main(args): file_name += "/" + "_".join(args.transforms) - args.save_file_name = file_name + state_manager: StateManager = StateManager( + LOGS_DIR + f"/benchmark_state_{serialize_args(args)}.pkl", + LOGS_DIR + file_name + ".log", + ) + restored_state = state_manager.restore_state() + if restored_state: + state_manager.update_state(restored_state) + time_tracker = RunTimeTracker(TIME_LIMIT) + if args.dataset in ["iwslt", "un_multi"]: data = create_dataset(args.dataset, supported_languages) else: @@ -261,16 +433,17 @@ def main(args): inference_model = Inference(args) args.data = data + args.logger = state_manager.logger - logger = create_logger(LOGS_DIR+file_name+".log") - logger.info(args) - with open(RESULTS_DIR+args.save_file_name+".txt", "a+") as f: - f.write(str(args)+"\n") - args.logger = logger - - attack(args, inference_model, RESULTS_DIR) + state_manager.logger.info(args) + with open(RESULTS_DIR + args.save_file_name + ".txt", "a+") as f: + f.write(str(args) + "\n") + try: + attack(args, inference_model, RESULTS_DIR) + except TimeoutError: + state_manager.logger.info("Exiting due to time limit.") -if __name__ == '__main__': +if __name__ == "__main__": args = get_args() main(args) diff --git a/nemo_utils/megatron_gpt_eval.py b/nemo_utils/megatron_gpt_eval.py index 69341f0..2ed9fb2 100644 --- a/nemo_utils/megatron_gpt_eval.py +++ b/nemo_utils/megatron_gpt_eval.py @@ -20,18 +20,23 @@ from typing import List import torch -from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import \ - MegatronGPTModel -from nemo.collections.nlp.modules.common.megatron.megatron_init import \ - fake_initialize_model_parallel -from nemo.collections.nlp.modules.common.text_generation_server import \ - MegatronServer +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import ( + MegatronGPTModel, +) +from nemo.collections.nlp.modules.common.megatron.megatron_init import ( + fake_initialize_model_parallel, +) +from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer from nemo.collections.nlp.modules.common.text_generation_utils import generate from nemo.collections.nlp.modules.common.transformer.text_generation import ( - LengthParam, SamplingParam) -from 
     args = get_args()
     main(args)
diff --git a/nemo_utils/megatron_gpt_eval.py b/nemo_utils/megatron_gpt_eval.py
index 69341f0..2ed9fb2 100644
--- a/nemo_utils/megatron_gpt_eval.py
+++ b/nemo_utils/megatron_gpt_eval.py
@@ -20,18 +20,23 @@
 from typing import List
 
 import torch
-from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import \
-    MegatronGPTModel
-from nemo.collections.nlp.modules.common.megatron.megatron_init import \
-    fake_initialize_model_parallel
-from nemo.collections.nlp.modules.common.text_generation_server import \
-    MegatronServer
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
+    MegatronGPTModel,
+)
+from nemo.collections.nlp.modules.common.megatron.megatron_init import (
+    fake_initialize_model_parallel,
+)
+from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
 from nemo.collections.nlp.modules.common.text_generation_utils import generate
 from nemo.collections.nlp.modules.common.transformer.text_generation import (
-    LengthParam, SamplingParam)
-from nemo.collections.nlp.parts.nlp_overrides import (CustomProgressBar,
-                                                      NLPDDPStrategy,
-                                                      NLPSaveRestoreConnector)
+    LengthParam,
+    SamplingParam,
+)
+from nemo.collections.nlp.parts.nlp_overrides import (
+    CustomProgressBar,
+    NLPDDPStrategy,
+    NLPSaveRestoreConnector,
+)
 from nemo.core.config import hydra_runner
 from nemo.utils.app_state import AppState
 from nemo.utils.model_utils import inject_model_parallel_rank
@@ -45,10 +50,9 @@
     HAVE_MEGATRON_CORE = True
 
 except (ImportError, ModuleNotFoundError):
-    HAVE_MEGATRON_CORE = False
+    # The fallback assignment must stay: an empty except block is a SyntaxError.
+    HAVE_MEGATRON_CORE = False
 
-__all__ = ['init_model', 'pred_by_generation']
+__all__ = ["init_model", "pred_by_generation"]
 
 """
 This is the script to run GPT text generation.
@@ -63,7 +67,9 @@
     def __init__(self, sentences):
         super().__init__()
         self.sentences = sentences
 
-    def __len__(self,):
+    def __len__(
+        self,
+    ):
         return len(self.sentences)
 
     def __getitem__(self, idx):
@@ -84,13 +90,13 @@
     trainer = Trainer(
         strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
         **cfg.trainer,
-        callbacks=[CustomProgressBar()],
+        # callbacks=[CustomProgressBar()],
     )
 
     if (
         cfg.tensor_model_parallel_size < 0
         or cfg.pipeline_model_parallel_size < 0
-        or cfg.get('pipeline_model_parallel_split_rank', -1) < 0
+        or cfg.get("pipeline_model_parallel_split_rank", -1) < 0
     ):
         save_restore_connector = NLPSaveRestoreConnector()
         if os.path.isdir(cfg.gpt_model_file):
@@ -103,9 +109,15 @@
         )
 
         with open_dict(cfg):
-            cfg.tensor_model_parallel_size = model_config.get('tensor_model_parallel_size', 1)
-            cfg.pipeline_model_parallel_size = model_config.get('pipeline_model_parallel_size', 1)
-            cfg.pipeline_model_parallel_split_rank = model_config.get('pipeline_model_parallel_split_rank', 0)
+            cfg.tensor_model_parallel_size = model_config.get(
+                "tensor_model_parallel_size", 1
+            )
+            cfg.pipeline_model_parallel_size = model_config.get(
+                "pipeline_model_parallel_size", 1
+            )
+            cfg.pipeline_model_parallel_split_rank = model_config.get(
+                "pipeline_model_parallel_split_rank", 0
+            )
 
     assert (
         cfg.trainer.devices * cfg.trainer.num_nodes
@@ -128,20 +140,24 @@
         pretrained_cfg.activations_checkpoint_granularity = None
         pretrained_cfg.activations_checkpoint_method = None
         pretrained_cfg.precision = trainer.precision
-        if pretrained_cfg.get('mcore_gpt', False):
+        if pretrained_cfg.get("mcore_gpt", False):
             # with dist checkpointing we can use the model parallel config specified by the user
             pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size
-            pretrained_cfg.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
+            pretrained_cfg.pipeline_model_parallel_size = (
+                cfg.pipeline_model_parallel_size
+            )
         if trainer.precision == "16":
             pretrained_cfg.megatron_amp_O2 = False
-        elif trainer.precision in ['bf16', 'bf16-mixed'] and cfg.get('megatron_amp_O2', False):
+        elif trainer.precision in ["bf16", "bf16-mixed"] and cfg.get(
+            "megatron_amp_O2", False
+        ):
             pretrained_cfg.megatron_amp_O2 = True
     model = MegatronGPTModel.restore_from(
         restore_path=cfg.gpt_model_file,
         trainer=trainer,
         override_config_path=pretrained_cfg,
         save_restore_connector=save_restore_connector,
-        map_location=f'cuda:{trainer.local_rank}',  # map_location is needed for converted models
+        map_location=f"cuda:{trainer.local_rank}",  # map_location is needed for converted models
     )
     model.freeze()
@@ -153,16 +169,15 @@
         pass
     return model, trainer
 
-def nemo_generate(model, prompts: List[str], batch_size: int, trainer, cfg: OmegaConf) -> List[str]:
+
+def nemo_generate(
+    model, prompts: List[str], batch_size: int, trainer, cfg: OmegaConf
+) -> List[str]:
     cfg_infer = OmegaConf.to_container(cfg.inference)
-
+    cfg_infer["batch_size"] = batch_size
     ds = RequestDataSet(prompts)
     request_dl = DataLoader(dataset=ds, batch_size=batch_size)
     model.set_inference_config(cfg_infer)
     response = trainer.predict(model, request_dl)
     return response
-
-
-
-
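+
+
+# Usage sketch (hypothetical; cfg as parsed from
+# nemo_cfgs/megatron_gpt_inference.yaml):
+#   model, trainer = nemo_init_model(cfg)
+#   responses = nemo_generate(
+#       model, ["Translate: guten Morgen"], batch_size=1, trainer=trainer, cfg=cfg
+#   )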
diff --git a/nemo_utils/nemo_cfgs/megatron_gpt_inference.yaml b/nemo_utils/nemo_cfgs/megatron_gpt_inference.yaml
index 4648575..4c3572f 100644
--- a/nemo_utils/nemo_cfgs/megatron_gpt_inference.yaml
+++ b/nemo_utils/nemo_cfgs/megatron_gpt_inference.yaml
@@ -18,6 +18,7 @@ trainer:
   logger: False # logger provided by exp_manager
   precision: bf16 # 16, 32, or bf16
   use_distributed_sampler: False
+  enable_progress_bar: False
 
 tensor_model_parallel_size: 1
 pipeline_model_parallel_size: 1
diff --git a/nemo_utils/run_time_tracker.py b/nemo_utils/run_time_tracker.py
new file mode 100644
index 0000000..4d1b973
--- /dev/null
+++ b/nemo_utils/run_time_tracker.py
@@ -0,0 +1,33 @@
+import time
+
+from nemo_utils.state_manager import SingletonMeta
+
+
+class RunTimeTracker(metaclass=SingletonMeta):
+    def __init__(self, time_limit_sec=None):
+        if not hasattr(self, "initialized"):
+            if time_limit_sec is None:
+                raise ValueError(
+                    "time_limit_sec must be provided for the first initialization."
+                )
+            self.start_time = time.time()
+            self.time_limit = time_limit_sec
+            self.initialized = True
+
+    def elapsed_time(self):
+        return time.time() - self.start_time
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
+    def has_sufficient_time(self, buffer_time=30):
+        """
+        Check if there is sufficient time left before the time limit.
+        :param buffer_time: Time in seconds reserved for saving state (default 30 seconds)
+        :return: True if there is enough time left, False otherwise
+        """
+        return self.elapsed_time() < self.time_limit - buffer_time
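+
+
+# Usage sketch: the first caller fixes the clock and the limit; later callers
+# omit the argument and get the same singleton back.
+#   RunTimeTracker(TIME_LIMIT)                      # once, in main()
+#   if not RunTimeTracker().has_sufficient_time():  # anywhere later
+#       raise TimeoutError  # handled by handle_timeout_error / main()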
diff --git a/nemo_utils/state_manager.py b/nemo_utils/state_manager.py
new file mode 100644
index 0000000..2d5d5a9
--- /dev/null
+++ b/nemo_utils/state_manager.py
@@ -0,0 +1,75 @@
+import logging
+import pickle
+
+
+def create_logger(log_path):
+    logging.getLogger().handlers = []
+
+    logger = logging.getLogger(__name__)
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+
+    file_handler = logging.FileHandler(log_path)
+    file_handler.setFormatter(formatter)
+    file_handler.setLevel(logging.INFO)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+class SingletonMeta(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(SingletonMeta, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+class StateManager(metaclass=SingletonMeta):
+    def __init__(self, file_path=None, log_path=None):
+        if not hasattr(self, "initialized"):
+            if file_path is None or log_path is None:
+                raise ValueError(
+                    "file_path and log_path must be provided for the first initialization."
+                )
+            self.file_path = file_path
+            self.state = {}
+            self.logger = create_logger(log_path)
+            self.initialized = True
+            self.logger.info(
+                f"StateManager initialized with file_path={self.file_path} and log_path={log_path}"
+            )
+        else:
+            self.logger.info(
+                f"StateManager retrieved with file_path={self.file_path}"
+            )
+
+    def save_state(self, state=None):
+        if state is not None:
+            self.state.update(state)
+        with open(self.file_path, "wb") as f:
+            pickle.dump(self.state, f)
+        self.logger.info(f"State saved to {self.file_path}")
+
+    def restore_state(self):
+        try:
+            with open(self.file_path, "rb") as f:
+                self.state = pickle.load(f)
+            return self.state
+        except FileNotFoundError:
+            return None
+
+    def add_to_state(self, key, value):
+        """
+        Add a key-value pair to the state.
+        """
+        self.state[key] = value
+
+    def update_state(self, state):
+        """
+        Add a dictionary of key-value pairs to the state.
+        """
+        self.state.update(state)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..0fc63aa
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,44 @@
+import hashlib
+
+from nemo_utils.state_manager import StateManager
+
+
+def serialize_args(args):
+    """Serialize the argparse Namespace to a hash string."""
+    args_str = "_".join([f"{key}_{value}" for key, value in vars(args).items()])
+    # Create a hash of the arguments string for a shorter, fixed-length file name
+    return hashlib.md5(args_str.encode()).hexdigest()
+
+
+def handle_timeout_error(state_generator_func):
+    """
+    A decorator that handles TimeoutError exceptions by saving the state and logging a message.
+    state_generator_func (function): A function that generates the state.
+    """
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            state_manager = StateManager()
+            try:
+                return func(*args, **kwargs)
+            except TimeoutError:
+                state_manager.logger.info("Time limit reached, saving state...")
+                state = state_generator_func(*args, **kwargs)
+                state_manager.save_state(state)
+                state_manager.logger.info("Exiting due to time limit.")
+                raise
+
+        return wrapper
+
+    return decorator
+
+
+def generate_predict_step(obj, batch_id, gts, preds, model, all_data=None):
+    return {"batch_id": batch_id, "gts": gts, "preds": preds, "all_data": all_data}
+
+
+# No leading ``obj`` parameter here: semantic_attack is a plain function, so
+# the wrapper forwards only its own keyword arguments to this state generator.
+def generate_semantic_attack_state(language, inference_model, results_dir, dataset):
+    return {"language": language}
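+
+
+# Usage sketch: the state generator must accept the same arguments as the
+# function it guards, e.g.
+#   @handle_timeout_error(generate_semantic_attack_state)
+#   def semantic_attack(language, inference_model, results_dir, dataset): ...
+# On TimeoutError the generator's dict ({"language": ...}) is merged into the
+# pickled state via StateManager.save_state() and the exception is re-raised.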