Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions opennyai/summarizer/others/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,19 @@
import unicodedata
from io import open

from pytorch_transformers import cached_path
# --- START OF MODIFICATION ---
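# Replaced 'from pytorch_transformers import cached_path' with cached_download from huggingface_hub.
# NOTE(review): cached_download is itself a deprecated legacy API in huggingface_hub and is absent from
# recent releases -- confirm the pinned huggingface_hub version before relying on it.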
from huggingface_hub import cached_download, HfFolder
from wasabi import msg

from opennyai.utils.download import CACHE_DIR
# from opennyai.utils.download import CACHE_DIR # This import is no longer needed/relevant for Hugging Face caching.
# Hugging Face Hub manages its own cache directories.

# EXTRACTIVE_SUMMARIZER_CACHE_PATH = os.path.join(CACHE_DIR, 'ExtractiveSummarizer'.lower())
# By default the cache location is left to huggingface_hub (or its environment variables).
# from_pretrained forwards its cache_dir argument to cached_download, so callers can still override it.
# --- END OF MODIFICATION ---

EXTRACTIVE_SUMMARIZER_CACHE_PATH = os.path.join(CACHE_DIR, 'ExtractiveSummarizer'.lower())

PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
Expand Down Expand Up @@ -131,30 +138,43 @@ def convert_ids_to_tokens(self, ids):
return tokens

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=EXTRACTIVE_SUMMARIZER_CACHE_PATH, *inputs,
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, # Changed default cache_dir to None
**kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
vocab_url = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
# Use huggingface_hub's default cache logic if cache_dir is not provided
local_path = cached_download(url=vocab_url, cache_dir=cache_dir)
vocab_file = local_path
elif os.path.isfile(pretrained_model_name_or_path):
# If it's already a local file path, use it directly
vocab_file = pretrained_model_name_or_path
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
else:
# Neither a known model name nor an existing local file: report the failure and return None.
# Resolving an arbitrary hub model name would require a model_id and filename. For this old BERT
# tokenizer, only the explicit URL map and local file paths are supported.
msg.fail(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
pretrained_model_name_or_path))
return None

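# If it's a directory, look for VOCAB_NAME inside it.
# NOTE(review): the branches above only let archive-map names and existing files through, so a
# directory path appears to hit the failure branch before reaching this check -- confirm.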
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)

# Ensure the resolved vocabulary file (vocab_file) exists
if not os.path.isfile(vocab_file):
msg.fail(
"Resolved vocabulary file '{}' could not be found.".format(vocab_file))
return None

# if resolved_vocab_file == vocab_file:
# msg.info("loading vocabulary file {}".format(vocab_file))
# else:
Expand All @@ -166,7 +186,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=EXTRACTIVE_SUM
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
tokenizer = cls(vocab_file, *inputs, **kwargs)
return tokenizer


Expand Down