diff --git a/src/train.py b/src/train.py
index 6040368..4c6f81c 100644
--- a/src/train.py
+++ b/src/train.py
@@ -33,13 +33,9 @@
 DEFAULT_UNK_TOKEN = "<unk>"
 PROMPT_DICT = {
     "prompt_input": (
-        "Below is an instruction that describes a task, paired with an input that provides further context. "
-        "Write a response that appropriately completes the request.\n\n"
         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
     ),
     "prompt_no_input": (
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request.\n\n"
         "### Instruction:\n{instruction}\n\n### Response:"
     ),
 }
@@ -58,9 +54,9 @@ class DataArguments:
 @dataclass
 class TrainingArguments(transformers.TrainingArguments):
     cache_dir: Optional[str] = field(default=None)
-    optim: str = field(default="adamw_torch")
+    optim: str = field(default="adamw_bnb_8bit")
     model_max_length: int = field(
-        default=512,
+        default=2048,
         metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
     )
 
diff --git a/src/utils.py b/src/utils.py
index 29d2499..f49e77e 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -13,123 +13,6 @@
 from openai import openai_object
 import copy
 
-StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]
-
-openai_org = os.getenv("OPENAI_ORG")
-if openai_org is not None:
-    openai.organization = openai_org
-    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")
-
-
-@dataclasses.dataclass
-class OpenAIDecodingArguments(object):
-    max_tokens: int = 1800
-    temperature: float = 0.2
-    top_p: float = 1.0
-    n: int = 1
-    stream: bool = False
-    stop: Optional[Sequence[str]] = None
-    presence_penalty: float = 0.0
-    frequency_penalty: float = 0.0
-    suffix: Optional[str] = None
-    logprobs: Optional[int] = None
-    echo: bool = False
-
-
-def openai_completion(
-    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
-    decoding_args: OpenAIDecodingArguments,
-    model_name="text-davinci-003",
-    sleep_time=2,
-    batch_size=1,
-    max_instances=sys.maxsize,
-    max_batches=sys.maxsize,
-    return_text=False,
-    **decoding_kwargs,
-) -> Union[Union[StrOrOpenAIObject], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]],]:
-    """Decode with OpenAI API.
-
-    Args:
-        prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
-            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
-            it can also be a dictionary (or list thereof) as explained here:
-            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
-        decoding_args: Decoding arguments.
-        model_name: Model name. Can be either in the format of "org/model" or just "model".
-        sleep_time: Time to sleep once the rate-limit is hit.
-        batch_size: Number of prompts to send in a single request. Only for non chat model.
-        max_instances: Maximum number of prompts to decode.
-        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
-        return_text: If True, return text instead of full completion object (which contains things like logprob).
-        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.
-
-    Returns:
-        A completion or a list of completions.
-        Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of
-            - a string (if return_text is True)
-            - an openai_object.OpenAIObject object (if return_text is False)
-            - a list of objects of the above types (if decoding_args.n > 1)
-    """
-    is_single_prompt = isinstance(prompts, (str, dict))
-    if is_single_prompt:
-        prompts = [prompts]
-
-    if max_batches < sys.maxsize:
-        logging.warning(
-            "`max_batches` will be deprecated in the future, please use `max_instances` instead."
-            "Setting `max_instances` to `max_batches * batch_size` for now."
-        )
-        max_instances = max_batches * batch_size
-
-    prompts = prompts[:max_instances]
-    num_prompts = len(prompts)
-    prompt_batches = [
-        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
-        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
-    ]
-
-    completions = []
-    for batch_id, prompt_batch in tqdm.tqdm(
-        enumerate(prompt_batches),
-        desc="prompt_batches",
-        total=len(prompt_batches),
-    ):
-        batch_decoding_args = copy.deepcopy(decoding_args)  # cloning the decoding_args
-
-        while True:
-            try:
-                shared_kwargs = dict(
-                    model=model_name,
-                    **batch_decoding_args.__dict__,
-                    **decoding_kwargs,
-                )
-                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
-                choices = completion_batch.choices
-
-                for choice in choices:
-                    choice["total_tokens"] = completion_batch.usage.total_tokens
-                completions.extend(choices)
-                break
-            except openai.error.OpenAIError as e:
-                logging.warning(f"OpenAIError: {e}.")
-                if "Please reduce your prompt" in str(e):
-                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
-                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
-                else:
-                    logging.warning("Hit request rate limit; retrying...")
-                    time.sleep(sleep_time)  # Annoying rate limit on requests.
-
-    if return_text:
-        completions = [completion.text for completion in completions]
-    if decoding_args.n > 1:
-        # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
-        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
-    if is_single_prompt:
-        # Return non-tuple if only 1 input and 1 generation.
-        (completions,) = completions
-    return completions
-
-
 def _make_w_io_base(f, mode: str):
     if not isinstance(f, io.IOBase):
         f_dirname = os.path.dirname(f)