Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 104 additions & 14 deletions nemo_skills/inference/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class GenerationTaskConfig:

input_file: str # Path to the input file with data
output_file: str # Where to save the generations
prompt_config: str | None = None # How to format the data into prompts
prompt_config: Any = None # How to format the data into prompts (str path, dict, or None)

# Deprecated, please use endpoint_type in the InferenceConfig instead
use_completions_api: bool = False
Expand All @@ -104,6 +104,7 @@ class GenerationTaskConfig:
prompt_format: str = "ns"
prompt_suffix: str = "" # suffix to add to the prompt, e.g. " /no_think"
system_message: str | None = None # can override the default system message in the config
user_message: str | None = None # can override the user message in the prompt config template
code_tags: str | None = None # required when using code execution
examples_type: str | None = None # to be able to customize few-shot examples

Expand Down Expand Up @@ -251,10 +252,9 @@ def _post_init_validate_params(self):
if self.prompt_format not in ["ns", "openai"]:
raise ValueError(f"prompt_format must be either 'ns' or 'openai', got '{self.prompt_format}'")

if self.prompt_format == "openai":
assert self.prompt_config is None, "prompt_config is not supported for prompt_format == 'openai'"
else:
assert self.prompt_config is not None, "prompt_config is required when prompt_format == 'ns'"
if self.prompt_format == "ns":
if self.prompt_config is None:
raise ValueError("prompt_config is required when prompt_format == 'ns'")
for param, default_value in self._get_disallowed_params():
if getattr(self, param) != default_value:
raise ValueError(f"{param} must be {default_value}")
Expand Down Expand Up @@ -311,6 +311,7 @@ def __init__(self, cfg: GenerationTaskConfig):
cfg: GenerationTaskConfig object with the configuration parameters or subclass.
"""
self.cfg = cfg

if isinstance(self.cfg.inference.extra_body, DictConfig):
self.cfg.inference.extra_body = OmegaConf.to_container(self.cfg.inference.extra_body, resolve=True)
else:
Expand Down Expand Up @@ -405,7 +406,8 @@ def __init__(self, cfg: GenerationTaskConfig):
self.output_lock = None

def setup_prompt(self):
if self.cfg.prompt_format == "openai":
if self.cfg.prompt_config is None:
# openai format without prompt_config -- messages come from data
return None

prompt = get_prompt(
Expand All @@ -414,6 +416,7 @@ def setup_prompt(self):
code_tags=self.cfg.code_tags,
examples_type=self.cfg.examples_type,
system_message=self.cfg.system_message,
user_message=self.cfg.user_message,
)

LOG.info("Prompt used: %s", prompt)
Expand Down Expand Up @@ -579,20 +582,107 @@ def skip_completed_samples(self, data):

return remaining_data

def _merge_audio_from_data(self, template_filled_messages, data_point):
"""Copy audio metadata from original data messages into template-generated messages.

Normalizes to a single "audios" list on each message. Openai-format data always
has messages; top-level audio is not used in practice.

The template-generated messages are modified in place and nothing is returned.
Each original message is paired with at most one template-generated message.
"""
# Nothing to merge unless the data point carries a list-like "messages" entry.
if "messages" not in data_point or not isinstance(data_point["messages"], (list, ListConfig)):
return
original_messages = data_point["messages"]
# Tracks which original messages were already paired, so none is reused.
used_original_messages = [False] * len(original_messages)
for msg in template_filled_messages:
# Template messages that already carry audio are left untouched.
if "audios" in msg:
continue
for idx, orig_msg in enumerate(original_messages):
if used_original_messages[idx]:
continue
# Candidates must have the same role...
if orig_msg["role"] != msg["role"]:
continue
# ...and the same name, when the template message has one.
if "name" in msg and orig_msg.get("name") != msg["name"]:
continue

# Prefer an existing "audios" list; otherwise wrap a single "audio" entry in a list.
audios = orig_msg.get("audios") or ([orig_msg["audio"]] if "audio" in orig_msg else None)
if audios:
msg["audios"] = audios
# NOTE(review): block nesting is not visible here -- confirm whether the two
# statements below run only when audio was found or for every role/name match.
used_original_messages[idx] = True
break

@staticmethod
def _set_message_text_content(message: dict, text: str) -> None:
    """Overwrite the text of a chat message in place, leaving non-text parts alone.

    Args:
        message: Message dict whose "content" entry is either a plain string or
            a list of content-part dicts. Mutated in place.
        text: The new text.

    For string content the whole value is replaced. For list content only the
    first part whose "type" is "text" gets the new text; when no such part
    exists, a new text part is inserted at the front of the list.

    Raises:
        TypeError: If "content" is neither a string nor a list.
    """
    current = message["content"]

    if isinstance(current, list):
        # Locate the first text part, if any; every other part stays as it is.
        text_part = next((part for part in current if part.get("type") == "text"), None)
        if text_part is None:
            current.insert(0, {"type": "text", "text": text})
        else:
            text_part["text"] = text
    elif isinstance(current, str):
        message["content"] = text
    else:
        raise TypeError(f"Unexpected content type: {type(current)}")

@staticmethod
def _append_message_text_suffix(message: dict, suffix: str) -> None:
    """Append ``suffix`` to the text of a chat message, in place.

    String content is concatenated with the suffix. For list (multimodal)
    content the suffix is added to the *last* part whose "type" is "text", so
    it ends up after the final piece of text. This matches the string branch
    and the fallback below, which both place the suffix at the end.

    Args:
        message: Message dict whose "content" entry is either a plain string or
            a list of content-part dicts. Mutated in place.
        suffix: Text to append.

    Raises:
        TypeError: If "content" is neither a string nor a list.
    """
    content = message["content"]
    if isinstance(content, str):
        message["content"] = content + suffix
        return
    if not isinstance(content, list):
        raise TypeError(f"Unexpected content type: {type(content)}")

    # Walk backwards: with several text parts interleaved with media parts, a
    # suffix belongs after the final text part, not after the first one.
    for item in reversed(content):
        if item.get("type") == "text":
            item["text"] += suffix
            return
    # No text part at all -- add one at the end so non-text parts are preserved.
    content.append({"type": "text", "text": suffix})

# TODO: data will not include any samples skipped after restart
def fill_prompt(self, data_point, data, prompt_format=None):
"""Passing in full data in case it's needed to fill the prompt in subclasses."""
prompt_format = prompt_format or self.cfg.prompt_format
if prompt_format == "openai":
if self.prompt is None:
# Pure openai path -- messages come from the data
data_point = deepcopy(data_point)
if self.cfg.user_message:
user_msgs = [m for m in data_point["messages"] if m["role"] == "user"]
if len(user_msgs) != 1:
raise ValueError(
f"user_message override expects exactly 1 user message, found {len(user_msgs)}"
)
GenerationTask._set_message_text_content(user_msgs[0], self.cfg.user_message)
if self.cfg.prompt_suffix:
GenerationTask._append_message_text_suffix(data_point["messages"][-1], self.cfg.prompt_suffix)
if self.cfg.system_message:
if data_point["messages"][0]["role"] != "system":
data_point["messages"].insert(0, {"role": "system", "content": self.cfg.system_message})
else:
data_point["messages"][0]["content"] = self.cfg.system_message
return data_point["messages"]

