46 changes: 24 additions & 22 deletions src/transformers/processing_utils.py
@@ -654,26 +654,18 @@ def to_dict(self) -> dict[str, Any]:
Returns:
`dict[str, Any]`: Dictionary of all the attributes that make up this processor instance.
"""
- output = copy.deepcopy(self.__dict__)
+ # shallow copy to avoid deepcopy errors
+ output = self.__dict__.copy()

# Get the kwargs in `__init__`.
sig = inspect.signature(self.__init__)
- # Only save the attributes that are presented in the kwargs of `__init__`.
- # or in the attributes
- attrs_to_save = list(sig.parameters) + self.__class__.attributes
- # extra attributes to be kept
- attrs_to_save += ["auto_map"]
-
- if "tokenizer" in output:
- del output["tokenizer"]
- if "qformer_tokenizer" in output:
- del output["qformer_tokenizer"]
- if "protein_tokenizer" in output:
- del output["protein_tokenizer"]
- if "char_tokenizer" in output:
- del output["char_tokenizer"]
- if "chat_template" in output:
- del output["chat_template"]
+ # Save only the attributes that are either passed as kwargs to `__init__`,
+ # defined in the class's `attributes` list, or included in "auto_map".
+ attrs_to_save = list(sig.parameters) + self.__class__.attributes + ["auto_map"]
+
+ # Special attributes to handle: tokenizers and chat_template
+ for key in ["tokenizer", "qformer_tokenizer", "protein_tokenizer", "char_tokenizer", "chat_template"]:
+ output.pop(key, None)

def save_public_processor_class(dictionary):
# make sure private name "_processor_class" is correctly
@@ -748,7 +740,7 @@ def __repr__(self):
attributes_repr = "\n".join(attributes_repr)
return f"{self.__class__.__name__}:\n{attributes_repr}\n\n{self.to_json_string()}"

- def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
+ def save_pretrained(self, save_directory, save_jinja_files=False, push_to_hub: bool = False, **kwargs):
Contributor review comment:
let's avoid adding an arg here for API consistency, and I don't think it's needed anyway
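
A minimal sketch of one way to read this suggestion: keep the original signature and let callers pass the flag through `**kwargs` instead. The kwargs-based handling below is an assumption for illustration, not the reviewer's prescribed fix.

# Hypothetical alternative: no new named parameter; callers could still opt in
# with processor.save_pretrained(path, save_jinja_files=True).
def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
    # assumption: keep the current default behaviour when the flag is absent
    save_jinja_files = kwargs.pop("save_jinja_files", False)
    # ... rest of the method unchanged ...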

"""
Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
@@ -792,9 +784,12 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
if hasattr(attribute, "_set_processor_class"):
attribute._set_processor_class(self.__class__.__name__)

- # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json`
- if attribute_name == "tokenizer":
- attribute.save_pretrained(save_directory)
+ # if the attribute is a tokenizer, save it in its own subfolder to avoid overwriting files
+ if hasattr(attribute, "save_pretrained"):
+ # use the attribute_name as the subdirectory name so each attribute gets a unique path
+ attribute_save_dir = os.path.join(save_directory, attribute_name)
+ os.makedirs(attribute_save_dir, exist_ok=True)
+ attribute.save_pretrained(attribute_save_dir, save_jinja_files=save_jinja_files)
Contributor review comment on lines +790 to +792:
Indeed evolla already uses the subdir approach. However, I'd reserve that for multi-tokenizer models only, checking with len(tokenizer_attributes) > 1 and keeping the normal save path in other cases.
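
A rough sketch of that suggestion, assuming a hypothetical `tokenizer_attributes` list derived from `self.attributes`; the filtering and names are illustrative only, not the final implementation.

# Illustrative sketch: use per-attribute subfolders only when the processor
# holds more than one tokenizer-like attribute, otherwise keep the flat layout.
tokenizer_attributes = [name for name in self.attributes if "tokenizer" in name]
for attribute_name in tokenizer_attributes:
    attribute = getattr(self, attribute_name)
    if len(tokenizer_attributes) > 1:
        # multi-tokenizer processor: each tokenizer gets its own subdirectory
        attribute_save_dir = os.path.join(save_directory, attribute_name)
        os.makedirs(attribute_save_dir, exist_ok=True)
        attribute.save_pretrained(attribute_save_dir)
    else:
        # single tokenizer: keep the existing save path
        attribute.save_pretrained(save_directory)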

elif attribute._auto_class is not None:
custom_object_save(attribute, save_directory, config=attribute)

@@ -1425,7 +1420,14 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
else:
attribute_class = cls.get_possibly_dynamic_module(class_name)

- args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
+ # updated loading path to handle attributes saved in their own subfolders (e.g. multiple tokenizers)
+ attribute_path = os.path.join(pretrained_model_name_or_path, attribute_name)
+ if os.path.isdir(attribute_path):
+ # load from the attribute-specific subfolder
+ args.append(attribute_class.from_pretrained(attribute_path, **kwargs))
+ else:
+ # fall back to the original flat layout
+ args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))

return args

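For illustration, with the save and load changes above, a processor holding two tokenizer attributes would round-trip through a layout roughly like the one below (attribute names follow the test that comes next; the exact files written depend on the tokenizer class):

save_directory/
    processor_config.json
    tokenizer1/
        tokenizer_config.json, vocab files, ...
    tokenizer2/
        tokenizer_config.json, vocab files, ...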
37 changes: 37 additions & 0 deletions tests/test_processor_utils.py
Contributor review comment:
This does not need to be in an entirely new file; it can live in test_processing_common and actually be part of the processor testing suite.

@@ -0,0 +1,37 @@
import tempfile

from transformers import AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
from transformers.testing_utils import TestCasePlus


class ProcessorSavePretrainedMultipleAttributes(TestCasePlus):
Contributor review comment:
The test case test_processor_from_and_save_pretrained should be modified to handle the multi-tokenizer case.

def test_processor_loads_separate_attributes(self):
class OtherProcessor(ProcessorMixin):
name = "other-processor"

attributes = [
"tokenizer1",
"tokenizer2",
]
tokenizer1_class = "AutoTokenizer"
tokenizer2_class = "AutoTokenizer"

def __init__(self,
tokenizer1: PreTrainedTokenizer,
tokenizer2: PreTrainedTokenizer
):
super().__init__(tokenizer1=tokenizer1,
tokenizer2=tokenizer2)

tokenizer1 = AutoTokenizer.from_pretrained("google/gemma-3-270m")
tokenizer2 = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B")

processor = OtherProcessor(tokenizer1=tokenizer1,
tokenizer2=tokenizer2)
assert processor.tokenizer1.__class__ != processor.tokenizer2.__class__

with tempfile.TemporaryDirectory() as temp_dir:
processor.save_pretrained(save_directory=temp_dir, push_to_hub=False)
new_processor = OtherProcessor.from_pretrained(temp_dir)

assert new_processor.tokenizer1.__class__ != new_processor.tokenizer2.__class__