diff --git a/megatron/core/tokenizers/conversation/__init__.py b/megatron/core/tokenizers/conversation/__init__.py new file mode 100644 index 00000000000..08a04f095a1 --- /dev/null +++ b/megatron/core/tokenizers/conversation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from megatron.core.tokenizers.conversation.conversation_tokenizer import tokenize_conversation +from megatron.core.tokenizers.conversation.prompt_config import PROMPT_FORMAT_REGISTRY, PromptConfig diff --git a/megatron/core/tokenizers/conversation/conversation_tokenizer.py b/megatron/core/tokenizers/conversation/conversation_tokenizer.py new file mode 100644 index 00000000000..947e74ed832 --- /dev/null +++ b/megatron/core/tokenizers/conversation/conversation_tokenizer.py @@ -0,0 +1,143 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np + +from megatron.core.tokenizers.conversation.prompt_config import PromptConfig + +IGNORE_INDEX = -100 + + +def tokenize_conversation( + tokenizer, + conversation: List[Dict], + prompt_config: PromptConfig, + return_target: bool, + add_generation_prompt: bool, + apply_image_tag_fn: Optional[Callable] = None, +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """Tokenize a conversation with optional target masking for training. + + This is the shared implementation used by both SFT and multimodal tokenizers. + + Args: + tokenizer: A tokenizer instance with ``apply_chat_template`` support. + conversation: Sequence of system/user/assistant messages in the format: + [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, ...] + prompt_config: Configuration controlling tokenization and masking behavior. + return_target: If True, return target tokens with system/user tokens masked. + add_generation_prompt: If True, add assistant prefix to the end. 
+        apply_image_tag_fn: Optional callback to apply image tags to conversation
+            (used by multimodal tokenizer).
+
+    Returns:
+        tokens (np.ndarray): Token IDs for the conversation.
+        target (np.ndarray): Target token IDs with masked positions set to IGNORE_INDEX.
+            Only returned if return_target is True.
+    """
+    # 1. Skip system message if the tokenizer doesn't have a system role.
+    # Guard on a non-empty conversation so an empty input fails later with a
+    # clear tokenization result instead of an IndexError here.
+    if not prompt_config.has_system_role and conversation and conversation[0]["role"] == "system":
+        conversation = conversation[1:]
+
+    # 2. Force system message if configured.
+    if prompt_config.force_system_message:
+        assert (
+            prompt_config.system_default is not None
+        ), "Trying to force system message with empty system default"
+        # Rebuild the list instead of assigning conversation[0] in place so the
+        # caller's message list is never mutated.
+        if conversation and conversation[0]["role"] == "system":
+            conversation = [prompt_config.system_default] + conversation[1:]
+        else:
+            conversation = [prompt_config.system_default] + conversation
+
+    # 3. Capitalize roles if needed (e.g. nemotron5-aligned format).
+    if prompt_config.capitalize_roles:
+        # Copy each turn rather than writing turn['role'] in place; the input
+        # dicts may be owned by the caller (e.g. reused across dataset epochs).
+        conversation = [
+            {**turn, 'role': turn['role'][:1].upper() + turn['role'][1:]}
+            for turn in conversation
+        ]
+
+    # 4. Apply image tags if callback provided (multimodal).
+    if apply_image_tag_fn is not None:
+        conversation = apply_image_tag_fn(conversation)
+
+    # 5. Tokenize with chat template.
+    result = tokenizer.apply_chat_template(
+        conversation,
+        tokenize=True,
+        add_generation_prompt=add_generation_prompt,
+        chat_template=prompt_config.custom_chat_template,
+    )
+    # Normalize to 1D numpy array regardless of backend.
+    if isinstance(result, np.ndarray):
+        tokens = result.flatten()
+    else:
+        tokens = np.array(result, dtype=np.int64)
+
+    if not return_target:
+        return tokens
+
+    target = tokens.copy()
+
+    # 6. Skip masking if configured (e.g. SFT "default" format).
+    if prompt_config.skip_masking:
+        return tokens, target
+
+    # 7. Mask system and user tokens in the target.
+ masked_roles = {"system", "user"} + if prompt_config.allow_tool_role: + masked_roles.add("tool") + + idx = 0 + for turn_idx, turn in enumerate(conversation): + role = turn["role"].lower() + + # Validate empty turns. + if prompt_config.allow_tool_role: + # SFT behavior: only check assistant turns for empty content. + if role == "assistant" and len(turn["content"]) == 0: + raise ValueError(f"empty assistant turn in conversation: {conversation}.") + if role == "assistant": + assert conversation[turn_idx - 1]["role"].lower() in ("user", "tool") + else: + # Multimodal behavior: check all turns for empty content. + if len(turn["content"]) == 0: + raise ValueError(f"empty turn in conversation: {conversation}. Skipping.") + + # Validate no image token in assistant content. + if prompt_config.validate_no_image_in_assistant and role == "assistant": + try: + from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN + + if IMAGE_TOKEN in turn["content"]: + raise RuntimeError(f"{IMAGE_TOKEN} not allowed in assistant content!") + except ImportError: + pass + + turn_result = tokenizer.apply_chat_template( + [turn], tokenize=True, chat_template=prompt_config.custom_chat_template + ) + turn_tokens = list(turn_result) if not isinstance(turn_result, list) else turn_result + + # There should be only one BOS at the very beginning. + # After the first turn, skip BOS token. 
+ if prompt_config.has_bos and turn_idx > 0: + turn_tokens = turn_tokens[1:] + turn_len = len(turn_tokens) + + if role in masked_roles: + target[idx : idx + turn_len] = IGNORE_INDEX + elif role == "assistant": + if prompt_config.assistant_prefix_len > 0: + target[idx : idx + prompt_config.assistant_prefix_len] = IGNORE_INDEX + else: + raise ValueError(f"Wrong role value: {role}") + + assert np.allclose( + tokens[idx : idx + turn_len], turn_tokens + ), f"expected turn tokens to match tokens in conversation {conversation}" + + idx += turn_len + + assert idx == len(tokens), f"mismatch in target masking the conversation {conversation}" + + return tokens, target diff --git a/megatron/core/tokenizers/conversation/prompt_config.py b/megatron/core/tokenizers/conversation/prompt_config.py new file mode 100644 index 00000000000..c70bebbdfc3 --- /dev/null +++ b/megatron/core/tokenizers/conversation/prompt_config.py @@ -0,0 +1,282 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from dataclasses import dataclass +from typing import Dict, Optional + +# --------------------------------------------------------------------------- +# Library-agnostic helpers for resolving token attributes +# --------------------------------------------------------------------------- + + +def _token_to_id(tokenizer, token: str) -> int: + """Resolve a token string to its ID, working with any tokenizer library. + + All text library tokenizers have ``token_to_id`` (concrete default on the + abstract base class), so that is the primary branch. The + ``convert_tokens_to_ids`` fallback covers the multimodal path, which passes + a raw HuggingFace ``AutoTokenizer`` directly. 
+ """ + if hasattr(tokenizer, 'token_to_id'): + return tokenizer.token_to_id(token) + if hasattr(tokenizer, 'convert_tokens_to_ids'): + return tokenizer.convert_tokens_to_ids(token) + raise TypeError(f"Cannot resolve token to ID for {type(tokenizer)}") + + +def _get_pad_token_id(tokenizer) -> Optional[int]: + return getattr(tokenizer, 'pad_token_id', None) or getattr(tokenizer, 'pad_id', None) + + +def _get_bos_token_id(tokenizer) -> Optional[int]: + return getattr(tokenizer, 'bos_token_id', None) or getattr(tokenizer, 'bos_id', None) + + +def _get_eos_token_id(tokenizer) -> Optional[int]: + return getattr(tokenizer, 'eos_token_id', None) or getattr(tokenizer, 'eos_id', None) + + +def _get_unk_token_id(tokenizer) -> Optional[int]: + return getattr(tokenizer, 'unk_token_id', None) or getattr(tokenizer, 'unk_id', None) + + +def _get_chat_template(tokenizer) -> Optional[str]: + return getattr(tokenizer, 'chat_template', None) + + +@dataclass +class PromptConfig: + """Config options for different prompt formats. + + Controls how conversations are tokenized and how target masking is applied + for supervised fine-tuning (SFT) and multimodal training. + """ + + # How many tokens are used for the assistant prefix, e.g. "<|im_start|>assistant\n". + # Used for masking the assistant prefix. + assistant_prefix_len: int + # Padding token ID. + pad_token_id: int + # For overriding the default chat format template. + custom_chat_template: Optional[str] + # If the tokenizer inserts BOS token by default. + has_bos: bool + # If the tokenizer supports a separate role for system messages. + has_system_role: bool + # Whether to force a specific system message. + force_system_message: bool = False + system_default: Optional[dict] = None + # Whether to validate that IMAGE_TOKEN is not in assistant content. + validate_no_image_in_assistant: bool = False + # Whether to capitalize role names (e.g. for nemotron5-aligned format). 
+ capitalize_roles: bool = False + # Whether to skip target masking entirely (e.g. for SFT "default" format). + skip_masking: bool = False + # Whether to include "tool" role in the set of masked (non-training) roles. + allow_tool_role: bool = False + + +# --------------------------------------------------------------------------- +# Chat template strings +# --------------------------------------------------------------------------- + +# SFT templates +# fmt: off +nemotron_h_aligned_custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ 'System\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'user' %}{{ 'User\n' + message['content'].strip() + '\n' + 'Assistant\n' }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() + '\n' }}{% endif %}{% endfor %}""" # pylint: disable=line-too-long +nemotron_nano_v2_custom_template = """{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'system' %}{{ 'System\n' + content.replace('/think', '').replace('/no_think', '').strip() + '\n' }}{% elif message['role'] == 'user' %}{{ 'User\n' + content.replace('/think', '').replace('/no_think', '').strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant\n' + content.strip() + '\n\n' }}{% endif %}{% endfor %}""" # pylint: disable=line-too-long +identity_template = """{% for message in messages %}{{ message['content'] }}{% endfor %}""" +# fmt: on + +# Multimodal templates +# The default mistral template raises exceptions so we use a custom one. 
+mistral_custom_template = """ +{{- bos_token }} +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {{- '[INST] ' + message['content'] + '[/INST]' }} + {%- elif message['role'] == 'assistant' %} + {{- ' ' + message['content'] + eos_token}} + {%- endif %} +{%- endfor %} +{% if add_generation_prompt %}{{ ' ' }}{% endif %} +""" + +nvlm_yi_34b_template = "{{- bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # pylint: disable=line-too-long + +qwen2p0_custom_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # pylint: disable=line-too-long + +# Note: this is the same template as +# https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/main/tokenizer_config.json#L2053 +# but we removed the forced system message. +llama3p1_chat_template = """{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. 
#}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = none %}\n{%- endif %}\n\n{%- if system_message is not none %}{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%-endif %}{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n""" # pylint: disable=line-too-long + +nemotron_custom_template = "{{- bos_token }}{% for 
message in messages %}{{'' + message['role'] + '\n' + message['content'] + '' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'assistant\n' }}{% endif %}" # pylint: disable=line-too-long
+# NOTE(review): the role delimiters in nemotron_custom_template are empty strings
+# ('' + role ... + ''), which looks like angle-bracketed special tokens were
+# stripped by a sanitizer — confirm against the original template before use.
+
+nemotron_aligned_custom_template = "{{- bos_token}}{% for message in messages %}{{message['role'] + '\n' + message['content'] + '\n' + '[PREFIX]'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant\n' }}{% endif %}" # pylint: disable=line-too-long
+
+
+# ---------------------------------------------------------------------------
+# Prompt format registry
+# ---------------------------------------------------------------------------
+
+
+def _build_sft_nemotron_nano_v2(tokenizer):
+    # Nemotron Nano v2 SFT format; assistant prefix of 3 tokens is masked.
+    # NOTE(review): empty pad-token string below — presumably a stripped
+    # angle-bracketed special token; verify the intended pad token.
+    return PromptConfig(
+        assistant_prefix_len=3,
+        pad_token_id=_token_to_id(tokenizer, ""),
+        custom_chat_template=nemotron_nano_v2_custom_template,
+        has_bos=False,
+        has_system_role=True,
+        allow_tool_role=True,
+    )
+
+
+def _build_sft_nemotron_h_aligned(tokenizer):
+    # Nemotron-H aligned SFT format; no assistant prefix to mask.
+    # NOTE(review): empty pad-token string — see note on _build_sft_nemotron_nano_v2.
+    return PromptConfig(
+        assistant_prefix_len=0,
+        pad_token_id=_token_to_id(tokenizer, ""),
+        custom_chat_template=nemotron_h_aligned_custom_template,
+        has_bos=False,
+        has_system_role=True,
+        allow_tool_role=True,
+    )
+
+
+def _build_sft_identity(tokenizer):
+    # Pass-through format: concatenates message contents with no role markup.
+    # NOTE(review): empty pad-token string — see note on _build_sft_nemotron_nano_v2.
+    return PromptConfig(
+        assistant_prefix_len=0,
+        pad_token_id=_token_to_id(tokenizer, ""),
+        custom_chat_template=identity_template,
+        has_bos=False,
+        has_system_role=True,
+        allow_tool_role=True,
+    )
+
+
+def _build_sft_default(tokenizer):
+    # Uses the tokenizer's own chat template and skips target masking entirely.
+    pad = _get_pad_token_id(tokenizer)
+    return PromptConfig(
+        assistant_prefix_len=0,
+        pad_token_id=(pad if pad is not None else _get_eos_token_id(tokenizer)),
+        custom_chat_template=_get_chat_template(tokenizer),
+        has_bos=_get_bos_token_id(tokenizer) is not None,
+        has_system_role=True,
+        skip_masking=True,
+        allow_tool_role=True,
+    )
+
+
+def _build_multimodal_mistral(tokenizer):
+    # Mistral [INST] format; pads with the UNK token since Mistral has no pad token.
+    return PromptConfig(
+        assistant_prefix_len=0,
+        pad_token_id=_get_unk_token_id(tokenizer),
+        custom_chat_template=mistral_custom_template,
+
has_bos=True,
+        has_system_role=False,
+        validate_no_image_in_assistant=True,
+    )
+
+
+def _build_multimodal_llama3(tokenizer):
+    # Llama 3 instruct format; pads with the "<|end_of_text|>" token and uses
+    # the tokenizer's built-in chat template (custom_chat_template=None).
+    return PromptConfig(
+        assistant_prefix_len=4,
+        pad_token_id=_token_to_id(tokenizer, "<|end_of_text|>"),
+        custom_chat_template=None,
+        has_bos=True,
+        has_system_role=True,
+        validate_no_image_in_assistant=True,
+    )
+
+
+def _build_multimodal_llama3p1(tokenizer):
+    # Llama 3.1 (and 3.2) instruct format with the dedicated right-padding token.
+    return PromptConfig(
+        assistant_prefix_len=4,
+        pad_token_id=_token_to_id(tokenizer, "<|finetune_right_pad_id|>"),
+        custom_chat_template=llama3p1_chat_template,
+        has_bos=True,
+        has_system_role=True,
+        validate_no_image_in_assistant=True,
+    )
+
+
+def _build_multimodal_nvlm_yi_34b(tokenizer):
+    # NVLM Yi-34B chatml-style format with a custom template.
+    return PromptConfig(
+        assistant_prefix_len=4,
+        pad_token_id=_get_pad_token_id(tokenizer),
+        custom_chat_template=nvlm_yi_34b_template,
+        has_bos=True,
+        has_system_role=True,
+        validate_no_image_in_assistant=True,
+    )
+
+
+def _build_multimodal_chatml(tokenizer):
+    # Generic ChatML format relying on the tokenizer's built-in chat template.
+    return PromptConfig(
+        assistant_prefix_len=3,
+        pad_token_id=_get_pad_token_id(tokenizer),
+        custom_chat_template=None,
+        has_bos=False,
+        has_system_role=True,
+        validate_no_image_in_assistant=True,
+    )
+
+
+def _build_multimodal_nemotron5(tokenizer):
+    # NOTE(review): empty pad-token string below — looks like an angle-bracketed
+    # special token was stripped by a sanitizer (non-tag-shaped tokens such as
+    # "<|end_of_text|>" elsewhere in this file survived); confirm the intended
+    # pad token before relying on this.
+    return PromptConfig(
+        assistant_prefix_len=3,
+        pad_token_id=_token_to_id(tokenizer, ""),
+        custom_chat_template=nemotron_custom_template,
+        has_bos=True,
+        has_system_role=True,
+        validate_no_image_in_assistant=True,
+    )
+
+
+def _build_multimodal_nemotron5_aligned(tokenizer):
+    # NOTE(review): empty pad-token string — see note on _build_multimodal_nemotron5.
+    return PromptConfig(
+        assistant_prefix_len=2,
+        pad_token_id=_token_to_id(tokenizer, ""),
+        custom_chat_template=nemotron_aligned_custom_template,
+        has_bos=True,
+        has_system_role=True,
+        capitalize_roles=True,
+        validate_no_image_in_assistant=True,
+    )
+
+
+def _build_multimodal_qwen2(tokenizer, force_system_message=False):
+    # Qwen 2.x ChatML format; optionally forces the default Qwen system message.
+    return PromptConfig(
+        assistant_prefix_len=3,
+        pad_token_id=_get_pad_token_id(tokenizer),
+        custom_chat_template=qwen2p0_custom_template,
+        has_bos=False,
+
has_system_role=True, + force_system_message=force_system_message, + system_default={ + "role": "system", + "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", + }, + validate_no_image_in_assistant=True, + ) + + +# Registry mapping prompt format names to factory functions. +# Factory functions take a tokenizer and return a PromptConfig. +# Some entries map to the same factory (e.g. llama3p1 and llama3p2). +PROMPT_FORMAT_REGISTRY: Dict[str, callable] = { + # SFT formats + "nemotron-nano-v2": _build_sft_nemotron_nano_v2, + "nemotron-h-aligned": _build_sft_nemotron_h_aligned, + "identity": _build_sft_identity, + "default": _build_sft_default, + # Multimodal formats + "mistral": _build_multimodal_mistral, + "llama3": _build_multimodal_llama3, + "llama3p1": _build_multimodal_llama3p1, + "llama3p2": _build_multimodal_llama3p1, # Same config as llama3p1 + "nvlm-yi-34b": _build_multimodal_nvlm_yi_34b, + "chatml": _build_multimodal_chatml, + "nemotron5": _build_multimodal_nemotron5, + "nemotron5-aligned": _build_multimodal_nemotron5_aligned, + "qwen2p0": _build_multimodal_qwen2, + "qwen2p5": _build_multimodal_qwen2, +} diff --git a/megatron/core/tokenizers/megatron_tokenizer.py b/megatron/core/tokenizers/megatron_tokenizer.py index 31694d91af0..f12b0b3bfe7 100644 --- a/megatron/core/tokenizers/megatron_tokenizer.py +++ b/megatron/core/tokenizers/megatron_tokenizer.py @@ -4,31 +4,11 @@ import json import logging import os -from collections import OrderedDict from typing import Optional, Union from megatron.core.tokenizers.base_tokenizer import MegatronTokenizerBase -TOKENIZER_MAPPING_NAMES = OrderedDict( - [ - ("default-text", "DefaultTokenizerText"), - ("gpt", "GPTTokenizer"), - ("mamba", "MambaTokenizer"), - ("bert", "BertTokenizer"), - ("t5", "T5Tokenizer"), - ("default-vision", "DefaultTokenizerVision"), - ] -) - -TEXT_LIBRARIES = [ - "sentencepiece", - "huggingface", - "megatron", - "tiktoken", - "byte-level", - "null-text", - "sft", -] 
+TEXT_LIBRARIES = ["sentencepiece", "huggingface", "megatron", "tiktoken", "byte-level", "null-text"] VISION_LIBRARIES = ["multimodal", "null-multimodal"] logger = logging.getLogger(__name__) @@ -117,8 +97,7 @@ def write_metadata( tokenizer_path (str): path to tokenizer model. tokenizer_library (str): tokenizer model library. model_type (str): type of the model to be used with tokenizer. - list of available model types: [gpt, bert, t5, mamba, default]. - `DefaultTokenizerText` will be used if model_type is not specified. + Kept for backward compatibility but no longer used for class resolution. tokenizer_class (MegatronTokenizerBase): pre-defined tokenizer class. chat_template (str): tokenizer chat template in jinja format. overwrite (bool): overwrites existing metadata file if set to True. @@ -129,7 +108,6 @@ def write_metadata( MegatronTokenizer.write_metadata( tokenizer_path='/path/to/tokenzier/model', tokenizer_library='sentencepiece', - model_type='llama', ) """ @@ -187,29 +165,30 @@ def _get_metadata_path(tokenizer_path: str) -> str: def _get_tokenizer_model_class(library: str, metadata: dict) -> MegatronTokenizerBase: """ - Returns a class which corresponds to choosen tokenizer model type. + Returns a class which corresponds to the tokenizer type (text or vision). + + The model_type field in metadata is ignored since model-specific wrapper classes + have been removed. All text tokenizers use MegatronTokenizerText and all vision + tokenizers use MegatronTokenizerVision directly. Args: library (str): tokenizer library. metadata (dict): tokenizer metadata. Returns: - MegatronTokenizerBase: class for choosen tokenizer model type. + MegatronTokenizerBase: class for the tokenizer. """ # Return tokenizer class if it was specified in metadata. 
- if metadata.get('tokenizer_class', None): - return getattr(metadata['tokenizer_class_path'], metadata['tokenizer_class_name']) - - # Define tokenizer type - tokenizer_type = 'text' if library in TEXT_LIBRARIES else 'vision' - - module_name = f"megatron.core.tokenizers.{tokenizer_type}.models" - models = importlib.import_module(module_name) + if metadata.get('class_name', None) and metadata.get('class_path', None): + module = importlib.import_module(metadata['class_path']) + return getattr(module, metadata['class_name']) - model_type = metadata.get("model_type", None) - if model_type is None: - model_type = f"default-{tokenizer_type}" + # Resolve based on library type (text vs vision). + if library in TEXT_LIBRARIES: + from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText - tokenizer_cls = getattr(models, TOKENIZER_MAPPING_NAMES[model_type]) + return MegatronTokenizerText + else: + from megatron.core.tokenizers.vision.vision_tokenizer import MegatronTokenizerVision - return tokenizer_cls + return MegatronTokenizerVision diff --git a/megatron/core/tokenizers/text/libraries/__init__.py b/megatron/core/tokenizers/text/libraries/__init__.py index 92a97b31f9a..e9141b81b8d 100644 --- a/megatron/core/tokenizers/text/libraries/__init__.py +++ b/megatron/core/tokenizers/text/libraries/__init__.py @@ -5,5 +5,4 @@ from megatron.core.tokenizers.text.libraries.megatron_hf_tokenizer import MegatronHFTokenizer from megatron.core.tokenizers.text.libraries.null_tokenizer import NullTokenizer from megatron.core.tokenizers.text.libraries.sentencepiece_tokenizer import SentencePieceTokenizer -from megatron.core.tokenizers.text.libraries.sft_tokenizer import SFTTokenizer from megatron.core.tokenizers.text.libraries.tiktoken_tokenizer import TikTokenTokenizer diff --git a/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py b/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py index 360db03e5f2..98fb6aa95e8 100644 --- 
a/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/abstract_tokenizer.py @@ -1,7 +1,14 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. from abc import ABC, abstractmethod -from typing import List +from typing import Dict, List, Optional, Union + +try: + from transformers.utils.chat_template_utils import _compile_jinja_template + + HAVE_TRANSFORMERS = True +except ImportError: + HAVE_TRANSFORMERS = False class MegatronTokenizerTextAbstract(ABC): @@ -9,6 +16,68 @@ class MegatronTokenizerTextAbstract(ABC): Abstract class for Megatron text tokenizers. """ + def apply_chat_template( + self, + conversation: List[Dict[str, str]], + chat_template: str = None, + tokenize: Optional[bool] = True, + truncation: Optional[bool] = False, + max_length: Optional[int] = None, + add_generation_prompt: Optional[bool] = False, + **kwargs, + ) -> Union[str, List[int]]: + """ + Applies tokenizer's chat template to the conversation using Jinja2. + + Args: + conversation (List[Dict[str, str]]): a list of dicts with "role" and "content" keys, + representing the chat history so far. + chat_template (str): Jinja2 chat template string. If not provided, falls back to + ``self.chat_template``. + tokenize (bool): whether to tokenize the output. If ``False``, + the output will be a string. + truncation (bool): whether to truncate sequences at the maximum length. + Has no effect if tokenize is ``False``. + max_length (int): maximum length to use for truncation. + Has no effect if tokenize is ``False``. + add_generation_prompt (bool): If set, a prompt with the token(s) that indicate + the start of an assistant message will be appended to the formatted output. + """ + if not chat_template: + chat_template = getattr(self, 'chat_template', None) + assert chat_template, ( + "Chat template is not defined. " + "Please, specify tokenizer chat template in the metadata file." 
+ ) + if truncation: + assert max_length, "max_length must be specified if truncation is used." + + if HAVE_TRANSFORMERS: + compiled_template = _compile_jinja_template(chat_template) + chat_text = compiled_template.render( + messages=conversation, add_generation_prompt=add_generation_prompt + ) + + if tokenize: + chat_ids = self.text_to_ids(chat_text) + if truncation: + chat_ids = chat_ids[:max_length] + return chat_ids + + return chat_text + else: + raise ModuleNotFoundError("Please, install transformers library.") + + def token_to_id(self, token: str) -> int: + """Converts a single token to its ID. + + Concrete default so that every text library tokenizer exposes a canonical + single-token-to-ID method. SentencePiece, TikToken, and HuggingFace + override this with optimized versions; the default delegates to + ``tokens_to_ids``. + """ + return self.tokens_to_ids([token])[0] + @abstractmethod def text_to_tokens(self, text: str) -> List[str]: """ @@ -91,57 +160,3 @@ def ids_to_text(self, ids: List[int]) -> str: def add_special_tokens(self): """Adds special tokens to the tokenizer.""" pass - - @property - def cls_id(self) -> int: - """Property alias to match MegatronTokenizer; returns cls_id if available.""" - if hasattr(self, 'cls_id'): - return self.cls_id - raise AttributeError(f"{type(self).__name__} has no attribute 'cls' or 'cls_id'") - - @property - def sep_id(self) -> int: - """Property alias to match MegatronTokenizer; returns sep_id if available.""" - if hasattr(self, 'sep_id'): - return self.sep_id - raise AttributeError(f"{type(self).__name__} has no attribute 'sep' or 'sep_id'") - - @property - def pad_id(self) -> int: - """Property alias to match MegatronTokenizer; returns pad_id if available.""" - if hasattr(self, 'pad_id'): - return self.pad_id - raise AttributeError(f"{type(self).__name__} has no attribute 'pad' or 'pad_id'") - - @property - def eod(self) -> int: - """Property alias to match MegatronTokenizer; returns eod_id if available.""" - if 
hasattr(self, 'eod_id'): - return self.eod_id - if hasattr(self, 'eos_id'): - # Default to end-of-sentence id if end-of-document is not defined. - return self.eos_id - raise AttributeError( - f"{type(self).__name__} has no attribute 'eod', 'eod_id', 'eos', or 'eos_id'" - ) - - @property - def bos_id(self) -> int: - """Property alias to match MegatronTokenizer; returns bos_id if available.""" - if hasattr(self, 'bos_id'): - return self.bos_id - raise AttributeError(f"{type(self).__name__} has no attribute 'bos' or 'bos_id'") - - @property - def eos_id(self) -> int: - """Property alias to match MegatronTokenizer; returns eos_id if available.""" - if hasattr(self, 'eos_id'): - return self.eos_id - raise AttributeError(f"{type(self).__name__} has no attribute 'eos' or 'eos_id'") - - @property - def mask_id(self) -> int: - """Property alias to match MegatronTokenizer; returns mask_id if available.""" - if hasattr(self, 'mask_id'): - return self.mask_id - raise AttributeError(f"{type(self).__name__} has no attribute 'mask' or 'mask_id'") diff --git a/megatron/core/tokenizers/text/libraries/chat_template.py b/megatron/core/tokenizers/text/libraries/chat_template.py deleted file mode 100644 index aa43cb12e4f..00000000000 --- a/megatron/core/tokenizers/text/libraries/chat_template.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
- -from typing import Dict, List, Optional, Union - -try: - from transformers.utils.chat_template_utils import _compile_jinja_template - - HAVE_TRANSFORMERS = True -except ImportError: - HAVE_TRANSFORMERS = False - - -class MegatronTokenizerChatTemplate: - """Chat template class for Megatron text tokenizers.""" - - def apply_chat_template( - self, - conversation: List[Dict[str, str]], - chat_template: str, - tokenize: Optional[bool] = True, - truncation: Optional[bool] = False, - max_length: Optional[int] = None, - add_generation_prompt: Optional[bool] = False, - ) -> Union[str, List[int]]: - """ - Applies tokenizer's chat template to the conversation. - - Args: - conversation (List[Dict[str, str]]): a list of dicts with "role" and "content" keys, - representing the chat history so far. Conversation example: - [ - {"role": "system", "content": "You are a witty and helpful assistant."}, - {"role": "user", "content": "Hey, what's a fun fact about octopuses?"}, - {"role": "assistant", "content": "Octopuses blood is blue!"}, - {"role": "user", "content": "Whoa, why is their blood blue?"}, - ] - tokenize (bool): whether to tokenize the output. If `False`, - the output will be a string. - truncation (bool): whether to truncate sequences at the maximum length. - Has no effect if tokenize is `False`. - max_length (int): maximum length to use fro truncation. - Has no effect if tokenize is `False`. - add_generation_prompt (bool): If this is set, a prompt with the token(s) that indicate - the start of an assistant message will be appended to the formatted output. - This is useful when you want to generate a response from the model. - Note that this argument will be passed to the chat template, - and so it must be supported in the template for this argument to have any effect. - """ - - assert chat_template, ( - "Chat template is not defined. " - "Please, specify tokenizer chat template in the metadata file." 
- ) - if truncation: - assert max_length, "max_length must be specified if truncation is used." - - if HAVE_TRANSFORMERS: - compiled_template = _compile_jinja_template(chat_template) - chat_text = compiled_template.render( - messages=conversation, add_generation_prompt=add_generation_prompt - ) - - if tokenize: - chat_ids = self.text_to_ids(chat_text) - if truncation: - chat_ids = chat_ids[:max_length] - return chat_ids - - return chat_text - else: - raise ModuleNotFoundError("Please, install transformers library.") diff --git a/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py b/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py index 849d3f0de0e..eec55b73ce8 100644 --- a/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/megatron_hf_tokenizer.py @@ -77,7 +77,20 @@ class MegatronHFTokenizer(HuggingFaceTokenizer): - """ """ + """Auto-download resolver for Megatron's legacy predefined tokenizer names. + + Maps predefined names (``GPT2BPETokenizer``, ``BertWordPieceCase``, + ``BertWordPieceLowerCase``, ``megatron-gpt-345m``, ``biomegatron-bert-*``, + etc.) to the corresponding HuggingFace tokenizer name and vocab/merges file + URLs. Files are automatically downloaded from NGC/S3 and cached under + ``~/.cache/torch/megatron/``. + + Once the name is resolved and files are fetched, all tokenization is + delegated to :class:`HuggingFaceTokenizer`. + + This class is selected via ``--tokenizer-library megatron``. See + :data:`MEGATRON_CONFIG_MAP` for the full list of supported model names. 
+ """ def __init__( self, diff --git a/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py b/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py index feaf1c4e9a1..2370542fbf9 100644 --- a/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/sentencepiece_tokenizer.py @@ -16,10 +16,9 @@ import torch from .abstract_tokenizer import MegatronTokenizerTextAbstract -from .chat_template import MegatronTokenizerChatTemplate -class SentencePieceTokenizer(MegatronTokenizerTextAbstract, MegatronTokenizerChatTemplate): +class SentencePieceTokenizer(MegatronTokenizerTextAbstract): """Sentencepiecetokenizer https://github.com/google/sentencepiece.""" def __init__( diff --git a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py b/megatron/core/tokenizers/text/libraries/sft_tokenizer.py deleted file mode 100644 index 8a418f2dd7f..00000000000 --- a/megatron/core/tokenizers/text/libraries/sft_tokenizer.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
- -from dataclasses import dataclass -from typing import Dict, List, Union - -import numpy as np - -try: - import transformers - - HAVE_TRANSFORMERS = True -except ModuleNotFoundError: - HAVE_TRANSFORMERS = False - - -# fmt: off -nemotron_h_aligned_custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ 'System\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'user' %}{{ 'User\n' + message['content'].strip() + '\n' + 'Assistant\n' }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() + '\n' }}{% endif %}{% endfor %}""" # pylint: disable=line-too-long -nemotron_nano_v2_custom_template = """{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'system' %}{{ 'System\n' + content.replace('/think', '').replace('/no_think', '').strip() + '\n' }}{% elif message['role'] == 'user' %}{{ 'User\n' + content.replace('/think', '').replace('/no_think', '').strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant\n' + content.strip() + '\n\n' }}{% endif %}{% endfor %}""" # pylint: disable=line-too-long -identity_template = """{% for message in messages %}{{ message['content'] }}{% endfor %}""" -# fmt: on - - -IGNORE_INDEX = -100 - - -@dataclass -class PromptConfig: - """Config options for different prompt formats.""" - - # How many tokens are used for the assistant prefix, e.g. "<|im_start|>assistant\n". - # Used for masking the assistant prefix. - assistant_prefix_len: int - # Padding token ID. - pad_token_id: int - # For overriding the default chat format template. - custom_chat_template: str - # If the tokenizer inserts BOS token by default. - has_bos: bool - # If the tokenizer supports a separate role for system messages. - has_system_role: bool - # Wether to force a specific system message. 
- force_system_message: bool = False - system_default: dict = None - - -class SFTTokenizer: - """SFT Tokenizer.""" - - def __init__(self, tokenizer_path: str, prompt_format: str): - """ - Note: Currently, only HuggingFaceTokenizer is supported as the underlying text tokenizer. - - Args: - tokenizer_path (str): Underlying tokenizer path. - prompt_format (str): Prompt format for the tokenizer. - """ - if HAVE_TRANSFORMERS: - # Currently, only HuggingFace tokenizers are supported. - tokenizer = transformers.AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=tokenizer_path - ) - else: - raise ImportError( - "SFTTokenizer currently requires transformers library to be installed" - ) - - self._vocab_size = len(tokenizer) - self._tokenizer = tokenizer - - if prompt_format == "nemotron-nano-v2": - self._prompt_config = PromptConfig( - assistant_prefix_len=3, - pad_token_id=tokenizer.convert_tokens_to_ids(""), - custom_chat_template=nemotron_nano_v2_custom_template, - has_bos=False, - has_system_role=True, - ) - elif prompt_format == "nemotron-h-aligned": - self._prompt_config = PromptConfig( - assistant_prefix_len=0, - pad_token_id=tokenizer.convert_tokens_to_ids(""), - custom_chat_template=nemotron_h_aligned_custom_template, - has_bos=False, - has_system_role=True, - ) - elif prompt_format == "identity": - self._prompt_config = PromptConfig( - assistant_prefix_len=0, - pad_token_id=tokenizer.convert_tokens_to_ids(""), - custom_chat_template=identity_template, - has_bos=False, - has_system_role=True, - ) - elif prompt_format == "default": - self._prompt_config = PromptConfig( - assistant_prefix_len=0, - pad_token_id=( - tokenizer.pad_token_id - if tokenizer.pad_token_id is not None - else tokenizer.eos_token_id - ), - custom_chat_template=tokenizer.chat_template, - has_bos=tokenizer.bos_token_id is not None, - has_system_role=True, - ) - else: - raise NotImplementedError("unknown SFT prompt format", prompt_format) - - self._prompt_format = prompt_format - - def 
tokenize_conversation( - self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool - ): - """Convert a conversation to tokens. - - Args: - conversation (List[Dict]): Sequence of system/user/assistant messages. - Must be in the following format: - [ - {"role": "system", "content": "something"}, - {"role": "user", "content": "something1"}, - {"role": "assistant", "content": "something2"}, - ] - return_target (bool): Return target tokens with system and assistant masked. - add_generation_prompt (bool): Add assistant prefix to the end. - """ - # Skip system message if the tokenizer doesn't have a system role. - if not self._prompt_config.has_system_role and conversation[0]["role"] == "system": - conversation = conversation[1:] - - tokens = self._tokenizer.apply_chat_template( - conversation, - tokenize=True, - add_generation_prompt=add_generation_prompt, - return_assistant_token_mask=False, - return_tensors="np", - chat_template=self._prompt_config.custom_chat_template, - )[0] - - if not return_target: - return tokens - - target = tokens.copy() - - # When using the default prompt format, we do not replace any tokens with IGNORE_INDEX. - # Instead, all token losses will be used for simplicity. - if self._prompt_format == "default": - return tokens, target - - # Mask system and user tokens in the target. - idx = 0 - for turn_idx, turn in enumerate(conversation): - - if turn["role"].lower() == "assistant" and len(turn["content"]) == 0: - raise ValueError(f"empty assistant turn in conversation: {conversation}.") - if turn["role"].lower() == "assistant": - assert conversation[turn_idx - 1]["role"].lower() in ("user", "tool") - - turn_tokens = self._tokenizer.apply_chat_template( - [turn], tokenize=True, chat_template=self._prompt_config.custom_chat_template - ) - - # There should be only one BOS at the very beginning. - # After the first turn, skip BOS token. 
- if self._prompt_config.has_bos and turn_idx > 0: - turn_tokens = turn_tokens[1:] - turn_len = len(turn_tokens) - - role = turn["role"].lower() - if role in ("system", "user", "tool"): - target[idx : idx + turn_len] = IGNORE_INDEX - elif role == "assistant": - if self._prompt_config.assistant_prefix_len > 0: - target[idx : idx + self._prompt_config.assistant_prefix_len] = IGNORE_INDEX - else: - raise ValueError("Wrong role value.") - - assert np.allclose( - tokens[idx : idx + turn_len], turn_tokens - ), f"expected turn tokens to match tokens in conversation {conversation}" - - idx += turn_len - - assert idx == len(tokens), f"mismatch in target masking the conversation {conversation}" - - return tokens, target - - def text_to_ids(self, text: Union[str, List[Dict]]): - """Tokenize conversation or string input.""" - if isinstance(text, list): - # This code path is used by the inference code currently. - return self.tokenize_conversation( - text, return_target=False, add_generation_prompt=True - ).tolist() - - return self._tokenizer.encode(text) - - def tokens_to_ids(self, tokens: List[str]): - """Convert tokens to IDs.""" - return self._tokenizer.convert_tokens_to_ids(tokens) - - def ids_to_text(self, tokens: List[int]): - """Detokenize tokens.""" - return self._tokenizer.decode(tokens) - - def ids_to_tokens(self): - """Converts ids to tokens.""" - raise NotImplementedError("This method is not supported for SFTTokenizer.") - - def text_to_tokens(self): - """Converts text to tokens.""" - raise NotImplementedError("This method is not supported for SFTTokenizer.") - - def tokens_to_text(self): - """Converts tokens to text.""" - raise NotImplementedError("This method is not supported for SFTTokenizer.") - - def get_special_tokens(self): - """Get special tokens.""" - return self._tokenizer.get_added_vocab() - - def add_special_tokens(self): - """Add special tokens.""" - raise NotImplementedError("This method is not supported for SFTTokenizer.") - - @property - def 
pad_id(self): - """Pad token ID.""" - return self._prompt_config.pad_token_id - - @property - def bos_id(self): - """Beginning of sequence token ID.""" - return self._tokenizer.bos_token_id - - @property - def eod(self): - """End of sentence token ID.""" - return self._tokenizer.eos_token_id - - @property - def vocab(self): - """Vocab.""" - return NotImplementedError("not used") - - @property - def inv_vocab(self): - """Inverse vocab.""" - return NotImplementedError("not used") - - @property - def vocab_size(self): - """Vocabulary size.""" - return self._vocab_size diff --git a/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py b/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py index 39228ad4afd..939265bea1c 100644 --- a/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/tiktoken_tokenizer.py @@ -13,7 +13,6 @@ pass from .abstract_tokenizer import MegatronTokenizerTextAbstract -from .chat_template import MegatronTokenizerChatTemplate logger = logging.getLogger(__name__) @@ -73,7 +72,7 @@ def reload_mergeable_ranks( return ranks -class TikTokenTokenizer(MegatronTokenizerTextAbstract, MegatronTokenizerChatTemplate): +class TikTokenTokenizer(MegatronTokenizerTextAbstract): """TikTokenTokenizer https://github.com/openai/tiktoken.""" def __init__( diff --git a/megatron/core/tokenizers/text/models/__init__.py b/megatron/core/tokenizers/text/models/__init__.py index d1788adb417..b85f22cc4d6 100644 --- a/megatron/core/tokenizers/text/models/__init__.py +++ b/megatron/core/tokenizers/text/models/__init__.py @@ -1,7 +1,12 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
-from megatron.core.tokenizers.text.models.bert_tokenizer import BertTokenizer -from megatron.core.tokenizers.text.models.default_tokenizer import DefaultTokenizerText -from megatron.core.tokenizers.text.models.gpt_tokenizer import GPTTokenizer -from megatron.core.tokenizers.text.models.mamba_tokenizer import MambaTokenizer -from megatron.core.tokenizers.text.models.t5_tokenizer import T5Tokenizer +# The individual model wrapper classes (GPTTokenizer, BertTokenizer, etc.) have been removed +# as they were empty subclasses that added no functionality. These aliases are kept for +# backward compatibility with any code that imports them by name. +from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText as BertTokenizer +from megatron.core.tokenizers.text.text_tokenizer import ( + MegatronTokenizerText as DefaultTokenizerText, +) +from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText as GPTTokenizer +from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText as MambaTokenizer +from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText as T5Tokenizer diff --git a/megatron/core/tokenizers/text/models/bert_tokenizer.py b/megatron/core/tokenizers/text/models/bert_tokenizer.py deleted file mode 100644 index b577596ed2d..00000000000 --- a/megatron/core/tokenizers/text/models/bert_tokenizer.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
- -from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText - - -class BertTokenizer(MegatronTokenizerText): - """Base class for Megatron Bert tokenizer.""" - - def __init__(self, path: str = None, config: dict = None, **kwargs) -> None: - config['class_name'] = self.__class__.__name__ - config['class_path'] = self.__class__.__module__ - super().__init__(path, config, **kwargs) diff --git a/megatron/core/tokenizers/text/models/default_tokenizer.py b/megatron/core/tokenizers/text/models/default_tokenizer.py deleted file mode 100644 index 95ad765e17c..00000000000 --- a/megatron/core/tokenizers/text/models/default_tokenizer.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText - - -class DefaultTokenizerText(MegatronTokenizerText): - """Base class for Megatron default tokenizer.""" - - def __init__(self, path: str = None, config: dict = None, **kwargs) -> None: - config['class_name'] = self.__class__.__name__ - config['class_path'] = self.__class__.__module__ - super().__init__(path, config, **kwargs) diff --git a/megatron/core/tokenizers/text/models/gpt_tokenizer.py b/megatron/core/tokenizers/text/models/gpt_tokenizer.py deleted file mode 100644 index 4f11a2bec01..00000000000 --- a/megatron/core/tokenizers/text/models/gpt_tokenizer.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
- -from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText - - -class GPTTokenizer(MegatronTokenizerText): - """Base class for Megatron GPT tokenizer.""" - - def __init__(self, path: str = None, config: dict = None, **kwargs) -> None: - config['class_name'] = self.__class__.__name__ - config['class_path'] = self.__class__.__module__ - super().__init__(path, config, **kwargs) diff --git a/megatron/core/tokenizers/text/models/mamba_tokenizer.py b/megatron/core/tokenizers/text/models/mamba_tokenizer.py deleted file mode 100644 index 9f7d76d17b5..00000000000 --- a/megatron/core/tokenizers/text/models/mamba_tokenizer.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText - - -class MambaTokenizer(MegatronTokenizerText): - """Base class for Megatron Mamba tokenizer.""" - - def __init__(self, path: str = None, config: dict = None, **kwargs) -> None: - config['class_name'] = self.__class__.__name__ - config['class_path'] = self.__class__.__module__ - super().__init__(path, config, **kwargs) diff --git a/megatron/core/tokenizers/text/models/t5_tokenizer.py b/megatron/core/tokenizers/text/models/t5_tokenizer.py deleted file mode 100644 index 013f41d22c4..00000000000 --- a/megatron/core/tokenizers/text/models/t5_tokenizer.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
- -from megatron.core.tokenizers.text.text_tokenizer import MegatronTokenizerText - - -class T5Tokenizer(MegatronTokenizerText): - """Base class for Megatron T5 tokenizer.""" - - def __init__(self, path: str = None, config: dict = None, **kwargs) -> None: - config['class_name'] = self.__class__.__name__ - config['class_path'] = self.__class__.__module__ - super().__init__(path, config, **kwargs) diff --git a/megatron/core/tokenizers/text/text_tokenizer.py b/megatron/core/tokenizers/text/text_tokenizer.py index 0145ae353ca..b619998d207 100644 --- a/megatron/core/tokenizers/text/text_tokenizer.py +++ b/megatron/core/tokenizers/text/text_tokenizer.py @@ -14,7 +14,6 @@ ("tiktoken", "TikTokenTokenizer"), ("byte-level", "ByteLevelTokenizer"), ("null-text", "NullTokenizer"), - ("sft", "SFTTokenizer"), ] ) @@ -34,6 +33,8 @@ def __init__(self, path: str, config: dict, **kwargs) -> None: chat_template (str): tokenizer chat template. """ + config.setdefault('class_name', self.__class__.__name__) + config.setdefault('class_path', self.__class__.__module__) super().__init__(path, config, **kwargs) self._tokenizer = self._restore_model(**kwargs) self.additional_args = kwargs @@ -50,6 +51,18 @@ def __init__(self, path: str, config: dict, **kwargs) -> None: else: self.chat_template = kwargs_template + # SFT prompt config: when prompt_format is provided, SFT conversation + # tokenization becomes available on any text tokenizer library. 
+ from megatron.core.tokenizers.conversation import PROMPT_FORMAT_REGISTRY, PromptConfig + + prompt_format = kwargs.get('prompt_format', None) + if prompt_format is not None: + if prompt_format not in PROMPT_FORMAT_REGISTRY: + raise NotImplementedError(f"unknown prompt format: {prompt_format}") + self._prompt_config = PROMPT_FORMAT_REGISTRY[prompt_format](self._tokenizer) + else: + self._prompt_config = None + def _restore_model(self, **kwargs) -> MegatronTokenizerTextAbstract: """Returns tokenizer library object.""" @@ -57,10 +70,14 @@ def _restore_model(self, **kwargs) -> MegatronTokenizerTextAbstract: library_class = getattr(tokenizers, TOKENIZER_MAPPING_LIBRARIES[self.library]) + # Filter out mode-level kwargs consumed by MegatronTokenizerText, not the library. + _MODE_KWARGS = ('prompt_format',) + library_kwargs = {k: v for k, v in kwargs.items() if k not in _MODE_KWARGS} + if self.library in ['byte-level', 'null-text']: - return library_class(**kwargs) + return library_class(**library_kwargs) else: - return library_class(self.path, **kwargs) + return library_class(self.path, **library_kwargs) def tokenize(self, text: str) -> List[int]: """ @@ -116,7 +133,9 @@ def apply_chat_template( def tokenize_conversation( self, conversation: List[Dict], return_target: bool, add_generation_prompt: bool ): - """Convert a conversation to tokens. Needed for SFTTokenizer. + """Convert a conversation to tokens. + + Requires ``--tokenizer-prompt-format`` to be configured (SFT mode). Args: conversation (List[Dict]): Sequence of system/user/assistant messages. @@ -130,14 +149,20 @@ def tokenize_conversation( add_generation_prompt (bool): Add assistant prefix to the end. 
""" - if self.library == 'sft': - return self._tokenizer.tokenize_conversation( - conversation=conversation, - return_target=return_target, - add_generation_prompt=add_generation_prompt, + if self._prompt_config is None: + raise RuntimeError( + "tokenize_conversation requires --tokenizer-prompt-format to be configured. " + f"Current library: {self.library}" ) - else: - raise NotImplementedError("This method is supported only for SFTTokenizer.") + from megatron.core.tokenizers.conversation import tokenize_conversation + + return tokenize_conversation( + tokenizer=self._tokenizer, + conversation=conversation, + prompt_config=self._prompt_config, + return_target=return_target, + add_generation_prompt=add_generation_prompt, + ) def save_pretrained(self, path: str) -> None: """ @@ -211,17 +236,23 @@ def unique_identifiers(self) -> OrderedDict: @property def pad(self) -> int: """Returns id of padding token.""" + if self._prompt_config is not None: + return self._prompt_config.pad_token_id return self._tokenizer.pad_id @property def pad_id(self) -> int: """Returns id of padding token. Need for NeMo.""" + if self._prompt_config is not None: + return self._prompt_config.pad_token_id return self._tokenizer.pad_id @property def eod(self) -> int: """Returns id of end of document token.""" - return self._tokenizer.eod + if hasattr(self._tokenizer, 'eod'): + return self._tokenizer.eod + return self._tokenizer.eos_id @property def bos(self) -> int: diff --git a/megatron/core/tokenizers/tokenizers.MD b/megatron/core/tokenizers/tokenizers.MD new file mode 100644 index 00000000000..4ed756c5710 --- /dev/null +++ b/megatron/core/tokenizers/tokenizers.MD @@ -0,0 +1,478 @@ +# Megatron Tokenizer API + +## Architecture Overview + +The tokenizer subsystem is organized into a layered architecture: a **factory** entry point, **text/vision wrappers** that expose a unified API, and **library implementations** that wrap specific tokenization backends. 
+ +``` +megatron/core/tokenizers/ + megatron_tokenizer.py # Factory: MegatronTokenizer.from_pretrained() + base_tokenizer.py # ABC: MegatronTokenizerBase + text/ + text_tokenizer.py # Wrapper: MegatronTokenizerText + libraries/ + abstract_tokenizer.py # ABC for all text libraries + sentencepiece_tokenizer.py # Google SentencePiece + huggingface_tokenizer.py # HuggingFace AutoTokenizer + megatron_hf_tokenizer.py # Megatron-native HF (BPE, WordPiece) + tiktoken_tokenizer.py # OpenAI TikToken + bytelevel_tokenizer.py # Raw byte-level + null_tokenizer.py # No-op (testing / integer passthrough) + models/ # (reserved for future model-specific configs) + parsers/ # Output parsers (DeepSeek-R1, Qwen3 tool, etc.) + vision/ + vision_tokenizer.py # Wrapper: MegatronTokenizerVision + libraries/ + multimodal_tokenizer.py # HF-backed multimodal + null_multimodal_tokenizer.py # No-op multimodal + models/ + conversation/ + conversation_tokenizer.py # tokenize_conversation() function + prompt_config.py # PromptConfig dataclass + PROMPT_FORMAT_REGISTRY + utils/ + build_tokenizer.py # build_tokenizer(args) from CLI arguments +``` + +### Class Hierarchy + +``` +MegatronTokenizerBase (ABC) + MegatronTokenizerText # wraps a library tokenizer + library tokenizer (MegatronTokenizerTextAbstract) + SentencePieceTokenizer + HuggingFaceTokenizer + MegatronHFTokenizer + TikTokenTokenizer + ByteLevelTokenizer + NullTokenizer + MegatronTokenizerVision # wraps a vision library tokenizer + MegatronMultimodalTokenizer + MegatronNullMultimodalTokenizer +``` + +--- + +## Tokenizer Libraries + +Each library wraps a different tokenization backend. The library is selected via `--tokenizer-library` or the `library` field in metadata. 
+ +### Text Libraries + +| Library name | Class | Backend | Use case | +|---|---|---|---| +| `sentencepiece` | `SentencePieceTokenizer` | Google SentencePiece protobuf | LLaMA, Mistral, and other SP-based models | +| `huggingface` | `HuggingFaceTokenizer` | `transformers.AutoTokenizer` | Any model on HuggingFace Hub or local HF tokenizer | +| `megatron` | `MegatronHFTokenizer` | Megatron-native BPE / WordPiece via HF | GPT-2 BPE (`GPT2BPETokenizer`), BERT WordPiece (`BertWordPieceCase`, `BertWordPieceLowerCase`) | +| `tiktoken` | `TikTokenTokenizer` | OpenAI TikToken | Models using TikToken vocabulary (e.g. custom vocab.json) | +| `byte-level` | `ByteLevelTokenizer` | Raw bytes | Byte-level models; no tokenizer model file needed | +| `null` | `NullTokenizer` | Space-split integers | Testing and debugging; tokenizes `"11 325 97"` to `[11, 325, 97]` | + +### Vision Libraries + +| Library name | Class | Backend | Use case | +|---|---|---|---| +| `multimodal` | `MegatronMultimodalTokenizer` | HF AutoTokenizer + conversation module | Multimodal models (LLaVA, NVLM, etc.) 
with image token support | +| `null-multimodal` | `MegatronNullMultimodalTokenizer` | Space-split integers | Testing multimodal pipelines without a real tokenizer | + +--- + +## Command-Line API + +### Primary Flags + +``` +--tokenizer-library {huggingface,sentencepiece,tiktoken,megatron,byte-level,null} +--tokenizer-mode {text,sft,multimodal} # default: text +--tokenizer-model PATH # path to tokenizer file or directory +--tokenizer-prompt-format FORMAT # required for sft and multimodal modes +``` + +### Library-Specific Flags + +**SentencePiece:** +``` +--tokenizer-sentencepiece-legacy # restore legacy SentencePiece behavior +``` + +**HuggingFace / Megatron:** +``` +--tokenizer-hf-no-use-fast # disable fast tokenizer +--tokenizer-hf-no-include-special-tokens # exclude special tokens from encoding +``` + +**TikToken:** +``` +--tiktoken-pattern PATTERN # regex pattern for TikToken +--tiktoken-num-special-tokens N # default: 1000 +``` + +**Common:** +``` +--tokenizer-special-tokens TOK1 TOK2 ... # additional special token strings +--tokenizer-metadata PATH # path to metadata JSON +--vocab-file PATH # vocabulary file (megatron library) +--merge-file PATH # BPE merges file (megatron library) +--vocab-size N # vocabulary size (tiktoken, null, byte-level) +``` + +### Usage Examples + +```bash +# SentencePiece (e.g. LLaMA) +--tokenizer-library sentencepiece \ + --tokenizer-model /path/to/tokenizer.model + +# HuggingFace (e.g. 
any HF model) +--tokenizer-library huggingface \ + --tokenizer-model /path/to/hf-tokenizer-dir + +# TikToken +--tokenizer-library tiktoken \ + --tokenizer-model /path/to/tiktoken.vocab.json \ + --tiktoken-pattern v2 + +# Megatron-native BPE (GPT-2) +--tokenizer-library megatron \ + --tokenizer-model GPT2BPETokenizer \ + --vocab-file /path/to/gpt2-vocab.json \ + --merge-file /path/to/gpt2-merges.txt + +# Megatron-native WordPiece (BERT / T5) +--tokenizer-library megatron \ + --tokenizer-model BertWordPieceCase \ + --vocab-file /path/to/bert-large-cased-vocab.txt + +# SFT mode (supervised fine-tuning with any text library) +--tokenizer-library huggingface \ + --tokenizer-mode sft \ + --tokenizer-model /path/to/hf-tokenizer \ + --tokenizer-prompt-format nemotron-h-aligned + +# Multimodal +--tokenizer-library huggingface \ + --tokenizer-mode multimodal \ + --tokenizer-model /path/to/tokenizer \ + --tokenizer-prompt-format qwen2p0 + +# Null (testing) +--tokenizer-library null \ + --vocab-size 131072 +``` + +--- + +## Programmatic API + +### Loading a Tokenizer + +Use `MegatronTokenizer.from_pretrained()` (not the constructor): + +```python +from megatron.core.tokenizers import MegatronTokenizer + +# With metadata file (tokenizer_metadata.json auto-detected in tokenizer directory) +tokenizer = MegatronTokenizer.from_pretrained("/path/to/tokenizer") + +# With inline metadata dict +tokenizer = MegatronTokenizer.from_pretrained( + tokenizer_path="/path/to/tokenizer.model", + metadata_path={"library": "sentencepiece"}, +) + +# With SFT prompt format +tokenizer = MegatronTokenizer.from_pretrained( + tokenizer_path="/path/to/hf-tokenizer", + metadata_path={"library": "huggingface"}, + prompt_format="nemotron-h-aligned", +) + +# With chat template +tokenizer = MegatronTokenizer.from_pretrained( + tokenizer_path="/path/to/tokenizer.model", + metadata_path={"library": "sentencepiece"}, + chat_template="{% for message in messages %}...", +) +``` + +### Saving Metadata + +```python 
+MegatronTokenizer.write_metadata( + tokenizer_path="/path/to/tokenizer", + tokenizer_library="huggingface", + chat_template="{% for message in messages %}...", # optional + overwrite=True, +) +``` + +This creates a `tokenizer_metadata.json` file in the tokenizer directory: + +```json +{ + "library": "huggingface", + "class_name": null, + "class_path": null, + "model_type": "default-text", + "chat_template": "{% for message in messages %}..." +} +``` + +### Building from CLI Arguments + +Inside Megatron training scripts, use `build_tokenizer(args)`: + +```python +from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer + +tokenizer = build_tokenizer(args) +``` + +This reads `--tokenizer-library`, `--tokenizer-mode`, `--tokenizer-model`, and all library-specific flags from `args`, then calls `MegatronTokenizer.from_pretrained()` internally. + +--- + +## Core Methods + +All tokenizers (text and vision) expose the same high-level interface: + +### Encoding & Decoding + +```python +# Text to token IDs +ids = tokenizer.tokenize("Hello, world!") + +# Token IDs to text +text = tokenizer.detokenize([15496, 11, 995, 0]) +``` + +### Chat Templates + +Apply a Jinja2 chat template to a conversation: + +```python +conversation = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, +] + +# Returns token IDs (default: tokenize=True) +ids = tokenizer.apply_chat_template(conversation, chat_template=template) + +# Returns formatted string +text = tokenizer.apply_chat_template( + conversation, chat_template=template, tokenize=False +) +``` + +The chat template resolution order is: +1. Explicit `chat_template` argument +2. `chat_template` from tokenizer metadata +3. `chat_template` from the underlying tokenizer (e.g. 
HF tokenizer config) + +### Special Token Properties + +| Property | Aliases | Description | +|---|---|---| +| `vocab_size` | | Total vocabulary size | +| `pad` | `pad_id` | Padding token ID (uses `PromptConfig.pad_token_id` in SFT mode) | +| `eod` | | End-of-document token ID (falls back to `eos_id`) | +| `eos_id` | `eos` | End-of-sentence token ID | +| `bos` | `bos_id` | Beginning-of-sentence token ID | +| `unk` | `unk_id` | Unknown token ID | +| `mask` | `mask_id` | Mask token ID (for MLM) | +| `cls` | `cls_id` | Classification token ID (for BERT-like models) | +| `sep` | `sep_id` | Separator token ID | + +--- + +## SFT & Conversation Tokenization + +SFT (Supervised Fine-Tuning) conversation tokenization is a **capability** on any text tokenizer, activated by providing a `prompt_format`. It is not a separate tokenizer library. + +### How It Works + +When `--tokenizer-mode sft` and `--tokenizer-prompt-format ` are set, the text tokenizer gains a `tokenize_conversation()` method that: + +1. Formats a conversation using a chat template (Jinja2) +2. Tokenizes the formatted text +3. 
Optionally produces **target tokens** with non-assistant content masked (`IGNORE_INDEX = -100`) for loss computation + +```python +conversation = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, +] + +# Get token IDs only +tokens = tokenizer.tokenize_conversation( + conversation, return_target=False, add_generation_prompt=False +) + +# Get token IDs and target mask for training +tokens, target = tokenizer.tokenize_conversation( + conversation, return_target=True, add_generation_prompt=False +) +# target has IGNORE_INDEX (-100) for system/user tokens +``` + +### Target Masking + +When `return_target=True`, the function produces a target array where: +- **System** and **user** tokens are set to `IGNORE_INDEX` (-100) +- **Tool** tokens are masked if the prompt config enables `allow_tool_role` +- **Assistant prefix tokens** (e.g. `<|assistant|>\n`) are masked based on `assistant_prefix_len` +- **Assistant content tokens** are kept as-is (training signal) + +### PromptConfig + +Each prompt format defines a `PromptConfig` dataclass that controls tokenization behavior: + +```python +@dataclass +class PromptConfig: + assistant_prefix_len: int # tokens in assistant header to mask + pad_token_id: int # padding token ID for this format + custom_chat_template: str | None # Jinja2 template override + has_bos: bool # whether tokenizer adds BOS + has_system_role: bool # whether format supports system role + force_system_message: bool # force a default system message + system_default: dict | None # default system message dict + validate_no_image_in_assistant: bool # reject in assistant + capitalize_roles: bool # capitalize role names (e.g. 
"User") + skip_masking: bool # skip all target masking + allow_tool_role: bool # include "tool" in masked roles +``` + +### Available Prompt Formats + +**SFT formats** (used with `--tokenizer-mode sft`): + +| Format | Description | +|---|---| +| `nemotron-h-aligned` | Nemotron-H aligned format with its dedicated turn-delimiter special tokens | +| `nemotron-nano-v2` | Nemotron Nano v2 format | +| `identity` | Identity template (concatenates content without role markers) | +| `default` | Uses the tokenizer's native chat template; skips all masking | + +**Multimodal formats** (used with `--tokenizer-mode multimodal`): + +| Format | Description | +|---|---| +| `mistral` | Mistral format; no system role | +| `llama3` | LLaMA 3 format; has BOS and system role | +| `llama3p1` | LLaMA 3.1 with full chat template | +| `llama3p2` | Same as llama3p1 | +| `nvlm-yi-34b` | Yi 34B NVLM format | +| `chatml` | ChatML (OpenAI-style) format | +| `nemotron5` | Nemotron 5 with special tokens | +| `nemotron5-aligned` | Nemotron 5 aligned with capitalized roles | +| `qwen2p0` | Qwen 2.0 format | +| `qwen2p5` | Qwen 2.5 format | + +### Library-Agnostic Design + +The conversation tokenization pipeline is library-agnostic. 
It only requires the underlying tokenizer to implement: +- `apply_chat_template(conversation, tokenize=True, chat_template=..., add_generation_prompt=...)` +- `text_to_ids(text)` (used internally by `apply_chat_template`) + +This means SFT works with **any** text tokenizer library: + +```bash +# SFT with SentencePiece +--tokenizer-library sentencepiece --tokenizer-mode sft \ + --tokenizer-model /path/to/model.sp \ + --tokenizer-prompt-format nemotron-h-aligned + +# SFT with TikToken +--tokenizer-library tiktoken --tokenizer-mode sft \ + --tokenizer-model /path/to/vocab.json \ + --tokenizer-prompt-format default + +# SFT with HuggingFace (default for SFT mode) +--tokenizer-library huggingface --tokenizer-mode sft \ + --tokenizer-model /path \ + --tokenizer-prompt-format nemotron-h-aligned +``` + +--- + +## Multimodal Tokenization + +Vision/multimodal tokenizers (`--tokenizer-mode multimodal`) handle image tokens within conversations. + +### Image Tags + +The multimodal tokenizer wraps image tokens (e.g. 
`<image>`) with configurable tags: + +| `image_tag_type` | Open tag | Close tag | Result | +|---|---|---|---| +| `nvlm` | `<Image>` | `</Image>` | `<Image><image></Image>` | +| `internvl` | `<img>` | `</img>` | `<img><image></img>` | +| `""` (empty) | — | — | `<image>` (unchanged) | + +### Conversation Tokenization + +Multimodal tokenizers use the same `tokenize_conversation()` function as SFT, with an additional `apply_image_tag_fn` callback: + +```python +tokens = tokenizer.tokenize_conversation( + conversation, return_target=False, add_generation_prompt=True +) + +tokens, target = tokenizer.tokenize_conversation( + conversation, return_target=True, add_generation_prompt=False +) +``` + +### Required Arguments + +```bash +--tokenizer-library huggingface \ + --tokenizer-mode multimodal \ + --tokenizer-model /path/to/tokenizer \ + --tokenizer-prompt-format qwen2p0 \ + --special-tokens "<image>" \ + --image-tag-type nvlm +``` + +--- + +## Deprecation: `--tokenizer-type` + +The legacy `--tokenizer-type` flag is deprecated but still supported via automatic mapping in argument validation. It is converted to `--tokenizer-library` + `--tokenizer-mode` internally. + +| Old `--tokenizer-type` | New `--tokenizer-library` | New `--tokenizer-mode` | +|---|---|---| +| `GPTSentencePieceTokenizer` | `sentencepiece` | `text` | +| `SentencePieceTokenizer` | `sentencepiece` | `text` | +| `Llama2Tokenizer` | `sentencepiece` | `text` | +| `HuggingFaceTokenizer` | `huggingface` | `text` | +| `TikTokenizer` | `tiktoken` | `text` | +| `GPT2BPETokenizer` | `megatron` | `text` | +| `BertWordPieceCase` | `megatron` | `text` | +| `BertWordPieceLowerCase` | `megatron` | `text` | +| `NullTokenizer` | `null` | `text` | +| `SFTTokenizer` | `huggingface` | `sft` | +| `MultimodalTokenizer` | `huggingface` | `multimodal` | +| `NullMultimodalTokenizer` | `null` | `multimodal` | + +When `--tokenizer-type` is used, a deprecation warning is emitted and the value is mapped to the new flags automatically. 
+ +--- + +## Custom Tokenizer Classes + +You can register a custom tokenizer class in metadata: + +```python +MegatronTokenizer.write_metadata( + tokenizer_path="/path/to/tokenizer", + tokenizer_library="huggingface", + tokenizer_class=MyCustomTokenizerClass, +) +``` + +This stores `class_name` and `class_path` in the metadata. On load, the factory imports and instantiates the custom class instead of the default `MegatronTokenizerText` / `MegatronTokenizerVision`. + +Custom classes must inherit from `MegatronTokenizerBase`. diff --git a/megatron/core/tokenizers/utils/build_tokenizer.py b/megatron/core/tokenizers/utils/build_tokenizer.py index bf02451ae6c..7fd81763e26 100644 --- a/megatron/core/tokenizers/utils/build_tokenizer.py +++ b/megatron/core/tokenizers/utils/build_tokenizer.py @@ -5,88 +5,31 @@ from megatron.core.tokenizers import MegatronTokenizer -MEGATRON_TOKENIZERS = ['BertWordPieceLowerCase', 'BertWordPieceCase', 'GPT2BPETokenizer'] - -SP_TOKENIZERS = ['SentencePieceTokenizer', 'GPTSentencePieceTokenizer', 'Llama2Tokenizer'] - logger = logging.getLogger(__name__) +# Libraries that map to MegatronTokenizerVision (everything else is text). 
+_MULTIMODAL_LIBRARIES = {'multimodal', 'null-multimodal'} -def build_tokenizer(args, **kwargs): - """Initialize tokenizer.""" - kwargs = {} - tokenizer_library = None - tokenizer_path = None - if args.tokenizer_type in MEGATRON_TOKENIZERS: - tokenizer_library = 'megatron' - tokenizer_path = args.tokenizer_type - kwargs['additional_special_tokens'] = ( - args.tokenizer_special_tokens if args.tokenizer_special_tokens else [] - ) - if tokenizer_path == 'BertWordPieceCase': - special_tokens = {} - special_tokens['additional_special_tokens'] = [f'' for i in range(100)] - kwargs = special_tokens - kwargs['vocab_file'] = args.vocab_file - kwargs['merges_file'] = args.merge_file - kwargs['use_fast'] = not args.tokenizer_hf_no_use_fast - kwargs['trust_remote_code'] = args.trust_remote_code - kwargs['include_special_tokens'] = not args.tokenizer_hf_no_include_special_tokens - elif args.tokenizer_type in SP_TOKENIZERS: - tokenizer_library = 'sentencepiece' - tokenizer_path = args.tokenizer_model - kwargs['legacy'] = args.tokenizer_sentencepiece_legacy - kwargs['special_tokens'] = args.tokenizer_special_tokens - elif args.tokenizer_type == 'TikTokenizer': - tokenizer_library = 'tiktoken' - tokenizer_path = args.tokenizer_model - if args.tiktoken_pattern: - kwargs['pattern'] = args.tiktoken_pattern - if args.vocab_size: - kwargs['vocab_size'] = args.vocab_size - kwargs['num_special_tokens'] = args.tiktoken_num_special_tokens - kwargs['special_tokens'] = args.tokenizer_special_tokens - elif args.tokenizer_type == 'HuggingFaceTokenizer': - tokenizer_library = 'huggingface' - tokenizer_path = args.tokenizer_model - kwargs['vocab_file'] = args.vocab_file - kwargs['merges_file'] = args.merge_file - kwargs['additional_special_tokens'] = ( - args.tokenizer_special_tokens if args.tokenizer_special_tokens else [] - ) - kwargs['use_fast'] = not args.tokenizer_hf_no_use_fast - kwargs['trust_remote_code'] = args.trust_remote_code - kwargs['include_special_tokens'] = not 
args.tokenizer_hf_no_include_special_tokens - elif args.tokenizer_type == 'MultimodalTokenizer': - tokenizer_library = 'multimodal' - kwargs['prompt_format'] = args.tokenizer_prompt_format - kwargs['special_tokens'] = args.special_tokens - kwargs['image_tag_type'] = args.image_tag_type - kwargs['force_system_message'] = args.force_system_message - elif args.tokenizer_type == 'SFTTokenizer': - tokenizer_library = 'sft' - tokenizer_path = args.tokenizer_model - kwargs['prompt_format'] = args.sft_tokenizer_prompt_format - elif args.tokenizer_type in ['NullTokenizer', 'NullMultimodalTokenizer']: - tokenizer_library = ( - 'null-text' if args.tokenizer_type == 'NullTokenizer' else 'null-multimodal' - ) - metadata = {'library': tokenizer_library} - if args.vocab_size: - kwargs['vocab_size'] = args.vocab_size - tokenizer = MegatronTokenizer.from_pretrained(metadata_path=metadata, **kwargs) - # Add vocab size (if not already set from a checkpoint). - _set_padded_vocab_size(args, tokenizer) +def build_tokenizer(args, **kwargs): + """Initialize tokenizer from command-line arguments. - return tokenizer + Uses --tokenizer-library (and optionally --tokenizer-prompt-format) or + falls back to --tokenizer-type (deprecated, mapped in argument validation). + """ + build_kwargs = _build_library_kwargs(args) + build_kwargs.update(_build_prompt_kwargs(args)) + build_kwargs.update(kwargs) # Allow caller overrides - if args.tokenizer_metadata: + if getattr(args, 'tokenizer_metadata', None): metadata = args.tokenizer_metadata else: - metadata = {'library': tokenizer_library} + metadata = {'library': _resolve_library(args)} + + tokenizer_path = _resolve_tokenizer_path(args) + tokenizer = MegatronTokenizer.from_pretrained( - tokenizer_path=tokenizer_path, metadata_path=metadata, **kwargs + tokenizer_path=tokenizer_path, metadata_path=metadata, **build_kwargs ) # Add vocab size (if not already set from a checkpoint). 
@@ -95,6 +38,91 @@ def build_tokenizer(args, **kwargs): return tokenizer +def _resolve_library(args): + """Map CLI --tokenizer-library value to the internal library string.""" + lib = getattr(args, 'tokenizer_library', None) + + if lib == 'null': + return 'null-text' + return lib + + +def _resolve_tokenizer_path(args): + """Resolve tokenizer path from args.""" + lib = getattr(args, 'tokenizer_library', None) + + if lib == 'megatron': + # For 'megatron' library, the tokenizer_type or tokenizer_model is the + # predefined model name (e.g., "GPT2BPETokenizer", "BertWordPieceCase"). + return getattr(args, 'tokenizer_type', None) or args.tokenizer_model + return args.tokenizer_model + + +def _build_library_kwargs(args): + """Build kwargs specific to the tokenizer library.""" + build_kwargs = {} + lib = getattr(args, 'tokenizer_library', None) + + if lib in ('huggingface', 'megatron'): + # Special case for BertWordPieceCase: add extra_id tokens. + if ( + getattr(args, 'tokenizer_type', None) == 'BertWordPieceCase' + or getattr(args, 'tokenizer_model', None) == 'BertWordPieceCase' + ): + build_kwargs['additional_special_tokens'] = [f'' for i in range(100)] + else: + build_kwargs['additional_special_tokens'] = ( + args.tokenizer_special_tokens if args.tokenizer_special_tokens else [] + ) + build_kwargs['vocab_file'] = args.vocab_file + build_kwargs['merges_file'] = args.merge_file + build_kwargs['use_fast'] = not getattr(args, 'tokenizer_hf_no_use_fast', False) + build_kwargs['trust_remote_code'] = getattr(args, 'trust_remote_code', False) + build_kwargs['include_special_tokens'] = not getattr( + args, 'tokenizer_hf_no_include_special_tokens', False + ) + elif lib == 'sentencepiece': + build_kwargs['legacy'] = getattr(args, 'tokenizer_sentencepiece_legacy', False) + build_kwargs['special_tokens'] = args.tokenizer_special_tokens + elif lib == 'tiktoken': + if getattr(args, 'tiktoken_pattern', None): + build_kwargs['pattern'] = args.tiktoken_pattern + if getattr(args, 
'vocab_size', None): + build_kwargs['vocab_size'] = args.vocab_size + build_kwargs['num_special_tokens'] = getattr(args, 'tiktoken_num_special_tokens', 1000) + build_kwargs['special_tokens'] = args.tokenizer_special_tokens + elif lib in ('null', 'null-multimodal'): + if getattr(args, 'vocab_size', None): + build_kwargs['vocab_size'] = args.vocab_size + + return build_kwargs + + +def _build_prompt_kwargs(args): + """Build prompt_format and multimodal kwargs based on the tokenizer library. + + Text libraries always receive a prompt_format (defaulting to "default"), + which enables tokenize_conversation() on every text tokenizer. + Multimodal libraries receive prompt_format plus vision-specific kwargs. + """ + build_kwargs = {} + lib = getattr(args, 'tokenizer_library', None) + + prompt_format = getattr(args, 'tokenizer_prompt_format', None) + + if lib in _MULTIMODAL_LIBRARIES: + build_kwargs['prompt_format'] = prompt_format + build_kwargs['special_tokens'] = getattr(args, 'special_tokens', []) + build_kwargs['image_tag_type'] = getattr(args, 'image_tag_type', '') + build_kwargs['force_system_message'] = getattr(args, 'force_system_message', False) + else: + # All text tokenizers get prompt_format; "default" enables + # tokenize_conversation() with skip_masking=True (no-op for pretraining). 
+ build_kwargs['prompt_format'] = prompt_format or 'default' + + return build_kwargs + + def vocab_size_with_padding(orig_vocab_size, args, logging_enabled=True): """Pad vocab size so it is divisible by model parallel size and still having GPU friendly size.""" diff --git a/megatron/core/tokenizers/vision/libraries/multimodal_tokenizer.py b/megatron/core/tokenizers/vision/libraries/multimodal_tokenizer.py index 80712351095..f9b60c2ed2d 100644 --- a/megatron/core/tokenizers/vision/libraries/multimodal_tokenizer.py +++ b/megatron/core/tokenizers/vision/libraries/multimodal_tokenizer.py @@ -2,10 +2,8 @@ from typing import Dict, List, Union -import numpy as np - -from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN -from megatron.core.tokenizers.text.libraries.sft_tokenizer import PromptConfig +from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN +from megatron.core.tokenizers.conversation import PROMPT_FORMAT_REGISTRY, tokenize_conversation try: import transformers @@ -22,36 +20,6 @@ } -# The default mistral template raises exceptions so we use a custom one. 
-mistral_custom_template = """ -{{- bos_token }} -{%- for message in messages %} - {%- if message['role'] == 'user' %} - {{- '[INST] ' + message['content'] + '[/INST]' }} - {%- elif message['role'] == 'assistant' %} - {{- ' ' + message['content'] + eos_token}} - {%- endif %} -{%- endfor %} -{% if add_generation_prompt %}{{ ' ' }}{% endif %} -""" - - -nvlm_yi_34b_template = "{{- bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # pylint: disable=line-too-long - - -qwen2p0_custom_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # pylint: disable=line-too-long - - -# Note: this is the same template as -# https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/main/tokenizer_config.json#L2053 -# but we removed the forced system message. -llama3p1_chat_template = """{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. 
#}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = none %}\n{%- endif %}\n\n{%- if system_message is not none %}{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%-endif %}{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n""" # pylint: disable=line-too-long - -nemotron_custom_template = "{{- bos_token }}{% for 
message in messages %}{{'' + message['role'] + '\n' + message['content'] + '' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'assistant\n' }}{% endif %}" # pylint: disable=line-too-long - -nemotron_aligned_custom_template = "{{- bos_token}}{% for message in messages %}{{message['role'] + '\n' + message['content'] + '\n' + '[PREFIX]'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant\n' }}{% endif %}" # pylint: disable=line-too-long - - class MegatronMultimodalTokenizer: """Multimodal Tokenizer.""" @@ -94,95 +62,17 @@ def __init__( self.tokenizer = tokenizer - if prompt_format == "mistral": - # Mistral format doesn't have prefix for the assistant message. - self._prompt_config = PromptConfig( - assistant_prefix_len=0, - pad_token_id=tokenizer.unk_token_id, - custom_chat_template=mistral_custom_template, - has_bos=True, - has_system_role=False, - ) - elif prompt_format == "llama3": - # "<|start_header_id|>assistant<|end_header|>\n\n" is the prefix for assistant messages. - self._prompt_config = PromptConfig( - assistant_prefix_len=4, - pad_token_id=tokenizer.convert_tokens_to_ids("<|end_of_text|>"), - custom_chat_template=None, - has_bos=True, - has_system_role=True, - ) - elif prompt_format in ("llama3p1", "llama3p2"): - # "<|start_header_id|>assistant<|end_header|>\n\n" is the prefix for assistant messages. - # That occupies 4 tokens and can be masked in the target. 
- self._prompt_config = PromptConfig( - assistant_prefix_len=4, - pad_token_id=tokenizer.convert_tokens_to_ids("<|finetune_right_pad_id|>"), - custom_chat_template=llama3p1_chat_template, - has_bos=True, - has_system_role=True, - ) - elif prompt_format == "nvlm-yi-34b": - self._prompt_config = PromptConfig( - assistant_prefix_len=4, - pad_token_id=tokenizer.pad_token_id, - custom_chat_template=nvlm_yi_34b_template, - has_bos=True, - has_system_role=True, - ) - elif prompt_format == "chatml": - # "<|im_start|>assistant\n" is the prefix for assistant messages - self._prompt_config = PromptConfig( - assistant_prefix_len=3, - pad_token_id=tokenizer.pad_token_id, - custom_chat_template=None, - has_bos=False, - has_system_role=True, - ) - elif prompt_format == "nemotron5": - # "<|im_start|>assistant\n" is the prefix. - self._prompt_config = PromptConfig( - assistant_prefix_len=3, - pad_token_id=tokenizer.convert_tokens_to_ids(""), - custom_chat_template=nemotron_custom_template, - has_bos=True, - has_system_role=True, - ) - elif prompt_format == "nemotron5-aligned": - # "Assistant\n" is the prefix. - self._prompt_config = PromptConfig( - assistant_prefix_len=2, - pad_token_id=tokenizer.convert_tokens_to_ids(""), - custom_chat_template=nemotron_aligned_custom_template, - has_bos=True, - has_system_role=True, - ) - elif prompt_format in ("qwen2p0", "qwen2p5"): - # "<|im_start|>assistant\n" is the prefix for assistant messages - self._prompt_config = PromptConfig( - assistant_prefix_len=3, - pad_token_id=tokenizer.pad_token_id, - custom_chat_template=qwen2p0_custom_template, - has_bos=False, - has_system_role=True, - force_system_message=force_system_message, - system_default={ - "role": "system", - "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", # pylint: disable=line-too-long - }, - ) - elif prompt_format == "llama3p1": - # "<|start_header_id|>assistant<|end_header|>\n\n" is the prefix for assistant messages. 
- # That occupies 4 tokens and can be masked in the target. - self._prompt_config = PromptConfig( - assistant_prefix_len=4, - pad_token_id=tokenizer.convert_tokens_to_ids("<|finetune_right_pad_id|>"), - custom_chat_template=llama3p1_chat_template, - has_bos=True, - has_system_role=True, + if prompt_format not in PROMPT_FORMAT_REGISTRY: + raise NotImplementedError("unknown multimodal tokenizer type", prompt_format) + + # Build the prompt config from the registry. + # Qwen formats need the force_system_message flag passed through. + if prompt_format in ("qwen2p0", "qwen2p5"): + self._prompt_config = PROMPT_FORMAT_REGISTRY[prompt_format]( + tokenizer, force_system_message=force_system_message ) else: - raise NotImplementedError("unknown multimodal tokenizer type", prompt_format) + self._prompt_config = PROMPT_FORMAT_REGISTRY[prompt_format](tokenizer) self._prompt_format = prompt_format self._image_tag = IMAGE_TAGS[image_tag_type] @@ -230,77 +120,14 @@ def tokenize_conversation( return_target (bool): Return target tokens with system and assistant masked. add_generation_prompt (bool): Add assistant prefix to the end. """ - # Skip system message if the tokenizer doesn't have a system role. - if not self._prompt_config.has_system_role and conversation[0]["role"] == "system": - conversation = conversation[1:] - - if self._prompt_config.force_system_message: - assert ( - self._prompt_config.system_default is not None - ), "Trying to force system message with empty system default" - if conversation[0]["role"] == "system": - conversation[0] = self._prompt_config.system_default - else: - conversation = [self._prompt_config.system_default] + conversation - - if self._prompt_format == "nemotron5-aligned": - for turn in conversation: - tmp = turn['role'] - turn['role'] = tmp[:1].upper() + tmp[1:] - - # Apply possible image tag. 
- conversation = self._apply_image_tag(conversation) - - tokens = self.tokenizer.apply_chat_template( - conversation, - tokenize=True, + return tokenize_conversation( + tokenizer=self.tokenizer, + conversation=conversation, + prompt_config=self._prompt_config, + return_target=return_target, add_generation_prompt=add_generation_prompt, - return_assistant_token_mask=False, - return_tensors="np", - chat_template=self._prompt_config.custom_chat_template, - )[0] - - if not return_target: - return tokens - - target = tokens.copy() - - # Mask system and user tokens in the target. - idx = 0 - for turn_idx, turn in enumerate(conversation): - if len(turn["content"]) == 0: - raise ValueError(f"empty turn in conversation: {conversation}. Skipping.") - - turn_tokens = self.tokenizer.apply_chat_template( - [turn], tokenize=True, chat_template=self._prompt_config.custom_chat_template - ) - - # There should be only one BOS at the very beginning. - # After the first turn, skip BOS token. - if self._prompt_config.has_bos and turn_idx > 0: - turn_tokens = turn_tokens[1:] - - turn_len = len(turn_tokens) - - role = turn["role"].lower() - if role in ("system", "user"): - target[idx : idx + turn_len] = IGNORE_INDEX - elif role == "assistant": - if IMAGE_TOKEN in turn["content"]: - raise RuntimeError(f"{IMAGE_TOKEN} not allowed in assistant content!") - - if self._prompt_config.assistant_prefix_len > 0: - target[idx : idx + self._prompt_config.assistant_prefix_len] = IGNORE_INDEX - - assert np.allclose( - tokens[idx : idx + turn_len], turn_tokens - ), f"expected turn tokens to match tokens in conversation {conversation}" - - idx += turn_len - - assert idx == len(tokens), f"mismatch in target masking the conversation {conversation}" - - return tokens, target + apply_image_tag_fn=self._apply_image_tag, + ) def convert_tokens_to_ids(self, tokens: List[str]): """Convert tokens to IDs.""" diff --git a/megatron/core/tokenizers/vision/models/__init__.py 
b/megatron/core/tokenizers/vision/models/__init__.py index b9387786024..cfbb702b3e4 100644 --- a/megatron/core/tokenizers/vision/models/__init__.py +++ b/megatron/core/tokenizers/vision/models/__init__.py @@ -1,3 +1,7 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -from megatron.core.tokenizers.vision.models.default_tokenizer import DefaultTokenizerVision +# The DefaultTokenizerVision wrapper class has been removed as it was an empty subclass. +# This alias is kept for backward compatibility. +from megatron.core.tokenizers.vision.vision_tokenizer import ( + MegatronTokenizerVision as DefaultTokenizerVision, +) diff --git a/megatron/core/tokenizers/vision/models/default_tokenizer.py b/megatron/core/tokenizers/vision/models/default_tokenizer.py deleted file mode 100644 index 5366820727d..00000000000 --- a/megatron/core/tokenizers/vision/models/default_tokenizer.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -from megatron.core.tokenizers.vision.vision_tokenizer import MegatronTokenizerVision - - -class DefaultTokenizerVision(MegatronTokenizerVision): - """Base class for Megatron default vision tokenizer.""" - - def __init__(self, path: str = None, config: dict = None, **kwargs) -> None: - config['class_name'] = self.__class__.__name__ - config['class_path'] = self.__class__.__module__ - super().__init__(path, config, **kwargs) diff --git a/megatron/core/tokenizers/vision/vision_tokenizer.py b/megatron/core/tokenizers/vision/vision_tokenizer.py index 5e1769116a6..40372513767 100644 --- a/megatron/core/tokenizers/vision/vision_tokenizer.py +++ b/megatron/core/tokenizers/vision/vision_tokenizer.py @@ -27,6 +27,8 @@ def __init__(self, path: str, config: dict, **kwargs) -> None: model_type (str): type of the model to be used with tokenizer. 
""" + config.setdefault('class_name', self.__class__.__name__) + config.setdefault('class_path', self.__class__.__module__) super().__init__(path, config, **kwargs) self._tokenizer = self._restore_model(**kwargs) self.path = path diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 0420873124c..ecb1eeaa00b 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1552,7 +1552,7 @@ def validate_args(args, defaults={}): "Use --tokenizer-special-tokens instead." ) args.tokenizer_special_tokens = args.tiktoken_special_tokens - + if args.tokenizer_hf_use_fast: warn_rank_0( "--tokenizer-hf-use-fast argument is deprecated and will be removed soon. " @@ -1567,6 +1567,54 @@ def validate_args(args, defaults={}): "Use --tokenizer-hf-no-include-special-tokens if you want to disable `include_special_tokens`." ) + # Map deprecated --tokenizer-type to new --tokenizer-library. + _TOKENIZER_TYPE_TO_LIBRARY = { + 'BertWordPieceLowerCase': 'megatron', + 'BertWordPieceCase': 'megatron', + 'GPT2BPETokenizer': 'megatron', + 'SentencePieceTokenizer': 'sentencepiece', + 'GPTSentencePieceTokenizer': 'sentencepiece', + 'Llama2Tokenizer': 'sentencepiece', + 'HuggingFaceTokenizer': 'huggingface', + 'TikTokenizer': 'tiktoken', + 'MultimodalTokenizer': 'multimodal', + 'SFTTokenizer': 'huggingface', + 'NullTokenizer': 'null', + 'NullMultimodalTokenizer': 'null-multimodal', + } + if args.tokenizer_type is not None and args.tokenizer_library is None: + warn_rank_0( + "--tokenizer-type is deprecated and will be removed in a future release. " + "Use --tokenizer-library instead." + ) + args.tokenizer_library = _TOKENIZER_TYPE_TO_LIBRARY[args.tokenizer_type] + + # Map deprecated --sft-tokenizer-prompt-format to --tokenizer-prompt-format. + # Only applies when using the deprecated --tokenizer-type SFTTokenizer. 
+ if ( + getattr(args, 'tokenizer_type', None) == 'SFTTokenizer' + and getattr(args, 'sft_tokenizer_prompt_format', None) + and not args.tokenizer_prompt_format + ): + args.tokenizer_prompt_format = args.sft_tokenizer_prompt_format + + # Map deprecated --tokenizer-mode to new flags. + _mode = getattr(args, 'tokenizer_mode', 'text') + if _mode != 'text': + warn_rank_0( + "--tokenizer-mode is deprecated and will be removed in a future release. " + "SFT: use --tokenizer-prompt-format instead. " + "Multimodal: use --tokenizer-library multimodal instead." + ) + if _mode == 'multimodal' and args.tokenizer_library not in ('multimodal', 'null-multimodal'): + args.tokenizer_library = ( + 'null-multimodal' if args.tokenizer_library == 'null' else 'multimodal' + ) + if _mode == 'sft' and not args.tokenizer_prompt_format: + args.tokenizer_prompt_format = getattr( + args, 'sft_tokenizer_prompt_format', 'nemotron-h-aligned' + ) + # Print arguments. _print_args("arguments", args) @@ -2657,6 +2705,23 @@ def _add_validation_args(parser): def _add_tokenizer_args(parser): group = parser.add_argument_group(title='tokenizer') + + # --- New preferred flags --- + group.add_argument('--tokenizer-library', type=str, default=None, + choices=['huggingface', 'sentencepiece', 'tiktoken', + 'megatron', 'byte-level', 'null', + 'multimodal', 'null-multimodal'], + help='Tokenizer backend library. Preferred over --tokenizer-type.') + group.add_argument('--tokenizer-prompt-format', type=str, default=None, + help='Prompt format for conversation tokenization (SFT / multimodal).') + + # --- Deprecated mode flag (mapped to --tokenizer-library + --tokenizer-prompt-format) --- + group.add_argument('--tokenizer-mode', type=str, default='text', + choices=['text', 'sft', 'multimodal'], + help='Deprecated. 
Use --tokenizer-library and ' + '--tokenizer-prompt-format instead.') + + # --- Vocabulary --- group.add_argument('--vocab-size', type=int, default=None, help='Size of vocab before EOD or padding.') group.add_argument('--padded-vocab-size', type=int, default=None, @@ -2670,6 +2735,8 @@ def _add_tokenizer_args(parser): group.add_argument('--vocab-extra-ids', type=int, default=0, help='Number of additional vocabulary tokens. ' 'They are used for span masking in the T5 model') + + # --- Legacy --tokenizer-type (deprecated, use --tokenizer-library) --- group.add_argument('--tokenizer-type', type=str, default=None, choices=['BertWordPieceLowerCase', @@ -2684,31 +2751,40 @@ def _add_tokenizer_args(parser): 'NullTokenizer', 'NullMultimodalTokenizer', 'SFTTokenizer'], - help='What type of tokenizer to use.') + help='Deprecated. Use --tokenizer-library instead.') + + # --- Common --- group.add_argument('--tokenizer-model', type=str, default=None, - help='Sentencepiece tokenizer model.') + help='Path to the tokenizer model file.') group.add_argument('--tokenizer-metadata', type=str, default=None, help='Path to tokenizer metadata in json format.') group.add_argument('--tokenizer-special-tokens', type=str, nargs='+', default=None, help='List of special tokens. For TikTokenizer needs to have ' '["", "", "", "", "", "", ""]') + + # --- TikToken-specific --- group.add_argument('--tiktoken-pattern', type=str, default=None, help='Which tiktoken pattern to use. Options: [v1, v2]') group.add_argument('--tiktoken-num-special-tokens', type=int, default=1000, help='Number of special tokens in tiktoken tokenizer') group.add_argument('--tiktoken-special-tokens', type=str, nargs='+', default=None, - help='List of tiktoken special tokens, needs to have ' - '["", "", "", "", "", "", ""]') + help='Deprecated. 
Use --tokenizer-special-tokens instead.') + + # --- SentencePiece-specific --- group.add_argument('--tokenizer-sentencepiece-legacy', action='store_true', default=False, help='SentencePiece tokenizer wrapper legacy behavior. Allows special tokens usage.') + + # --- HuggingFace-specific --- group.add_argument('--tokenizer-hf-use-fast', action='store_true', default=True, - help='Whether to use fast HuggingFace tokenizer.') + help='Deprecated. use_fast is True by default. ' + 'Use --tokenizer-hf-no-use-fast to disable.') group.add_argument('--tokenizer-hf-include-special-tokens', action='store_true', default=True, - help='Converting text to ids will include special for HuggingFace tokenizer.') + help='Deprecated. include_special_tokens is True by default. ' + 'Use --tokenizer-hf-no-include-special-tokens to disable.') group.add_argument('--tokenizer-hf-no-use-fast', action='store_true', default=False, - help='Whether to use fast HuggingFace tokenizer.') + help='Disable fast HuggingFace tokenizer.') group.add_argument('--tokenizer-hf-no-include-special-tokens', action='store_true', default=False, - help='Converting text to ids will not include special for HuggingFace tokenizer.') + help='Do not include special tokens in text-to-ids for HuggingFace tokenizer.') group.add_argument("--trust-remote-code", action="store_true", default=False, help='Whether or not to allow PreTrainedTokenizer to execute remote code') return parser diff --git a/tests/functional_tests/test_cases/common/ckpt_converter/__main__.py b/tests/functional_tests/test_cases/common/ckpt_converter/__main__.py index 543ddd36a6d..b9fa90a1dca 100644 --- a/tests/functional_tests/test_cases/common/ckpt_converter/__main__.py +++ b/tests/functional_tests/test_cases/common/ckpt_converter/__main__.py @@ -576,8 +576,8 @@ def get_model_argv(self): "16", "--micro-batch-size", "1", # single sample generated. - "--tokenizer-type", - "NullTokenizer", + "--tokenizer-library", + "null", "--vocab-size", "127", # ... 
NullTokenizer adds +1 EOD token. "--make-vocab-size-divisible-by", @@ -623,8 +623,8 @@ def get_model_argv(self): "1024", "--micro-batch-size", "1", # single sample generated. - "--tokenizer-type", - "NullMultimodalTokenizer", + "--tokenizer-library", + "null-multimodal", "--vocab-size", "127", # ... NullTokenizer adds +1 EOD token. "--make-vocab-size-divisible-by", diff --git a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release/model_config.yaml index 7086790f957..a3a848acfcb 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release/model_config.yaml @@ -29,7 +29,7 @@ MODEL_ARGS: --transformer-impl: transformer_engine # Data args --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_gb200/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_gb200/model_config.yaml index 0f88e5f2a00..dc0c4b2e5a2 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_gb200/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_gb200/model_config.yaml @@ -33,7 +33,7 @@ MODEL_ARGS: --transformer-impl: transformer_engine # Data args --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm/model_config.yaml index 34ef8668e5d..0b657d7d0ca 100644 --- 
a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm/model_config.yaml @@ -29,7 +29,7 @@ MODEL_ARGS: --transformer-impl: transformer_engine # Data args --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm_gb200/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm_gb200/model_config.yaml index c3f57c4aa4f..39642198b67 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm_gb200/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_15b_8t_release_sm_gb200/model_config.yaml @@ -33,7 +33,7 @@ MODEL_ARGS: --transformer-impl: transformer_engine # Data args --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp1_pp4_memory_speed/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp1_pp4_memory_speed/model_config.yaml index d18d84fbf61..bab6b7d133f 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp1_pp4_memory_speed/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp1_pp4_memory_speed/model_config.yaml @@ -27,7 +27,7 @@ MODEL_ARGS: --lr-decay-iters: 320000 --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} - --tokenizer-type: NullTokenizer + --tokenizer-library: "null" --vocab-size: 131072 --mock-data: true --split: 949,50,1 diff --git a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/model_config.yaml
b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/model_config.yaml index c7dfcc00675..8ab60ec5ffc 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_7b_tp4_pp1_memory_speed/model_config.yaml @@ -27,7 +27,7 @@ MODEL_ARGS: --lr-decay-iters: 320000 --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} - --tokenizer-type: NullTokenizer + --tokenizer-library: "null" --vocab-size: 131072 --mock-data: true --split: 949,50,1 diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/model_config.yaml index 743c4f50da3..a7c0b0ebc56 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_fp8_logitsmatch/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_logitsmatch_decode_graphs_only/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_logitsmatch_decode_graphs_only/model_config.yaml index b5dc7cd5bd2..cae7209e64f 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_logitsmatch_decode_graphs_only/model_config.yaml +++
b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_logitsmatch_decode_graphs_only/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_validation/cuda_graphs.sh b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_validation/cuda_graphs.sh index 641019c9750..69323180f55 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_validation/cuda_graphs.sh +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_cuda_graphs_validation/cuda_graphs.sh @@ -37,7 +37,7 @@ export CUBLAS_WORKSPACE_CONFIG=:4096:8 ARGS=" \ --tiktoken-pattern v2 \ --use-mcore-models \ - --tokenizer-type TikTokenizer \ + --tokenizer-library tiktoken \ --tokenizer-model ${TOKENIZER_MODEL} \ --auto-detect-ckpt-format \ --max-tokens-to-oom 3600000 \ diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_logitsmatch/model_config.yaml index aae99fd1c4c..d26a0136a52 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_583m_logitsmatch/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: 
${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_logitsmatch_zmq/model_config.yaml index d84dd24487f..a4096ad9c8e 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_logitsmatch_zmq/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml index aa4fde5e512..510f27291b1 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp1_dp8_583m_throughputtest_zmq/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git 
a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq/model_config.yaml index 345fc250694..6a68fb06da3 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp1_pp8_dp1_583m_logitsmatch_zmq/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/model_config.yaml index 3b55b09e82e..b3ec65119d6 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp2_pp2_dp2_583m_logitsmatch_zmq/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_583m_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_583m_logitsmatch/model_config.yaml index 4458edf5772..b96ca529c8c 100644 --- 
a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_583m_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_583m_logitsmatch/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml index 88a3e40a193..2ffdca5ca41 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_dynamic_inference_tp8_pp1_dp1_583m_logitsmatch_zmq/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_basic_function/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_grpo_basic_function/model_config.yaml index 0143a39f017..67a7c2d3d1c 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_grpo_basic_function/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_basic_function/model_config.yaml @@ -41,7 +41,7 @@ MODEL_ARGS: --hidden-dropout: 0.0 --no-masked-softmax-fusion: true --attention-softmax-in-fp32: true - --tokenizer-type: HuggingFaceTokenizer + --tokenizer-library: 
huggingface --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/qwen3-8b-dist/tokenizer --vocab-size: 151936 --make-vocab-size-divisible-by: 128 diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml index 4f9be214289..d091b954729 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest/model_config.yaml @@ -8,7 +8,7 @@ MODE: rl MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --load: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/ --auto-detect-ckpt-format: true diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest_github/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest_github/model_config.yaml index c8fa19d0500..07ca071628c 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest_github/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp1tp2_pp1_dp8_583m_throughputtest_github/model_config.yaml @@ -8,7 +8,7 @@ MODE: rl MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --load: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/ --auto-detect-ckpt-format: true diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_cudagraphs_throughput/model_config.yaml 
b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_cudagraphs_throughput/model_config.yaml index f3c0c4ecc5b..22d1898fe86 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_cudagraphs_throughput/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_cudagraphs_throughput/model_config.yaml @@ -38,7 +38,7 @@ MODEL_ARGS: --hidden-dropout: 0.0 --no-masked-softmax-fusion: true --attention-softmax-in-fp32: true - --tokenizer-type: HuggingFaceTokenizer + --tokenizer-library: huggingface --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/qwen3-8b-dist/tokenizer --vocab-size: 151936 --make-vocab-size-divisible-by: 128 diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_throughput/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_throughput/model_config.yaml index 80664dcdc59..99470bb88f1 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_throughput/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_throughput/model_config.yaml @@ -38,7 +38,7 @@ MODEL_ARGS: --hidden-dropout: 0.0 --no-masked-softmax-fusion: true --attention-softmax-in-fp32: true - --tokenizer-type: HuggingFaceTokenizer + --tokenizer-library: huggingface --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/qwen3-8b-dist/tokenizer --vocab-size: 151936 --make-vocab-size-divisible-by: 128 diff --git a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_throughput_github/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_throughput_github/model_config.yaml index cc25f3ab90e..aa6c9b81377 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_throughput_github/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_grpo_tp4_pp1_dp2_8b_throughput_github/model_config.yaml @@ -38,7 +38,7 @@ MODEL_ARGS: --hidden-dropout: 0.0 --no-masked-softmax-fusion: true 
--attention-softmax-in-fp32: true - --tokenizer-type: HuggingFaceTokenizer + --tokenizer-library: huggingface --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/qwen3-8b-dist/tokenizer --vocab-size: 151936 --make-vocab-size-divisible-by: 128 diff --git a/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_16b_multiprompt_tokensmatch/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_16b_multiprompt_tokensmatch/model_config.yaml index 40b45024cb1..73c50e8f45c 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_16b_multiprompt_tokensmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_16b_multiprompt_tokensmatch/model_config.yaml @@ -14,7 +14,7 @@ MODEL_ARGS: # See the mount paths defined in the top level tests/test_utils/recipes/gpt-static-inference.yaml --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_cudagraphs/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_cudagraphs/model_config.yaml index 9a47281703a..8d27f44f50e 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_cudagraphs/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_cudagraphs/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json 
--auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_fp8_cudagraphs/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_fp8_cudagraphs/model_config.yaml index 99bcc433ad1..a9424666f20 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_fp8_cudagraphs/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_fp8_cudagraphs/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_logitsmatch/model_config.yaml index 1c78b466b1e..5819c519e6d 100644 --- a/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt_static_inference_tp1_pp1_583m_logitsmatch/model_config.yaml @@ -8,7 +8,7 @@ MODE: inference MODEL_ARGS: --tiktoken-pattern: v2 --use-mcore-models: true - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mcore_mistral/nemo_minitron-0.5b/v1/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json --auto-detect-ckpt-format: true --max-tokens-to-oom: 3600000 diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml index 778aa094f4d..783c6175d8c 100644 --- 
a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/checkpoint --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml index 823199f69e9..381c875df05 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_dynamic_inference_tp1_pp1_dp8_583m_chunked_prefill/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/checkpoint --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml index 26708b32a60..73739892ab4 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml +++ 
b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_cudagraphs/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/checkpoint --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml index 3964bcb8ecb..067289ecbba 100644 --- a/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/hybrid/hybrid_static_inference_tp1_pp1_2B_logitsmatch/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/checkpoint --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/mamba_hybrid_2b/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp1_dp8/model_config.yaml b/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp1_dp8/model_config.yaml index e95856e7308..cd8cd32962b 100644 --- a/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp1_dp8/model_config.yaml +++ b/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp1_dp8/model_config.yaml @@ -23,7 +23,7 @@ MODEL_ARGS: --lr-decay-iters: 2200 --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} - 
--tokenizer-type: HuggingFaceTokenizer + --tokenizer-library: huggingface --tokenizer-model: llava-hf/llava-1.5-7b-hf --distributed-backend: nccl --lr: 0.001 diff --git a/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp1_dp8_seq_packing/model_config.yaml b/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp1_dp8_seq_packing/model_config.yaml index 2e86278fa67..17d7124c8ef 100644 --- a/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp1_dp8_seq_packing/model_config.yaml +++ b/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp1_dp8_seq_packing/model_config.yaml @@ -26,7 +26,7 @@ MODEL_ARGS: --lr-decay-iters: 2200 --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} - --tokenizer-type: HuggingFaceTokenizer + --tokenizer-library: huggingface --tokenizer-model: llava-hf/llava-1.5-7b-hf --distributed-backend: nccl --lr: 0.001 diff --git a/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp2_dp8/model_config.yaml b/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp2_dp8/model_config.yaml index 37c55e4cd93..6267bd51bb6 100644 --- a/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp2_dp8/model_config.yaml +++ b/tests/functional_tests/test_cases/mimo/mimo_vlm_pretrain_convergence_tp1_pp1_cp2_dp8/model_config.yaml @@ -25,7 +25,7 @@ MODEL_ARGS: --lr-decay-iters: 2200 --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} - --tokenizer-type: HuggingFaceTokenizer + --tokenizer-library: huggingface --tokenizer-model: llava-hf/llava-1.5-7b-hf --distributed-backend: nccl --lr: 0.001 diff --git a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release/model_config.yaml b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release/model_config.yaml index e504bcb1320..fe3deba2cb1 100644 --- 
a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release/model_config.yaml @@ -48,7 +48,7 @@ MODEL_ARGS: # Data args --seq-length: 4096 --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release_sm/model_config.yaml b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release_sm/model_config.yaml index 49cca71a596..9bc0d5f35ac 100644 --- a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp1pp4emp16etp1cp1_release_sm/model_config.yaml @@ -48,7 +48,7 @@ MODEL_ARGS: # Data args --seq-length: 4096 --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp2pp2emp16etp1cp1_gb_200_release/model_config.yaml b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp2pp2emp16etp1cp1_gb_200_release/model_config.yaml index 4452ad22987..9d05ba65639 100644 --- a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp2pp2emp16etp1cp1_gb_200_release/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp2pp2emp16etp1cp1_gb_200_release/model_config.yaml @@ -50,7 +50,7 @@ MODEL_ARGS: # Data args --seq-length: 4096 --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + 
--tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp2pp2emp16etp1cp1_gb_200_release_sm/model_config.yaml b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp2pp2emp16etp1cp1_gb_200_release_sm/model_config.yaml index 36cd34bf752..9742e617aee 100644 --- a/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp2pp2emp16etp1cp1_gb_200_release_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/deepseekv3_proxy_flex_tp2pp2emp16etp1cp1_gb_200_release_sm/model_config.yaml @@ -50,7 +50,7 @@ MODEL_ARGS: # Data args --seq-length: 4096 --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/mixtral/mixtral_8x22b_tp2pp8ep8vpp1_release/model_config.yaml b/tests/functional_tests/test_cases/mixtral/mixtral_8x22b_tp2pp8ep8vpp1_release/model_config.yaml index efe39998065..a8d72229f60 100644 --- a/tests/functional_tests/test_cases/mixtral/mixtral_8x22b_tp2pp8ep8vpp1_release/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/mixtral_8x22b_tp2pp8ep8vpp1_release/model_config.yaml @@ -27,7 +27,7 @@ MODEL_ARGS: --transformer-impl: transformer_engine # Data args --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: Llama2Tokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/tokenizer.model --data-path: ${DATA_BLEND} --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_alltoall_tp2pp4ep4_release/model_config.yaml b/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_alltoall_tp2pp4ep4_release/model_config.yaml index f4476c712f2..743e8cc8b9a 100644 --- 
a/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_alltoall_tp2pp4ep4_release/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_alltoall_tp2pp4ep4_release/model_config.yaml @@ -30,7 +30,7 @@ MODEL_ARGS: --transformer-impl: transformer_engine # Data args --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_alltoall_tp2pp4ep4_release_sm/model_config.yaml b/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_alltoall_tp2pp4ep4_release_sm/model_config.yaml index cfeb7709839..15d42c3e4bf 100644 --- a/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_alltoall_tp2pp4ep4_release_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_alltoall_tp2pp4ep4_release_sm/model_config.yaml @@ -30,7 +30,7 @@ MODEL_ARGS: --transformer-impl: transformer_engine # Data args --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: GPTSentencePieceTokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/utils/nemotron_2_256k.model --data-path: $DATA_BLEND --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_tp1pp4ep8vpp8_release/model_config.yaml b/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_tp1pp4ep8vpp8_release/model_config.yaml index 29dcefadf0e..fb8b844fc7a 100644 --- a/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_tp1pp4ep8vpp8_release/model_config.yaml +++ b/tests/functional_tests/test_cases/mixtral/mixtral_8x7b_tp1pp4ep8vpp8_release/model_config.yaml @@ -29,7 +29,7 @@ MODEL_ARGS: --transformer-impl: transformer_engine # Data args --data-cache-path: ${DATA_CACHE_PATH} - --tokenizer-type: Llama2Tokenizer + --tokenizer-library: sentencepiece --tokenizer-model: ${DATA_PATH}/tokenizer.model --data-path: 
${DATA_BLEND} --split: 99,1,0 diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml index afc75144dc8..7ad7712b13b 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/model_config.yaml index 80e2a37c250..988aead03f0 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_cudagraph_zmq/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git 
a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/model_config.yaml index 479cb7a4751..01f044f6def 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/model_config.yaml index 1f302455440..2937d5cd913 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_etp1_pp1_ep8_16B_logitsmatch_zmq_suspend_resume/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git 
a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml index 5ed1f1205f6..fb64a341de1 100644 --- a/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_dynamic_inference_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/moe/gpt_grpo_tp8tp4_pp1_ep8ep2_dp8_throughputtest/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_grpo_tp8tp4_pp1_ep8ep2_dp8_throughputtest/model_config.yaml index 139c5a82e57..26a6014a574 100644 --- a/tests/functional_tests/test_cases/moe/gpt_grpo_tp8tp4_pp1_ep8ep2_dp8_throughputtest/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_grpo_tp8tp4_pp1_ep8ep2_dp8_throughputtest/model_config.yaml @@ -22,7 +22,7 @@ MODEL_ARGS: # Model loading --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --use-checkpoint-args: true --no-use-tokenizer-model-from-checkpoint-args: true diff --git a/tests/functional_tests/test_cases/moe/gpt_static_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml 
b/tests/functional_tests/test_cases/moe/gpt_static_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml index 1c22a729f6e..6744186cdf6 100644 --- a/tests/functional_tests/test_cases/moe/gpt_static_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_static_inference_cuda_graphs_pad_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/moe/gpt_static_inference_tp1_pp1_ep1_16B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_static_inference_tp1_pp1_ep1_16B_logitsmatch/model_config.yaml index 03895d97ee9..b2c3bdff960 100644 --- a/tests/functional_tests/test_cases/moe/gpt_static_inference_tp1_pp1_ep1_16B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_static_inference_tp1_pp1_ep1_16B_logitsmatch/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/moe/gpt_static_inference_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml b/tests/functional_tests/test_cases/moe/gpt_static_inference_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml index 9259d63c9d1..f8da065c853 100644 --- 
a/tests/functional_tests/test_cases/moe/gpt_static_inference_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml +++ b/tests/functional_tests/test_cases/moe/gpt_static_inference_tp4_pp1_ep4_16B_logitsmatch/model_config.yaml @@ -13,7 +13,7 @@ MODEL_ARGS: --timing-log-level: 0 --load: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/checkpoints --tokenizer-model: ${CHECKPOINT_LOAD_PATH}/model/deepseek_16b_pyt/dcp/mcore-v1_bf16/multiMixV8.gpt4o_nc_sd.500000.128k.vocab.json - --tokenizer-type: TikTokenizer + --tokenizer-library: tiktoken --tiktoken-pattern: v2 --distributed-backend: nccl --log-interval: 1 diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml index 2898070f957..81c2e9ae593 100644 --- a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp1_pp1/model_config.yaml @@ -24,7 +24,7 @@ MODEL_ARGS: --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} --split: 949,50,1 - --tokenizer-type: NullTokenizer + --tokenizer-library: null --vocab-size: 8192 --distributed-backend: nccl --lr: 0.00015 diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp4_sp_cp2/model_config.yaml b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp4_sp_cp2/model_config.yaml index 23bdaac5010..28d8839d198 100644 --- a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp4_sp_cp2/model_config.yaml +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mcore_te_tp4_sp_cp2/model_config.yaml @@ -26,7 +26,7 @@ MODEL_ARGS: --save: ${CHECKPOINT_SAVE_PATH} --load: ${CHECKPOINT_LOAD_PATH} --split: 949,50,1 - --tokenizer-type: NullTokenizer + --tokenizer-library: null 
--vocab-size: 8192 --distributed-backend: nccl --lr: 0.00015 diff --git a/tests/functional_tests/test_cases/t5/t5_11b_mcore_tp4_pp1/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_11b_mcore_tp4_pp1/model_config.yaml index bccce17cef1..7a1cb52defc 100644 --- a/tests/functional_tests/test_cases/t5/t5_11b_mcore_tp4_pp1/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_11b_mcore_tp4_pp1/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: transformer_engine --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp1_pp1_vp1_resume_torch/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp1_pp1_vp1_resume_torch/model_config.yaml index aa0f67ff311..9103bbfae7d 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp1_pp1_vp1_resume_torch/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp1_pp1_vp1_resume_torch/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: transformer_engine --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp2_pp1_vp1/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp2_pp1_vp1/model_config.yaml index 59c1d0f280f..73a4f14d0ab 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp2_pp1_vp1/model_config.yaml +++ 
b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp2_pp1_vp1/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: transformer_engine --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp2_pp1_vp1_sequence_parallel/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp2_pp1_vp1_sequence_parallel/model_config.yaml index 80a84a26e0c..f61f15f0582 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp2_pp1_vp1_sequence_parallel/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp2_pp1_vp1_sequence_parallel/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: transformer_engine --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp4_pp1/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp4_pp1/model_config.yaml index 047280dec39..84743778eb9 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp4_pp1/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp4_pp1/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: transformer_engine --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + 
--tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp4_pp1_resume_torch_dist/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp4_pp1_resume_torch_dist/model_config.yaml index 1611c02251b..c394ccb5713 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_te_tp4_pp1_resume_torch_dist/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_te_tp4_pp1_resume_torch_dist/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: transformer_engine --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_tp1_pp1_vp1/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_tp1_pp1_vp1/model_config.yaml index 12ccecb5883..cf5baac582d 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_tp1_pp1_vp1/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_tp1_pp1_vp1/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: local --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_tp1_pp1_vp1_resume_torch/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_tp1_pp1_vp1_resume_torch/model_config.yaml index 8559fd587d1..6d539e627fd 
100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_tp1_pp1_vp1_resume_torch/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_tp1_pp1_vp1_resume_torch/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: local --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_tp2_pp1_vp1/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_tp2_pp1_vp1/model_config.yaml index 9c6a835571c..853d2c484c0 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_tp2_pp1_vp1/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_tp2_pp1_vp1/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: local --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_tp4_pp1/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_tp4_pp1/model_config.yaml index dd3896ad88a..89c6bf9d0ee 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_tp4_pp1/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_tp4_pp1/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: local --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: 
megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_mcore_tp4_pp1_resume_torch_dist/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_mcore_tp4_pp1_resume_torch_dist/model_config.yaml index 4c955dd5441..88b61ced435 100644 --- a/tests/functional_tests/test_cases/t5/t5_mcore_tp4_pp1_resume_torch_dist/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_mcore_tp4_pp1_resume_torch_dist/model_config.yaml @@ -31,7 +31,8 @@ MODEL_ARGS: --transformer-impl: local --data-path: ${DATA_PATH}/text/the_pile/t5_shard00/my-t5_00_text_document --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --calculate-per-token-loss: true --split: 99982,9,9 --save: ${CHECKPOINT_SAVE_PATH} diff --git a/tests/functional_tests/test_cases/t5/t5_release/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_release/model_config.yaml index a7abdc1bdd4..b01d081b4dc 100644 --- a/tests/functional_tests/test_cases/t5/t5_release/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_release/model_config.yaml @@ -38,7 +38,8 @@ MODEL_ARGS: # Data args --data-path: ${DATA_BLEND} --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --split: 99982,9,9 --data-cache-path: ${DATA_CACHE_PATH} --vocab-extra-ids: 100 diff --git a/tests/functional_tests/test_cases/t5/t5_release_sm/model_config.yaml b/tests/functional_tests/test_cases/t5/t5_release_sm/model_config.yaml index 7f748273cd3..28a58fe95fa 100644 --- a/tests/functional_tests/test_cases/t5/t5_release_sm/model_config.yaml +++ b/tests/functional_tests/test_cases/t5/t5_release_sm/model_config.yaml @@ -38,7 +38,8 @@ MODEL_ARGS: # 
Data args --data-path: ${DATA_BLEND} --vocab-file: ${DATA_PATH}/text/the_pile/t5_shard00/bert-large-cased-vocab.txt - --tokenizer-type: BertWordPieceCase + --tokenizer-library: megatron + --tokenizer-model: BertWordPieceCase --split: 99982,9,9 --data-cache-path: ${DATA_CACHE_PATH} --vocab-extra-ids: 100 diff --git a/tests/unit_tests/data/test_bin_reader.py b/tests/unit_tests/data/test_bin_reader.py index e479676ac4b..4f23a338b5f 100644 --- a/tests/unit_tests/data/test_bin_reader.py +++ b/tests/unit_tests/data/test_bin_reader.py @@ -1,3 +1,5 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + import os import random import sys @@ -165,7 +167,9 @@ def test_bin_reader(): path_to_raws, path_to_data, extra_args=[ - "--tokenizer-type", + "--tokenizer-library", + "megatron", + "--tokenizer-model", "GPT2BPETokenizer", "--vocab-file", gpt2_vocab(temp_dir), diff --git a/tests/unit_tests/data/test_builder.py b/tests/unit_tests/data/test_builder.py index 59f73911c82..4041eb96608 100644 --- a/tests/unit_tests/data/test_builder.py +++ b/tests/unit_tests/data/test_builder.py @@ -364,7 +364,8 @@ def test_fast_builder( tokenizer = build_tokenizer( Namespace( vocab_size=vocab_size, - tokenizer_type="NullTokenizer", + tokenizer_library="null", + tokenizer_model=None, rank=0, make_vocab_size_divisible_by=128, tensor_model_parallel_size=1, diff --git a/tests/unit_tests/data/test_preprocess_data.py b/tests/unit_tests/data/test_preprocess_data.py index 8df8c63e052..4e54319754d 100644 --- a/tests/unit_tests/data/test_preprocess_data.py +++ b/tests/unit_tests/data/test_preprocess_data.py @@ -186,7 +186,9 @@ def test_preprocess_data_gpt(): # gpt specific args gpt_args = [ - "--tokenizer-type", + "--tokenizer-library", + "megatron", + "--tokenizer-model", "GPT2BPETokenizer", "--vocab-file", "/opt/data/tokenizers/megatron/gpt2-vocab.json", @@ -250,7 +252,9 @@ def test_preprocess_data_bert(): # bert specific args bert_args = [ - "--tokenizer-type", + 
"--tokenizer-library", + "megatron", + "--tokenizer-model", "BertWordPieceLowerCase", "--vocab-file", "/opt/data/tokenizers/megatron/gpt2-vocab.json", diff --git a/tests/unit_tests/data/test_preprocess_mmdata.py b/tests/unit_tests/data/test_preprocess_mmdata.py index d6ad4eddc74..d7eb3f404f1 100644 --- a/tests/unit_tests/data/test_preprocess_mmdata.py +++ b/tests/unit_tests/data/test_preprocess_mmdata.py @@ -199,7 +199,9 @@ def test_preprocess_mmdata(): gpt_args = [ "--pad-length", "1024", - "--tokenizer-type", + "--tokenizer-library", + "megatron", + "--tokenizer-model", "GPT2BPETokenizer", "--vocab-file", gpt2_vocab(temp_dir), diff --git a/tests/unit_tests/test_argument_utils.py b/tests/unit_tests/test_argument_utils.py index e5744c3b074..4cce7de2607 100644 --- a/tests/unit_tests/test_argument_utils.py +++ b/tests/unit_tests/test_argument_utils.py @@ -499,7 +499,7 @@ def test_default_override(self): def test_choices_override(self): """Test that argparse_meta can override choices.""" - parser = ArgumentParser() + parser = ArgumentParser(exit_on_error=False) factory = ArgumentGroupFactory(ConfigWithArgparseMeta) factory.build_group(parser, title="Test Group") @@ -509,7 +509,7 @@ def test_choices_override(self): assert args.custom_choices == "option2" # Invalid choice should fail - with pytest.raises(SystemExit): + with pytest.raises(ArgumentError): parser.parse_args(['--custom-choices', 'invalid_option']) def test_dest_override(self): diff --git a/tests/unit_tests/tokenizers/test_tokenizer.py b/tests/unit_tests/tokenizers/test_tokenizer.py index f38674c6329..251d5e256cb 100755 --- a/tests/unit_tests/tokenizers/test_tokenizer.py +++ b/tests/unit_tests/tokenizers/test_tokenizer.py @@ -367,11 +367,11 @@ def test_null_multimodal_tokenizer(): def test_sft_tokenizer(): - """Test SFTTokenizer.""" + """Test SFT tokenization via MegatronTokenizerText with prompt_format.""" prompt_format = "nemotron-nano-v2" tokenizer = MegatronTokenizer.from_pretrained( 
tokenizer_path="/opt/data/tokenizers/multimodal", - metadata_path={"library": "sft"}, + metadata_path={"library": "huggingface"}, prompt_format=prompt_format, ) diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index f472dd50dbf..44b8a33b859 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -247,7 +247,7 @@ def get_args(): args = parser.parse_args() args.keep_empty = False - if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences: + if getattr(args, 'tokenizer_type', None) and args.tokenizer_type.lower().startswith('bert') and not args.split_sentences: print("Are you sure you don't want to split sentences?") # some default/dummy values for the tokenizer