diff --git a/opennyai/summarizer/others/tokenization.py b/opennyai/summarizer/others/tokenization.py
index 116f1a9..0fd62e2 100644
--- a/opennyai/summarizer/others/tokenization.py
+++ b/opennyai/summarizer/others/tokenization.py
@@ -21,12 +21,19 @@
 import unicodedata
 from io import open
 
-from pytorch_transformers import cached_path
+# --- START OF MODIFICATION ---
+# Replaced 'from pytorch_transformers import cached_path' with the modern equivalent from huggingface_hub
+from huggingface_hub import cached_download
 from wasabi import msg
 
-from opennyai.utils.download import CACHE_DIR
+# from opennyai.utils.download import CACHE_DIR  # No longer needed for Hugging Face caching:
+# Hugging Face Hub manages its own cache directories.
+
+# EXTRACTIVE_SUMMARIZER_CACHE_PATH = os.path.join(CACHE_DIR, 'ExtractiveSummarizer'.lower())
+# The cache path is managed by huggingface_hub directly or via its environment variables;
+# we pass cache_dir explicitly to cached_download instead.
+# --- END OF MODIFICATION ---
 
-EXTRACTIVE_SUMMARIZER_CACHE_PATH = os.path.join(CACHE_DIR, 'ExtractiveSummarizer'.lower())
 
 PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
@@ -131,30 +138,43 @@ def convert_ids_to_tokens(self, ids):
         return tokens
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=EXTRACTIVE_SUMMARIZER_CACHE_PATH, *inputs,
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs,  # Changed default cache_dir to None
                         **kwargs):
         """
         Instantiate a PreTrainedBertModel from a pre-trained model file.
         Download and cache the pre-trained model file if needed.
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
+            vocab_url = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+            # Use huggingface_hub's default cache when cache_dir is not provided
+            local_path = cached_download(url=vocab_url, cache_dir=cache_dir)
+            vocab_file = local_path
+        elif os.path.exists(pretrained_model_name_or_path):
+            # An existing local file or directory: use it directly
             vocab_file = pretrained_model_name_or_path
-        if os.path.isdir(vocab_file):
-            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except EnvironmentError:
+        else:
+            # Anything else would need a hub repo_id and filename to resolve;
+            # for this legacy BERT tokenizer we only support the explicit URL
+            # map or a local path, so fail loudly.
             msg.fail(
                 "Model name '{}' was not found in model name list ({}). "
                 "We assumed '{}' was a path or url but couldn't find any file "
                 "associated to this path or url.".format(
                     pretrained_model_name_or_path,
                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    vocab_file))
+                    pretrained_model_name_or_path))
             return None
+
+        # If it's a directory, assume vocab.txt is inside
+        if os.path.isdir(vocab_file):
+            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
+
+        # Ensure the resolved vocabulary file exists
+        if not os.path.isfile(vocab_file):
+            msg.fail(
+                "Resolved vocabulary file '{}' could not be found.".format(vocab_file))
+            return None
+
         # if resolved_vocab_file == vocab_file:
         #     msg.info("loading vocabulary file {}".format(vocab_file))
         # else:
@@ -166,7 +186,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=EXTRACTIVE_SUM
             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
+        tokenizer = cls(vocab_file, *inputs, **kwargs)
         return tokenizer