diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 0622a4e16..fd46321e9 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -997,6 +997,11 @@ def __init__(
         force_max_length: bool = False,
         needs_manual_ocr: Optional[bool] = None,
         use_context_separator: bool = True,
+        transformers_tokenizer_kwargs: Dict[str, Any] = {},
+        transformers_config_kwargs: Dict[str, Any] = {},
+        transformers_model_kwargs: Dict[str, Any] = {},
+        peft_config=None,
+        peft_gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = {},
         **kwargs,
     ) -> None:
         """Instantiate transformers embeddings.
@@ -1023,6 +1028,11 @@ def __init__(
             force_max_length: If True, the tokenizer will always pad the sequences to maximum length.
             needs_manual_ocr: If True, bounding boxes will be calculated manually. This is used for models like `layoutlm `_ where the tokenizer doesn't compute the bounding boxes itself.
             use_context_separator: If True, the embedding will hold an additional token to allow the model to distinguish between context and prediction.
+            transformers_tokenizer_kwargs: Further values forwarded to the initialization of the transformers tokenizer.
+            transformers_config_kwargs: Further values forwarded to the initialization of the transformers config.
+            transformers_model_kwargs: Further values forwarded to the initialization of the transformers model.
+            peft_config: If set, the model will be trained using adapters and optionally QLoRA. Should be of type "PeftConfig" or a subtype.
+            peft_gradient_checkpointing_kwargs: Further values used when preparing the model for kbit training. Only used if peft_config is set.
             **kwargs: Further values forwarded to the transformers config
         """
         self.instance_parameters = self.get_instance_parameters(locals=locals())
@@ -1042,7 +1052,9 @@ def __init__(
 
         if tokenizer_data is None:
             # load tokenizer and transformer model
-            self.tokenizer = AutoTokenizer.from_pretrained(model, add_prefix_space=True, **kwargs)
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model, add_prefix_space=True, **transformers_tokenizer_kwargs, **kwargs
+            )
             try:
                 self.feature_extractor = AutoFeatureExtractor.from_pretrained(model, apply_ocr=False)
             except OSError:
@@ -1060,22 +1072,67 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
             return getattr(config, "model_type", "") in t5_supported_model_types
 
         if saved_config is None:
-            config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs)
+            config = AutoConfig.from_pretrained(
+                model, output_hidden_states=True, **transformers_config_kwargs, **kwargs
+            )
 
             if is_supported_t5_model(config):
                 from transformers import T5EncoderModel
 
-                transformer_model = T5EncoderModel.from_pretrained(model, config=config)
+                transformer_model = T5EncoderModel.from_pretrained(
+                    model, config=config, **transformers_model_kwargs, **kwargs
+                )
             else:
-                transformer_model = AutoModel.from_pretrained(model, config=config)
+                transformer_model = AutoModel.from_pretrained(
+                    model, config=config, **transformers_model_kwargs, **kwargs
+                )
         else:
             if is_supported_t5_model(saved_config):
                 from transformers import T5EncoderModel
 
-                transformer_model = T5EncoderModel(saved_config, **kwargs)
+                transformer_model = T5EncoderModel(saved_config, **transformers_model_kwargs, **kwargs)
             else:
-                transformer_model = AutoModel.from_config(saved_config, **kwargs)
-        transformer_model = transformer_model.to(flair.device)
+                transformer_model = AutoModel.from_config(saved_config, **transformers_model_kwargs, **kwargs)
+        try:
+            transformer_model = transformer_model.to(flair.device)
+        except ValueError as e:
+            # moving a model quantized by BitsAndBytes fails; any other ValueError is re-raised
+            if "Please use the model as it is" not in str(e):
+                raise
+
+        if peft_config is not None:
+            # add adapters for finetuning
+            try:
+                from peft import (
+                    TaskType,
+                    get_peft_model,
+                    prepare_model_for_kbit_training,
+                )
+            except ImportError:
+                log.error("You cannot use PEFT finetuning without peft being installed")
+                raise
+            # peft_config: PeftConfig
+            if peft_config.task_type is None:
+                peft_config.task_type = TaskType.FEATURE_EXTRACTION
+            if peft_config.task_type != TaskType.FEATURE_EXTRACTION:
+                log.warning("The task type for PEFT should be set to FEATURE_EXTRACTION, as it is the only supported type")
+            if (
+                "load_in_4bit" in {**kwargs, **transformers_model_kwargs}
+                or "load_in_8bit" in {**kwargs, **transformers_model_kwargs}
+                or "quantization_config" in {**kwargs, **transformers_model_kwargs}
+            ):
+                transformer_model = prepare_model_for_kbit_training(
+                    transformer_model,
+                    **(peft_gradient_checkpointing_kwargs or {}),
+                )
+            transformer_model = get_peft_model(model=transformer_model, peft_config=peft_config)
+
+            trainable_params, all_param = transformer_model.get_nb_trainable_parameters()
+            log.info(
+                f"trainable params: {trainable_params:,d} || "
+                f"all params: {all_param:,d} || "
+                f"trainable%: {100 * trainable_params / all_param:.4f}"
+            )
 
         self.truncate = True
         self.force_max_length = force_max_length
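
For reference, a minimal usage sketch of the new kwargs forwarding (not part of this diff; the model name and option values are illustrative, and it assumes the parameters are routed through TransformerWordEmbeddings to the patched __init__, as the existing **kwargs already are). Each dict is forwarded verbatim to exactly one transformers constructor, so per-component options no longer have to share the catch-all **kwargs:

    from flair.embeddings import TransformerWordEmbeddings

    # transformers_tokenizer_kwargs -> AutoTokenizer.from_pretrained
    # transformers_config_kwargs    -> AutoConfig.from_pretrained
    # transformers_model_kwargs     -> AutoModel.from_pretrained
    embeddings = TransformerWordEmbeddings(
        "xlm-roberta-base",  # illustrative model choice
        transformers_tokenizer_kwargs={"model_max_length": 512},
        transformers_config_kwargs={"hidden_dropout_prob": 0.2},
    )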
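
A second sketch for the PEFT/QLoRA path, under the same assumptions (requires the optional peft and bitsandbytes packages plus a CUDA device; the LoRA hyperparameters are illustrative). Supplying "quantization_config" in transformers_model_kwargs is what triggers the prepare_model_for_kbit_training branch above, and the loaded model is then wrapped via get_peft_model, with task_type defaulting to FEATURE_EXTRACTION when left unset:

    import torch
    from peft import LoraConfig
    from transformers import BitsAndBytesConfig

    from flair.embeddings import TransformerWordEmbeddings

    # QLoRA-style setup: frozen 4-bit base weights plus trainable LoRA adapters.
    embeddings = TransformerWordEmbeddings(
        "bert-base-uncased",  # illustrative model choice
        fine_tune=True,
        transformers_model_kwargs={
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        },
        peft_config=LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"]),
    )

Because a quantized base model cannot be moved after loading, the try/except around transformer_model.to(flair.device) leaves it on the device bitsandbytes chose.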