-
Notifications
You must be signed in to change notification settings - Fork 155
Update promt_config to working with openai format + inline setup #1210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Jorjeous
wants to merge
17
commits into
main
Choose a base branch
from
prompt_override_clean
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+449
−16
Open
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
9f49f58
Adress comments
Jorjeous 2ac4af6
linter
Jorjeous 6eade62
adress comments
Jorjeous 16d7e29
fix comment
Jorjeous 7341970
greptile comments
Jorjeous 1133d3e
rm new promt, add all modifications to promt_config.
Jorjeous 026d1d0
optimized to use audio field
Jorjeous 6bad619
linter and format
Jorjeous 4405f06
Merge branch 'main' into prompt_override_clean
Jorjeous 984e91e
Update nemo_skills/inference/generate.py
Jorjeous 80ea2a2
Merge branch 'main' into prompt_override_clean
Jorjeous ceaa6bf
Apply suggestion from @Kipok
Jorjeous 601d92d
address some comments
Jorjeous 962b06a
fixing cicd promblems
Jorjeous 1168f38
fix openai prompt text updates in mocked task tests
Jorjeous f5b1f0d
Merge origin/main into prompt_override_clean and resolve prompt merge…
Jorjeous 14b9676
linter fix
Jorjeous File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,7 +91,7 @@ class GenerationTaskConfig: | |
|
|
||
| input_file: str # Path to the input file with data | ||
| output_file: str # Where to save the generations | ||
| prompt_config: str | None = None # How to format the data into prompts | ||
| prompt_config: Any = None # How to format the data into prompts (str path, dict, or None) | ||
|
|
||
| # Deprecated, please use endpoint_type in the InferenceConfig instead | ||
| use_completions_api: bool = False | ||
|
|
@@ -104,6 +104,7 @@ class GenerationTaskConfig: | |
| prompt_format: str = "ns" | ||
| prompt_suffix: str = "" # suffix to add to the prompt, e.g. " /no_think" | ||
| system_message: str | None = None # can override the default system message in the config | ||
| user_message: str | None = None # can override the user message in the prompt config template | ||
| code_tags: str | None = None # required when using code execution | ||
| examples_type: str | None = None # to be able to customize few-shot examples | ||
|
|
||
|
|
@@ -251,10 +252,9 @@ def _post_init_validate_params(self): | |
| if self.prompt_format not in ["ns", "openai"]: | ||
| raise ValueError(f"prompt_format must be either 'ns' or 'openai', got '{self.prompt_format}'") | ||
|
|
||
| if self.prompt_format == "openai": | ||
| assert self.prompt_config is None, "prompt_config is not supported for prompt_format == 'openai'" | ||
| else: | ||
| assert self.prompt_config is not None, "prompt_config is required when prompt_format == 'ns'" | ||
| if self.prompt_format == "ns": | ||
| if self.prompt_config is None: | ||
| raise ValueError("prompt_config is required when prompt_format == 'ns'") | ||
| for param, default_value in self._get_disallowed_params(): | ||
| if getattr(self, param) != default_value: | ||
| raise ValueError(f"{param} must be {default_value}") | ||
|
|
@@ -311,6 +311,7 @@ def __init__(self, cfg: GenerationTaskConfig): | |
| cfg: GenerationTaskConfig object with the configuration parameters or subclass. | ||
| """ | ||
| self.cfg = cfg | ||
|
|
||
| if isinstance(self.cfg.inference.extra_body, DictConfig): | ||
| self.cfg.inference.extra_body = OmegaConf.to_container(self.cfg.inference.extra_body, resolve=True) | ||
| else: | ||
|
|
@@ -405,7 +406,8 @@ def __init__(self, cfg: GenerationTaskConfig): | |
| self.output_lock = None | ||
|
|
||
| def setup_prompt(self): | ||
| if self.cfg.prompt_format == "openai": | ||
| if self.cfg.prompt_config is None: | ||
| # openai format without prompt_config -- messages come from data | ||
| return None | ||
|
|
||
| prompt = get_prompt( | ||
|
|
@@ -414,6 +416,7 @@ def setup_prompt(self): | |
| code_tags=self.cfg.code_tags, | ||
| examples_type=self.cfg.examples_type, | ||
| system_message=self.cfg.system_message, | ||
| user_message=self.cfg.user_message, | ||
| ) | ||
|
|
||
| LOG.info("Prompt used: %s", prompt) | ||
|
|
@@ -579,20 +582,107 @@ def skip_completed_samples(self, data): | |
|
|
||
| return remaining_data | ||
|
|
||
| def _merge_audio_from_data(self, template_filled_messages, data_point): | ||
| """Copy audio metadata from original data messages into template-generated messages. | ||
|
|
||
| Normalizes to a single "audios" list on each message. Openai-format data always | ||
| has messages; top-level audio is not used in practice. | ||
| """ | ||
| if "messages" not in data_point or not isinstance(data_point["messages"], (list, ListConfig)): | ||
| return | ||
| original_messages = data_point["messages"] | ||
| used_original_messages = [False] * len(original_messages) | ||
| for msg in template_filled_messages: | ||
| if "audios" in msg: | ||
| continue | ||
| for idx, orig_msg in enumerate(original_messages): | ||
| if used_original_messages[idx]: | ||
| continue | ||
| if orig_msg["role"] != msg["role"]: | ||
| continue | ||
| if "name" in msg and orig_msg.get("name") != msg["name"]: | ||
| continue | ||
|
|
||
| audios = orig_msg.get("audios") or ([orig_msg["audio"]] if "audio" in orig_msg else None) | ||
| if audios: | ||
| msg["audios"] = audios | ||
| used_original_messages[idx] = True | ||
| break | ||
|
|
||
| @staticmethod | ||
| def _set_message_text_content(message: dict, text: str) -> None: | ||
| """Set text content for string or multimodal message content while preserving non-text items.""" | ||
| content = message["content"] | ||
| if isinstance(content, str): | ||
| message["content"] = text | ||
| return | ||
| if not isinstance(content, list): | ||
| raise TypeError(f"Unexpected content type: {type(content)}") | ||
|
|
||
| for item in content: | ||
| if item.get("type") == "text": | ||
| item["text"] = text | ||
| return | ||
| content.insert(0, {"type": "text", "text": text}) | ||
|
|
||
| @staticmethod | ||
| def _append_message_text_suffix(message: dict, suffix: str) -> None: | ||
| """Append suffix to text content for string or multimodal message content.""" | ||
| content = message["content"] | ||
| if isinstance(content, str): | ||
| message["content"] = content + suffix | ||
| return | ||
| if not isinstance(content, list): | ||
| raise TypeError(f"Unexpected content type: {type(content)}") | ||
|
|
||
| for item in content: | ||
| if item.get("type") == "text": | ||
| item["text"] += suffix | ||
| return | ||
| content.append({"type": "text", "text": suffix}) | ||
|
|
||
| # TODO: data will not include any samples skipped after restart | ||
| def fill_prompt(self, data_point, data, prompt_format=None): | ||
| """Passing in full data in case it's needed to fill the prompt in subclasses.""" | ||
| prompt_format = prompt_format or self.cfg.prompt_format | ||
| if prompt_format == "openai": | ||
| if self.prompt is None: | ||
| # Pure openai path -- messages come from the data | ||
| data_point = deepcopy(data_point) | ||
| if self.cfg.user_message: | ||
| user_msgs = [m for m in data_point["messages"] if m["role"] == "user"] | ||
| if len(user_msgs) != 1: | ||
| raise ValueError( | ||
| f"user_message override expects exactly 1 user message, found {len(user_msgs)}" | ||
| ) | ||
| GenerationTask._set_message_text_content(user_msgs[0], self.cfg.user_message) | ||
| if self.cfg.prompt_suffix: | ||
| GenerationTask._append_message_text_suffix(data_point["messages"][-1], self.cfg.prompt_suffix) | ||
| if self.cfg.system_message: | ||
| if data_point["messages"][0]["role"] != "system": | ||
| data_point["messages"].insert(0, {"role": "system", "content": self.cfg.system_message}) | ||
| else: | ||
| data_point["messages"][0]["content"] = self.cfg.system_message | ||
| return data_point["messages"] | ||
|
|
||
| # OpenAI path with prompt_config template -- build prompt from template, merge audio from data. | ||
| data_point = deepcopy(data_point) | ||
| filled_prompt = self.prompt.fill( | ||
| data_point, | ||
| start_assistant_response_key=self.cfg.start_assistant_response_key, | ||
| chat_template_kwargs=self.cfg.chat_template_kwargs, | ||
| format_as_string=(self.cfg.inference.endpoint_type == EndpointType.text), | ||
| ) | ||
| if isinstance(filled_prompt, list): | ||
| self._merge_audio_from_data(filled_prompt, data_point) | ||
| if self.cfg.prompt_suffix: | ||
| data_point["messages"][-1]["content"] += self.cfg.prompt_suffix | ||
| if self.cfg.system_message: | ||
| if data_point["messages"][0]["role"] != "system": | ||
| data_point["messages"].insert(0, {"role": "system", "content": self.cfg.system_message}) | ||
| if isinstance(filled_prompt, list): | ||
| GenerationTask._append_message_text_suffix(filled_prompt[-1], self.cfg.prompt_suffix) | ||
| else: | ||
| data_point["messages"][0]["content"] = self.cfg.system_message | ||
| return data_point["messages"] | ||
| filled_prompt += self.cfg.prompt_suffix | ||
| return filled_prompt | ||
|
|
||
| # NS path -- always uses prompt template | ||
| total_code_executions_in_prompt = self.cfg.total_code_executions_in_prompt | ||
| if total_code_executions_in_prompt is not None: | ||
| if isinstance(total_code_executions_in_prompt, (list, tuple)): | ||
|
|
@@ -608,8 +698,8 @@ def fill_prompt(self, data_point, data, prompt_format=None): | |
| ) | ||
| if self.cfg.prompt_suffix: | ||
| if isinstance(filled_prompt, list): | ||
| filled_prompt[-1]["content"] += self.cfg.prompt_suffix | ||
| else: | ||
| GenerationTask._append_message_text_suffix(filled_prompt[-1], self.cfg.prompt_suffix) | ||
| elif isinstance(filled_prompt, str): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the other possible case? Why do we need elif?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no other case, but that's "no silent approach" |
||
| filled_prompt += self.cfg.prompt_suffix | ||
| return filled_prompt | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,9 @@ | ||
| jiwer>=3.1.0,<4.0.0 # Word/Character Error Rate computation | ||
| sacrebleu # BLEU score computation | ||
| soundfile # Audio file I/O for dataset preparation | ||
| whisper-normalizer # Lightweight text normalization (EnglishTextNormalizer) | ||
| # torchcodec requires FFmpeg shared libraries (not installable via pip). | ||
| # Install via system package manager before running pip install: | ||
| # Linux: sudo apt install ffmpeg | ||
| # macOS: brew install ffmpeg | ||
| torchcodec | ||
| whisper-normalizer # Lightweight text normalization (EnglishTextNormalizer) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.