# OpenAI path with prompt_config template -- build prompt from template, merge audio from data.
data_point = deepcopy(data_point)
filled_prompt = self.prompt.fill(
data_point,
start_assistant_response_key=self.cfg.start_assistant_response_key,
chat_template_kwargs=self.cfg.chat_template_kwargs,
format_as_string=(self.cfg.inference.endpoint_type == EndpointType.text),
)
if isinstance(filled_prompt, list):
self._merge_audio_from_data(filled_prompt, data_point)
if self.cfg.prompt_suffix:
data_point["messages"][-1]["content"] += self.cfg.prompt_suffix
if self.cfg.system_message:
if data_point["messages"][0]["role"] != "system":
data_point["messages"].insert(0, {"role": "system", "content": self.cfg.system_message})
if isinstance(filled_prompt, list):
GenerationTask._append_message_text_suffix(filled_prompt[-1], self.cfg.prompt_suffix)
else:
data_point["messages"][0]["content"] = self.cfg.system_message
return data_point["messages"]
filled_prompt += self.cfg.prompt_suffix
return filled_prompt

# NS path -- always uses prompt template
total_code_executions_in_prompt = self.cfg.total_code_executions_in_prompt
if total_code_executions_in_prompt is not None:
if isinstance(total_code_executions_in_prompt, (list, tuple)):
Expand All @@ -608,8 +698,8 @@ def fill_prompt(self, data_point, data, prompt_format=None):
)
if self.cfg.prompt_suffix:
if isinstance(filled_prompt, list):
filled_prompt[-1]["content"] += self.cfg.prompt_suffix
else:
GenerationTask._append_message_text_suffix(filled_prompt[-1], self.cfg.prompt_suffix)
elif isinstance(filled_prompt, str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the other possible case? Why do we need `elif` here instead of a plain `else`?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no other case, but the explicit `elif` follows the "no silent fallback" approach.

filled_prompt += self.cfg.prompt_suffix
return filled_prompt

Expand Down
18 changes: 17 additions & 1 deletion nemo_skills/prompt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class PromptConfig:
image_field: str | None = None
# Whether to put image before or after the text in multimodal content
image_position: str = "before" # "before" or "after"
# Audio support: field name from input_dict containing audio metadata (dict or list of dicts)
# When set, audio metadata is attached to the user message as "audios" list for model processing
audio_field: str | None = None


class Prompt:
Expand Down Expand Up @@ -292,7 +295,16 @@ def fill(
else:
user_content = user_text

messages.append({"role": "user", "content": user_content})
user_message_dict = {"role": "user", "content": user_content}

# For audio: attach audio metadata to user message as audios list (model layer handles base64 conversion)
if self.config.audio_field and self.config.audio_field in input_dict:
audio_data = input_dict[self.config.audio_field]
if isinstance(audio_data, dict):
audio_data = [audio_data]
user_message_dict["audios"] = audio_data

messages.append(user_message_dict)

if not format_as_string:
if start_assistant_response_key:
Expand Down Expand Up @@ -443,6 +455,7 @@ def get_prompt(
code_tags: str | dict | None = None,
examples_type: str | None = None,
system_message: str | None = None,
user_message: str | None = None,
config_dir: str | None = None,
code_tags_dir: str | None = None,
) -> Prompt:
Expand All @@ -457,6 +470,9 @@ def get_prompt(
if system_message is not None:
config["system"] = system_message

if user_message is not None:
config["user"] = user_message

code_tags_obj = None
if code_tags is not None:
if isinstance(code_tags, str):
Expand Down
2 changes: 1 addition & 1 deletion requirements/audio.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
jiwer>=3.1.0,<4.0.0 # Word/Character Error Rate computation
sacrebleu # BLEU score computation
soundfile # Audio file I/O for dataset preparation
whisper-normalizer # Lightweight text normalization (EnglishTextNormalizer)
# torchcodec requires FFmpeg shared libraries (not installable via pip).
# Install via system package manager before running pip install:
# Linux: sudo apt install ffmpeg
# macOS: brew install ffmpeg
torchcodec
whisper-normalizer # Lightweight text normalization (EnglishTextNormalizer)
Loading