From f9ff09876e98b07b75a085f1975fde18c60ba9a8 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 3 May 2025 01:28:34 +0800 Subject: [PATCH 01/13] add esm --- keras_hub/src/models/esm/__init__.py | 0 keras_hub/src/models/esm/esm_attention.py | 81 +++++++ keras_hub/src/models/esm/esm_backbone.py | 211 ++++++++++++++++++ keras_hub/src/models/esm/esm_backbone_test.py | 33 +++ keras_hub/src/models/esm/esm_classifier.py | 121 ++++++++++ .../models/esm/esm_classifier_preprocessor.py | 137 ++++++++++++ .../esm/esm_classifier_preprocessor_test.py | 46 ++++ .../src/models/esm/esm_classifier_test.py | 59 +++++ keras_hub/src/models/esm/esm_encoder.py | 137 ++++++++++++ keras_hub/src/models/esm/esm_masked_plm.py | 119 ++++++++++ .../models/esm/esm_masked_plm_preprocessor.py | 148 ++++++++++++ .../esm/esm_masked_plm_preprocessor_test.py | 60 +++++ .../src/models/esm/esm_masked_plm_test.py | 57 +++++ keras_hub/src/models/esm/esm_presets.py | 0 keras_hub/src/models/esm/esm_tokenizer.py | 62 +++++ .../src/models/esm/esm_tokenizer_test.py | 40 ++++ .../src/utils/transformers/convert_esm.py | 153 +++++++++++++ 17 files changed, 1464 insertions(+) create mode 100644 keras_hub/src/models/esm/__init__.py create mode 100644 keras_hub/src/models/esm/esm_attention.py create mode 100644 keras_hub/src/models/esm/esm_backbone.py create mode 100644 keras_hub/src/models/esm/esm_backbone_test.py create mode 100644 keras_hub/src/models/esm/esm_classifier.py create mode 100644 keras_hub/src/models/esm/esm_classifier_preprocessor.py create mode 100644 keras_hub/src/models/esm/esm_classifier_preprocessor_test.py create mode 100644 keras_hub/src/models/esm/esm_classifier_test.py create mode 100644 keras_hub/src/models/esm/esm_encoder.py create mode 100644 keras_hub/src/models/esm/esm_masked_plm.py create mode 100644 keras_hub/src/models/esm/esm_masked_plm_preprocessor.py create mode 100644 keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py create mode 100644 keras_hub/src/models/esm/esm_masked_plm_test.py create mode 100644 keras_hub/src/models/esm/esm_presets.py create mode 100644 keras_hub/src/models/esm/esm_tokenizer.py create mode 100644 keras_hub/src/models/esm/esm_tokenizer_test.py create mode 100644 keras_hub/src/utils/transformers/convert_esm.py diff --git a/keras_hub/src/models/esm/__init__.py b/keras_hub/src/models/esm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/esm/esm_attention.py b/keras_hub/src/models/esm/esm_attention.py new file mode 100644 index 0000000000..e6cef2ffd7 --- /dev/null +++ b/keras_hub/src/models/esm/esm_attention.py @@ -0,0 +1,81 @@ +from keras import ops +from keras import initializers +import keras +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerAttention +class ESMRotaryEmbedding(RotaryEmbedding): + def _compute_cos_sin_embedding(self,x,position=1): + dim = x.shape[-1] + inv_freq = self.scaling_factor / (self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)) + t = ops.arange(x.shape[position],dtype=x.dtype) + freqs = ops.outer(t, inv_freq) + emb = ops.concatenate((freqs, freqs), axis=-1) + + cos_emb = ops.cos(emb)[None, :,None, :] + sin_emb = ops.sin(emb)[None, :,None, :] + return cos_emb, sin_emb + def call(self, q, k,position=1): + cos_emb, sin_emb = self._compute_cos_sin_embedding(q,position) + + return ( + self.apply_rotary_pos_emb(q, cos_emb, sin_emb), + self.apply_rotary_pos_emb(k, cos_emb, sin_emb), + ) + def rotate_half(self,x): + x1, x2 = ops.split(x,2,-1) + return ops.concatenate((-x2, x1), axis=-1) + def apply_rotary_pos_emb(self,x, cos, sin): + cos = cos[:, : x.shape[1],:, :] + sin = sin[:, : x.shape[1],:, :] + + return (x * cos) + (self.rotate_half(x) * sin) + +class EsmSelfAttention(RoformerAttention): + """MultiHeadAttention by ESM2 + + Referred to the implementation of HuggingFace. + In fact, this part of the calculation is exactly the same as RoFormer. + Only the calculation of the rotary part is different. + """ + def __init__(self,use_rotary=True,**kwargs): + super().__init__(**kwargs) + self.use_rotary = use_rotary + def build(self, input_shape): + super().build(input_shape) + if self.use_rotary: + self.rotary_embedding_layer = ESMRotaryEmbedding( + max_wavelength = self.max_wavelength, dtype=self.dtype_policy + ) + self.rotary_embedding_layer.build([]) + def call(self, x, attention_mask=None): + qw = self.q_dense(x) + kw = self.k_dense(x) + vw = self.v_dense(x) + + b, s = ops.shape(qw)[:2] + qw = ops.reshape(qw, (b, s, self.heads, self.head_size)) + kw = ops.reshape(kw, (b, s, self.heads, self.head_size)) + vw = ops.reshape(vw, (b, s, self.heads, self.head_size)) + + if self.use_rotary: + qw, kw = self.rotary_embedding_layer(qw, kw) + if keras.__version__ < "3.6": + raise ("Please make sure your Keras version is >=3.6.") + flash_attention = keras.config.is_flash_attention_enabled() + attention_mask = ops.reshape(attention_mask, [b, 1, s, 1]) + if keras.config.backend() == "torch": + attention_mask = ops.repeat(attention_mask, s, -1) + attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2]) + o = ops.dot_product_attention( + qw, kw, vw, mask=attention_mask, flash_attention=flash_attention + ) + return self.o_dense(ops.reshape(o, [b, s, -1])) + def get_config(self): + config = super().get_config() + config.update( + { + "use_rotary": self.use_rotary, + } + ) + return config + diff --git a/keras_hub/src/models/esm/esm_backbone.py b/keras_hub/src/models/esm/esm_backbone.py new file mode 100644 index 0000000000..26e0aeeb67 --- /dev/null +++ b/keras_hub/src/models/esm/esm_backbone.py @@ -0,0 +1,211 @@ +import keras +from keras import activations + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding +from keras_hub.src.models.esm.esm_encoder import ESMEncoder + +def esm2_kernel_initializer(stddev=0.02): + return keras.initializers.TruncatedNormal(stddev=stddev) +@keras_hub_export(["keras_hub.models.ESM2Backbone","keras_hub.models.ESMBackbone"]) +class ESMBackbone(Backbone): + """A ESM2 and ESM encoder network. + + This class implements a bi-directional Transformer-based encoder as + described in ["Roformer"](https://github.com/facebookresearch/esm). + + The default constructor gives a fully customizable, randomly initialized + ESM2 encoder with any number of layers, heads, and embed dim.To + load preset architectures and weights, use the `from_preset()` constructor. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_heads: int. The number of attention heads for each transformer. + The hidden size must be divisible by the number of attention heads. + hidden_dim: int. The size of the transformer encoding and pooler layers. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + dropout: float. Dropout probability for the Transformer encoder. + layer_norm_eps:bool.Should we use ln after embedding? + Since it's pre-norm, the default is false. + max_sequence_length: int. The maximum sequence length that this encoder + can consume. If None, `max_sequence_length` uses the value from + sequence length. This determines the variable shape for positional + embeddings. + position_embedding_type:esm1 use abs position embeding,esm2 use rope. + so this parameter is only except for absolute and rotary. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Examples: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + } + + # Pretrained ESM2 encoder. + model = keras_hub.models.ESM2Backbone.from_preset('hf://facebook/esm2_t6_8M_UR50D') + model(input_data) + + # Randomly initialized ESM2 encoder with a custom config. + model = keras_hub.models.ESM2Backbone( + vocabulary_size=30552, + num_layers=4, + num_heads=4, + hidden_dim=256, + intermediate_dim=512, + head_size = 64, + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_heads, + hidden_dim, + intermediate_dim, + head_size, + use_bias=False, + activation="gelu", + dropout=0.1, + dtype=None, + max_sequence_length = 1024, + max_wavelength=10000, + layer_norm_eps = 1e-12, + emb_layer_norm_before = False, + position_embedding_type = "rotary", + pad_token_id = 0, + **kwargs, + ): + support_positon_type = ["rotary","absolute"] + if position_embedding_type.lower() not in support_positon_type: + raise(f"This model only support below position embedding type: {support_positon_type}") + + # === Layers === + self.token_embedding = keras.layers.Embedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=esm2_kernel_initializer(), + dtype=dtype, + name="token_embedding", + ) + if position_embedding_type == "absolute": + self.position_embedding = PositionEmbedding( + initializer=esm2_kernel_initializer(), + sequence_length=max_sequence_length, + dtype=dtype, + name="position_embedding", + ) + self.embeddings_add = keras.layers.Add( + dtype=dtype, + name="embeddings_add", + ) + + + self.output_layer_norm = keras.layers.LayerNormalization( + epsilon=layer_norm_eps, + dtype=dtype, + name="output_layer_norm", + ) + if emb_layer_norm_before: + self.emb_layer_norm = keras.layers.LayerNormalization( + epsilon=layer_norm_eps, + dtype=dtype, + name="emb_layer_norm", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = ESMEncoder( + heads=num_heads, + head_size=head_size, + intermediate_size=intermediate_dim, + use_bias=use_bias, + max_wavelength=max_wavelength, + dropout=dropout, + activation=activation, + kernel_initializer=esm2_kernel_initializer(), + layer_norm_eps = layer_norm_eps, + dtype=dtype, + use_rotary=position_embedding_type=="rotary", + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + + attention_mask = keras.ops.not_equal(token_id_input, pad_token_id) + + + token_vector = self.token_embedding(token_id_input) + if position_embedding_type == "absolute": + position_vector = self.position_embedding(token_vector) + x = self.embeddings_add([token_vector, position_vector]) + else: + x = token_vector + if emb_layer_norm_before: + x = self.emb_layer_norm(x) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, attention_mask=attention_mask) + output = self.output_layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + }, + outputs=output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.dropout = dropout + self.max_wavelength = max_wavelength + self.head_size = head_size + self.dropout = dropout + self.activation = activations.get(activation) + self.use_bias = use_bias + self.start_token_index = 0 + self.layer_norm_eps = layer_norm_eps + self.max_sequence_length = max_sequence_length + self.emb_layer_norm_before = emb_layer_norm_before + self.position_embedding_type = position_embedding_type + self.pad_token_id = pad_token_id + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "dropout": self.dropout, + "max_wavelength": self.max_wavelength, + "head_size": self.head_size, + "use_bias": self.use_bias, + "activation": activations.serialize(self.activation), + "layer_norm_eps":self.layer_norm_eps, + "emb_layer_norm_before":self.emb_layer_norm_before, + "position_embedding_type":self.position_embedding_type, + "max_sequence_length":self.max_sequence_length, + "pad_token_id":self.pad_token_id, + } + ) + return config diff --git a/keras_hub/src/models/esm/esm_backbone_test.py b/keras_hub/src/models/esm/esm_backbone_test.py new file mode 100644 index 0000000000..aef8454c1a --- /dev/null +++ b/keras_hub/src/models/esm/esm_backbone_test.py @@ -0,0 +1,33 @@ +import keras +from keras import ops + +from keras_hub.src.models.esm.esm_backbone import ( + ESMBackbone, +) +from keras_hub.src.tests.test_case import TestCase + + +class ESMBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_heads": 1, + "hidden_dim": 2, + "intermediate_dim": 4, + "head_size": 2, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "segment_ids": ops.zeros((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + if keras.__version__ < "3.6": + self.skipTest("Failing on keras lower version") + self.run_backbone_test( + cls=ESMBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 2), + ) diff --git a/keras_hub/src/models/esm/esm_classifier.py b/keras_hub/src/models/esm/esm_classifier.py new file mode 100644 index 0000000000..1c6ca925ff --- /dev/null +++ b/keras_hub/src/models/esm/esm_classifier.py @@ -0,0 +1,121 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.roberta.roberta_text_classifier import ( + RobertaTextClassifier, # noqa: E501 +) +from keras_hub.src.models.esm.esm_backbone import ( + ESMBackbone, +) +from keras_hub.src.models.esm.esm_classifier_preprocessor import ( + ESMProteinClassifierPreprocessor +) + + +@keras_hub_export("keras_hub.models.ESMProteinClassifier") +class ESMProteinClassifier(RobertaTextClassifier): + """An end-to-end ESM model for classification tasks. + + This model attaches a classification head to + `keras_hub.model.ESMBackbone`, mapping from the backbone outputs + to logits suitable for a classification task. For usage of this model with + pre-trained weights, use the `from_preset()` constructor. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to raw inputs during + `fit()`, `predict()`, and `evaluate()`. This is done by default when + creating the model with `from_preset()`. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. + + Args: + backbone: A `keras_hub.models.ESMBackbone` instance. + num_classes: int. Number of classes to predict. + preprocessor: A `keras_hub.models.ESMProteinClassifierPreprocessor` + or `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + activation: Optional `str` or callable. The + activation function to use on the model outputs. Set + `activation="softmax"` to return output probabilities. + Defaults to `None`. + dropout: float. The dropout probability value, applied after the dense + layer. + + Examples: + + Raw string data. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + labels = [0, 3] + + # Pretrained classifier. + classifier = keras_hub.models.ESMProteinClassifier.from_preset( + "roformer_v2_base_zh", + num_classes=4, + ) + classifier.fit(x=features, y=labels, batch_size=2) + classifier.predict(x=features, batch_size=2) + + # Re-compile (e.g., with a new learning rate). + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + jit_compile=True, + ) + # Access backbone programmatically (e.g., to change `trainable`). + classifier.backbone.trainable = False + # Fit again. + classifier.fit(x=features, y=labels, batch_size=2) + ``` + + Preprocessed integer data. + ```python + features = { + "token_ids": np.ones(shape=(2, 12), dtype="int32"), + "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2), + } + labels = [0, 3] + + # Pretrained classifier without preprocessing. + classifier = keras_hub.models.ESMProteinClassifier.from_preset( + "roformer_v2_base_zh", + num_classes=4, + preprocessor=None, + ) + classifier.fit(x=features, y=labels, batch_size=2) + ``` + + Custom backbone and vocabulary. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + labels = [0, 3] + + vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] + vocab += ["The", "quick", "brown", "fox", "jumped", "."] + tokenizer = keras_hub.models.ESMTokenizer( + vocabulary=vocab, + ) + preprocessor = keras_hub.models.ESMProteinClassifierPreprocessor( + tokenizer=tokenizer, + sequence_length=128, + ) + backbone = keras_hub.models.ESMBackbone( + vocabulary_size=30552, + num_layers=4, + num_heads=4, + hidden_dim=256, + intermediate_dim=512, + max_wavelength=128, + head_size=64, + ) + classifier = keras_hub.models.ESMProteinClassifier( + backbone=backbone, + preprocessor=preprocessor, + num_classes=4, + ) + classifier.fit(x=features, y=labels, batch_size=2) + ``` + """ + + backbone_cls = ESMBackbone + preprocessor_cls = ESMProteinClassifierPreprocessor diff --git a/keras_hub/src/models/esm/esm_classifier_preprocessor.py b/keras_hub/src/models/esm/esm_classifier_preprocessor.py new file mode 100644 index 0000000000..42259d012f --- /dev/null +++ b/keras_hub/src/models/esm/esm_classifier_preprocessor.py @@ -0,0 +1,137 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( + BertTextClassifierPreprocessor, +) +from keras_hub.src.models.esm.esm_backbone import ( + ESMBackbone, +) +from keras_hub.src.models.esm.esm_tokenizer import ( + ESMTokenizer, +) +from keras_hub.src.utils.tensor_utils import preprocessing_function + +from keras_hub.src.layers.preprocessing.start_end_packer import ( + StartEndPacker, +) +@keras_hub_export("keras_hub.models.ESMProteinClassifierPreprocessor") +class ESMProteinClassifierPreprocessor(BertTextClassifierPreprocessor): + """A ESM preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do three things: + + 1. Tokenize any number of input segments using the `tokenizer`. + 2. Pack the inputs together using a `keras_hub.layers.MultiSegmentPacker`. + with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens. + 3. Construct a dictionary with keys `"token_ids"`, `"segment_ids"`, + `"padding_mask"`, that can be passed directly to a ESM model. + + This layer can be used directly with `tf.data.Dataset.map` to preprocess + string data in the `(x, y, sample_weight)` format used by + `keras.Model.fit`. + + Args: + tokenizer: A `keras_hub.models.ESMTokenizer` instance. + sequence_length: The length of the packed inputs. + truncate: string. The algorithm to truncate a list of batched segments + to fit within `sequence_length`. The value can be either + `round_robin` or `waterfall`: + - `"round_robin"`: Available space is assigned one token at a + time in a round-robin fashion to the inputs that still need + some, until the limit is reached. + - `"waterfall"`: The allocation of the budget is done using a + "waterfall" algorithm that allocates quota in a + left-to-right manner and fills up the buckets until we run + out of budget. It supports an arbitrary number of segments. + + Call arguments: + x: A tensor of single string sequences, or a tuple of multiple + tensor sequences to be packed together. Inputs may be batched or + unbatched. For single sequences, raw python inputs will be converted + to tensors. For multiple sequences, pass tensors directly. + y: Any label data. Will be passed through unaltered. + sample_weight: Any label weight data. Will be passed through unaltered. + + Examples: + + Directly calling the layer on data. + ```python + preprocessor = keras_hub.models.ProteinClassifierPreprocessor.from_preset( + "roformer_v2_base_zh" + ) + + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of single sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Preprocess a batch of sentence pairs. + # When handling multiple sequences, always convert to tensors first! + first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + second = tf.constant(["The fox tripped.", "Oh look, a whale."]) + preprocessor((first, second)) + + # Custom vocabulary. + vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] + vocab += ["The", "quick", "brown", "fox", "jumped", "."] + tokenizer = keras_hub.models.ESMTokenizer(vocabulary=vocab) + preprocessor = + keras_hub.models.ESMProteinClassifierPreprocessor(tokenizer) + preprocessor("The quick brown fox jumped.") + ``` + + Mapping with `tf.data.Dataset`. + ```python + preprocessor = keras_hub.models.ProteinClassifierPreprocessor.from_preset( + "roformer_v2_base_zh" + ) + + first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + second = tf.constant(["The fox tripped.", "Oh look, a whale."]) + label = tf.constant([1, 1]) + + # Map labeled single sentences. + ds = tf.data.Dataset.from_tensor_slices((first, label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map unlabeled single sentences. + ds = tf.data.Dataset.from_tensor_slices(first) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map labeled sentence pairs. + ds = tf.data.Dataset.from_tensor_slices(((first, second), label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map unlabeled sentence pairs. + ds = tf.data.Dataset.from_tensor_slices((first, second)) + # Watch out for tf.data's default unpacking of tuples here! + # Best to invoke the `preprocessor` directly in this case. + ds = ds.map( + lambda first, second: preprocessor(x=(first, second)), + num_parallel_calls=tf.data.AUTOTUNE, + ) + ``` + """ + + backbone_cls = ESMBackbone + tokenizer_cls = ESMTokenizer + def build(self, input_shape): + super().build(input_shape) + # Defer masker creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + ) + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + x = self.tokenizer(x) + token_ids = self.packer(x) + x = { + "token_ids": token_ids, + } + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py b/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py new file mode 100644 index 0000000000..93ecce4f4d --- /dev/null +++ b/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py @@ -0,0 +1,46 @@ + +from keras_hub.src.models.esm.esm_tokenizer import ( + ESMTokenizer, +) +from keras_hub.src.tests.test_case import TestCase +from keras_hub.src.models.esm.esm_classifier_preprocessor import ( + ESMProteinClassifierPreprocessor +) + + + +class ESMProteinClassifierPreprocessorTest(TestCase): + def setUp(self): + self.vocab = [ "[UNK]","[PAD]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["THE", "QUICK", "BROWN", "FOX"] + self.vocab += ["the", "quick", "brown", "fox"] + self.tokenizer = ESMTokenizer(vocabulary=self.vocab) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ( + ["THE QUICK BROWN FOX."], + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=ESMProteinClassifierPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[2, 5, 6, 7, 8, 0, 3, 1]], + }, + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ), + ) + + def test_errors_for_2d_list_input(self): + preprocessor = ESMProteinClassifierPreprocessor(**self.init_kwargs) + ambiguous_input = [["one", "two"], ["three", "four"]] + with self.assertRaises(ValueError): + preprocessor(ambiguous_input) diff --git a/keras_hub/src/models/esm/esm_classifier_test.py b/keras_hub/src/models/esm/esm_classifier_test.py new file mode 100644 index 0000000000..92474ab0ab --- /dev/null +++ b/keras_hub/src/models/esm/esm_classifier_test.py @@ -0,0 +1,59 @@ +import keras + +from keras_hub.src.models.roformer_v2 import ( + roformer_v2_text_classifier_preprocessor as r, +) +from keras_hub.src.models.esm.esm_backbone import ( + ESMBackbone, +) +from keras_hub.src.models.esm.esm_tokenizer import ( + ESMTokenizer, +) +from keras_hub.src.models.esm.esm_classifier import ( + ESMProteinClassifier, +) +from keras_hub.src.models.esm.esm_classifier_preprocessor import ( + ESMProteinClassifierPreprocessor +) +from keras_hub.src.tests.test_case import TestCase + + + + +class RoformerVTextClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["the", "quick", "brown", "fox", "."] + self.preprocessor = ESMProteinClassifierPreprocessor( + ESMTokenizer(vocabulary=self.vocab), + sequence_length=5, + ) + self.backbone = ESMBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_heads=2, + hidden_dim=4, + intermediate_dim=8, + head_size=2, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + "num_classes": 2, + } + self.train_data = ( + ["the quick brown fox.", "the slow brown fox."], # Features. + [1, 0], # Labels. + ) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_classifier_basics(self): + if keras.__version__ < "3.6": + self.skipTest("Failing on keras lower version") + self.run_task_test( + cls=ESMProteinClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) diff --git a/keras_hub/src/models/esm/esm_encoder.py b/keras_hub/src/models/esm/esm_encoder.py new file mode 100644 index 0000000000..ab9fbf5f53 --- /dev/null +++ b/keras_hub/src/models/esm/esm_encoder.py @@ -0,0 +1,137 @@ +import keras +from keras import activations +from keras import initializers + +from keras_hub.src.models.esm.esm_attention import EsmSelfAttention + + + +class ESMEncoder(keras.layers.Layer): + """MultiHeadAttention by ESM + + Referred to the implementation of HuggingFace. + reference: + https://github.com/huggingface/transformers/ + blob/main/src/transformers/models/esm/modeling_esm.py + """ + + def __init__( + self, + heads, + head_size, + intermediate_size=None, + max_wavelength=10000, + dropout=0, + activation="gelu", + use_bias=False, + kernel_initializer="glorot_uniform", + layer_norm_eps = 1e-12, + use_rotary = True, + **kwargs, + ): + super().__init__(**kwargs) + self.heads = heads + self.head_size = head_size + self.intermediate_size = intermediate_size + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.max_wavelength = max_wavelength + self.dropout = dropout + self.activation = activations.get(activation) + self.layer_norm_eps = layer_norm_eps + self.use_rotary = use_rotary + + def build(self, input_shape): + super().build(input_shape) + self.attention_layer = EsmSelfAttention( + heads=self.heads, + head_size=self.head_size, + use_bias=self.use_bias, + max_wavelength=self.max_wavelength, + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + use_rotary = self.use_rotary, + name="attention_layer", + ) + self.attention_layer.build(input_shape) + + self.dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) + self.dropout_layer.build([]) + + # Feedforward layers. + self.feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_size, + kernel_initializer=self.kernel_initializer, + use_bias=self.use_bias, + dtype=self.dtype_policy, + activation=self.activation, + name="feedforward_intermediate_dense", + ) + self.feedforward_intermediate_dense.build(input_shape) + + self.feedforward_output_dense = keras.layers.Dense( + input_shape[-1], + kernel_initializer=self.kernel_initializer, + use_bias=self.use_bias, + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + + self.feedforward_output_dense.build( + [None, None, self.intermediate_size] + ) + import torch + self.attention_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_eps, + name="attention_norm", + dtype=self.dtype_policy, + ) + self.attention_norm.build(input_shape) + + self.feedforward_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_eps, + name="ffn_norm", + dtype=self.dtype_policy, + ) + self.feedforward_norm.build(input_shape) + + def call(self, x, attention_mask=None): + + attention_output = self.attention_layer( + self.attention_norm(self.dropout_layer(x)), + attention_mask=attention_mask, + ) + residual = x + attention_output + + x = self.feedforward_norm(self.dropout_layer(residual)) + intermediate_output = self.feedforward_intermediate_dense(x) + feedroward_output = self.feedforward_output_dense(intermediate_output) + return residual + self.dropout_layer(feedroward_output) + + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "heads": self.heads, + "head_size": self.head_size, + "intermediate_size": self.intermediate_size, + "max_wavelength": self.max_wavelength, + "use_bias": self.use_bias, + "activation": activations.serialize(self.activation), + "dropout": self.dropout, + "layer_norm_eps":self.layer_norm_eps, + "use_rotary":self.use_rotary, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + } + ) + return config diff --git a/keras_hub/src/models/esm/esm_masked_plm.py b/keras_hub/src/models/esm/esm_masked_plm.py new file mode 100644 index 0000000000..42fa26f297 --- /dev/null +++ b/keras_hub/src/models/esm/esm_masked_plm.py @@ -0,0 +1,119 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead +from keras_hub.src.models.masked_lm import MaskedLM +from keras_hub.src.models.esm.esm_backbone import ( + ESMBackbone,esm2_kernel_initializer +) +from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( + ESMMaskedPLMPreprocessor, +) + + +@keras_hub_export(["keras_hub.models.ESM2MaskedPLM","keras_hub.models.ESMMaskedPLM"]) +class ESMMaskedPLM(MaskedLM): + """An end-to-end ESM2 model for the masked protein language modeling task. + + This model will train ESM2 on a masked protein language modeling task. + The model will predict labels for a number of masked tokens in the + input data. For usage of this model with pre-trained weights, see the + `from_preset()` method. + + This model can optionally be configured with a `preprocessor` layer, in + which case inputs can be raw string features during `fit()`, `predict()`, + and `evaluate()`. Inputs will be tokenized and dynamically masked during + training and evaluation. This is done by default when creating the model + with `from_preset()`. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. The underlying model is provided by a + third party and subject to a separate license, available + [here](https://github.com/facebookresearch/esm). + + Args: + backbone: A `keras_hub.models.ESM2Backbone` instance. + preprocessor: A `keras_hub.models.ESM2MaskedPLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Examples: + + Raw string data. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + + # Pretrained protein language model. + masked_lm = keras_hub.models.ESM2MaskedPLM.from_preset( + "ESM2_base_en", + ) + masked_lm.fit(x=features, batch_size=2) + + # Re-compile (e.g., with a new learning rate). + masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + jit_compile=True, + ) + # Access backbone programmatically (e.g., to change `trainable`). + masked_lm.backbone.trainable = False + # Fit again. + masked_lm.fit(x=features, batch_size=2) + ``` + + Preprocessed integer data. + ```python + # Create a preprocessed dataset where 0 is the mask token. + features = { + "token_ids": np.array([[1, 2, 0, 4, 0, 6, 7, 8]] * 2), + "mask_positions": np.array([[2, 4]] * 2) + } + # Labels are the original masked values. + labels = [[3, 5]] * 2 + + masked_lm = keras_hub.models.ESM2MaskedPLM.from_preset( + 'hf://facebook/esm2_t6_8M_UR50D', + preprocessor=None, + ) + + masked_lm.fit(x=features, y=labels, batch_size=2) + ``` + """ + + backbone_cls = ESMBackbone + preprocessor_cls = ESMMaskedPLMPreprocessor + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.masked_lm_head = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + intermediate_activation=backbone.activation, + kernel_initializer=esm2_kernel_initializer(), + dtype=backbone.dtype_policy, + layer_norm_epsilon = backbone.layer_norm_eps, + name="mlm_head", + ) + + # === Functional Model === + inputs = { + **backbone.input, + "mask_positions": keras.Input( + shape=(None,), dtype="int32", name="mask_positions" + ), + } + backbone_outputs = backbone(backbone.input) + + outputs = self.masked_lm_head( + backbone_outputs, inputs["mask_positions"] + ) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) diff --git a/keras_hub/src/models/esm/esm_masked_plm_preprocessor.py b/keras_hub/src/models/esm/esm_masked_plm_preprocessor.py new file mode 100644 index 0000000000..553643f62f --- /dev/null +++ b/keras_hub/src/models/esm/esm_masked_plm_preprocessor.py @@ -0,0 +1,148 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor +from keras_hub.src.models.esm.esm_backbone import ( + ESMBackbone, +) +from keras_hub.src.models.esm.esm_tokenizer import ( + ESMTokenizer, +) +from keras_hub.src.utils.tensor_utils import preprocessing_function + +from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import ( + MaskedLMMaskGenerator, +) +from keras_hub.src.layers.preprocessing.start_end_packer import ( + StartEndPacker, +) +@keras_hub_export("keras_hub.models.ESMMaskedPLMPreprocessor") +class ESMMaskedPLMPreprocessor(MaskedLMPreprocessor): + """ESM preprocessing for the masked language modeling task. + + This preprocessing layer will prepare inputs for a masked language modeling + task. It is primarily intended for use with the + `keras_hub.models.ESMMaskedPLM` task model. + Preprocessing will occur in multiple steps. + + 1. Tokenize any number of input segments using the `tokenizer`. + 2. Pack the inputs together with the appropriate `"[CLS]"`, `"[SEP]"` and + `"[PAD]"` tokens. + 3. Randomly select non-special tokens to mask, controlled by + `mask_selection_rate`. + 4. Construct a `(x, y, sample_weight)` tuple suitable for training with a + `keras_hub.models.ESMMaskedPLM` task model. + + Args: + tokenizer: A `keras_hub.models.ESMTokenizer` instance. + sequence_length: int. The length of the packed inputs. + truncate: string. The algorithm to truncate a list of batched segments + to fit within `sequence_length`. The value can be either + `round_robin` or `waterfall`: + - `"round_robin"`: Available space is assigned one token at a + time in a round-robin fashion to the inputs that still need + some, until the limit is reached. + - `"waterfall"`: The allocation of the budget is done using a + "waterfall" algorithm that allocates quota in a + left-to-right manner and fills up the buckets until we run + out of budget. It supports an arbitrary number of segments. + mask_selection_rate: float. The probability an input token will be + dynamically masked. + mask_selection_length: int. The maximum number of masked tokens + in a given sample. + mask_token_rate: float. The probability the a selected token will be + replaced with the mask token. + random_token_rate: float. The probability the a selected token will be + replaced with a random token from the vocabulary. A selected token + will be left as is with probability + `1 - mask_token_rate - random_token_rate`. + + Call arguments: + x: A tensor of single string sequences, or a tuple of multiple + tensor sequences to be packed together. Inputs may be batched or + unbatched. For single sequences, raw python inputs will be converted + to tensors. For multiple sequences, pass tensors directly. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + + Examples: + + Directly calling the layer on data. + ```python + preprocessor = keras_hub.models.ESMMaskedPLMPreprocessor.from_preset( + "roformer_v2_base_zh" + ) + + # Tokenize and mask a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize and mask a batch of single sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Tokenize and mask sentence pairs. + # In this case, always convert input to tensors before calling the layer. + first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + second = tf.constant(["The fox tripped.", "Oh look, a whale."]) + preprocessor((first, second)) + ``` + + Mapping with `tf.data.Dataset`. + ```python + preprocessor = keras_hub.models.ESMMaskedPLMPreprocessor.from_preset( + "roformer_v2_base_zh" + ) + + first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + second = tf.constant(["The fox tripped.", "Oh look, a whale."]) + + # Map single sentences. + ds = tf.data.Dataset.from_tensor_slices(first) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map sentence pairs. + ds = tf.data.Dataset.from_tensor_slices((first, second)) + # Watch out for tf.data's default unpacking of tuples here! + # Best to invoke the `preprocessor` directly in this case. + ds = ds.map( + lambda first, second: preprocessor(x=(first, second)), + num_parallel_calls=tf.data.AUTOTUNE, + ) + ``` + """ + + backbone_cls = ESMBackbone + tokenizer_cls = ESMTokenizer + def build(self, input_shape): + super().build(input_shape) + # Defer masker creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + ) + self.masker = MaskedLMMaskGenerator( + mask_selection_rate=self.mask_selection_rate, + mask_selection_length=self.mask_selection_length, + mask_token_rate=self.mask_token_rate, + random_token_rate=self.random_token_rate, + vocabulary_size=self.tokenizer.vocabulary_size(), + mask_token_id=self.tokenizer.mask_token_id, + unselectable_token_ids=self.tokenizer.special_token_ids, + ) + + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + x = self.tokenizer(x) + token_ids = self.packer(x) + masker_outputs = self.masker(token_ids) + x = { + "token_ids": masker_outputs["token_ids"], + "mask_positions": masker_outputs["mask_positions"], + } + y = masker_outputs["mask_ids"] + sample_weight = masker_outputs["mask_weights"] + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py b/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py new file mode 100644 index 0000000000..316927f247 --- /dev/null +++ b/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py @@ -0,0 +1,60 @@ +from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( + ESMMaskedPLMPreprocessor, +) +from keras_hub.src.models.esm.esm_tokenizer import ( + ESMTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class ESMMaskedPLMPreprocessort(TestCase): + def setUp(self): + self.vocab = [ "[UNK]", "[PAD]","[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["THE", "QUICK", "BROWN", "FOX"] + self.vocab += ["the", "quick", "brown", "fox"] + self.tokenizer = ESMTokenizer(vocabulary=self.vocab) + self.init_kwargs = { + "tokenizer": self.tokenizer, + # Simplify our testing by masking every available token. + "mask_selection_rate": 1.0, + "mask_token_rate": 1.0, + "random_token_rate": 0.0, + "mask_selection_length": 4, + "sequence_length": 12, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=ESMMaskedPLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[2, 4, 4, 4, 4, 3, 1, 1, 1, 1, 1, 1]], + "mask_positions": [[1, 2, 3, 4]], + }, + [[9, 10, 11, 12]], + [[1.0, 1.0, 1.0, 1.0]], + ), + ) + + def test_no_masking_zero_rate(self): + no_mask_preprocessor = ESMMaskedPLMPreprocessor( + self.tokenizer, + mask_selection_rate=0.0, + mask_selection_length=4, + sequence_length=12, + ) + input_data = ["the quick brown fox"] + self.assertAllClose( + no_mask_preprocessor(input_data), + ( + { + "token_ids": [[2, 9, 10, 11, 12, 3, 1, 1, 1, 1, 1, 1]], + "mask_positions": [[0, 0, 0, 0]], + }, + [[0, 0, 0, 0]], + [[0.0, 0.0, 0.0, 0.0]], + ), + ) diff --git a/keras_hub/src/models/esm/esm_masked_plm_test.py b/keras_hub/src/models/esm/esm_masked_plm_test.py new file mode 100644 index 0000000000..c601d15638 --- /dev/null +++ b/keras_hub/src/models/esm/esm_masked_plm_test.py @@ -0,0 +1,57 @@ +import keras + +from keras_hub.src.models.esm.esm_backbone import ( + ESMBackbone, +) +from keras_hub.src.models.esm.esm_masked_plm import ( + ESMMaskedPLM +) +from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( + ESMMaskedPLMPreprocessor, +) +from keras_hub.src.models.esm.esm_tokenizer import ( + ESMTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class ESMMaskedLMTest(TestCase): + def setUp(self): + # Setup model. + self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["the", "quick", "brown", "fox", "."] + self.preprocessor = ESMMaskedPLMPreprocessor( + ESMTokenizer(vocabulary=self.vocab), + # Simplify our testing by masking every available token. + mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=5, + sequence_length=5, + ) + self.backbone = ESMBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_heads=2, + hidden_dim=4, + intermediate_dim=8, + head_size=2, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = ( + ["the quick brown fox.", "the slow brown fox."], # Features. + ) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_masked_lm_basics(self): + if keras.__version__ < "3.6": + self.skipTest("Failing on keras lower version") + self.run_task_test( + cls=ESMMaskedPLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 5, 10), + ) diff --git a/keras_hub/src/models/esm/esm_presets.py b/keras_hub/src/models/esm/esm_presets.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/esm/esm_tokenizer.py b/keras_hub/src/models/esm/esm_tokenizer.py new file mode 100644 index 0000000000..7898200dad --- /dev/null +++ b/keras_hub/src/models/esm/esm_tokenizer.py @@ -0,0 +1,62 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer +from keras_hub.src.models.esm.esm_backbone import ( + ESMBackbone, +) + + +@keras_hub_export( + [ + "keras_hub.tokenizers.ESMTokenizer", + "keras_hub.models.ESMTokenizer", + ] +) +class ESMTokenizer(BertTokenizer): + """A ESM tokenizer using WordPiece subword segmentation. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_hub.tokenizers.WordPieceTokenizer`. Unlike the + underlying tokenizer, it will check for special tokens needed by ESM + models and provides a `from_preset()` method to automatically download + a matching vocabulary for a ESM preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + vocabulary: A list of strings or a string filename path. If + passing a list, each element of the list should be a single word + piece token string. If passing a filename, the file should be a + plain text file containing a single word piece token per line. + lowercase: If `True`, the input text will be first lowered before + tokenization. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. + + Examples: + ```python + # Unbatched input. + tokenizer = keras_hub.models.ESMTokenizer.from_preset( + "roformer_v2_base_zh", + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. + vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] + vocab += ["The", "quick", "brown", "fox", "jumped", "."] + tokenizer = keras_hub.models.ESMTokenizer(vocabulary=vocab) + tokenizer("The quick brown fox jumped.") + ``` + """ + + backbone_cls = ESMBackbone diff --git a/keras_hub/src/models/esm/esm_tokenizer_test.py b/keras_hub/src/models/esm/esm_tokenizer_test.py new file mode 100644 index 0000000000..89ec34032b --- /dev/null +++ b/keras_hub/src/models/esm/esm_tokenizer_test.py @@ -0,0 +1,40 @@ +from keras_hub.src.models.esm.esm_tokenizer import ( + ESMTokenizer, +) +from keras_hub.src.tests.test_case import TestCase + + +class ESMTokenizerTest(TestCase): + def setUp(self): + self.vocab = [ "[UNK]", "[PAD]","[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["THE", "QUICK", "BROWN", "FOX"] + self.vocab += ["the", "quick", "brown", "fox"] + self.init_kwargs = {"vocabulary": self.vocab} + self.input_data = ["THE QUICK BROWN FOX", "THE FOX"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=ESMTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[5, 6, 7, 8], [5, 8]], + ) + + def test_lowercase(self): + tokenizer = ESMTokenizer(vocabulary=self.vocab, lowercase=True) + output = tokenizer(self.input_data) + self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]]) + + def test_tokenizer_special_tokens(self): + input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"] + tokenizer = ESMTokenizer( + **self.init_kwargs, special_tokens_in_strings=True + ) + output_data = tokenizer(input_data) + expected_output = [[2, 5, 4, 8, 3, 1]] + + self.assertAllEqual(output_data, expected_output) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + ESMTokenizer(vocabulary=["a", "b", "c"]) diff --git a/keras_hub/src/utils/transformers/convert_esm.py b/keras_hub/src/utils/transformers/convert_esm.py new file mode 100644 index 0000000000..8e5ce17584 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_esm.py @@ -0,0 +1,153 @@ +import numpy as np + +from keras_hub.src.models.vit.vit_backbone import ViTBackbone + +backbone_cls = ViTBackbone + + +def convert_backbone_config(transformers_config): + + return { + "vocabulary_size": transformers_config["vocab_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_heads": transformers_config["num_attention_heads"], + "hidden_dim": transformers_config["hidden_size"], + "intermediate_dim": transformers_config["intermediate_size"], + "dropout": transformers_config["hidden_dropout_prob"], + "position_embedding_type": transformers_config["position_embedding_type"], + "pad_token_id": transformers_config["pad_token_id"], + "max_sequence_length": transformers_config.get("max_position_embeddings", None), # 默认值为None + "layer_norm_eps": transformers_config.get("layer_norm_eps", 1e-12), # 默认值为1e-12 + "emb_layer_norm_before": transformers_config.get("emb_layer_norm_before", False), # 默认值为False + "head_size": transformers_config.get("head_size", 64), # 默认值为64 + "activation": transformers_config.get("activation", "gelu"), # 默认值为"gelu" + "max_wavelength": transformers_config.get("max_wavelength", 10000), # 默认值为10000 + } + + +def convert_weights(backbone, loader, transformers_config): + # Embedding layer + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="bert.embeddings.word_embeddings.weight", + ) + if transformers_config["position_embedding_type"]=="absolute": + pass + loader.port_weight( + keras_variable=backbone.get_layer( + "position_embedding" + ).position_embeddings, + hf_weight_key="bert.embeddings.position_embeddings.weight", + ) + loader.port_weight( + keras_variable=backbone.get_layer("segment_embedding").embeddings, + hf_weight_key="bert.embeddings.token_type_embeddings.weight", + ) + loader.port_weight( + keras_variable=backbone.get_layer("embeddings_layer_norm").beta, + hf_weight_key="bert.embeddings.LayerNorm.beta", + ) + loader.port_weight( + keras_variable=backbone.get_layer("embeddings_layer_norm").gamma, + hf_weight_key="bert.embeddings.LayerNorm.gamma", + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + # Attention blocks + for i in range(backbone.num_layers): + block = backbone.get_layer(f"transformer_layer_{i}") + attn = block._self_attention_layer + hf_prefix = "bert.encoder.layer." + # Attention layers + loader.port_weight( + keras_variable=attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.attention.self.query.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=attn.query_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.attention.self.query.bias", + hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), + ) + loader.port_weight( + keras_variable=attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.attention.self.key.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=attn.key_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.attention.self.key.bias", + hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), + ) + loader.port_weight( + keras_variable=attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.attention.self.value.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=attn.value_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.attention.self.value.bias", + hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), + ) + loader.port_weight( + keras_variable=attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.attention.output.dense.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=attn.output_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.attention.output.dense.bias", + hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), + ) + # Attention layer norm. + loader.port_weight( + keras_variable=block._self_attention_layer_norm.beta, + hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.beta", + ) + loader.port_weight( + keras_variable=block._self_attention_layer_norm.gamma, + hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.gamma", + ) + # MLP layers + loader.port_weight( + keras_variable=block._feedforward_intermediate_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.intermediate.dense.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=block._feedforward_intermediate_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.intermediate.dense.bias", + ) + loader.port_weight( + keras_variable=block._feedforward_output_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.output.dense.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=block._feedforward_output_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.output.dense.bias", + ) + # Output layer norm. + loader.port_weight( + keras_variable=block._feedforward_layer_norm.beta, + hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.beta", + ) + loader.port_weight( + keras_variable=block._feedforward_layer_norm.gamma, + hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.gamma", + ) + + +def convert_head(task, loader, transformers_config): + prefix = "classifier." + loader.port_weight( + task.output_dense.kernel, + hf_weight_key=prefix + "weight", + hook_fn=lambda x, _: x.T, + ) + loader.port_weight( + task.output_dense.bias, + hf_weight_key=prefix + "bias", + ) From cc4123b94b58d904e69e5e436603a8100f1d2615 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 3 May 2025 15:38:21 +0800 Subject: [PATCH 02/13] add esm2 --- keras_hub/api/__init__.py | 16 +- keras_hub/api/layers/__init__.py | 128 ++-- keras_hub/api/metrics/__init__.py | 10 +- keras_hub/api/models/__init__.py | 614 +++++++----------- keras_hub/api/samplers/__init__.py | 22 +- keras_hub/api/tokenizers/__init__.py | 119 ++-- keras_hub/api/utils/__init__.py | 18 +- keras_hub/src/models/esm/esm_attention.py | 53 +- keras_hub/src/models/esm/esm_backbone.py | 58 +- keras_hub/src/models/esm/esm_backbone_test.py | 4 +- keras_hub/src/models/esm/esm_classifier.py | 10 +- .../models/esm/esm_classifier_preprocessor.py | 15 +- .../esm/esm_classifier_preprocessor_test.py | 12 +- .../src/models/esm/esm_classifier_test.py | 19 +- keras_hub/src/models/esm/esm_encoder.py | 21 +- keras_hub/src/models/esm/esm_masked_plm.py | 16 +- .../models/esm/esm_masked_plm_preprocessor.py | 21 +- .../esm/esm_masked_plm_preprocessor_test.py | 6 +- .../src/models/esm/esm_masked_plm_test.py | 12 +- keras_hub/src/models/esm/esm_tokenizer.py | 32 +- .../src/models/esm/esm_tokenizer_test.py | 8 +- .../src/utils/transformers/convert_esm.py | 136 ++-- .../src/utils/transformers/preset_loader.py | 3 + 23 files changed, 559 insertions(+), 794 deletions(-) diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 2aa98bf3f9..3796e4c7f4 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,12 +4,12 @@ since your modifications would be overwritten. """ -from keras_hub import layers as layers -from keras_hub import metrics as metrics -from keras_hub import models as models -from keras_hub import samplers as samplers -from keras_hub import tokenizers as tokenizers -from keras_hub import utils as utils -from keras_hub.src.utils.preset_utils import upload_preset as upload_preset +from keras_hub import layers +from keras_hub import metrics +from keras_hub import models +from keras_hub import samplers +from keras_hub import tokenizers +from keras_hub import utils +from keras_hub.src.utils.preset_utils import upload_preset from keras_hub.src.version import __version__ as __version__ -from keras_hub.src.version import version as version +from keras_hub.src.version import version diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 61eb0621b6..d42af86a3c 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -4,128 +4,86 @@ since your modifications would be overwritten. """ -from keras_hub.src.layers.modeling.alibi_bias import AlibiBias as AlibiBias -from keras_hub.src.layers.modeling.anchor_generator import ( - AnchorGenerator as AnchorGenerator, -) -from keras_hub.src.layers.modeling.box_matcher import BoxMatcher as BoxMatcher +from keras_hub.src.layers.modeling.alibi_bias import AlibiBias +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator +from keras_hub.src.layers.modeling.box_matcher import BoxMatcher from keras_hub.src.layers.modeling.cached_multi_head_attention import ( - CachedMultiHeadAttention as CachedMultiHeadAttention, -) -from keras_hub.src.layers.modeling.f_net_encoder import ( - FNetEncoder as FNetEncoder, -) -from keras_hub.src.layers.modeling.masked_lm_head import ( - MaskedLMHead as MaskedLMHead, -) -from keras_hub.src.layers.modeling.non_max_supression import ( - NonMaxSuppression as NonMaxSuppression, -) -from keras_hub.src.layers.modeling.position_embedding import ( - PositionEmbedding as PositionEmbedding, + CachedMultiHeadAttention, ) +from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder +from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead +from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression +from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding from keras_hub.src.layers.modeling.reversible_embedding import ( - ReversibleEmbedding as ReversibleEmbedding, -) -from keras_hub.src.layers.modeling.rms_normalization import ( - RMSNormalization as RMSNormalization, -) -from keras_hub.src.layers.modeling.rotary_embedding import ( - RotaryEmbedding as RotaryEmbedding, + ReversibleEmbedding, ) +from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding from keras_hub.src.layers.modeling.sine_position_encoding import ( - SinePositionEncoding as SinePositionEncoding, + SinePositionEncoding, ) from keras_hub.src.layers.modeling.token_and_position_embedding import ( - TokenAndPositionEmbedding as TokenAndPositionEmbedding, -) -from keras_hub.src.layers.modeling.transformer_decoder import ( - TransformerDecoder as TransformerDecoder, -) -from keras_hub.src.layers.modeling.transformer_encoder import ( - TransformerEncoder as TransformerEncoder, -) -from keras_hub.src.layers.preprocessing.audio_converter import ( - AudioConverter as AudioConverter, -) -from keras_hub.src.layers.preprocessing.image_converter import ( - ImageConverter as ImageConverter, + TokenAndPositionEmbedding, ) +from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder +from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder +from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator as MaskedLMMaskGenerator, + MaskedLMMaskGenerator, ) from keras_hub.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker as MultiSegmentPacker, -) -from keras_hub.src.layers.preprocessing.random_deletion import ( - RandomDeletion as RandomDeletion, -) -from keras_hub.src.layers.preprocessing.random_swap import ( - RandomSwap as RandomSwap, -) -from keras_hub.src.layers.preprocessing.start_end_packer import ( - StartEndPacker as StartEndPacker, + MultiSegmentPacker, ) +from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion +from keras_hub.src.layers.preprocessing.random_swap import RandomSwap +from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_hub.src.models.basnet.basnet_image_converter import ( - BASNetImageConverter as BASNetImageConverter, -) -from keras_hub.src.models.clip.clip_image_converter import ( - CLIPImageConverter as CLIPImageConverter, + BASNetImageConverter, ) +from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter from keras_hub.src.models.cspnet.cspnet_image_converter import ( - CSPNetImageConverter as CSPNetImageConverter, + CSPNetImageConverter, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( - DeepLabV3ImageConverter as DeepLabV3ImageConverter, + DeepLabV3ImageConverter, ) from keras_hub.src.models.densenet.densenet_image_converter import ( - DenseNetImageConverter as DenseNetImageConverter, + DenseNetImageConverter, ) from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( - EfficientNetImageConverter as EfficientNetImageConverter, + EfficientNetImageConverter, ) from keras_hub.src.models.gemma3.gemma3_image_converter import ( - Gemma3ImageConverter as Gemma3ImageConverter, -) -from keras_hub.src.models.mit.mit_image_converter import ( - MiTImageConverter as MiTImageConverter, + Gemma3ImageConverter, ) +from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( - MobileNetImageConverter as MobileNetImageConverter, + MobileNetImageConverter, ) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( - PaliGemmaImageConverter as PaliGemmaImageConverter, + PaliGemmaImageConverter, ) from keras_hub.src.models.resnet.resnet_image_converter import ( - ResNetImageConverter as ResNetImageConverter, + ResNetImageConverter, ) from keras_hub.src.models.retinanet.retinanet_image_converter import ( - RetinaNetImageConverter as RetinaNetImageConverter, -) -from keras_hub.src.models.sam.sam_image_converter import ( - SAMImageConverter as SAMImageConverter, -) -from keras_hub.src.models.sam.sam_mask_decoder import ( - SAMMaskDecoder as SAMMaskDecoder, -) -from keras_hub.src.models.sam.sam_prompt_encoder import ( - SAMPromptEncoder as SAMPromptEncoder, + RetinaNetImageConverter, ) +from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter +from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder +from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder from keras_hub.src.models.segformer.segformer_image_converter import ( - SegFormerImageConverter as SegFormerImageConverter, + SegFormerImageConverter, ) from keras_hub.src.models.siglip.siglip_image_converter import ( - SigLIPImageConverter as SigLIPImageConverter, -) -from keras_hub.src.models.vgg.vgg_image_converter import ( - VGGImageConverter as VGGImageConverter, -) -from keras_hub.src.models.vit.vit_image_converter import ( - ViTImageConverter as ViTImageConverter, + SigLIPImageConverter, ) +from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( - WhisperAudioConverter as WhisperAudioConverter, + WhisperAudioConverter, ) from keras_hub.src.models.xception.xception_image_converter import ( - XceptionImageConverter as XceptionImageConverter, + XceptionImageConverter, ) diff --git a/keras_hub/api/metrics/__init__.py b/keras_hub/api/metrics/__init__.py index 100c2c66fb..88a0a7df2b 100644 --- a/keras_hub/api/metrics/__init__.py +++ b/keras_hub/api/metrics/__init__.py @@ -4,8 +4,8 @@ since your modifications would be overwritten. """ -from keras_hub.src.metrics.bleu import Bleu as Bleu -from keras_hub.src.metrics.edit_distance import EditDistance as EditDistance -from keras_hub.src.metrics.perplexity import Perplexity as Perplexity -from keras_hub.src.metrics.rouge_l import RougeL as RougeL -from keras_hub.src.metrics.rouge_n import RougeN as RougeN +from keras_hub.src.metrics.bleu import Bleu +from keras_hub.src.metrics.edit_distance import EditDistance +from keras_hub.src.metrics.perplexity import Perplexity +from keras_hub.src.metrics.rouge_l import RougeL +from keras_hub.src.metrics.rouge_n import RougeN diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 2a78362e9a..e3f0a3aa16 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,606 +4,452 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_backbone import ( - AlbertBackbone as AlbertBackbone, -) -from keras_hub.src.models.albert.albert_masked_lm import ( - AlbertMaskedLM as AlbertMaskedLM, -) +from keras_hub.src.models.albert.albert_backbone import AlbertBackbone +from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( - AlbertMaskedLMPreprocessor as AlbertMaskedLMPreprocessor, + AlbertMaskedLMPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertClassifier, + AlbertTextClassifier, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertTextClassifier, + AlbertTextClassifier as AlbertClassifier, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertPreprocessor, + AlbertTextClassifierPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertTextClassifierPreprocessor, -) -from keras_hub.src.models.albert.albert_tokenizer import ( - AlbertTokenizer as AlbertTokenizer, -) -from keras_hub.src.models.backbone import Backbone as Backbone -from keras_hub.src.models.bart.bart_backbone import BartBackbone as BartBackbone -from keras_hub.src.models.bart.bart_seq_2_seq_lm import ( - BartSeq2SeqLM as BartSeq2SeqLM, + AlbertTextClassifierPreprocessor as AlbertPreprocessor, ) +from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.bart.bart_backbone import BartBackbone +from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( - BartSeq2SeqLMPreprocessor as BartSeq2SeqLMPreprocessor, -) -from keras_hub.src.models.bart.bart_tokenizer import ( - BartTokenizer as BartTokenizer, -) -from keras_hub.src.models.basnet.basnet import ( - BASNetImageSegmenter as BASNetImageSegmenter, -) -from keras_hub.src.models.basnet.basnet_backbone import ( - BASNetBackbone as BASNetBackbone, -) -from keras_hub.src.models.basnet.basnet_preprocessor import ( - BASNetPreprocessor as BASNetPreprocessor, -) -from keras_hub.src.models.bert.bert_backbone import BertBackbone as BertBackbone -from keras_hub.src.models.bert.bert_masked_lm import ( - BertMaskedLM as BertMaskedLM, -) + BartSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor +from keras_hub.src.models.bert.bert_backbone import BertBackbone +from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( - BertMaskedLMPreprocessor as BertMaskedLMPreprocessor, + BertMaskedLMPreprocessor, ) +from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier from keras_hub.src.models.bert.bert_text_classifier import ( BertTextClassifier as BertClassifier, ) -from keras_hub.src.models.bert.bert_text_classifier import ( - BertTextClassifier as BertTextClassifier, -) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor as BertPreprocessor, + BertTextClassifierPreprocessor, ) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor as BertTextClassifierPreprocessor, -) -from keras_hub.src.models.bert.bert_tokenizer import ( - BertTokenizer as BertTokenizer, -) -from keras_hub.src.models.bloom.bloom_backbone import ( - BloomBackbone as BloomBackbone, -) -from keras_hub.src.models.bloom.bloom_causal_lm import ( - BloomCausalLM as BloomCausalLM, + BertTextClassifierPreprocessor as BertPreprocessor, ) +from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer +from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone +from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( - BloomCausalLMPreprocessor as BloomCausalLMPreprocessor, -) -from keras_hub.src.models.bloom.bloom_tokenizer import ( - BloomTokenizer as BloomTokenizer, -) -from keras_hub.src.models.causal_lm import CausalLM as CausalLM -from keras_hub.src.models.causal_lm_preprocessor import ( - CausalLMPreprocessor as CausalLMPreprocessor, -) -from keras_hub.src.models.clip.clip_backbone import CLIPBackbone as CLIPBackbone -from keras_hub.src.models.clip.clip_preprocessor import ( - CLIPPreprocessor as CLIPPreprocessor, -) -from keras_hub.src.models.clip.clip_text_encoder import ( - CLIPTextEncoder as CLIPTextEncoder, -) -from keras_hub.src.models.clip.clip_tokenizer import ( - CLIPTokenizer as CLIPTokenizer, -) -from keras_hub.src.models.clip.clip_vision_encoder import ( - CLIPVisionEncoder as CLIPVisionEncoder, -) -from keras_hub.src.models.cspnet.cspnet_backbone import ( - CSPNetBackbone as CSPNetBackbone, -) + BloomCausalLMPreprocessor, +) +from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.clip.clip_backbone import CLIPBackbone +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone from keras_hub.src.models.cspnet.cspnet_image_classifier import ( - CSPNetImageClassifier as CSPNetImageClassifier, + CSPNetImageClassifier, ) from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( - CSPNetImageClassifierPreprocessor as CSPNetImageClassifierPreprocessor, + CSPNetImageClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( - DebertaV3Backbone as DebertaV3Backbone, + DebertaV3Backbone, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( - DebertaV3MaskedLM as DebertaV3MaskedLM, + DebertaV3MaskedLM, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( - DebertaV3MaskedLMPreprocessor as DebertaV3MaskedLMPreprocessor, + DebertaV3MaskedLMPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3Classifier, + DebertaV3TextClassifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3TextClassifier, + DebertaV3TextClassifier as DebertaV3Classifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, + DebertaV3TextClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3TextClassifierPreprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer as DebertaV3Tokenizer, + DebertaV3Tokenizer, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( - DeepLabV3Backbone as DeepLabV3Backbone, + DeepLabV3Backbone, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( - DeepLabV3ImageSegmenterPreprocessor as DeepLabV3ImageSegmenterPreprocessor, + DeepLabV3ImageSegmenterPreprocessor, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( - DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter, -) -from keras_hub.src.models.densenet.densenet_backbone import ( - DenseNetBackbone as DenseNetBackbone, + DeepLabV3ImageSegmenter, ) +from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier as DenseNetImageClassifier, + DenseNetImageClassifier, ) from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( - DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor, + DenseNetImageClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_backbone import ( - DistilBertBackbone as DistilBertBackbone, + DistilBertBackbone, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( - DistilBertMaskedLM as DistilBertMaskedLM, + DistilBertMaskedLM, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor as DistilBertMaskedLMPreprocessor, + DistilBertMaskedLMPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertClassifier, + DistilBertTextClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertTextClassifier, + DistilBertTextClassifier as DistilBertClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, + DistilBertTextClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertTextClassifierPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer as DistilBertTokenizer, + DistilBertTokenizer, ) from keras_hub.src.models.efficientnet.efficientnet_backbone import ( - EfficientNetBackbone as EfficientNetBackbone, + EfficientNetBackbone, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( - EfficientNetImageClassifier as EfficientNetImageClassifier, + EfficientNetImageClassifier, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( - EfficientNetImageClassifierPreprocessor as EfficientNetImageClassifierPreprocessor, -) -from keras_hub.src.models.electra.electra_backbone import ( - ElectraBackbone as ElectraBackbone, -) -from keras_hub.src.models.electra.electra_tokenizer import ( - ElectraTokenizer as ElectraTokenizer, -) -from keras_hub.src.models.f_net.f_net_backbone import ( - FNetBackbone as FNetBackbone, -) -from keras_hub.src.models.f_net.f_net_masked_lm import ( - FNetMaskedLM as FNetMaskedLM, -) + EfficientNetImageClassifierPreprocessor, +) +from keras_hub.src.models.electra.electra_backbone import ElectraBackbone +from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone +from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier +from keras_hub.src.models.esm.esm_classifier_preprocessor import ( + ESMProteinClassifierPreprocessor, +) +from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM +from keras_hub.src.models.esm.esm_masked_plm import ( + ESMMaskedPLM as ESM2MaskedPLM, +) +from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( + ESMMaskedPLMPreprocessor, +) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer +from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone +from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( - FNetMaskedLMPreprocessor as FNetMaskedLMPreprocessor, + FNetMaskedLMPreprocessor, ) +from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier from keras_hub.src.models.f_net.f_net_text_classifier import ( FNetTextClassifier as FNetClassifier, ) -from keras_hub.src.models.f_net.f_net_text_classifier import ( - FNetTextClassifier as FNetTextClassifier, -) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor as FNetPreprocessor, + FNetTextClassifierPreprocessor, ) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor as FNetTextClassifierPreprocessor, -) -from keras_hub.src.models.f_net.f_net_tokenizer import ( - FNetTokenizer as FNetTokenizer, -) -from keras_hub.src.models.falcon.falcon_backbone import ( - FalconBackbone as FalconBackbone, -) -from keras_hub.src.models.falcon.falcon_causal_lm import ( - FalconCausalLM as FalconCausalLM, + FNetTextClassifierPreprocessor as FNetPreprocessor, ) +from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone +from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( - FalconCausalLMPreprocessor as FalconCausalLMPreprocessor, -) -from keras_hub.src.models.falcon.falcon_tokenizer import ( - FalconTokenizer as FalconTokenizer, -) -from keras_hub.src.models.feature_pyramid_backbone import ( - FeaturePyramidBackbone as FeaturePyramidBackbone, -) -from keras_hub.src.models.flux.flux_model import FluxBackbone as FluxBackbone -from keras_hub.src.models.flux.flux_text_to_image import ( - FluxTextToImage as FluxTextToImage, + FalconCausalLMPreprocessor, ) +from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.models.flux.flux_model import FluxBackbone +from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( - FluxTextToImagePreprocessor as FluxTextToImagePreprocessor, -) -from keras_hub.src.models.gemma.gemma_backbone import ( - GemmaBackbone as GemmaBackbone, -) -from keras_hub.src.models.gemma.gemma_causal_lm import ( - GemmaCausalLM as GemmaCausalLM, + FluxTextToImagePreprocessor, ) +from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone +from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor as GemmaCausalLMPreprocessor, -) -from keras_hub.src.models.gemma.gemma_tokenizer import ( - GemmaTokenizer as GemmaTokenizer, -) -from keras_hub.src.models.gemma3.gemma3_backbone import ( - Gemma3Backbone as Gemma3Backbone, -) -from keras_hub.src.models.gemma3.gemma3_causal_lm import ( - Gemma3CausalLM as Gemma3CausalLM, + GemmaCausalLMPreprocessor, ) +from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( - Gemma3CausalLMPreprocessor as Gemma3CausalLMPreprocessor, -) -from keras_hub.src.models.gemma3.gemma3_tokenizer import ( - Gemma3Tokenizer as Gemma3Tokenizer, + Gemma3CausalLMPreprocessor, ) +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( - Gemma3VisionEncoder as Gemma3VisionEncoder, -) -from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone as GPT2Backbone -from keras_hub.src.models.gpt2.gpt2_causal_lm import ( - GPT2CausalLM as GPT2CausalLM, + Gemma3VisionEncoder, ) +from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( - GPT2CausalLMPreprocessor as GPT2CausalLMPreprocessor, -) -from keras_hub.src.models.gpt2.gpt2_preprocessor import ( - GPT2Preprocessor as GPT2Preprocessor, -) -from keras_hub.src.models.gpt2.gpt2_tokenizer import ( - GPT2Tokenizer as GPT2Tokenizer, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import ( - GPTNeoXBackbone as GPTNeoXBackbone, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import ( - GPTNeoXCausalLM as GPTNeoXCausalLM, + GPT2CausalLMPreprocessor, ) +from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( - GPTNeoXCausalLMPreprocessor as GPTNeoXCausalLMPreprocessor, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( - GPTNeoXTokenizer as GPTNeoXTokenizer, -) -from keras_hub.src.models.image_classifier import ( - ImageClassifier as ImageClassifier, + GPTNeoXCausalLMPreprocessor, ) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.image_classifier_preprocessor import ( - ImageClassifierPreprocessor as ImageClassifierPreprocessor, -) -from keras_hub.src.models.image_segmenter import ( - ImageSegmenter as ImageSegmenter, + ImageClassifierPreprocessor, ) +from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.image_segmenter_preprocessor import ( - ImageSegmenterPreprocessor as ImageSegmenterPreprocessor, -) -from keras_hub.src.models.image_to_image import ImageToImage as ImageToImage -from keras_hub.src.models.inpaint import Inpaint as Inpaint -from keras_hub.src.models.llama.llama_backbone import ( - LlamaBackbone as LlamaBackbone, -) -from keras_hub.src.models.llama.llama_causal_lm import ( - LlamaCausalLM as LlamaCausalLM, + ImageSegmenterPreprocessor, ) +from keras_hub.src.models.image_to_image import ImageToImage +from keras_hub.src.models.inpaint import Inpaint +from keras_hub.src.models.llama.llama_backbone import LlamaBackbone +from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( - LlamaCausalLMPreprocessor as LlamaCausalLMPreprocessor, -) -from keras_hub.src.models.llama.llama_tokenizer import ( - LlamaTokenizer as LlamaTokenizer, -) -from keras_hub.src.models.llama3.llama3_backbone import ( - Llama3Backbone as Llama3Backbone, -) -from keras_hub.src.models.llama3.llama3_causal_lm import ( - Llama3CausalLM as Llama3CausalLM, + LlamaCausalLMPreprocessor, ) +from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer +from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone +from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( - Llama3CausalLMPreprocessor as Llama3CausalLMPreprocessor, -) -from keras_hub.src.models.llama3.llama3_tokenizer import ( - Llama3Tokenizer as Llama3Tokenizer, -) -from keras_hub.src.models.masked_lm import MaskedLM as MaskedLM -from keras_hub.src.models.masked_lm_preprocessor import ( - MaskedLMPreprocessor as MaskedLMPreprocessor, -) -from keras_hub.src.models.mistral.mistral_backbone import ( - MistralBackbone as MistralBackbone, -) -from keras_hub.src.models.mistral.mistral_causal_lm import ( - MistralCausalLM as MistralCausalLM, + Llama3CausalLMPreprocessor, ) +from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_hub.src.models.masked_lm import MaskedLM +from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor +from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone +from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( - MistralCausalLMPreprocessor as MistralCausalLMPreprocessor, -) -from keras_hub.src.models.mistral.mistral_tokenizer import ( - MistralTokenizer as MistralTokenizer, -) -from keras_hub.src.models.mit.mit_backbone import MiTBackbone as MiTBackbone -from keras_hub.src.models.mit.mit_image_classifier import ( - MiTImageClassifier as MiTImageClassifier, + MistralCausalLMPreprocessor, ) +from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( - MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor, -) -from keras_hub.src.models.mobilenet.mobilenet_backbone import ( - MobileNetBackbone as MobileNetBackbone, + MiTImageClassifierPreprocessor, ) +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier as MobileNetImageClassifier, + MobileNetImageClassifier, ) from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor, + MobileNetImageClassifierPreprocessor, ) +from keras_hub.src.models.object_detector import ObjectDetector from keras_hub.src.models.object_detector import ( ObjectDetector as ImageObjectDetector, ) -from keras_hub.src.models.object_detector import ( - ObjectDetector as ObjectDetector, -) from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, + ObjectDetectorPreprocessor, ) from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor as ObjectDetectorPreprocessor, + ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, ) -from keras_hub.src.models.opt.opt_backbone import OPTBackbone as OPTBackbone -from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM as OPTCausalLM +from keras_hub.src.models.opt.opt_backbone import OPTBackbone +from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( - OPTCausalLMPreprocessor as OPTCausalLMPreprocessor, + OPTCausalLMPreprocessor, ) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone as PaliGemmaBackbone, + PaliGemmaBackbone, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( - PaliGemmaCausalLM as PaliGemmaCausalLM, + PaliGemmaCausalLM, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor as PaliGemmaCausalLMPreprocessor, + PaliGemmaCausalLMPreprocessor, ) from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer as PaliGemmaTokenizer, -) -from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone -from keras_hub.src.models.phi3.phi3_causal_lm import ( - Phi3CausalLM as Phi3CausalLM, + PaliGemmaTokenizer, ) +from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( - Phi3CausalLMPreprocessor as Phi3CausalLMPreprocessor, + Phi3CausalLMPreprocessor, ) -from keras_hub.src.models.phi3.phi3_tokenizer import ( - Phi3Tokenizer as Phi3Tokenizer, -) -from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor +from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_hub.src.models.preprocessor import Preprocessor +from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone from keras_hub.src.models.qwen.qwen_backbone import ( QwenBackbone as Qwen2Backbone, ) -from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone as QwenBackbone +from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM from keras_hub.src.models.qwen.qwen_causal_lm import ( QwenCausalLM as Qwen2CausalLM, ) -from keras_hub.src.models.qwen.qwen_causal_lm import ( - QwenCausalLM as QwenCausalLM, -) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, + QwenCausalLMPreprocessor, ) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor as QwenCausalLMPreprocessor, + QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, ) +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen.qwen_tokenizer import ( - QwenTokenizer as QwenTokenizer, -) -from keras_hub.src.models.resnet.resnet_backbone import ( - ResNetBackbone as ResNetBackbone, -) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.models.resnet.resnet_image_classifier import ( - ResNetImageClassifier as ResNetImageClassifier, + ResNetImageClassifier, ) from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( - ResNetImageClassifierPreprocessor as ResNetImageClassifierPreprocessor, -) -from keras_hub.src.models.retinanet.retinanet_backbone import ( - RetinaNetBackbone as RetinaNetBackbone, + ResNetImageClassifierPreprocessor, ) +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_object_detector import ( - RetinaNetObjectDetector as RetinaNetObjectDetector, + RetinaNetObjectDetector, ) from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( - RetinaNetObjectDetectorPreprocessor as RetinaNetObjectDetectorPreprocessor, -) -from keras_hub.src.models.roberta.roberta_backbone import ( - RobertaBackbone as RobertaBackbone, -) -from keras_hub.src.models.roberta.roberta_masked_lm import ( - RobertaMaskedLM as RobertaMaskedLM, + RetinaNetObjectDetectorPreprocessor, ) +from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone +from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( - RobertaMaskedLMPreprocessor as RobertaMaskedLMPreprocessor, + RobertaMaskedLMPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaClassifier, + RobertaTextClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaTextClassifier, + RobertaTextClassifier as RobertaClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaPreprocessor, + RobertaTextClassifierPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaTextClassifierPreprocessor, -) -from keras_hub.src.models.roberta.roberta_tokenizer import ( - RobertaTokenizer as RobertaTokenizer, + RobertaTextClassifierPreprocessor as RobertaPreprocessor, ) +from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_backbone import ( - RoformerV2Backbone as RoformerV2Backbone, + RoformerV2Backbone, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import ( - RoformerV2MaskedLM as RoformerV2MaskedLM, + RoformerV2MaskedLM, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import ( - RoformerV2MaskedLMPreprocessor as RoformerV2MaskedLMPreprocessor, + RoformerV2MaskedLMPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import ( - RoformerV2TextClassifier as RoformerV2TextClassifier, + RoformerV2TextClassifier, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import ( - RoformerV2TextClassifierPreprocessor as RoformerV2TextClassifierPreprocessor, + RoformerV2TextClassifierPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer as RoformerV2Tokenizer, -) -from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone -from keras_hub.src.models.sam.sam_image_segmenter import ( - SAMImageSegmenter as SAMImageSegmenter, + RoformerV2Tokenizer, ) +from keras_hub.src.models.sam.sam_backbone import SAMBackbone +from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor as SAMImageSegmenterPreprocessor, -) -from keras_hub.src.models.segformer.segformer_backbone import ( - SegFormerBackbone as SegFormerBackbone, + SAMImageSegmenterPreprocessor, ) +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone from keras_hub.src.models.segformer.segformer_image_segmenter import ( - SegFormerImageSegmenter as SegFormerImageSegmenter, + SegFormerImageSegmenter, ) from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( - SegFormerImageSegmenterPreprocessor as SegFormerImageSegmenterPreprocessor, -) -from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM as Seq2SeqLM -from keras_hub.src.models.seq_2_seq_lm_preprocessor import ( - Seq2SeqLMPreprocessor as Seq2SeqLMPreprocessor, -) -from keras_hub.src.models.siglip.siglip_backbone import ( - SigLIPBackbone as SigLIPBackbone, -) -from keras_hub.src.models.siglip.siglip_preprocessor import ( - SigLIPPreprocessor as SigLIPPreprocessor, -) -from keras_hub.src.models.siglip.siglip_text_encoder import ( - SigLIPTextEncoder as SigLIPTextEncoder, -) -from keras_hub.src.models.siglip.siglip_tokenizer import ( - SigLIPTokenizer as SigLIPTokenizer, -) + SegFormerImageSegmenterPreprocessor, +) +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.models.siglip.siglip_backbone import SigLIPBackbone +from keras_hub.src.models.siglip.siglip_preprocessor import SigLIPPreprocessor +from keras_hub.src.models.siglip.siglip_text_encoder import SigLIPTextEncoder +from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer from keras_hub.src.models.siglip.siglip_vision_encoder import ( - SigLIPVisionEncoder as SigLIPVisionEncoder, + SigLIPVisionEncoder, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( - StableDiffusion3Backbone as StableDiffusion3Backbone, + StableDiffusion3Backbone, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( - StableDiffusion3ImageToImage as StableDiffusion3ImageToImage, + StableDiffusion3ImageToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( - StableDiffusion3Inpaint as StableDiffusion3Inpaint, + StableDiffusion3Inpaint, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( - StableDiffusion3TextToImage as StableDiffusion3TextToImage, + StableDiffusion3TextToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( - StableDiffusion3TextToImagePreprocessor as StableDiffusion3TextToImagePreprocessor, + StableDiffusion3TextToImagePreprocessor, ) -from keras_hub.src.models.t5.t5_backbone import T5Backbone as T5Backbone -from keras_hub.src.models.t5.t5_preprocessor import ( - T5Preprocessor as T5Preprocessor, -) -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer -from keras_hub.src.models.task import Task as Task +from keras_hub.src.models.t5.t5_backbone import T5Backbone +from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_hub.src.models.task import Task +from keras_hub.src.models.text_classifier import TextClassifier from keras_hub.src.models.text_classifier import TextClassifier as Classifier -from keras_hub.src.models.text_classifier import ( - TextClassifier as TextClassifier, -) from keras_hub.src.models.text_classifier_preprocessor import ( - TextClassifierPreprocessor as TextClassifierPreprocessor, + TextClassifierPreprocessor, ) -from keras_hub.src.models.text_to_image import TextToImage as TextToImage +from keras_hub.src.models.text_to_image import TextToImage from keras_hub.src.models.text_to_image_preprocessor import ( - TextToImagePreprocessor as TextToImagePreprocessor, -) -from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone as VGGBackbone -from keras_hub.src.models.vgg.vgg_image_classifier import ( - VGGImageClassifier as VGGImageClassifier, + TextToImagePreprocessor, ) +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( - VGGImageClassifierPreprocessor as VGGImageClassifierPreprocessor, -) -from keras_hub.src.models.vit.vit_backbone import ViTBackbone as ViTBackbone -from keras_hub.src.models.vit.vit_image_classifier import ( - ViTImageClassifier as ViTImageClassifier, + VGGImageClassifierPreprocessor, ) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( - ViTImageClassifierPreprocessor as ViTImageClassifierPreprocessor, -) -from keras_hub.src.models.vit_det.vit_det_backbone import ( - ViTDetBackbone as ViTDetBackbone, -) -from keras_hub.src.models.whisper.whisper_backbone import ( - WhisperBackbone as WhisperBackbone, -) -from keras_hub.src.models.whisper.whisper_tokenizer import ( - WhisperTokenizer as WhisperTokenizer, -) -from keras_hub.src.models.xception.xception_backbone import ( - XceptionBackbone as XceptionBackbone, + ViTImageClassifierPreprocessor, ) +from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone +from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone +from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer +from keras_hub.src.models.xception.xception_backbone import XceptionBackbone from keras_hub.src.models.xception.xception_image_classifier import ( - XceptionImageClassifier as XceptionImageClassifier, + XceptionImageClassifier, ) from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( - XceptionImageClassifierPreprocessor as XceptionImageClassifierPreprocessor, + XceptionImageClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( - XLMRobertaBackbone as XLMRobertaBackbone, + XLMRobertaBackbone, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( - XLMRobertaMaskedLM as XLMRobertaMaskedLM, + XLMRobertaMaskedLM, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( - XLMRobertaMaskedLMPreprocessor as XLMRobertaMaskedLMPreprocessor, + XLMRobertaMaskedLMPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaClassifier, + XLMRobertaTextClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaTextClassifier, + XLMRobertaTextClassifier as XLMRobertaClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, + XLMRobertaTextClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaTextClassifierPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer as XLMRobertaTokenizer, -) -from keras_hub.src.models.xlnet.xlnet_backbone import ( - XLNetBackbone as XLNetBackbone, + XLMRobertaTokenizer, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer +from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone +from keras_hub.src.tokenizers.tokenizer import Tokenizer diff --git a/keras_hub/api/samplers/__init__.py b/keras_hub/api/samplers/__init__.py index 29bfef00fc..9feb76c669 100644 --- a/keras_hub/api/samplers/__init__.py +++ b/keras_hub/api/samplers/__init__.py @@ -4,15 +4,13 @@ since your modifications would be overwritten. """ -from keras_hub.src.samplers.beam_sampler import BeamSampler as BeamSampler -from keras_hub.src.samplers.contrastive_sampler import ( - ContrastiveSampler as ContrastiveSampler, -) -from keras_hub.src.samplers.greedy_sampler import GreedySampler as GreedySampler -from keras_hub.src.samplers.random_sampler import RandomSampler as RandomSampler -from keras_hub.src.samplers.sampler import Sampler as Sampler -from keras_hub.src.samplers.serialization import deserialize as deserialize -from keras_hub.src.samplers.serialization import get as get -from keras_hub.src.samplers.serialization import serialize as serialize -from keras_hub.src.samplers.top_k_sampler import TopKSampler as TopKSampler -from keras_hub.src.samplers.top_p_sampler import TopPSampler as TopPSampler +from keras_hub.src.samplers.beam_sampler import BeamSampler +from keras_hub.src.samplers.contrastive_sampler import ContrastiveSampler +from keras_hub.src.samplers.greedy_sampler import GreedySampler +from keras_hub.src.samplers.random_sampler import RandomSampler +from keras_hub.src.samplers.sampler import Sampler +from keras_hub.src.samplers.serialization import deserialize +from keras_hub.src.samplers.serialization import get +from keras_hub.src.samplers.serialization import serialize +from keras_hub.src.samplers.top_k_sampler import TopKSampler +from keras_hub.src.samplers.top_p_sampler import TopPSampler diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 79b6efa192..3615e77581 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -4,105 +4,60 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_tokenizer import ( - AlbertTokenizer as AlbertTokenizer, -) -from keras_hub.src.models.bart.bart_tokenizer import ( - BartTokenizer as BartTokenizer, -) -from keras_hub.src.models.bert.bert_tokenizer import ( - BertTokenizer as BertTokenizer, -) -from keras_hub.src.models.bloom.bloom_tokenizer import ( - BloomTokenizer as BloomTokenizer, -) -from keras_hub.src.models.clip.clip_tokenizer import ( - CLIPTokenizer as CLIPTokenizer, -) +from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer +from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer +from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer +from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer as DebertaV3Tokenizer, + DebertaV3Tokenizer, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer as DistilBertTokenizer, -) -from keras_hub.src.models.electra.electra_tokenizer import ( - ElectraTokenizer as ElectraTokenizer, -) -from keras_hub.src.models.f_net.f_net_tokenizer import ( - FNetTokenizer as FNetTokenizer, -) -from keras_hub.src.models.falcon.falcon_tokenizer import ( - FalconTokenizer as FalconTokenizer, -) -from keras_hub.src.models.gemma.gemma_tokenizer import ( - GemmaTokenizer as GemmaTokenizer, -) -from keras_hub.src.models.gemma3.gemma3_tokenizer import ( - Gemma3Tokenizer as Gemma3Tokenizer, -) -from keras_hub.src.models.gpt2.gpt2_tokenizer import ( - GPT2Tokenizer as GPT2Tokenizer, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( - GPTNeoXTokenizer as GPTNeoXTokenizer, -) -from keras_hub.src.models.llama.llama_tokenizer import ( - LlamaTokenizer as LlamaTokenizer, -) -from keras_hub.src.models.llama3.llama3_tokenizer import ( - Llama3Tokenizer as Llama3Tokenizer, -) -from keras_hub.src.models.mistral.mistral_tokenizer import ( - MistralTokenizer as MistralTokenizer, -) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer + DistilBertTokenizer, +) +from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer +from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer +from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer +from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer as PaliGemmaTokenizer, -) -from keras_hub.src.models.phi3.phi3_tokenizer import ( - Phi3Tokenizer as Phi3Tokenizer, + PaliGemmaTokenizer, ) +from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen.qwen_tokenizer import ( - QwenTokenizer as QwenTokenizer, -) -from keras_hub.src.models.roberta.roberta_tokenizer import ( - RobertaTokenizer as RobertaTokenizer, -) +from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer as RoformerV2Tokenizer, -) -from keras_hub.src.models.siglip.siglip_tokenizer import ( - SigLIPTokenizer as SigLIPTokenizer, -) -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer -from keras_hub.src.models.whisper.whisper_tokenizer import ( - WhisperTokenizer as WhisperTokenizer, + RoformerV2Tokenizer, ) +from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer as XLMRobertaTokenizer, -) -from keras_hub.src.tokenizers.byte_pair_tokenizer import ( - BytePairTokenizer as BytePairTokenizer, -) -from keras_hub.src.tokenizers.byte_tokenizer import ( - ByteTokenizer as ByteTokenizer, + XLMRobertaTokenizer, ) +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer +from keras_hub.src.tokenizers.byte_tokenizer import ByteTokenizer from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( - SentencePieceTokenizer as SentencePieceTokenizer, + SentencePieceTokenizer, ) from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import ( - compute_sentence_piece_proto as compute_sentence_piece_proto, + compute_sentence_piece_proto, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import ( - UnicodeCodepointTokenizer as UnicodeCodepointTokenizer, -) -from keras_hub.src.tokenizers.word_piece_tokenizer import ( - WordPieceTokenizer as WordPieceTokenizer, + UnicodeCodepointTokenizer, ) +from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( - compute_word_piece_vocabulary as compute_word_piece_vocabulary, + compute_word_piece_vocabulary, ) diff --git a/keras_hub/api/utils/__init__.py b/keras_hub/api/utils/__init__.py index 0bd8cb642e..8ce47790b0 100644 --- a/keras_hub/api/utils/__init__.py +++ b/keras_hub/api/utils/__init__.py @@ -4,18 +4,10 @@ since your modifications would be overwritten. """ -from keras_hub.src.utils.coco.coco_utils import ( - coco_id_to_name as coco_id_to_name, -) -from keras_hub.src.utils.coco.coco_utils import ( - coco_name_to_id as coco_name_to_id, -) -from keras_hub.src.utils.imagenet.imagenet_utils import ( - decode_imagenet_predictions as decode_imagenet_predictions, -) -from keras_hub.src.utils.imagenet.imagenet_utils import ( - imagenet_id_to_name as imagenet_id_to_name, -) +from keras_hub.src.utils.coco.coco_utils import coco_id_to_name +from keras_hub.src.utils.coco.coco_utils import coco_name_to_id from keras_hub.src.utils.imagenet.imagenet_utils import ( - imagenet_name_to_id as imagenet_name_to_id, + decode_imagenet_predictions, ) +from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name +from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id diff --git a/keras_hub/src/models/esm/esm_attention.py b/keras_hub/src/models/esm/esm_attention.py index e6cef2ffd7..3c89f33d7b 100644 --- a/keras_hub/src/models/esm/esm_attention.py +++ b/keras_hub/src/models/esm/esm_attention.py @@ -1,52 +1,65 @@ -from keras import ops -from keras import initializers import keras +from keras import ops + from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding -from keras_hub.src.models.roformer_v2.roformer_v2_attention import RoformerAttention +from keras_hub.src.models.roformer_v2.roformer_v2_attention import ( + RoformerAttention, +) + + class ESMRotaryEmbedding(RotaryEmbedding): - def _compute_cos_sin_embedding(self,x,position=1): + def _compute_cos_sin_embedding(self, x, position=1): dim = x.shape[-1] - inv_freq = self.scaling_factor / (self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)) - t = ops.arange(x.shape[position],dtype=x.dtype) + inv_freq = self.scaling_factor / ( + self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim) + ) + t = ops.arange(x.shape[position], dtype=x.dtype) freqs = ops.outer(t, inv_freq) emb = ops.concatenate((freqs, freqs), axis=-1) - cos_emb = ops.cos(emb)[None, :,None, :] - sin_emb = ops.sin(emb)[None, :,None, :] + cos_emb = ops.cos(emb)[None, :, None, :] + sin_emb = ops.sin(emb)[None, :, None, :] return cos_emb, sin_emb - def call(self, q, k,position=1): - cos_emb, sin_emb = self._compute_cos_sin_embedding(q,position) + + def call(self, q, k, position=1): + cos_emb, sin_emb = self._compute_cos_sin_embedding(q, position) return ( self.apply_rotary_pos_emb(q, cos_emb, sin_emb), self.apply_rotary_pos_emb(k, cos_emb, sin_emb), ) - def rotate_half(self,x): - x1, x2 = ops.split(x,2,-1) + + def rotate_half(self, x): + x1, x2 = ops.split(x, 2, -1) return ops.concatenate((-x2, x1), axis=-1) - def apply_rotary_pos_emb(self,x, cos, sin): - cos = cos[:, : x.shape[1],:, :] - sin = sin[:, : x.shape[1],:, :] + + def apply_rotary_pos_emb(self, x, cos, sin): + cos = cos[:, : x.shape[1], :, :] + sin = sin[:, : x.shape[1], :, :] return (x * cos) + (self.rotate_half(x) * sin) + class EsmSelfAttention(RoformerAttention): """MultiHeadAttention by ESM2 - + Referred to the implementation of HuggingFace. In fact, this part of the calculation is exactly the same as RoFormer. Only the calculation of the rotary part is different. """ - def __init__(self,use_rotary=True,**kwargs): + + def __init__(self, use_rotary=True, **kwargs): super().__init__(**kwargs) self.use_rotary = use_rotary + def build(self, input_shape): super().build(input_shape) if self.use_rotary: - self.rotary_embedding_layer = ESMRotaryEmbedding( - max_wavelength = self.max_wavelength, dtype=self.dtype_policy + self.rotary_embedding_layer = ESMRotaryEmbedding( + max_wavelength=self.max_wavelength, dtype=self.dtype_policy ) self.rotary_embedding_layer.build([]) + def call(self, x, attention_mask=None): qw = self.q_dense(x) kw = self.k_dense(x) @@ -70,6 +83,7 @@ def call(self, x, attention_mask=None): qw, kw, vw, mask=attention_mask, flash_attention=flash_attention ) return self.o_dense(ops.reshape(o, [b, s, -1])) + def get_config(self): config = super().get_config() config.update( @@ -78,4 +92,3 @@ def get_config(self): } ) return config - diff --git a/keras_hub/src/models/esm/esm_backbone.py b/keras_hub/src/models/esm/esm_backbone.py index 26e0aeeb67..5939ff0ee1 100644 --- a/keras_hub/src/models/esm/esm_backbone.py +++ b/keras_hub/src/models/esm/esm_backbone.py @@ -2,13 +2,18 @@ from keras import activations from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.backbone import Backbone from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding +from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.esm.esm_encoder import ESMEncoder + def esm2_kernel_initializer(stddev=0.02): return keras.initializers.TruncatedNormal(stddev=stddev) -@keras_hub_export(["keras_hub.models.ESM2Backbone","keras_hub.models.ESMBackbone"]) + + +@keras_hub_export( + ["keras_hub.models.ESM2Backbone", "keras_hub.models.ESMBackbone"] +) class ESMBackbone(Backbone): """A ESM2 and ESM encoder network. @@ -31,7 +36,7 @@ class ESMBackbone(Backbone): intermediate_dim: int. The output dimension of the first Dense layer in a two-layer feedforward network for each transformer. dropout: float. Dropout probability for the Transformer encoder. - layer_norm_eps:bool.Should we use ln after embedding? + layer_norm_eps:bool.Should we use ln after embedding? Since it's pre-norm, the default is false. max_sequence_length: int. The maximum sequence length that this encoder can consume. If None, `max_sequence_length` uses the value from @@ -74,23 +79,24 @@ def __init__( num_heads, hidden_dim, intermediate_dim, - head_size, - use_bias=False, + use_bias=True, activation="gelu", dropout=0.1, dtype=None, - max_sequence_length = 1024, + max_sequence_length=1024, max_wavelength=10000, - layer_norm_eps = 1e-12, - emb_layer_norm_before = False, - position_embedding_type = "rotary", - pad_token_id = 0, + layer_norm_eps=1e-12, + emb_layer_norm_before=False, + position_embedding_type="rotary", + pad_token_id=0, **kwargs, ): - support_positon_type = ["rotary","absolute"] + support_positon_type = ["rotary", "absolute"] if position_embedding_type.lower() not in support_positon_type: - raise(f"This model only support below position embedding type: {support_positon_type}") - + raise ( + f"This model only support below position embedding type: {support_positon_type}" # noqa: E501 + ) + head_size = hidden_dim // num_heads # === Layers === self.token_embedding = keras.layers.Embedding( input_dim=vocabulary_size, @@ -110,8 +116,7 @@ def __init__( dtype=dtype, name="embeddings_add", ) - - + self.output_layer_norm = keras.layers.LayerNormalization( epsilon=layer_norm_eps, dtype=dtype, @@ -134,9 +139,9 @@ def __init__( dropout=dropout, activation=activation, kernel_initializer=esm2_kernel_initializer(), - layer_norm_eps = layer_norm_eps, + layer_norm_eps=layer_norm_eps, dtype=dtype, - use_rotary=position_embedding_type=="rotary", + use_rotary=position_embedding_type == "rotary", name=f"transformer_layer_{i}", ) self.transformer_layers.append(layer) @@ -145,13 +150,14 @@ def __init__( token_id_input = keras.Input( shape=(None,), dtype="int32", name="token_ids" ) - + attention_mask = keras.ops.not_equal(token_id_input, pad_token_id) - token_vector = self.token_embedding(token_id_input) if position_embedding_type == "absolute": - position_vector = self.position_embedding(token_vector) + position_vector = self.position_embedding( + token_vector, start_index=pad_token_id + ) x = self.embeddings_add([token_vector, position_vector]) else: x = token_vector @@ -187,6 +193,7 @@ def __init__( self.emb_layer_norm_before = emb_layer_norm_before self.position_embedding_type = position_embedding_type self.pad_token_id = pad_token_id + def get_config(self): config = super().get_config() config.update( @@ -198,14 +205,13 @@ def get_config(self): "intermediate_dim": self.intermediate_dim, "dropout": self.dropout, "max_wavelength": self.max_wavelength, - "head_size": self.head_size, "use_bias": self.use_bias, "activation": activations.serialize(self.activation), - "layer_norm_eps":self.layer_norm_eps, - "emb_layer_norm_before":self.emb_layer_norm_before, - "position_embedding_type":self.position_embedding_type, - "max_sequence_length":self.max_sequence_length, - "pad_token_id":self.pad_token_id, + "layer_norm_eps": self.layer_norm_eps, + "emb_layer_norm_before": self.emb_layer_norm_before, + "position_embedding_type": self.position_embedding_type, + "max_sequence_length": self.max_sequence_length, + "pad_token_id": self.pad_token_id, } ) return config diff --git a/keras_hub/src/models/esm/esm_backbone_test.py b/keras_hub/src/models/esm/esm_backbone_test.py index aef8454c1a..91e5227ebe 100644 --- a/keras_hub/src/models/esm/esm_backbone_test.py +++ b/keras_hub/src/models/esm/esm_backbone_test.py @@ -1,9 +1,7 @@ import keras from keras import ops -from keras_hub.src.models.esm.esm_backbone import ( - ESMBackbone, -) +from keras_hub.src.models.esm.esm_backbone import ESMBackbone from keras_hub.src.tests.test_case import TestCase diff --git a/keras_hub/src/models/esm/esm_classifier.py b/keras_hub/src/models/esm/esm_classifier.py index 1c6ca925ff..f6225157c6 100644 --- a/keras_hub/src/models/esm/esm_classifier.py +++ b/keras_hub/src/models/esm/esm_classifier.py @@ -1,13 +1,11 @@ from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.models.esm.esm_classifier_preprocessor import ( + ESMProteinClassifierPreprocessor, +) from keras_hub.src.models.roberta.roberta_text_classifier import ( RobertaTextClassifier, # noqa: E501 ) -from keras_hub.src.models.esm.esm_backbone import ( - ESMBackbone, -) -from keras_hub.src.models.esm.esm_classifier_preprocessor import ( - ESMProteinClassifierPreprocessor -) @keras_hub_export("keras_hub.models.ESMProteinClassifier") diff --git a/keras_hub/src/models/esm/esm_classifier_preprocessor.py b/keras_hub/src/models/esm/esm_classifier_preprocessor.py index 42259d012f..bb45495921 100644 --- a/keras_hub/src/models/esm/esm_classifier_preprocessor.py +++ b/keras_hub/src/models/esm/esm_classifier_preprocessor.py @@ -1,20 +1,15 @@ import keras from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( BertTextClassifierPreprocessor, ) -from keras_hub.src.models.esm.esm_backbone import ( - ESMBackbone, -) -from keras_hub.src.models.esm.esm_tokenizer import ( - ESMTokenizer, -) +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.utils.tensor_utils import preprocessing_function -from keras_hub.src.layers.preprocessing.start_end_packer import ( - StartEndPacker, -) + @keras_hub_export("keras_hub.models.ESMProteinClassifierPreprocessor") class ESMProteinClassifierPreprocessor(BertTextClassifierPreprocessor): """A ESM preprocessing layer which tokenizes and packs inputs. @@ -117,6 +112,7 @@ class ESMProteinClassifierPreprocessor(BertTextClassifierPreprocessor): backbone_cls = ESMBackbone tokenizer_cls = ESMTokenizer + def build(self, input_shape): super().build(input_shape) # Defer masker creation to `build()` so that we can be sure tokenizer @@ -127,6 +123,7 @@ def build(self, input_shape): pad_value=self.tokenizer.pad_token_id, sequence_length=self.sequence_length, ) + @preprocessing_function def call(self, x, y=None, sample_weight=None): x = self.tokenizer(x) diff --git a/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py b/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py index 93ecce4f4d..9868a009f9 100644 --- a/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py +++ b/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py @@ -1,17 +1,13 @@ - -from keras_hub.src.models.esm.esm_tokenizer import ( - ESMTokenizer, -) -from keras_hub.src.tests.test_case import TestCase from keras_hub.src.models.esm.esm_classifier_preprocessor import ( - ESMProteinClassifierPreprocessor + ESMProteinClassifierPreprocessor, ) - +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer +from keras_hub.src.tests.test_case import TestCase class ESMProteinClassifierPreprocessorTest(TestCase): def setUp(self): - self.vocab = [ "[UNK]","[PAD]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"] self.vocab += ["THE", "QUICK", "BROWN", "FOX"] self.vocab += ["the", "quick", "brown", "fox"] self.tokenizer = ESMTokenizer(vocabulary=self.vocab) diff --git a/keras_hub/src/models/esm/esm_classifier_test.py b/keras_hub/src/models/esm/esm_classifier_test.py index 92474ab0ab..d410f23a0c 100644 --- a/keras_hub/src/models/esm/esm_classifier_test.py +++ b/keras_hub/src/models/esm/esm_classifier_test.py @@ -1,25 +1,14 @@ import keras -from keras_hub.src.models.roformer_v2 import ( - roformer_v2_text_classifier_preprocessor as r, -) -from keras_hub.src.models.esm.esm_backbone import ( - ESMBackbone, -) -from keras_hub.src.models.esm.esm_tokenizer import ( - ESMTokenizer, -) -from keras_hub.src.models.esm.esm_classifier import ( - ESMProteinClassifier, -) +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier from keras_hub.src.models.esm.esm_classifier_preprocessor import ( - ESMProteinClassifierPreprocessor + ESMProteinClassifierPreprocessor, ) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.tests.test_case import TestCase - - class RoformerVTextClassifierTest(TestCase): def setUp(self): # Setup model. diff --git a/keras_hub/src/models/esm/esm_encoder.py b/keras_hub/src/models/esm/esm_encoder.py index ab9fbf5f53..f482913cf6 100644 --- a/keras_hub/src/models/esm/esm_encoder.py +++ b/keras_hub/src/models/esm/esm_encoder.py @@ -5,10 +5,9 @@ from keras_hub.src.models.esm.esm_attention import EsmSelfAttention - class ESMEncoder(keras.layers.Layer): """MultiHeadAttention by ESM - + Referred to the implementation of HuggingFace. reference: https://github.com/huggingface/transformers/ @@ -25,8 +24,8 @@ def __init__( activation="gelu", use_bias=False, kernel_initializer="glorot_uniform", - layer_norm_eps = 1e-12, - use_rotary = True, + layer_norm_eps=1e-12, + use_rotary=True, **kwargs, ): super().__init__(**kwargs) @@ -50,7 +49,7 @@ def build(self, input_shape): max_wavelength=self.max_wavelength, kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, - use_rotary = self.use_rotary, + use_rotary=self.use_rotary, name="attention_layer", ) self.attention_layer.build(input_shape) @@ -84,7 +83,7 @@ def build(self, input_shape): self.feedforward_output_dense.build( [None, None, self.intermediate_size] ) - import torch + self.attention_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_eps, name="attention_norm", @@ -100,18 +99,16 @@ def build(self, input_shape): self.feedforward_norm.build(input_shape) def call(self, x, attention_mask=None): - attention_output = self.attention_layer( self.attention_norm(self.dropout_layer(x)), attention_mask=attention_mask, ) residual = x + attention_output - + x = self.feedforward_norm(self.dropout_layer(residual)) intermediate_output = self.feedforward_intermediate_dense(x) feedroward_output = self.feedforward_output_dense(intermediate_output) - return residual + self.dropout_layer(feedroward_output) - + return residual + self.dropout_layer(feedroward_output) def compute_output_shape(self, input_shape): return input_shape @@ -127,8 +124,8 @@ def get_config(self): "use_bias": self.use_bias, "activation": activations.serialize(self.activation), "dropout": self.dropout, - "layer_norm_eps":self.layer_norm_eps, - "use_rotary":self.use_rotary, + "layer_norm_eps": self.layer_norm_eps, + "use_rotary": self.use_rotary, "kernel_initializer": initializers.serialize( self.kernel_initializer ), diff --git a/keras_hub/src/models/esm/esm_masked_plm.py b/keras_hub/src/models/esm/esm_masked_plm.py index 42fa26f297..d7bc609e54 100644 --- a/keras_hub/src/models/esm/esm_masked_plm.py +++ b/keras_hub/src/models/esm/esm_masked_plm.py @@ -2,16 +2,17 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead -from keras_hub.src.models.masked_lm import MaskedLM -from keras_hub.src.models.esm.esm_backbone import ( - ESMBackbone,esm2_kernel_initializer -) +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.models.esm.esm_backbone import esm2_kernel_initializer from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( ESMMaskedPLMPreprocessor, ) +from keras_hub.src.models.masked_lm import MaskedLM -@keras_hub_export(["keras_hub.models.ESM2MaskedPLM","keras_hub.models.ESMMaskedPLM"]) +@keras_hub_export( + ["keras_hub.models.ESM2MaskedPLM", "keras_hub.models.ESMMaskedPLM"] +) class ESMMaskedPLM(MaskedLM): """An end-to-end ESM2 model for the masked protein language modeling task. @@ -82,6 +83,7 @@ class ESMMaskedPLM(MaskedLM): backbone_cls = ESMBackbone preprocessor_cls = ESMMaskedPLMPreprocessor + def __init__( self, backbone, @@ -96,7 +98,7 @@ def __init__( intermediate_activation=backbone.activation, kernel_initializer=esm2_kernel_initializer(), dtype=backbone.dtype_policy, - layer_norm_epsilon = backbone.layer_norm_eps, + layer_norm_epsilon=backbone.layer_norm_eps, name="mlm_head", ) @@ -108,7 +110,7 @@ def __init__( ), } backbone_outputs = backbone(backbone.input) - + outputs = self.masked_lm_head( backbone_outputs, inputs["mask_positions"] ) diff --git a/keras_hub/src/models/esm/esm_masked_plm_preprocessor.py b/keras_hub/src/models/esm/esm_masked_plm_preprocessor.py index 553643f62f..b24b30bba7 100644 --- a/keras_hub/src/models/esm/esm_masked_plm_preprocessor.py +++ b/keras_hub/src/models/esm/esm_masked_plm_preprocessor.py @@ -1,21 +1,16 @@ import keras from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_hub.src.models.esm.esm_backbone import ( - ESMBackbone, -) -from keras_hub.src.models.esm.esm_tokenizer import ( - ESMTokenizer, -) -from keras_hub.src.utils.tensor_utils import preprocessing_function - from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import ( MaskedLMMaskGenerator, ) -from keras_hub.src.layers.preprocessing.start_end_packer import ( - StartEndPacker, -) +from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer +from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor +from keras_hub.src.utils.tensor_utils import preprocessing_function + + @keras_hub_export("keras_hub.models.ESMMaskedPLMPreprocessor") class ESMMaskedPLMPreprocessor(MaskedLMPreprocessor): """ESM preprocessing for the masked language modeling task. @@ -113,6 +108,7 @@ class ESMMaskedPLMPreprocessor(MaskedLMPreprocessor): backbone_cls = ESMBackbone tokenizer_cls = ESMTokenizer + def build(self, input_shape): super().build(input_shape) # Defer masker creation to `build()` so that we can be sure tokenizer @@ -133,7 +129,6 @@ def build(self, input_shape): unselectable_token_ids=self.tokenizer.special_token_ids, ) - @preprocessing_function def call(self, x, y=None, sample_weight=None): x = self.tokenizer(x) diff --git a/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py b/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py index 316927f247..05b9fea70c 100644 --- a/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py +++ b/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py @@ -1,15 +1,13 @@ from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( ESMMaskedPLMPreprocessor, ) -from keras_hub.src.models.esm.esm_tokenizer import ( - ESMTokenizer, -) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.tests.test_case import TestCase class ESMMaskedPLMPreprocessort(TestCase): def setUp(self): - self.vocab = [ "[UNK]", "[PAD]","[CLS]", "[SEP]", "[MASK]"] + self.vocab = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"] self.vocab += ["THE", "QUICK", "BROWN", "FOX"] self.vocab += ["the", "quick", "brown", "fox"] self.tokenizer = ESMTokenizer(vocabulary=self.vocab) diff --git a/keras_hub/src/models/esm/esm_masked_plm_test.py b/keras_hub/src/models/esm/esm_masked_plm_test.py index c601d15638..fe6f1996f9 100644 --- a/keras_hub/src/models/esm/esm_masked_plm_test.py +++ b/keras_hub/src/models/esm/esm_masked_plm_test.py @@ -1,17 +1,11 @@ import keras -from keras_hub.src.models.esm.esm_backbone import ( - ESMBackbone, -) -from keras_hub.src.models.esm.esm_masked_plm import ( - ESMMaskedPLM -) +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( ESMMaskedPLMPreprocessor, ) -from keras_hub.src.models.esm.esm_tokenizer import ( - ESMTokenizer, -) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.tests.test_case import TestCase diff --git a/keras_hub/src/models/esm/esm_tokenizer.py b/keras_hub/src/models/esm/esm_tokenizer.py index 7898200dad..684f8ddf1c 100644 --- a/keras_hub/src/models/esm/esm_tokenizer.py +++ b/keras_hub/src/models/esm/esm_tokenizer.py @@ -1,8 +1,6 @@ from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer -from keras_hub.src.models.esm.esm_backbone import ( - ESMBackbone, -) +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer @keras_hub_export( @@ -11,7 +9,7 @@ "keras_hub.models.ESMTokenizer", ] ) -class ESMTokenizer(BertTokenizer): +class ESMTokenizer(WordPieceTokenizer): """A ESM tokenizer using WordPiece subword segmentation. This tokenizer class will tokenize raw strings into integer sequences and @@ -52,7 +50,7 @@ class ESMTokenizer(BertTokenizer): tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) # Custom vocabulary. - vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] + vocab = ["[UNK]", "", "", "", ""] vocab += ["The", "quick", "brown", "fox", "jumped", "."] tokenizer = keras_hub.models.ESMTokenizer(vocabulary=vocab) tokenizer("The quick brown fox jumped.") @@ -60,3 +58,25 @@ class ESMTokenizer(BertTokenizer): """ backbone_cls = ESMBackbone + + def __init__( + self, + vocabulary=None, + lowercase=False, + oov_token="", + **kwargs, + ): + self._add_special_token("", "cls_token") + self._add_special_token("", "sep_token") + self._add_special_token("", "pad_token") + self._add_special_token("", "mask_token") + # Also add `tokenizer.start_token` and `tokenizer.end_token` for + # compatibility with other tokenizers. + self._add_special_token("", "start_token") + self._add_special_token("", "end_token") + super().__init__( + vocabulary=vocabulary, + lowercase=lowercase, + oov_token=oov_token, + **kwargs, + ) diff --git a/keras_hub/src/models/esm/esm_tokenizer_test.py b/keras_hub/src/models/esm/esm_tokenizer_test.py index 89ec34032b..218c26d7b0 100644 --- a/keras_hub/src/models/esm/esm_tokenizer_test.py +++ b/keras_hub/src/models/esm/esm_tokenizer_test.py @@ -1,12 +1,10 @@ -from keras_hub.src.models.esm.esm_tokenizer import ( - ESMTokenizer, -) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.tests.test_case import TestCase class ESMTokenizerTest(TestCase): def setUp(self): - self.vocab = [ "[UNK]", "[PAD]","[CLS]", "[SEP]", "[MASK]"] + self.vocab = ["", "", "", "", ""] self.vocab += ["THE", "QUICK", "BROWN", "FOX"] self.vocab += ["the", "quick", "brown", "fox"] self.init_kwargs = {"vocabulary": self.vocab} @@ -26,7 +24,7 @@ def test_lowercase(self): self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]]) def test_tokenizer_special_tokens(self): - input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"] + input_data = [" THE FOX "] tokenizer = ESMTokenizer( **self.init_kwargs, special_tokens_in_strings=True ) diff --git a/keras_hub/src/utils/transformers/convert_esm.py b/keras_hub/src/utils/transformers/convert_esm.py index 8e5ce17584..c35de05b86 100644 --- a/keras_hub/src/utils/transformers/convert_esm.py +++ b/keras_hub/src/utils/transformers/convert_esm.py @@ -1,12 +1,12 @@ import numpy as np -from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.utils.preset_utils import get_file -backbone_cls = ViTBackbone +backbone_cls = ESMBackbone def convert_backbone_config(transformers_config): - return { "vocabulary_size": transformers_config["vocab_size"], "num_layers": transformers_config["num_hidden_layers"], @@ -14,140 +14,152 @@ def convert_backbone_config(transformers_config): "hidden_dim": transformers_config["hidden_size"], "intermediate_dim": transformers_config["intermediate_size"], "dropout": transformers_config["hidden_dropout_prob"], - "position_embedding_type": transformers_config["position_embedding_type"], + "position_embedding_type": transformers_config[ + "position_embedding_type" + ], "pad_token_id": transformers_config["pad_token_id"], - "max_sequence_length": transformers_config.get("max_position_embeddings", None), # 默认值为None - "layer_norm_eps": transformers_config.get("layer_norm_eps", 1e-12), # 默认值为1e-12 - "emb_layer_norm_before": transformers_config.get("emb_layer_norm_before", False), # 默认值为False - "head_size": transformers_config.get("head_size", 64), # 默认值为64 - "activation": transformers_config.get("activation", "gelu"), # 默认值为"gelu" - "max_wavelength": transformers_config.get("max_wavelength", 10000), # 默认值为10000 + "max_sequence_length": transformers_config.get( + "max_position_embeddings", None + ), # 默认值为None + "layer_norm_eps": transformers_config.get( + "layer_norm_eps", 1e-12 + ), # 默认值为1e-12 + "emb_layer_norm_before": transformers_config.get( + "emb_layer_norm_before", False + ), # 默认值为False + "activation": transformers_config.get( + "activation", "gelu" + ), # 默认值为"gelu" + "max_wavelength": transformers_config.get( + "max_wavelength", 10000 + ), # 默认值为10000 } +def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + def convert_weights(backbone, loader, transformers_config): # Embedding layer loader.port_weight( keras_variable=backbone.get_layer("token_embedding").embeddings, - hf_weight_key="bert.embeddings.word_embeddings.weight", - ) - if transformers_config["position_embedding_type"]=="absolute": - pass - loader.port_weight( - keras_variable=backbone.get_layer( - "position_embedding" - ).position_embeddings, - hf_weight_key="bert.embeddings.position_embeddings.weight", - ) - loader.port_weight( - keras_variable=backbone.get_layer("segment_embedding").embeddings, - hf_weight_key="bert.embeddings.token_type_embeddings.weight", + hf_weight_key="embeddings.word_embeddings.weight", ) + if transformers_config["position_embedding_type"] == "absolute": + loader.port_weight( + keras_variable=backbone.get_layer( + "position_embedding" + ).position_embeddings, + hf_weight_key="embeddings.position_embeddings.weight", + ) + if transformers_config.get("emb_layer_norm_before", False): + loader.port_weight( + keras_variable=backbone.get_layer("embeddings_layer_norm").gamma, + hf_weight_key="embeddings.layer_norm.weight", + ) + loader.port_weight( + keras_variable=backbone.get_layer("embeddings_layer_norm").beta, + hf_weight_key="embeddings.layer_norm.bias", + ) + loader.port_weight( - keras_variable=backbone.get_layer("embeddings_layer_norm").beta, - hf_weight_key="bert.embeddings.LayerNorm.beta", + keras_variable=backbone.output_layer_norm.gamma, + hf_weight_key="encoder.emb_layer_norm_after.weight", ) loader.port_weight( - keras_variable=backbone.get_layer("embeddings_layer_norm").gamma, - hf_weight_key="bert.embeddings.LayerNorm.gamma", + keras_variable=backbone.output_layer_norm.beta, + hf_weight_key="encoder.emb_layer_norm_after.bias", ) - def transpose_and_reshape(x, shape): - return np.reshape(np.transpose(x), shape) - # Attention blocks for i in range(backbone.num_layers): block = backbone.get_layer(f"transformer_layer_{i}") - attn = block._self_attention_layer - hf_prefix = "bert.encoder.layer." + attn = block.attention_layer + hf_prefix = "encoder.layer." # Attention layers loader.port_weight( - keras_variable=attn.query_dense.kernel, + keras_variable=attn.q_dense.kernel, hf_weight_key=f"{hf_prefix}{i}.attention.self.query.weight", hook_fn=transpose_and_reshape, ) loader.port_weight( - keras_variable=attn.query_dense.bias, + keras_variable=attn.q_dense.bias, hf_weight_key=f"{hf_prefix}{i}.attention.self.query.bias", hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), ) loader.port_weight( - keras_variable=attn.key_dense.kernel, + keras_variable=attn.k_dense.kernel, hf_weight_key=f"{hf_prefix}{i}.attention.self.key.weight", hook_fn=transpose_and_reshape, ) loader.port_weight( - keras_variable=attn.key_dense.bias, + keras_variable=attn.k_dense.bias, hf_weight_key=f"{hf_prefix}{i}.attention.self.key.bias", hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), ) loader.port_weight( - keras_variable=attn.value_dense.kernel, + keras_variable=attn.v_dense.kernel, hf_weight_key=f"{hf_prefix}{i}.attention.self.value.weight", hook_fn=transpose_and_reshape, ) loader.port_weight( - keras_variable=attn.value_dense.bias, + keras_variable=attn.v_dense.bias, hf_weight_key=f"{hf_prefix}{i}.attention.self.value.bias", hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), ) loader.port_weight( - keras_variable=attn.output_dense.kernel, + keras_variable=attn.o_dense.kernel, hf_weight_key=f"{hf_prefix}{i}.attention.output.dense.weight", hook_fn=transpose_and_reshape, ) loader.port_weight( - keras_variable=attn.output_dense.bias, + keras_variable=attn.o_dense.bias, hf_weight_key=f"{hf_prefix}{i}.attention.output.dense.bias", hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), ) # Attention layer norm. loader.port_weight( - keras_variable=block._self_attention_layer_norm.beta, - hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.beta", + keras_variable=block.attention_norm.gamma, + hf_weight_key=f"{hf_prefix}{i}.attention.LayerNorm.weight", ) loader.port_weight( - keras_variable=block._self_attention_layer_norm.gamma, - hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.gamma", + keras_variable=block.attention_norm.beta, + hf_weight_key=f"{hf_prefix}{i}.attention.LayerNorm.bias", ) # MLP layers loader.port_weight( - keras_variable=block._feedforward_intermediate_dense.kernel, + keras_variable=block.feedforward_intermediate_dense.kernel, hf_weight_key=f"{hf_prefix}{i}.intermediate.dense.weight", hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), ) loader.port_weight( - keras_variable=block._feedforward_intermediate_dense.bias, + keras_variable=block.feedforward_intermediate_dense.bias, hf_weight_key=f"{hf_prefix}{i}.intermediate.dense.bias", ) loader.port_weight( - keras_variable=block._feedforward_output_dense.kernel, + keras_variable=block.feedforward_output_dense.kernel, hf_weight_key=f"{hf_prefix}{i}.output.dense.weight", hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), ) loader.port_weight( - keras_variable=block._feedforward_output_dense.bias, + keras_variable=block.feedforward_output_dense.bias, hf_weight_key=f"{hf_prefix}{i}.output.dense.bias", ) # Output layer norm. loader.port_weight( - keras_variable=block._feedforward_layer_norm.beta, - hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.beta", + keras_variable=block.feedforward_norm.gamma, + hf_weight_key=f"{hf_prefix}{i}.LayerNorm.weight", ) loader.port_weight( - keras_variable=block._feedforward_layer_norm.gamma, - hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.gamma", + keras_variable=block.feedforward_norm.beta, + hf_weight_key=f"{hf_prefix}{i}.LayerNorm.bias", ) -def convert_head(task, loader, transformers_config): - prefix = "classifier." - loader.port_weight( - task.output_dense.kernel, - hf_weight_key=prefix + "weight", - hook_fn=lambda x, _: x.T, - ) - loader.port_weight( - task.output_dense.bias, - hf_weight_key=prefix + "bias", +def convert_tokenizer(cls, preset, **kwargs): + return cls( + get_file(preset, "vocab.txt"), + lowercase=True, + **kwargs, ) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 0d58747631..90c81be5d6 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -7,6 +7,7 @@ from keras_hub.src.utils.transformers import convert_bart from keras_hub.src.utils.transformers import convert_bert from keras_hub.src.utils.transformers import convert_distilbert +from keras_hub.src.utils.transformers import convert_esm from keras_hub.src.utils.transformers import convert_gemma from keras_hub.src.utils.transformers import convert_gpt2 from keras_hub.src.utils.transformers import convert_llama3 @@ -29,6 +30,8 @@ def __init__(self, preset, config): self.converter = convert_bert elif model_type == "distilbert": self.converter = convert_distilbert + elif model_type == "esm": + self.converter = convert_esm elif model_type == "gemma" or model_type == "gemma2": self.converter = convert_gemma elif model_type == "gpt2": From d3f598d0119d3c6083281f0dbc4e22a1f2fac042 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 3 May 2025 17:35:57 +0800 Subject: [PATCH 03/13] fix --- keras_hub/src/models/esm/esm_classifier_preprocessor_test.py | 2 +- keras_hub/src/models/esm/esm_classifier_test.py | 2 +- keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py | 2 +- keras_hub/src/models/esm/esm_masked_plm_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py b/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py index 9868a009f9..85e3e0835f 100644 --- a/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py +++ b/keras_hub/src/models/esm/esm_classifier_preprocessor_test.py @@ -7,7 +7,7 @@ class ESMProteinClassifierPreprocessorTest(TestCase): def setUp(self): - self.vocab = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab = ["", "", "", "", ""] self.vocab += ["THE", "QUICK", "BROWN", "FOX"] self.vocab += ["the", "quick", "brown", "fox"] self.tokenizer = ESMTokenizer(vocabulary=self.vocab) diff --git a/keras_hub/src/models/esm/esm_classifier_test.py b/keras_hub/src/models/esm/esm_classifier_test.py index d410f23a0c..cf55cc8285 100644 --- a/keras_hub/src/models/esm/esm_classifier_test.py +++ b/keras_hub/src/models/esm/esm_classifier_test.py @@ -12,7 +12,7 @@ class RoformerVTextClassifierTest(TestCase): def setUp(self): # Setup model. - self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab = ["", "", "", "", ""] self.vocab += ["the", "quick", "brown", "fox", "."] self.preprocessor = ESMProteinClassifierPreprocessor( ESMTokenizer(vocabulary=self.vocab), diff --git a/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py b/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py index 05b9fea70c..6839dc3352 100644 --- a/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py +++ b/keras_hub/src/models/esm/esm_masked_plm_preprocessor_test.py @@ -7,7 +7,7 @@ class ESMMaskedPLMPreprocessort(TestCase): def setUp(self): - self.vocab = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab = ["", "", "", "", ""] self.vocab += ["THE", "QUICK", "BROWN", "FOX"] self.vocab += ["the", "quick", "brown", "fox"] self.tokenizer = ESMTokenizer(vocabulary=self.vocab) diff --git a/keras_hub/src/models/esm/esm_masked_plm_test.py b/keras_hub/src/models/esm/esm_masked_plm_test.py index fe6f1996f9..92a76cbb1b 100644 --- a/keras_hub/src/models/esm/esm_masked_plm_test.py +++ b/keras_hub/src/models/esm/esm_masked_plm_test.py @@ -12,7 +12,7 @@ class ESMMaskedLMTest(TestCase): def setUp(self): # Setup model. - self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab = ["", "", "", "", ""] self.vocab += ["the", "quick", "brown", "fox", "."] self.preprocessor = ESMMaskedPLMPreprocessor( ESMTokenizer(vocabulary=self.vocab), From 737a1473b7b694cfd2336e4449e61ba26f6f6d40 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 3 May 2025 18:01:25 +0800 Subject: [PATCH 04/13] fix --- keras_hub/api/__init__.py | 3 +- keras_hub/api/layers/__init__.py | 81 ++--- keras_hub/api/metrics/__init__.py | 1 + keras_hub/api/models/__init__.py | 445 +++++++-------------------- keras_hub/api/samplers/__init__.py | 1 + keras_hub/api/tokenizers/__init__.py | 41 +-- keras_hub/api/utils/__init__.py | 5 +- 7 files changed, 150 insertions(+), 427 deletions(-) diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 3796e4c7f4..72580b9101 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. """ + from keras_hub import layers from keras_hub import metrics from keras_hub import models @@ -11,5 +12,5 @@ from keras_hub import tokenizers from keras_hub import utils from keras_hub.src.utils.preset_utils import upload_preset -from keras_hub.src.version import __version__ as __version__ from keras_hub.src.version import version +from keras_hub.src.version import __version__ as __version__ diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index d42af86a3c..76be91cc4c 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -4,86 +4,47 @@ since your modifications would be overwritten. """ + from keras_hub.src.layers.modeling.alibi_bias import AlibiBias from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.layers.modeling.box_matcher import BoxMatcher -from keras_hub.src.layers.modeling.cached_multi_head_attention import ( - CachedMultiHeadAttention, -) +from keras_hub.src.layers.modeling.cached_multi_head_attention import CachedMultiHeadAttention from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding -from keras_hub.src.layers.modeling.reversible_embedding import ( - ReversibleEmbedding, -) +from keras_hub.src.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding -from keras_hub.src.layers.modeling.sine_position_encoding import ( - SinePositionEncoding, -) -from keras_hub.src.layers.modeling.token_and_position_embedding import ( - TokenAndPositionEmbedding, -) +from keras_hub.src.layers.modeling.sine_position_encoding import SinePositionEncoding +from keras_hub.src.layers.modeling.token_and_position_embedding import TokenAndPositionEmbedding from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, -) -from keras_hub.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker, -) +from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import MaskedLMMaskGenerator +from keras_hub.src.layers.preprocessing.multi_segment_packer import MultiSegmentPacker from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion from keras_hub.src.layers.preprocessing.random_swap import RandomSwap from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker -from keras_hub.src.models.basnet.basnet_image_converter import ( - BASNetImageConverter, -) +from keras_hub.src.models.basnet.basnet_image_converter import BASNetImageConverter from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter -from keras_hub.src.models.cspnet.cspnet_image_converter import ( - CSPNetImageConverter, -) -from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( - DeepLabV3ImageConverter, -) -from keras_hub.src.models.densenet.densenet_image_converter import ( - DenseNetImageConverter, -) -from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( - EfficientNetImageConverter, -) -from keras_hub.src.models.gemma3.gemma3_image_converter import ( - Gemma3ImageConverter, -) +from keras_hub.src.models.cspnet.cspnet_image_converter import CSPNetImageConverter +from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import DeepLabV3ImageConverter +from keras_hub.src.models.densenet.densenet_image_converter import DenseNetImageConverter +from keras_hub.src.models.efficientnet.efficientnet_image_converter import EfficientNetImageConverter +from keras_hub.src.models.gemma3.gemma3_image_converter import Gemma3ImageConverter from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter -from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( - MobileNetImageConverter, -) -from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( - PaliGemmaImageConverter, -) -from keras_hub.src.models.resnet.resnet_image_converter import ( - ResNetImageConverter, -) -from keras_hub.src.models.retinanet.retinanet_image_converter import ( - RetinaNetImageConverter, -) +from keras_hub.src.models.mobilenet.mobilenet_image_converter import MobileNetImageConverter +from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import PaliGemmaImageConverter +from keras_hub.src.models.resnet.resnet_image_converter import ResNetImageConverter +from keras_hub.src.models.retinanet.retinanet_image_converter import RetinaNetImageConverter from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder -from keras_hub.src.models.segformer.segformer_image_converter import ( - SegFormerImageConverter, -) -from keras_hub.src.models.siglip.siglip_image_converter import ( - SigLIPImageConverter, -) +from keras_hub.src.models.segformer.segformer_image_converter import SegFormerImageConverter +from keras_hub.src.models.siglip.siglip_image_converter import SigLIPImageConverter from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter -from keras_hub.src.models.whisper.whisper_audio_converter import ( - WhisperAudioConverter, -) -from keras_hub.src.models.xception.xception_image_converter import ( - XceptionImageConverter, -) +from keras_hub.src.models.whisper.whisper_audio_converter import WhisperAudioConverter +from keras_hub.src.models.xception.xception_image_converter import XceptionImageConverter diff --git a/keras_hub/api/metrics/__init__.py b/keras_hub/api/metrics/__init__.py index 88a0a7df2b..636c3117d1 100644 --- a/keras_hub/api/metrics/__init__.py +++ b/keras_hub/api/metrics/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. """ + from keras_hub.src.metrics.bleu import Bleu from keras_hub.src.metrics.edit_distance import EditDistance from keras_hub.src.metrics.perplexity import Perplexity diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index e3f0a3aa16..a884bea89c 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,55 +4,34 @@ since your modifications would be overwritten. """ + from keras_hub.src.models.albert.albert_backbone import AlbertBackbone from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM -from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( - AlbertMaskedLMPreprocessor, -) -from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier, -) -from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertClassifier, -) -from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor, -) -from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertPreprocessor, -) +from keras_hub.src.models.albert.albert_masked_lm_preprocessor import AlbertMaskedLMPreprocessor +from keras_hub.src.models.albert.albert_text_classifier import AlbertTextClassifier +from keras_hub.src.models.albert.albert_text_classifier import AlbertTextClassifier as AlbertClassifier +from keras_hub.src.models.albert.albert_text_classifier_preprocessor import AlbertTextClassifierPreprocessor +from keras_hub.src.models.albert.albert_text_classifier_preprocessor import AlbertTextClassifierPreprocessor as AlbertPreprocessor from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.bart.bart_backbone import BartBackbone from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM -from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( - BartSeq2SeqLMPreprocessor, -) +from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import BartSeq2SeqLMPreprocessor from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM -from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( - BertMaskedLMPreprocessor, -) +from keras_hub.src.models.bert.bert_masked_lm_preprocessor import BertMaskedLMPreprocessor from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier -from keras_hub.src.models.bert.bert_text_classifier import ( - BertTextClassifier as BertClassifier, -) -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor, -) -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor as BertPreprocessor, -) +from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier as BertClassifier +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import BertTextClassifierPreprocessor +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import BertTextClassifierPreprocessor as BertPreprocessor from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM -from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( - BloomCausalLMPreprocessor, -) +from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import BloomCausalLMPreprocessor from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_hub.src.models.causal_lm import CausalLM from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor @@ -62,394 +41,194 @@ from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone -from keras_hub.src.models.cspnet.cspnet_image_classifier import ( - CSPNetImageClassifier, -) -from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( - CSPNetImageClassifierPreprocessor, -) -from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( - DebertaV3Backbone, -) -from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( - DebertaV3MaskedLM, -) -from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( - DebertaV3MaskedLMPreprocessor, -) -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier, -) -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3Classifier, -) -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor, -) -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, -) -from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, -) -from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( - DeepLabV3Backbone, -) -from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( - DeepLabV3ImageSegmenterPreprocessor, -) -from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( - DeepLabV3ImageSegmenter, -) +from keras_hub.src.models.cspnet.cspnet_image_classifier import CSPNetImageClassifier +from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import CSPNetImageClassifierPreprocessor +from keras_hub.src.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone +from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import DebertaV3MaskedLM +from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import DebertaV3MaskedLMPreprocessor +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import DebertaV3TextClassifier +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import DebertaV3TextClassifier as DebertaV3Classifier +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import DebertaV3TextClassifierPreprocessor +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor +from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import DeepLabV3Backbone +from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import DeepLabV3ImageSegmenterPreprocessor +from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import DeepLabV3ImageSegmenter from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone -from keras_hub.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier, -) -from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( - DenseNetImageClassifierPreprocessor, -) -from keras_hub.src.models.distil_bert.distil_bert_backbone import ( - DistilBertBackbone, -) -from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( - DistilBertMaskedLM, -) -from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor, -) -from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier, -) -from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertClassifier, -) -from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor, -) -from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, -) -from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, -) -from keras_hub.src.models.efficientnet.efficientnet_backbone import ( - EfficientNetBackbone, -) -from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( - EfficientNetImageClassifier, -) -from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( - EfficientNetImageClassifierPreprocessor, -) +from keras_hub.src.models.densenet.densenet_image_classifier import DenseNetImageClassifier +from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import DenseNetImageClassifierPreprocessor +from keras_hub.src.models.distil_bert.distil_bert_backbone import DistilBertBackbone +from keras_hub.src.models.distil_bert.distil_bert_masked_lm import DistilBertMaskedLM +from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import DistilBertMaskedLMPreprocessor +from keras_hub.src.models.distil_bert.distil_bert_text_classifier import DistilBertTextClassifier +from keras_hub.src.models.distil_bert.distil_bert_text_classifier import DistilBertTextClassifier as DistilBertClassifier +from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import DistilBertTextClassifierPreprocessor +from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import DistilBertTextClassifierPreprocessor as DistilBertPreprocessor +from keras_hub.src.models.distil_bert.distil_bert_tokenizer import DistilBertTokenizer +from keras_hub.src.models.efficientnet.efficientnet_backbone import EfficientNetBackbone +from keras_hub.src.models.efficientnet.efficientnet_image_classifier import EfficientNetImageClassifier +from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import EfficientNetImageClassifierPreprocessor from keras_hub.src.models.electra.electra_backbone import ElectraBackbone from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_hub.src.models.esm.esm_backbone import ESMBackbone from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier -from keras_hub.src.models.esm.esm_classifier_preprocessor import ( - ESMProteinClassifierPreprocessor, -) +from keras_hub.src.models.esm.esm_classifier_preprocessor import ESMProteinClassifierPreprocessor from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM -from keras_hub.src.models.esm.esm_masked_plm import ( - ESMMaskedPLM as ESM2MaskedPLM, -) -from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( - ESMMaskedPLMPreprocessor, -) +from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESM2MaskedPLM +from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ESMMaskedPLMPreprocessor from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM -from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( - FNetMaskedLMPreprocessor, -) +from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import FNetMaskedLMPreprocessor from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier -from keras_hub.src.models.f_net.f_net_text_classifier import ( - FNetTextClassifier as FNetClassifier, -) -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor, -) -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor as FNetPreprocessor, -) +from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier as FNetClassifier +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import FNetTextClassifierPreprocessor +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import FNetTextClassifierPreprocessor as FNetPreprocessor from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM -from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( - FalconCausalLMPreprocessor, -) +from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import FalconCausalLMPreprocessor from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_hub.src.models.flux.flux_model import FluxBackbone from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage -from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( - FluxTextToImagePreprocessor, -) +from keras_hub.src.models.flux.flux_text_to_image_preprocessor import FluxTextToImagePreprocessor from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM -from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor, -) +from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import GemmaCausalLMPreprocessor from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM -from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( - Gemma3CausalLMPreprocessor, -) +from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import Gemma3CausalLMPreprocessor from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer -from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( - Gemma3VisionEncoder, -) +from keras_hub.src.models.gemma3.gemma3_vision_encoder import Gemma3VisionEncoder from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM -from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( - GPT2CausalLMPreprocessor, -) +from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import GPT2CausalLMPreprocessor from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( - GPTNeoXCausalLMPreprocessor, -) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import GPTNeoXCausalLMPreprocessor from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_hub.src.models.image_classifier import ImageClassifier -from keras_hub.src.models.image_classifier_preprocessor import ( - ImageClassifierPreprocessor, -) +from keras_hub.src.models.image_classifier_preprocessor import ImageClassifierPreprocessor from keras_hub.src.models.image_segmenter import ImageSegmenter -from keras_hub.src.models.image_segmenter_preprocessor import ( - ImageSegmenterPreprocessor, -) +from keras_hub.src.models.image_segmenter_preprocessor import ImageSegmenterPreprocessor from keras_hub.src.models.image_to_image import ImageToImage from keras_hub.src.models.inpaint import Inpaint from keras_hub.src.models.llama.llama_backbone import LlamaBackbone from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM -from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( - LlamaCausalLMPreprocessor, -) +from keras_hub.src.models.llama.llama_causal_lm_preprocessor import LlamaCausalLMPreprocessor from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM -from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( - Llama3CausalLMPreprocessor, -) +from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import Llama3CausalLMPreprocessor from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_hub.src.models.masked_lm import MaskedLM from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM -from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( - MistralCausalLMPreprocessor, -) +from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import MistralCausalLMPreprocessor from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer from keras_hub.src.models.mit.mit_backbone import MiTBackbone from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier -from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( - MiTImageClassifierPreprocessor, -) +from keras_hub.src.models.mit.mit_image_classifier_preprocessor import MiTImageClassifierPreprocessor from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone -from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier, -) -from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor, -) +from keras_hub.src.models.mobilenet.mobilenet_image_classifier import MobileNetImageClassifier +from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import MobileNetImageClassifierPreprocessor from keras_hub.src.models.object_detector import ObjectDetector -from keras_hub.src.models.object_detector import ( - ObjectDetector as ImageObjectDetector, -) -from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor, -) -from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, -) +from keras_hub.src.models.object_detector import ObjectDetector as ImageObjectDetector +from keras_hub.src.models.object_detector_preprocessor import ObjectDetectorPreprocessor +from keras_hub.src.models.object_detector_preprocessor import ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor from keras_hub.src.models.opt.opt_backbone import OPTBackbone from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM -from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( - OPTCausalLMPreprocessor, -) +from keras_hub.src.models.opt.opt_causal_lm_preprocessor import OPTCausalLMPreprocessor from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer -from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone, -) -from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( - PaliGemmaCausalLM, -) -from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor, -) -from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, -) +from keras_hub.src.models.pali_gemma.pali_gemma_backbone import PaliGemmaBackbone +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import PaliGemmaCausalLM +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import PaliGemmaCausalLMPreprocessor +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import PaliGemmaTokenizer from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM -from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( - Phi3CausalLMPreprocessor, -) +from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import Phi3CausalLMPreprocessor from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone -from keras_hub.src.models.qwen.qwen_backbone import ( - QwenBackbone as Qwen2Backbone, -) +from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone as Qwen2Backbone from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM -from keras_hub.src.models.qwen.qwen_causal_lm import ( - QwenCausalLM as Qwen2CausalLM, -) -from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor, -) -from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, -) +from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM as Qwen2CausalLM +from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import QwenCausalLMPreprocessor +from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer -from keras_hub.src.models.qwen.qwen_tokenizer import ( - QwenTokenizer as Qwen2Tokenizer, -) +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer as Qwen2Tokenizer from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone -from keras_hub.src.models.resnet.resnet_image_classifier import ( - ResNetImageClassifier, -) -from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( - ResNetImageClassifierPreprocessor, -) +from keras_hub.src.models.resnet.resnet_image_classifier import ResNetImageClassifier +from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ResNetImageClassifierPreprocessor from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone -from keras_hub.src.models.retinanet.retinanet_object_detector import ( - RetinaNetObjectDetector, -) -from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( - RetinaNetObjectDetectorPreprocessor, -) +from keras_hub.src.models.retinanet.retinanet_object_detector import RetinaNetObjectDetector +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import RetinaNetObjectDetectorPreprocessor from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM -from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( - RobertaMaskedLMPreprocessor, -) -from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier, -) -from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaClassifier, -) -from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor, -) -from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaPreprocessor, -) +from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import RobertaMaskedLMPreprocessor +from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier +from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier as RobertaClassifier +from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import RobertaTextClassifierPreprocessor +from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import RobertaTextClassifierPreprocessor as RobertaPreprocessor from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer -from keras_hub.src.models.roformer_v2.roformer_v2_backbone import ( - RoformerV2Backbone, -) -from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import ( - RoformerV2MaskedLM, -) -from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import ( - RoformerV2MaskedLMPreprocessor, -) -from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import ( - RoformerV2TextClassifier, -) -from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import ( - RoformerV2TextClassifierPreprocessor, -) -from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer, -) +from keras_hub.src.models.roformer_v2.roformer_v2_backbone import RoformerV2Backbone +from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import RoformerV2MaskedLM +from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import RoformerV2MaskedLMPreprocessor +from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import RoformerV2TextClassifier +from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import RoformerV2TextClassifierPreprocessor +from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import RoformerV2Tokenizer from keras_hub.src.models.sam.sam_backbone import SAMBackbone from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter -from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor, -) +from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import SAMImageSegmenterPreprocessor from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone -from keras_hub.src.models.segformer.segformer_image_segmenter import ( - SegFormerImageSegmenter, -) -from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( - SegFormerImageSegmenterPreprocessor, -) +from keras_hub.src.models.segformer.segformer_image_segmenter import SegFormerImageSegmenter +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import SegFormerImageSegmenterPreprocessor from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor from keras_hub.src.models.siglip.siglip_backbone import SigLIPBackbone from keras_hub.src.models.siglip.siglip_preprocessor import SigLIPPreprocessor from keras_hub.src.models.siglip.siglip_text_encoder import SigLIPTextEncoder from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer -from keras_hub.src.models.siglip.siglip_vision_encoder import ( - SigLIPVisionEncoder, -) -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( - StableDiffusion3Backbone, -) -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( - StableDiffusion3ImageToImage, -) -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( - StableDiffusion3Inpaint, -) -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( - StableDiffusion3TextToImage, -) -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( - StableDiffusion3TextToImagePreprocessor, -) +from keras_hub.src.models.siglip.siglip_vision_encoder import SigLIPVisionEncoder +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import StableDiffusion3Backbone +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import StableDiffusion3ImageToImage +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import StableDiffusion3Inpaint +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import StableDiffusion3TextToImage +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import StableDiffusion3TextToImagePreprocessor from keras_hub.src.models.t5.t5_backbone import T5Backbone from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.models.task import Task from keras_hub.src.models.text_classifier import TextClassifier from keras_hub.src.models.text_classifier import TextClassifier as Classifier -from keras_hub.src.models.text_classifier_preprocessor import ( - TextClassifierPreprocessor, -) +from keras_hub.src.models.text_classifier_preprocessor import TextClassifierPreprocessor from keras_hub.src.models.text_to_image import TextToImage -from keras_hub.src.models.text_to_image_preprocessor import ( - TextToImagePreprocessor, -) +from keras_hub.src.models.text_to_image_preprocessor import TextToImagePreprocessor from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier -from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( - VGGImageClassifierPreprocessor, -) +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import VGGImageClassifierPreprocessor from keras_hub.src.models.vit.vit_backbone import ViTBackbone from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier -from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( - ViTImageClassifierPreprocessor, -) +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ViTImageClassifierPreprocessor from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_hub.src.models.xception.xception_backbone import XceptionBackbone -from keras_hub.src.models.xception.xception_image_classifier import ( - XceptionImageClassifier, -) -from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( - XceptionImageClassifierPreprocessor, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( - XLMRobertaBackbone, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( - XLMRobertaMaskedLM, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( - XLMRobertaMaskedLMPreprocessor, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaClassifier, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, -) -from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer, -) +from keras_hub.src.models.xception.xception_image_classifier import XceptionImageClassifier +from keras_hub.src.models.xception.xception_image_classifier_preprocessor import XceptionImageClassifierPreprocessor +from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone +from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import XLMRobertaMaskedLM +from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import XLMRobertaMaskedLMPreprocessor +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import XLMRobertaTextClassifier +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import XLMRobertaTextClassifier as XLMRobertaClassifier +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import XLMRobertaTextClassifierPreprocessor +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor +from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import XLMRobertaTokenizer from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone from keras_hub.src.tokenizers.tokenizer import Tokenizer diff --git a/keras_hub/api/samplers/__init__.py b/keras_hub/api/samplers/__init__.py index 9feb76c669..5270b41c63 100644 --- a/keras_hub/api/samplers/__init__.py +++ b/keras_hub/api/samplers/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. """ + from keras_hub.src.samplers.beam_sampler import BeamSampler from keras_hub.src.samplers.contrastive_sampler import ContrastiveSampler from keras_hub.src.samplers.greedy_sampler import GreedySampler diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 3615e77581..a59f1beb18 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -4,17 +4,14 @@ since your modifications would be overwritten. """ + from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer -from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, -) -from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, -) +from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer +from keras_hub.src.models.distil_bert.distil_bert_tokenizer import DistilBertTokenizer from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer @@ -27,37 +24,21 @@ from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer -from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, -) +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import PaliGemmaTokenizer from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer -from keras_hub.src.models.qwen.qwen_tokenizer import ( - QwenTokenizer as Qwen2Tokenizer, -) +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer as Qwen2Tokenizer from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer -from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer, -) +from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import RoformerV2Tokenizer from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer, -) +from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import XLMRobertaTokenizer from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer from keras_hub.src.tokenizers.byte_tokenizer import ByteTokenizer -from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( - SentencePieceTokenizer, -) -from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import ( - compute_sentence_piece_proto, -) +from keras_hub.src.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import compute_sentence_piece_proto from keras_hub.src.tokenizers.tokenizer import Tokenizer -from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import ( - UnicodeCodepointTokenizer, -) +from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import UnicodeCodepointTokenizer from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer -from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( - compute_word_piece_vocabulary, -) +from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import compute_word_piece_vocabulary diff --git a/keras_hub/api/utils/__init__.py b/keras_hub/api/utils/__init__.py index 8ce47790b0..69d3a766e5 100644 --- a/keras_hub/api/utils/__init__.py +++ b/keras_hub/api/utils/__init__.py @@ -4,10 +4,9 @@ since your modifications would be overwritten. """ + from keras_hub.src.utils.coco.coco_utils import coco_id_to_name from keras_hub.src.utils.coco.coco_utils import coco_name_to_id -from keras_hub.src.utils.imagenet.imagenet_utils import ( - decode_imagenet_predictions, -) +from keras_hub.src.utils.imagenet.imagenet_utils import decode_imagenet_predictions from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id From 140207bf21545380923d491a80897f08e291e115 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 3 May 2025 18:01:50 +0800 Subject: [PATCH 05/13] format --- keras_hub/api/__init__.py | 1 - keras_hub/api/layers/__init__.py | 81 +++-- keras_hub/api/metrics/__init__.py | 1 - keras_hub/api/models/__init__.py | 445 ++++++++++++++++++++------- keras_hub/api/samplers/__init__.py | 1 - keras_hub/api/tokenizers/__init__.py | 41 ++- keras_hub/api/utils/__init__.py | 5 +- 7 files changed, 426 insertions(+), 149 deletions(-) diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 72580b9101..0d9b1a3eb8 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,7 +4,6 @@ since your modifications would be overwritten. """ - from keras_hub import layers from keras_hub import metrics from keras_hub import models diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 76be91cc4c..d42af86a3c 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -4,47 +4,86 @@ since your modifications would be overwritten. """ - from keras_hub.src.layers.modeling.alibi_bias import AlibiBias from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator from keras_hub.src.layers.modeling.box_matcher import BoxMatcher -from keras_hub.src.layers.modeling.cached_multi_head_attention import CachedMultiHeadAttention +from keras_hub.src.layers.modeling.cached_multi_head_attention import ( + CachedMultiHeadAttention, +) from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding -from keras_hub.src.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding -from keras_hub.src.layers.modeling.sine_position_encoding import SinePositionEncoding -from keras_hub.src.layers.modeling.token_and_position_embedding import TokenAndPositionEmbedding +from keras_hub.src.layers.modeling.sine_position_encoding import ( + SinePositionEncoding, +) +from keras_hub.src.layers.modeling.token_and_position_embedding import ( + TokenAndPositionEmbedding, +) from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter from keras_hub.src.layers.preprocessing.image_converter import ImageConverter -from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import MaskedLMMaskGenerator -from keras_hub.src.layers.preprocessing.multi_segment_packer import MultiSegmentPacker +from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import ( + MaskedLMMaskGenerator, +) +from keras_hub.src.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, +) from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion from keras_hub.src.layers.preprocessing.random_swap import RandomSwap from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker -from keras_hub.src.models.basnet.basnet_image_converter import BASNetImageConverter +from keras_hub.src.models.basnet.basnet_image_converter import ( + BASNetImageConverter, +) from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter -from keras_hub.src.models.cspnet.cspnet_image_converter import CSPNetImageConverter -from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import DeepLabV3ImageConverter -from keras_hub.src.models.densenet.densenet_image_converter import DenseNetImageConverter -from keras_hub.src.models.efficientnet.efficientnet_image_converter import EfficientNetImageConverter -from keras_hub.src.models.gemma3.gemma3_image_converter import Gemma3ImageConverter +from keras_hub.src.models.cspnet.cspnet_image_converter import ( + CSPNetImageConverter, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( + DeepLabV3ImageConverter, +) +from keras_hub.src.models.densenet.densenet_image_converter import ( + DenseNetImageConverter, +) +from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( + EfficientNetImageConverter, +) +from keras_hub.src.models.gemma3.gemma3_image_converter import ( + Gemma3ImageConverter, +) from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter -from keras_hub.src.models.mobilenet.mobilenet_image_converter import MobileNetImageConverter -from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import PaliGemmaImageConverter -from keras_hub.src.models.resnet.resnet_image_converter import ResNetImageConverter -from keras_hub.src.models.retinanet.retinanet_image_converter import RetinaNetImageConverter +from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( + MobileNetImageConverter, +) +from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( + PaliGemmaImageConverter, +) +from keras_hub.src.models.resnet.resnet_image_converter import ( + ResNetImageConverter, +) +from keras_hub.src.models.retinanet.retinanet_image_converter import ( + RetinaNetImageConverter, +) from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder -from keras_hub.src.models.segformer.segformer_image_converter import SegFormerImageConverter -from keras_hub.src.models.siglip.siglip_image_converter import SigLIPImageConverter +from keras_hub.src.models.segformer.segformer_image_converter import ( + SegFormerImageConverter, +) +from keras_hub.src.models.siglip.siglip_image_converter import ( + SigLIPImageConverter, +) from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter -from keras_hub.src.models.whisper.whisper_audio_converter import WhisperAudioConverter -from keras_hub.src.models.xception.xception_image_converter import XceptionImageConverter +from keras_hub.src.models.whisper.whisper_audio_converter import ( + WhisperAudioConverter, +) +from keras_hub.src.models.xception.xception_image_converter import ( + XceptionImageConverter, +) diff --git a/keras_hub/api/metrics/__init__.py b/keras_hub/api/metrics/__init__.py index 636c3117d1..88a0a7df2b 100644 --- a/keras_hub/api/metrics/__init__.py +++ b/keras_hub/api/metrics/__init__.py @@ -4,7 +4,6 @@ since your modifications would be overwritten. """ - from keras_hub.src.metrics.bleu import Bleu from keras_hub.src.metrics.edit_distance import EditDistance from keras_hub.src.metrics.perplexity import Perplexity diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index a884bea89c..e3f0a3aa16 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,34 +4,55 @@ since your modifications would be overwritten. """ - from keras_hub.src.models.albert.albert_backbone import AlbertBackbone from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM -from keras_hub.src.models.albert.albert_masked_lm_preprocessor import AlbertMaskedLMPreprocessor -from keras_hub.src.models.albert.albert_text_classifier import AlbertTextClassifier -from keras_hub.src.models.albert.albert_text_classifier import AlbertTextClassifier as AlbertClassifier -from keras_hub.src.models.albert.albert_text_classifier_preprocessor import AlbertTextClassifierPreprocessor -from keras_hub.src.models.albert.albert_text_classifier_preprocessor import AlbertTextClassifierPreprocessor as AlbertPreprocessor +from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( + AlbertMaskedLMPreprocessor, +) +from keras_hub.src.models.albert.albert_text_classifier import ( + AlbertTextClassifier, +) +from keras_hub.src.models.albert.albert_text_classifier import ( + AlbertTextClassifier as AlbertClassifier, +) +from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( + AlbertTextClassifierPreprocessor, +) +from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( + AlbertTextClassifierPreprocessor as AlbertPreprocessor, +) from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.bart.bart_backbone import BartBackbone from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM -from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import BartSeq2SeqLMPreprocessor +from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( + BartSeq2SeqLMPreprocessor, +) from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM -from keras_hub.src.models.bert.bert_masked_lm_preprocessor import BertMaskedLMPreprocessor +from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( + BertMaskedLMPreprocessor, +) from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier -from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier as BertClassifier -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import BertTextClassifierPreprocessor -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import BertTextClassifierPreprocessor as BertPreprocessor +from keras_hub.src.models.bert.bert_text_classifier import ( + BertTextClassifier as BertClassifier, +) +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( + BertTextClassifierPreprocessor, +) +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( + BertTextClassifierPreprocessor as BertPreprocessor, +) from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM -from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import BloomCausalLMPreprocessor +from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_hub.src.models.causal_lm import CausalLM from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor @@ -41,194 +62,394 @@ from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone -from keras_hub.src.models.cspnet.cspnet_image_classifier import CSPNetImageClassifier -from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import CSPNetImageClassifierPreprocessor -from keras_hub.src.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone -from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import DebertaV3MaskedLM -from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import DebertaV3MaskedLMPreprocessor -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import DebertaV3TextClassifier -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import DebertaV3TextClassifier as DebertaV3Classifier -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import DebertaV3TextClassifierPreprocessor -from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor -from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer -from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import DeepLabV3Backbone -from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import DeepLabV3ImageSegmenterPreprocessor -from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import DeepLabV3ImageSegmenter +from keras_hub.src.models.cspnet.cspnet_image_classifier import ( + CSPNetImageClassifier, +) +from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( + CSPNetImageClassifierPreprocessor, +) +from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( + DebertaV3Backbone, +) +from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( + DebertaV3MaskedLM, +) +from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( + DebertaV3MaskedLMPreprocessor, +) +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( + DebertaV3TextClassifier, +) +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( + DebertaV3TextClassifier as DebertaV3Classifier, +) +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( + DebertaV3TextClassifierPreprocessor, +) +from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( + DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, +) +from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( + DebertaV3Tokenizer, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( + DeepLabV3Backbone, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( + DeepLabV3ImageSegmenterPreprocessor, +) +from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( + DeepLabV3ImageSegmenter, +) from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone -from keras_hub.src.models.densenet.densenet_image_classifier import DenseNetImageClassifier -from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import DenseNetImageClassifierPreprocessor -from keras_hub.src.models.distil_bert.distil_bert_backbone import DistilBertBackbone -from keras_hub.src.models.distil_bert.distil_bert_masked_lm import DistilBertMaskedLM -from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import DistilBertMaskedLMPreprocessor -from keras_hub.src.models.distil_bert.distil_bert_text_classifier import DistilBertTextClassifier -from keras_hub.src.models.distil_bert.distil_bert_text_classifier import DistilBertTextClassifier as DistilBertClassifier -from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import DistilBertTextClassifierPreprocessor -from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import DistilBertTextClassifierPreprocessor as DistilBertPreprocessor -from keras_hub.src.models.distil_bert.distil_bert_tokenizer import DistilBertTokenizer -from keras_hub.src.models.efficientnet.efficientnet_backbone import EfficientNetBackbone -from keras_hub.src.models.efficientnet.efficientnet_image_classifier import EfficientNetImageClassifier -from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import EfficientNetImageClassifierPreprocessor +from keras_hub.src.models.densenet.densenet_image_classifier import ( + DenseNetImageClassifier, +) +from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( + DenseNetImageClassifierPreprocessor, +) +from keras_hub.src.models.distil_bert.distil_bert_backbone import ( + DistilBertBackbone, +) +from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( + DistilBertMaskedLM, +) +from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( + DistilBertMaskedLMPreprocessor, +) +from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( + DistilBertTextClassifier, +) +from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( + DistilBertTextClassifier as DistilBertClassifier, +) +from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( + DistilBertTextClassifierPreprocessor, +) +from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( + DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, +) +from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, +) +from keras_hub.src.models.efficientnet.efficientnet_backbone import ( + EfficientNetBackbone, +) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( + EfficientNetImageClassifier, +) +from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( + EfficientNetImageClassifierPreprocessor, +) from keras_hub.src.models.electra.electra_backbone import ElectraBackbone from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_hub.src.models.esm.esm_backbone import ESMBackbone from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier -from keras_hub.src.models.esm.esm_classifier_preprocessor import ESMProteinClassifierPreprocessor +from keras_hub.src.models.esm.esm_classifier_preprocessor import ( + ESMProteinClassifierPreprocessor, +) from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM -from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESM2MaskedPLM -from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ESMMaskedPLMPreprocessor +from keras_hub.src.models.esm.esm_masked_plm import ( + ESMMaskedPLM as ESM2MaskedPLM, +) +from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( + ESMMaskedPLMPreprocessor, +) from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM -from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import FNetMaskedLMPreprocessor +from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( + FNetMaskedLMPreprocessor, +) from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier -from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier as FNetClassifier -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import FNetTextClassifierPreprocessor -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import FNetTextClassifierPreprocessor as FNetPreprocessor +from keras_hub.src.models.f_net.f_net_text_classifier import ( + FNetTextClassifier as FNetClassifier, +) +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( + FNetTextClassifierPreprocessor, +) +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( + FNetTextClassifierPreprocessor as FNetPreprocessor, +) from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM -from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import FalconCausalLMPreprocessor +from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( + FalconCausalLMPreprocessor, +) from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_hub.src.models.flux.flux_model import FluxBackbone from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage -from keras_hub.src.models.flux.flux_text_to_image_preprocessor import FluxTextToImagePreprocessor +from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( + FluxTextToImagePreprocessor, +) from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM -from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import GemmaCausalLMPreprocessor +from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM -from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import Gemma3CausalLMPreprocessor +from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( + Gemma3CausalLMPreprocessor, +) from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer -from keras_hub.src.models.gemma3.gemma3_vision_encoder import Gemma3VisionEncoder +from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( + Gemma3VisionEncoder, +) from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM -from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import GPT2CausalLMPreprocessor +from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( + GPT2CausalLMPreprocessor, +) from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import GPTNeoXCausalLMPreprocessor +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( + GPTNeoXCausalLMPreprocessor, +) from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_hub.src.models.image_classifier import ImageClassifier -from keras_hub.src.models.image_classifier_preprocessor import ImageClassifierPreprocessor +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) from keras_hub.src.models.image_segmenter import ImageSegmenter -from keras_hub.src.models.image_segmenter_preprocessor import ImageSegmenterPreprocessor +from keras_hub.src.models.image_segmenter_preprocessor import ( + ImageSegmenterPreprocessor, +) from keras_hub.src.models.image_to_image import ImageToImage from keras_hub.src.models.inpaint import Inpaint from keras_hub.src.models.llama.llama_backbone import LlamaBackbone from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM -from keras_hub.src.models.llama.llama_causal_lm_preprocessor import LlamaCausalLMPreprocessor +from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM -from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import Llama3CausalLMPreprocessor +from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( + Llama3CausalLMPreprocessor, +) from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_hub.src.models.masked_lm import MaskedLM from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM -from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import MistralCausalLMPreprocessor +from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( + MistralCausalLMPreprocessor, +) from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer from keras_hub.src.models.mit.mit_backbone import MiTBackbone from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier -from keras_hub.src.models.mit.mit_image_classifier_preprocessor import MiTImageClassifierPreprocessor +from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( + MiTImageClassifierPreprocessor, +) from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone -from keras_hub.src.models.mobilenet.mobilenet_image_classifier import MobileNetImageClassifier -from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import MobileNetImageClassifierPreprocessor +from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( + MobileNetImageClassifier, +) +from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( + MobileNetImageClassifierPreprocessor, +) from keras_hub.src.models.object_detector import ObjectDetector -from keras_hub.src.models.object_detector import ObjectDetector as ImageObjectDetector -from keras_hub.src.models.object_detector_preprocessor import ObjectDetectorPreprocessor -from keras_hub.src.models.object_detector_preprocessor import ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor +from keras_hub.src.models.object_detector import ( + ObjectDetector as ImageObjectDetector, +) +from keras_hub.src.models.object_detector_preprocessor import ( + ObjectDetectorPreprocessor, +) +from keras_hub.src.models.object_detector_preprocessor import ( + ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, +) from keras_hub.src.models.opt.opt_backbone import OPTBackbone from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM -from keras_hub.src.models.opt.opt_causal_lm_preprocessor import OPTCausalLMPreprocessor +from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( + OPTCausalLMPreprocessor, +) from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer -from keras_hub.src.models.pali_gemma.pali_gemma_backbone import PaliGemmaBackbone -from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import PaliGemmaCausalLM -from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import PaliGemmaCausalLMPreprocessor -from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import PaliGemmaTokenizer +from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( + PaliGemmaCausalLM, +) +from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( + PaliGemmaCausalLMPreprocessor, +) +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( + PaliGemmaTokenizer, +) from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM -from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import Phi3CausalLMPreprocessor +from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( + Phi3CausalLMPreprocessor, +) from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone -from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone as Qwen2Backbone +from keras_hub.src.models.qwen.qwen_backbone import ( + QwenBackbone as Qwen2Backbone, +) from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM -from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM as Qwen2CausalLM -from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import QwenCausalLMPreprocessor -from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor +from keras_hub.src.models.qwen.qwen_causal_lm import ( + QwenCausalLM as Qwen2CausalLM, +) +from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( + QwenCausalLMPreprocessor, +) +from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( + QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, +) from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer -from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer as Qwen2Tokenizer +from keras_hub.src.models.qwen.qwen_tokenizer import ( + QwenTokenizer as Qwen2Tokenizer, +) from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone -from keras_hub.src.models.resnet.resnet_image_classifier import ResNetImageClassifier -from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ResNetImageClassifierPreprocessor +from keras_hub.src.models.resnet.resnet_image_classifier import ( + ResNetImageClassifier, +) +from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( + ResNetImageClassifierPreprocessor, +) from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone -from keras_hub.src.models.retinanet.retinanet_object_detector import RetinaNetObjectDetector -from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import RetinaNetObjectDetectorPreprocessor +from keras_hub.src.models.retinanet.retinanet_object_detector import ( + RetinaNetObjectDetector, +) +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( + RetinaNetObjectDetectorPreprocessor, +) from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM -from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import RobertaMaskedLMPreprocessor -from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier -from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier as RobertaClassifier -from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import RobertaTextClassifierPreprocessor -from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import RobertaTextClassifierPreprocessor as RobertaPreprocessor +from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( + RobertaMaskedLMPreprocessor, +) +from keras_hub.src.models.roberta.roberta_text_classifier import ( + RobertaTextClassifier, +) +from keras_hub.src.models.roberta.roberta_text_classifier import ( + RobertaTextClassifier as RobertaClassifier, +) +from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( + RobertaTextClassifierPreprocessor, +) +from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( + RobertaTextClassifierPreprocessor as RobertaPreprocessor, +) from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer -from keras_hub.src.models.roformer_v2.roformer_v2_backbone import RoformerV2Backbone -from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import RoformerV2MaskedLM -from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import RoformerV2MaskedLMPreprocessor -from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import RoformerV2TextClassifier -from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import RoformerV2TextClassifierPreprocessor -from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import RoformerV2Tokenizer +from keras_hub.src.models.roformer_v2.roformer_v2_backbone import ( + RoformerV2Backbone, +) +from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import ( + RoformerV2MaskedLM, +) +from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import ( + RoformerV2MaskedLMPreprocessor, +) +from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import ( + RoformerV2TextClassifier, +) +from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import ( + RoformerV2TextClassifierPreprocessor, +) +from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( + RoformerV2Tokenizer, +) from keras_hub.src.models.sam.sam_backbone import SAMBackbone from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter -from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import SAMImageSegmenterPreprocessor +from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( + SAMImageSegmenterPreprocessor, +) from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone -from keras_hub.src.models.segformer.segformer_image_segmenter import SegFormerImageSegmenter -from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import SegFormerImageSegmenterPreprocessor +from keras_hub.src.models.segformer.segformer_image_segmenter import ( + SegFormerImageSegmenter, +) +from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( + SegFormerImageSegmenterPreprocessor, +) from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor from keras_hub.src.models.siglip.siglip_backbone import SigLIPBackbone from keras_hub.src.models.siglip.siglip_preprocessor import SigLIPPreprocessor from keras_hub.src.models.siglip.siglip_text_encoder import SigLIPTextEncoder from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer -from keras_hub.src.models.siglip.siglip_vision_encoder import SigLIPVisionEncoder -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import StableDiffusion3Backbone -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import StableDiffusion3ImageToImage -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import StableDiffusion3Inpaint -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import StableDiffusion3TextToImage -from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import StableDiffusion3TextToImagePreprocessor +from keras_hub.src.models.siglip.siglip_vision_encoder import ( + SigLIPVisionEncoder, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( + StableDiffusion3Backbone, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( + StableDiffusion3ImageToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( + StableDiffusion3Inpaint, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( + StableDiffusion3TextToImage, +) +from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( + StableDiffusion3TextToImagePreprocessor, +) from keras_hub.src.models.t5.t5_backbone import T5Backbone from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.models.task import Task from keras_hub.src.models.text_classifier import TextClassifier from keras_hub.src.models.text_classifier import TextClassifier as Classifier -from keras_hub.src.models.text_classifier_preprocessor import TextClassifierPreprocessor +from keras_hub.src.models.text_classifier_preprocessor import ( + TextClassifierPreprocessor, +) from keras_hub.src.models.text_to_image import TextToImage -from keras_hub.src.models.text_to_image_preprocessor import TextToImagePreprocessor +from keras_hub.src.models.text_to_image_preprocessor import ( + TextToImagePreprocessor, +) from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier -from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import VGGImageClassifierPreprocessor +from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( + VGGImageClassifierPreprocessor, +) from keras_hub.src.models.vit.vit_backbone import ViTBackbone from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier -from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ViTImageClassifierPreprocessor +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_hub.src.models.xception.xception_backbone import XceptionBackbone -from keras_hub.src.models.xception.xception_image_classifier import XceptionImageClassifier -from keras_hub.src.models.xception.xception_image_classifier_preprocessor import XceptionImageClassifierPreprocessor -from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone -from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import XLMRobertaMaskedLM -from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import XLMRobertaMaskedLMPreprocessor -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import XLMRobertaTextClassifier -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import XLMRobertaTextClassifier as XLMRobertaClassifier -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import XLMRobertaTextClassifierPreprocessor -from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor -from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import XLMRobertaTokenizer +from keras_hub.src.models.xception.xception_image_classifier import ( + XceptionImageClassifier, +) +from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( + XceptionImageClassifierPreprocessor, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( + XLMRobertaBackbone, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( + XLMRobertaMaskedLM, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( + XLMRobertaMaskedLMPreprocessor, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( + XLMRobertaTextClassifier, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( + XLMRobertaTextClassifier as XLMRobertaClassifier, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( + XLMRobertaTextClassifierPreprocessor, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( + XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, +) +from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( + XLMRobertaTokenizer, +) from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone from keras_hub.src.tokenizers.tokenizer import Tokenizer diff --git a/keras_hub/api/samplers/__init__.py b/keras_hub/api/samplers/__init__.py index 5270b41c63..9feb76c669 100644 --- a/keras_hub/api/samplers/__init__.py +++ b/keras_hub/api/samplers/__init__.py @@ -4,7 +4,6 @@ since your modifications would be overwritten. """ - from keras_hub.src.samplers.beam_sampler import BeamSampler from keras_hub.src.samplers.contrastive_sampler import ContrastiveSampler from keras_hub.src.samplers.greedy_sampler import GreedySampler diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index a59f1beb18..3615e77581 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -4,14 +4,17 @@ since your modifications would be overwritten. """ - from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer -from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer -from keras_hub.src.models.distil_bert.distil_bert_tokenizer import DistilBertTokenizer +from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( + DebertaV3Tokenizer, +) +from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, +) from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer @@ -24,21 +27,37 @@ from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer -from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import PaliGemmaTokenizer +from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( + PaliGemmaTokenizer, +) from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer -from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer as Qwen2Tokenizer +from keras_hub.src.models.qwen.qwen_tokenizer import ( + QwenTokenizer as Qwen2Tokenizer, +) from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer -from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import RoformerV2Tokenizer +from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( + RoformerV2Tokenizer, +) from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import XLMRobertaTokenizer +from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( + XLMRobertaTokenizer, +) from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer from keras_hub.src.tokenizers.byte_tokenizer import ByteTokenizer -from keras_hub.src.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import compute_sentence_piece_proto +from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( + SentencePieceTokenizer, +) +from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import ( + compute_sentence_piece_proto, +) from keras_hub.src.tokenizers.tokenizer import Tokenizer -from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import UnicodeCodepointTokenizer +from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import ( + UnicodeCodepointTokenizer, +) from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer -from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import compute_word_piece_vocabulary +from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( + compute_word_piece_vocabulary, +) diff --git a/keras_hub/api/utils/__init__.py b/keras_hub/api/utils/__init__.py index 69d3a766e5..8ce47790b0 100644 --- a/keras_hub/api/utils/__init__.py +++ b/keras_hub/api/utils/__init__.py @@ -4,9 +4,10 @@ since your modifications would be overwritten. """ - from keras_hub.src.utils.coco.coco_utils import coco_id_to_name from keras_hub.src.utils.coco.coco_utils import coco_name_to_id -from keras_hub.src.utils.imagenet.imagenet_utils import decode_imagenet_predictions +from keras_hub.src.utils.imagenet.imagenet_utils import ( + decode_imagenet_predictions, +) from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id From cc9a11c0eb18823063ce9301992fb005dbed4e49 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 3 May 2025 18:59:45 +0800 Subject: [PATCH 06/13] fix test --- keras_hub/src/models/esm/esm_backbone_test.py | 1 - keras_hub/src/models/esm/esm_classifier_test.py | 1 - keras_hub/src/models/esm/esm_masked_plm_test.py | 1 - 3 files changed, 3 deletions(-) diff --git a/keras_hub/src/models/esm/esm_backbone_test.py b/keras_hub/src/models/esm/esm_backbone_test.py index 91e5227ebe..8be6aa9c25 100644 --- a/keras_hub/src/models/esm/esm_backbone_test.py +++ b/keras_hub/src/models/esm/esm_backbone_test.py @@ -13,7 +13,6 @@ def setUp(self): "num_heads": 1, "hidden_dim": 2, "intermediate_dim": 4, - "head_size": 2, } self.input_data = { "token_ids": ops.ones((2, 5), dtype="int32"), diff --git a/keras_hub/src/models/esm/esm_classifier_test.py b/keras_hub/src/models/esm/esm_classifier_test.py index cf55cc8285..3f6e9a1501 100644 --- a/keras_hub/src/models/esm/esm_classifier_test.py +++ b/keras_hub/src/models/esm/esm_classifier_test.py @@ -24,7 +24,6 @@ def setUp(self): num_heads=2, hidden_dim=4, intermediate_dim=8, - head_size=2, ) self.init_kwargs = { "preprocessor": self.preprocessor, diff --git a/keras_hub/src/models/esm/esm_masked_plm_test.py b/keras_hub/src/models/esm/esm_masked_plm_test.py index 92a76cbb1b..b02adc106d 100644 --- a/keras_hub/src/models/esm/esm_masked_plm_test.py +++ b/keras_hub/src/models/esm/esm_masked_plm_test.py @@ -29,7 +29,6 @@ def setUp(self): num_heads=2, hidden_dim=4, intermediate_dim=8, - head_size=2, ) self.init_kwargs = { "preprocessor": self.preprocessor, From f8da784f324ee9197bfc48385e7bb279ddddf954 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Thu, 8 May 2025 18:03:45 +0800 Subject: [PATCH 07/13] format --- keras_hub/api/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 0d9b1a3eb8..3796e4c7f4 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -11,5 +11,5 @@ from keras_hub import tokenizers from keras_hub import utils from keras_hub.src.utils.preset_utils import upload_preset -from keras_hub.src.version import version from keras_hub.src.version import __version__ as __version__ +from keras_hub.src.version import version From 72e9829bde00304ae1706c8ff660b6ae5168710c Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 10 May 2025 19:26:11 +0800 Subject: [PATCH 08/13] renew --- keras_hub/api/models/__init__.py | 614 +++++++++++++++++---------- keras_hub/api/tokenizers/__init__.py | 121 ++++-- 2 files changed, 467 insertions(+), 268 deletions(-) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index e3f0a3aa16..fe831ccabb 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,452 +4,606 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_backbone import AlbertBackbone -from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM +from keras_hub.src.models.albert.albert_backbone import ( + AlbertBackbone as AlbertBackbone, +) +from keras_hub.src.models.albert.albert_masked_lm import ( + AlbertMaskedLM as AlbertMaskedLM, +) from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( - AlbertMaskedLMPreprocessor, + AlbertMaskedLMPreprocessor as AlbertMaskedLMPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier, + AlbertTextClassifier as AlbertClassifier, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertClassifier, + AlbertTextClassifier as AlbertTextClassifier, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor, + AlbertTextClassifierPreprocessor as AlbertPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertPreprocessor, + AlbertTextClassifierPreprocessor as AlbertTextClassifierPreprocessor, +) +from keras_hub.src.models.albert.albert_tokenizer import ( + AlbertTokenizer as AlbertTokenizer, +) +from keras_hub.src.models.backbone import Backbone as Backbone +from keras_hub.src.models.bart.bart_backbone import BartBackbone as BartBackbone +from keras_hub.src.models.bart.bart_seq_2_seq_lm import ( + BartSeq2SeqLM as BartSeq2SeqLM, ) -from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer -from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.bart.bart_backbone import BartBackbone -from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( - BartSeq2SeqLMPreprocessor, -) -from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer -from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone -from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor -from keras_hub.src.models.bert.bert_backbone import BertBackbone -from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM + BartSeq2SeqLMPreprocessor as BartSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.bart.bart_tokenizer import ( + BartTokenizer as BartTokenizer, +) +from keras_hub.src.models.basnet.basnet import ( + BASNetImageSegmenter as BASNetImageSegmenter, +) +from keras_hub.src.models.basnet.basnet_backbone import ( + BASNetBackbone as BASNetBackbone, +) +from keras_hub.src.models.basnet.basnet_preprocessor import ( + BASNetPreprocessor as BASNetPreprocessor, +) +from keras_hub.src.models.bert.bert_backbone import BertBackbone as BertBackbone +from keras_hub.src.models.bert.bert_masked_lm import ( + BertMaskedLM as BertMaskedLM, +) from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( - BertMaskedLMPreprocessor, + BertMaskedLMPreprocessor as BertMaskedLMPreprocessor, ) -from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier from keras_hub.src.models.bert.bert_text_classifier import ( BertTextClassifier as BertClassifier, ) -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor, +from keras_hub.src.models.bert.bert_text_classifier import ( + BertTextClassifier as BertTextClassifier, ) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( BertTextClassifierPreprocessor as BertPreprocessor, ) -from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer -from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone -from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( + BertTextClassifierPreprocessor as BertTextClassifierPreprocessor, +) +from keras_hub.src.models.bert.bert_tokenizer import ( + BertTokenizer as BertTokenizer, +) +from keras_hub.src.models.bloom.bloom_backbone import ( + BloomBackbone as BloomBackbone, +) +from keras_hub.src.models.bloom.bloom_causal_lm import ( + BloomCausalLM as BloomCausalLM, +) from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( - BloomCausalLMPreprocessor, -) -from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer -from keras_hub.src.models.causal_lm import CausalLM -from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_hub.src.models.clip.clip_backbone import CLIPBackbone -from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor -from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder -from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer -from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder -from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone + BloomCausalLMPreprocessor as BloomCausalLMPreprocessor, +) +from keras_hub.src.models.bloom.bloom_tokenizer import ( + BloomTokenizer as BloomTokenizer, +) +from keras_hub.src.models.causal_lm import CausalLM as CausalLM +from keras_hub.src.models.causal_lm_preprocessor import ( + CausalLMPreprocessor as CausalLMPreprocessor, +) +from keras_hub.src.models.clip.clip_backbone import CLIPBackbone as CLIPBackbone +from keras_hub.src.models.clip.clip_preprocessor import ( + CLIPPreprocessor as CLIPPreprocessor, +) +from keras_hub.src.models.clip.clip_text_encoder import ( + CLIPTextEncoder as CLIPTextEncoder, +) +from keras_hub.src.models.clip.clip_tokenizer import ( + CLIPTokenizer as CLIPTokenizer, +) +from keras_hub.src.models.clip.clip_vision_encoder import ( + CLIPVisionEncoder as CLIPVisionEncoder, +) +from keras_hub.src.models.cspnet.cspnet_backbone import ( + CSPNetBackbone as CSPNetBackbone, +) from keras_hub.src.models.cspnet.cspnet_image_classifier import ( - CSPNetImageClassifier, + CSPNetImageClassifier as CSPNetImageClassifier, ) from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( - CSPNetImageClassifierPreprocessor, + CSPNetImageClassifierPreprocessor as CSPNetImageClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( - DebertaV3Backbone, + DebertaV3Backbone as DebertaV3Backbone, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( - DebertaV3MaskedLM, + DebertaV3MaskedLM as DebertaV3MaskedLM, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( - DebertaV3MaskedLMPreprocessor, + DebertaV3MaskedLMPreprocessor as DebertaV3MaskedLMPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier, + DebertaV3TextClassifier as DebertaV3Classifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3Classifier, + DebertaV3TextClassifier as DebertaV3TextClassifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3TextClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, + DebertaV3Tokenizer as DebertaV3Tokenizer, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( - DeepLabV3Backbone, + DeepLabV3Backbone as DeepLabV3Backbone, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( - DeepLabV3ImageSegmenterPreprocessor, + DeepLabV3ImageSegmenterPreprocessor as DeepLabV3ImageSegmenterPreprocessor, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( - DeepLabV3ImageSegmenter, + DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter, +) +from keras_hub.src.models.densenet.densenet_backbone import ( + DenseNetBackbone as DenseNetBackbone, ) -from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier, + DenseNetImageClassifier as DenseNetImageClassifier, ) from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( - DenseNetImageClassifierPreprocessor, + DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_backbone import ( - DistilBertBackbone, + DistilBertBackbone as DistilBertBackbone, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( - DistilBertMaskedLM, + DistilBertMaskedLM as DistilBertMaskedLM, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor, + DistilBertMaskedLMPreprocessor as DistilBertMaskedLMPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier, + DistilBertTextClassifier as DistilBertClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertClassifier, + DistilBertTextClassifier as DistilBertTextClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertTextClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, + DistilBertTokenizer as DistilBertTokenizer, ) from keras_hub.src.models.efficientnet.efficientnet_backbone import ( - EfficientNetBackbone, + EfficientNetBackbone as EfficientNetBackbone, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( - EfficientNetImageClassifier, + EfficientNetImageClassifier as EfficientNetImageClassifier, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( - EfficientNetImageClassifierPreprocessor, -) -from keras_hub.src.models.electra.electra_backbone import ElectraBackbone -from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer -from keras_hub.src.models.esm.esm_backbone import ESMBackbone -from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone -from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier -from keras_hub.src.models.esm.esm_classifier_preprocessor import ( - ESMProteinClassifierPreprocessor, -) -from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM -from keras_hub.src.models.esm.esm_masked_plm import ( - ESMMaskedPLM as ESM2MaskedPLM, -) -from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( - ESMMaskedPLMPreprocessor, -) -from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer -from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone -from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM + EfficientNetImageClassifierPreprocessor as EfficientNetImageClassifierPreprocessor, +) +from keras_hub.src.models.electra.electra_backbone import ( + ElectraBackbone as ElectraBackbone, +) +from keras_hub.src.models.electra.electra_tokenizer import ( + ElectraTokenizer as ElectraTokenizer, +) +from keras_hub.src.models.f_net.f_net_backbone import ( + FNetBackbone as FNetBackbone, +) +from keras_hub.src.models.f_net.f_net_masked_lm import ( + FNetMaskedLM as FNetMaskedLM, +) from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( - FNetMaskedLMPreprocessor, + FNetMaskedLMPreprocessor as FNetMaskedLMPreprocessor, ) -from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier from keras_hub.src.models.f_net.f_net_text_classifier import ( FNetTextClassifier as FNetClassifier, ) -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor, +from keras_hub.src.models.f_net.f_net_text_classifier import ( + FNetTextClassifier as FNetTextClassifier, ) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( FNetTextClassifierPreprocessor as FNetPreprocessor, ) -from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer -from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone -from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( + FNetTextClassifierPreprocessor as FNetTextClassifierPreprocessor, +) +from keras_hub.src.models.f_net.f_net_tokenizer import ( + FNetTokenizer as FNetTokenizer, +) +from keras_hub.src.models.falcon.falcon_backbone import ( + FalconBackbone as FalconBackbone, +) +from keras_hub.src.models.falcon.falcon_causal_lm import ( + FalconCausalLM as FalconCausalLM, +) from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( - FalconCausalLMPreprocessor, + FalconCausalLMPreprocessor as FalconCausalLMPreprocessor, +) +from keras_hub.src.models.falcon.falcon_tokenizer import ( + FalconTokenizer as FalconTokenizer, +) +from keras_hub.src.models.feature_pyramid_backbone import ( + FeaturePyramidBackbone as FeaturePyramidBackbone, +) +from keras_hub.src.models.flux.flux_model import FluxBackbone as FluxBackbone +from keras_hub.src.models.flux.flux_text_to_image import ( + FluxTextToImage as FluxTextToImage, ) -from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer -from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone -from keras_hub.src.models.flux.flux_model import FluxBackbone -from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( - FluxTextToImagePreprocessor, + FluxTextToImagePreprocessor as FluxTextToImagePreprocessor, +) +from keras_hub.src.models.gemma.gemma_backbone import ( + GemmaBackbone as GemmaBackbone, +) +from keras_hub.src.models.gemma.gemma_causal_lm import ( + GemmaCausalLM as GemmaCausalLM, ) -from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone -from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor, + GemmaCausalLMPreprocessor as GemmaCausalLMPreprocessor, +) +from keras_hub.src.models.gemma.gemma_tokenizer import ( + GemmaTokenizer as GemmaTokenizer, +) +from keras_hub.src.models.gemma3.gemma3_backbone import ( + Gemma3Backbone as Gemma3Backbone, +) +from keras_hub.src.models.gemma3.gemma3_causal_lm import ( + Gemma3CausalLM as Gemma3CausalLM, ) -from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer -from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone -from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( - Gemma3CausalLMPreprocessor, + Gemma3CausalLMPreprocessor as Gemma3CausalLMPreprocessor, +) +from keras_hub.src.models.gemma3.gemma3_tokenizer import ( + Gemma3Tokenizer as Gemma3Tokenizer, ) -from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( - Gemma3VisionEncoder, + Gemma3VisionEncoder as Gemma3VisionEncoder, +) +from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone as GPT2Backbone +from keras_hub.src.models.gpt2.gpt2_causal_lm import ( + GPT2CausalLM as GPT2CausalLM, ) -from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone -from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( - GPT2CausalLMPreprocessor, + GPT2CausalLMPreprocessor as GPT2CausalLMPreprocessor, +) +from keras_hub.src.models.gpt2.gpt2_preprocessor import ( + GPT2Preprocessor as GPT2Preprocessor, +) +from keras_hub.src.models.gpt2.gpt2_tokenizer import ( + GPT2Tokenizer as GPT2Tokenizer, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import ( + GPTNeoXBackbone as GPTNeoXBackbone, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import ( + GPTNeoXCausalLM as GPTNeoXCausalLM, ) -from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor -from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( - GPTNeoXCausalLMPreprocessor, + GPTNeoXCausalLMPreprocessor as GPTNeoXCausalLMPreprocessor, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( + GPTNeoXTokenizer as GPTNeoXTokenizer, +) +from keras_hub.src.models.image_classifier import ( + ImageClassifier as ImageClassifier, ) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer -from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.image_classifier_preprocessor import ( - ImageClassifierPreprocessor, + ImageClassifierPreprocessor as ImageClassifierPreprocessor, +) +from keras_hub.src.models.image_segmenter import ( + ImageSegmenter as ImageSegmenter, ) -from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.image_segmenter_preprocessor import ( - ImageSegmenterPreprocessor, + ImageSegmenterPreprocessor as ImageSegmenterPreprocessor, +) +from keras_hub.src.models.image_to_image import ImageToImage as ImageToImage +from keras_hub.src.models.inpaint import Inpaint as Inpaint +from keras_hub.src.models.llama.llama_backbone import ( + LlamaBackbone as LlamaBackbone, +) +from keras_hub.src.models.llama.llama_causal_lm import ( + LlamaCausalLM as LlamaCausalLM, ) -from keras_hub.src.models.image_to_image import ImageToImage -from keras_hub.src.models.inpaint import Inpaint -from keras_hub.src.models.llama.llama_backbone import LlamaBackbone -from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( - LlamaCausalLMPreprocessor, + LlamaCausalLMPreprocessor as LlamaCausalLMPreprocessor, +) +from keras_hub.src.models.llama.llama_tokenizer import ( + LlamaTokenizer as LlamaTokenizer, +) +from keras_hub.src.models.llama3.llama3_backbone import ( + Llama3Backbone as Llama3Backbone, +) +from keras_hub.src.models.llama3.llama3_causal_lm import ( + Llama3CausalLM as Llama3CausalLM, ) -from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer -from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone -from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( - Llama3CausalLMPreprocessor, + Llama3CausalLMPreprocessor as Llama3CausalLMPreprocessor, +) +from keras_hub.src.models.llama3.llama3_tokenizer import ( + Llama3Tokenizer as Llama3Tokenizer, +) +from keras_hub.src.models.masked_lm import MaskedLM as MaskedLM +from keras_hub.src.models.masked_lm_preprocessor import ( + MaskedLMPreprocessor as MaskedLMPreprocessor, +) +from keras_hub.src.models.mistral.mistral_backbone import ( + MistralBackbone as MistralBackbone, +) +from keras_hub.src.models.mistral.mistral_causal_lm import ( + MistralCausalLM as MistralCausalLM, ) -from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer -from keras_hub.src.models.masked_lm import MaskedLM -from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone -from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( - MistralCausalLMPreprocessor, + MistralCausalLMPreprocessor as MistralCausalLMPreprocessor, +) +from keras_hub.src.models.mistral.mistral_tokenizer import ( + MistralTokenizer as MistralTokenizer, +) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone as MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import ( + MiTImageClassifier as MiTImageClassifier, ) -from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer -from keras_hub.src.models.mit.mit_backbone import MiTBackbone -from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( - MiTImageClassifierPreprocessor, + MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor, +) +from keras_hub.src.models.mobilenet.mobilenet_backbone import ( + MobileNetBackbone as MobileNetBackbone, ) -from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier, + MobileNetImageClassifier as MobileNetImageClassifier, ) from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor, + MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor, ) -from keras_hub.src.models.object_detector import ObjectDetector from keras_hub.src.models.object_detector import ( ObjectDetector as ImageObjectDetector, ) -from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor, +from keras_hub.src.models.object_detector import ( + ObjectDetector as ObjectDetector, ) from keras_hub.src.models.object_detector_preprocessor import ( ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, ) -from keras_hub.src.models.opt.opt_backbone import OPTBackbone -from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM +from keras_hub.src.models.object_detector_preprocessor import ( + ObjectDetectorPreprocessor as ObjectDetectorPreprocessor, +) +from keras_hub.src.models.opt.opt_backbone import OPTBackbone as OPTBackbone +from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM as OPTCausalLM from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( - OPTCausalLMPreprocessor, + OPTCausalLMPreprocessor as OPTCausalLMPreprocessor, ) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone, + PaliGemmaBackbone as PaliGemmaBackbone, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( - PaliGemmaCausalLM, + PaliGemmaCausalLM as PaliGemmaCausalLM, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor, + PaliGemmaCausalLMPreprocessor as PaliGemmaCausalLMPreprocessor, ) from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, + PaliGemmaTokenizer as PaliGemmaTokenizer, +) +from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone +from keras_hub.src.models.phi3.phi3_causal_lm import ( + Phi3CausalLM as Phi3CausalLM, ) -from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone -from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( - Phi3CausalLMPreprocessor, + Phi3CausalLMPreprocessor as Phi3CausalLMPreprocessor, ) -from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer -from keras_hub.src.models.preprocessor import Preprocessor -from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone +from keras_hub.src.models.phi3.phi3_tokenizer import ( + Phi3Tokenizer as Phi3Tokenizer, +) +from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor from keras_hub.src.models.qwen.qwen_backbone import ( QwenBackbone as Qwen2Backbone, ) -from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM +from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone as QwenBackbone from keras_hub.src.models.qwen.qwen_causal_lm import ( QwenCausalLM as Qwen2CausalLM, ) -from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor, +from keras_hub.src.models.qwen.qwen_causal_lm import ( + QwenCausalLM as QwenCausalLM, ) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, ) -from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer +from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( + QwenCausalLMPreprocessor as QwenCausalLMPreprocessor, +) from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.qwen.qwen_tokenizer import ( + QwenTokenizer as QwenTokenizer, +) +from keras_hub.src.models.resnet.resnet_backbone import ( + ResNetBackbone as ResNetBackbone, +) from keras_hub.src.models.resnet.resnet_image_classifier import ( - ResNetImageClassifier, + ResNetImageClassifier as ResNetImageClassifier, ) from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( - ResNetImageClassifierPreprocessor, + ResNetImageClassifierPreprocessor as ResNetImageClassifierPreprocessor, +) +from keras_hub.src.models.retinanet.retinanet_backbone import ( + RetinaNetBackbone as RetinaNetBackbone, ) -from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_object_detector import ( - RetinaNetObjectDetector, + RetinaNetObjectDetector as RetinaNetObjectDetector, ) from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( - RetinaNetObjectDetectorPreprocessor, + RetinaNetObjectDetectorPreprocessor as RetinaNetObjectDetectorPreprocessor, +) +from keras_hub.src.models.roberta.roberta_backbone import ( + RobertaBackbone as RobertaBackbone, +) +from keras_hub.src.models.roberta.roberta_masked_lm import ( + RobertaMaskedLM as RobertaMaskedLM, ) -from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone -from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( - RobertaMaskedLMPreprocessor, + RobertaMaskedLMPreprocessor as RobertaMaskedLMPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier, + RobertaTextClassifier as RobertaClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaClassifier, + RobertaTextClassifier as RobertaTextClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor, + RobertaTextClassifierPreprocessor as RobertaPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaPreprocessor, + RobertaTextClassifierPreprocessor as RobertaTextClassifierPreprocessor, +) +from keras_hub.src.models.roberta.roberta_tokenizer import ( + RobertaTokenizer as RobertaTokenizer, ) -from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_backbone import ( - RoformerV2Backbone, + RoformerV2Backbone as RoformerV2Backbone, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import ( - RoformerV2MaskedLM, + RoformerV2MaskedLM as RoformerV2MaskedLM, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import ( - RoformerV2MaskedLMPreprocessor, + RoformerV2MaskedLMPreprocessor as RoformerV2MaskedLMPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import ( - RoformerV2TextClassifier, + RoformerV2TextClassifier as RoformerV2TextClassifier, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import ( - RoformerV2TextClassifierPreprocessor, + RoformerV2TextClassifierPreprocessor as RoformerV2TextClassifierPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer, + RoformerV2Tokenizer as RoformerV2Tokenizer, +) +from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone +from keras_hub.src.models.sam.sam_image_segmenter import ( + SAMImageSegmenter as SAMImageSegmenter, ) -from keras_hub.src.models.sam.sam_backbone import SAMBackbone -from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor, + SAMImageSegmenterPreprocessor as SAMImageSegmenterPreprocessor, +) +from keras_hub.src.models.segformer.segformer_backbone import ( + SegFormerBackbone as SegFormerBackbone, ) -from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone from keras_hub.src.models.segformer.segformer_image_segmenter import ( - SegFormerImageSegmenter, + SegFormerImageSegmenter as SegFormerImageSegmenter, ) from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( - SegFormerImageSegmenterPreprocessor, -) -from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM -from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor -from keras_hub.src.models.siglip.siglip_backbone import SigLIPBackbone -from keras_hub.src.models.siglip.siglip_preprocessor import SigLIPPreprocessor -from keras_hub.src.models.siglip.siglip_text_encoder import SigLIPTextEncoder -from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer + SegFormerImageSegmenterPreprocessor as SegFormerImageSegmenterPreprocessor, +) +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM as Seq2SeqLM +from keras_hub.src.models.seq_2_seq_lm_preprocessor import ( + Seq2SeqLMPreprocessor as Seq2SeqLMPreprocessor, +) +from keras_hub.src.models.siglip.siglip_backbone import ( + SigLIPBackbone as SigLIPBackbone, +) +from keras_hub.src.models.siglip.siglip_preprocessor import ( + SigLIPPreprocessor as SigLIPPreprocessor, +) +from keras_hub.src.models.siglip.siglip_text_encoder import ( + SigLIPTextEncoder as SigLIPTextEncoder, +) +from keras_hub.src.models.siglip.siglip_tokenizer import ( + SigLIPTokenizer as SigLIPTokenizer, +) from keras_hub.src.models.siglip.siglip_vision_encoder import ( - SigLIPVisionEncoder, + SigLIPVisionEncoder as SigLIPVisionEncoder, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( - StableDiffusion3Backbone, + StableDiffusion3Backbone as StableDiffusion3Backbone, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( - StableDiffusion3ImageToImage, + StableDiffusion3ImageToImage as StableDiffusion3ImageToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( - StableDiffusion3Inpaint, + StableDiffusion3Inpaint as StableDiffusion3Inpaint, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( - StableDiffusion3TextToImage, + StableDiffusion3TextToImage as StableDiffusion3TextToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( - StableDiffusion3TextToImagePreprocessor, + StableDiffusion3TextToImagePreprocessor as StableDiffusion3TextToImagePreprocessor, ) -from keras_hub.src.models.t5.t5_backbone import T5Backbone -from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer -from keras_hub.src.models.task import Task -from keras_hub.src.models.text_classifier import TextClassifier +from keras_hub.src.models.t5.t5_backbone import T5Backbone as T5Backbone +from keras_hub.src.models.t5.t5_preprocessor import ( + T5Preprocessor as T5Preprocessor, +) +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.task import Task as Task from keras_hub.src.models.text_classifier import TextClassifier as Classifier +from keras_hub.src.models.text_classifier import ( + TextClassifier as TextClassifier, +) from keras_hub.src.models.text_classifier_preprocessor import ( - TextClassifierPreprocessor, + TextClassifierPreprocessor as TextClassifierPreprocessor, ) -from keras_hub.src.models.text_to_image import TextToImage +from keras_hub.src.models.text_to_image import TextToImage as TextToImage from keras_hub.src.models.text_to_image_preprocessor import ( - TextToImagePreprocessor, + TextToImagePreprocessor as TextToImagePreprocessor, +) +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone as VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier import ( + VGGImageClassifier as VGGImageClassifier, ) -from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone -from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( - VGGImageClassifierPreprocessor, + VGGImageClassifierPreprocessor as VGGImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone as ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ( + ViTImageClassifier as ViTImageClassifier, ) -from keras_hub.src.models.vit.vit_backbone import ViTBackbone -from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( - ViTImageClassifierPreprocessor, + ViTImageClassifierPreprocessor as ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit_det.vit_det_backbone import ( + ViTDetBackbone as ViTDetBackbone, +) +from keras_hub.src.models.whisper.whisper_backbone import ( + WhisperBackbone as WhisperBackbone, +) +from keras_hub.src.models.whisper.whisper_tokenizer import ( + WhisperTokenizer as WhisperTokenizer, +) +from keras_hub.src.models.xception.xception_backbone import ( + XceptionBackbone as XceptionBackbone, ) -from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone -from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone -from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_hub.src.models.xception.xception_backbone import XceptionBackbone from keras_hub.src.models.xception.xception_image_classifier import ( - XceptionImageClassifier, + XceptionImageClassifier as XceptionImageClassifier, ) from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( - XceptionImageClassifierPreprocessor, + XceptionImageClassifierPreprocessor as XceptionImageClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( - XLMRobertaBackbone, + XLMRobertaBackbone as XLMRobertaBackbone, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( - XLMRobertaMaskedLM, + XLMRobertaMaskedLM as XLMRobertaMaskedLM, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( - XLMRobertaMaskedLMPreprocessor, + XLMRobertaMaskedLMPreprocessor as XLMRobertaMaskedLMPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier, + XLMRobertaTextClassifier as XLMRobertaClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaClassifier, + XLMRobertaTextClassifier as XLMRobertaTextClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaTextClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer, + XLMRobertaTokenizer as XLMRobertaTokenizer, +) +from keras_hub.src.models.xlnet.xlnet_backbone import ( + XLNetBackbone as XLNetBackbone, ) -from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone -from keras_hub.src.tokenizers.tokenizer import Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer \ No newline at end of file diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 3615e77581..3cc5d3bc20 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -4,60 +4,105 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer -from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer -from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer -from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer -from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.albert.albert_tokenizer import ( + AlbertTokenizer as AlbertTokenizer, +) +from keras_hub.src.models.bart.bart_tokenizer import ( + BartTokenizer as BartTokenizer, +) +from keras_hub.src.models.bert.bert_tokenizer import ( + BertTokenizer as BertTokenizer, +) +from keras_hub.src.models.bloom.bloom_tokenizer import ( + BloomTokenizer as BloomTokenizer, +) +from keras_hub.src.models.clip.clip_tokenizer import ( + CLIPTokenizer as CLIPTokenizer, +) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, + DebertaV3Tokenizer as DebertaV3Tokenizer, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, -) -from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer -from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer -from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer -from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer -from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer -from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer -from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer -from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer -from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer -from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer + DistilBertTokenizer as DistilBertTokenizer, +) +from keras_hub.src.models.electra.electra_tokenizer import ( + ElectraTokenizer as ElectraTokenizer, +) +from keras_hub.src.models.f_net.f_net_tokenizer import ( + FNetTokenizer as FNetTokenizer, +) +from keras_hub.src.models.falcon.falcon_tokenizer import ( + FalconTokenizer as FalconTokenizer, +) +from keras_hub.src.models.gemma.gemma_tokenizer import ( + GemmaTokenizer as GemmaTokenizer, +) +from keras_hub.src.models.gemma3.gemma3_tokenizer import ( + Gemma3Tokenizer as Gemma3Tokenizer, +) +from keras_hub.src.models.gpt2.gpt2_tokenizer import ( + GPT2Tokenizer as GPT2Tokenizer, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( + GPTNeoXTokenizer as GPTNeoXTokenizer, +) +from keras_hub.src.models.llama.llama_tokenizer import ( + LlamaTokenizer as LlamaTokenizer, +) +from keras_hub.src.models.llama3.llama3_tokenizer import ( + Llama3Tokenizer as Llama3Tokenizer, +) +from keras_hub.src.models.mistral.mistral_tokenizer import ( + MistralTokenizer as MistralTokenizer, +) +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, + PaliGemmaTokenizer as PaliGemmaTokenizer, +) +from keras_hub.src.models.phi3.phi3_tokenizer import ( + Phi3Tokenizer as Phi3Tokenizer, ) -from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer -from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_hub.src.models.qwen.qwen_tokenizer import ( + QwenTokenizer as QwenTokenizer, +) +from keras_hub.src.models.roberta.roberta_tokenizer import ( + RobertaTokenizer as RobertaTokenizer, +) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer, + RoformerV2Tokenizer as RoformerV2Tokenizer, +) +from keras_hub.src.models.siglip.siglip_tokenizer import ( + SigLIPTokenizer as SigLIPTokenizer, +) +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.whisper.whisper_tokenizer import ( + WhisperTokenizer as WhisperTokenizer, ) -from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer -from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer, + XLMRobertaTokenizer as XLMRobertaTokenizer, +) +from keras_hub.src.tokenizers.byte_pair_tokenizer import ( + BytePairTokenizer as BytePairTokenizer, +) +from keras_hub.src.tokenizers.byte_tokenizer import ( + ByteTokenizer as ByteTokenizer, ) -from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_hub.src.tokenizers.byte_tokenizer import ByteTokenizer from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( - SentencePieceTokenizer, + SentencePieceTokenizer as SentencePieceTokenizer, ) from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import ( - compute_sentence_piece_proto, + compute_sentence_piece_proto as compute_sentence_piece_proto, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import ( - UnicodeCodepointTokenizer, + UnicodeCodepointTokenizer as UnicodeCodepointTokenizer, ) -from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer -from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( - compute_word_piece_vocabulary, +from keras_hub.src.tokenizers.word_piece_tokenizer import ( + WordPieceTokenizer as WordPieceTokenizer, ) +from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( + compute_word_piece_vocabulary as compute_word_piece_vocabulary, +) \ No newline at end of file From 5cbf577369d71ace3db7234798e4d515de1e05de Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 10 May 2025 20:05:37 +0800 Subject: [PATCH 09/13] format --- keras_hub/api/models/__init__.py | 638 ++++++++++----------------- keras_hub/api/tokenizers/__init__.py | 129 ++---- 2 files changed, 277 insertions(+), 490 deletions(-) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index a67fb94aab..d0e2c7333f 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,627 +4,463 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_backbone import ( - AlbertBackbone as AlbertBackbone, -) -from keras_hub.src.models.albert.albert_masked_lm import ( - AlbertMaskedLM as AlbertMaskedLM, -) +from keras_hub.src.models.albert.albert_backbone import AlbertBackbone +from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( - AlbertMaskedLMPreprocessor as AlbertMaskedLMPreprocessor, + AlbertMaskedLMPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertClassifier, + AlbertTextClassifier, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertTextClassifier, + AlbertTextClassifier as AlbertClassifier, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertPreprocessor, + AlbertTextClassifierPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertTextClassifierPreprocessor, -) -from keras_hub.src.models.albert.albert_tokenizer import ( - AlbertTokenizer as AlbertTokenizer, -) -from keras_hub.src.models.backbone import Backbone as Backbone -from keras_hub.src.models.bart.bart_backbone import BartBackbone as BartBackbone -from keras_hub.src.models.bart.bart_seq_2_seq_lm import ( - BartSeq2SeqLM as BartSeq2SeqLM, + AlbertTextClassifierPreprocessor as AlbertPreprocessor, ) +from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.bart.bart_backbone import BartBackbone +from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( - BartSeq2SeqLMPreprocessor as BartSeq2SeqLMPreprocessor, -) -from keras_hub.src.models.bart.bart_tokenizer import ( - BartTokenizer as BartTokenizer, -) -from keras_hub.src.models.basnet.basnet import ( - BASNetImageSegmenter as BASNetImageSegmenter, -) -from keras_hub.src.models.basnet.basnet_backbone import ( - BASNetBackbone as BASNetBackbone, -) -from keras_hub.src.models.basnet.basnet_preprocessor import ( - BASNetPreprocessor as BASNetPreprocessor, -) -from keras_hub.src.models.bert.bert_backbone import BertBackbone as BertBackbone -from keras_hub.src.models.bert.bert_masked_lm import ( - BertMaskedLM as BertMaskedLM, -) + BartSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor +from keras_hub.src.models.bert.bert_backbone import BertBackbone +from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( - BertMaskedLMPreprocessor as BertMaskedLMPreprocessor, + BertMaskedLMPreprocessor, ) +from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier from keras_hub.src.models.bert.bert_text_classifier import ( BertTextClassifier as BertClassifier, ) -from keras_hub.src.models.bert.bert_text_classifier import ( - BertTextClassifier as BertTextClassifier, -) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor as BertPreprocessor, + BertTextClassifierPreprocessor, ) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor as BertTextClassifierPreprocessor, -) -from keras_hub.src.models.bert.bert_tokenizer import ( - BertTokenizer as BertTokenizer, -) -from keras_hub.src.models.bloom.bloom_backbone import ( - BloomBackbone as BloomBackbone, -) -from keras_hub.src.models.bloom.bloom_causal_lm import ( - BloomCausalLM as BloomCausalLM, + BertTextClassifierPreprocessor as BertPreprocessor, ) +from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer +from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone +from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( - BloomCausalLMPreprocessor as BloomCausalLMPreprocessor, -) -from keras_hub.src.models.bloom.bloom_tokenizer import ( - BloomTokenizer as BloomTokenizer, -) -from keras_hub.src.models.causal_lm import CausalLM as CausalLM -from keras_hub.src.models.causal_lm_preprocessor import ( - CausalLMPreprocessor as CausalLMPreprocessor, -) -from keras_hub.src.models.clip.clip_backbone import CLIPBackbone as CLIPBackbone -from keras_hub.src.models.clip.clip_preprocessor import ( - CLIPPreprocessor as CLIPPreprocessor, -) -from keras_hub.src.models.clip.clip_text_encoder import ( - CLIPTextEncoder as CLIPTextEncoder, -) -from keras_hub.src.models.clip.clip_tokenizer import ( - CLIPTokenizer as CLIPTokenizer, -) -from keras_hub.src.models.clip.clip_vision_encoder import ( - CLIPVisionEncoder as CLIPVisionEncoder, -) -from keras_hub.src.models.cspnet.cspnet_backbone import ( - CSPNetBackbone as CSPNetBackbone, -) + BloomCausalLMPreprocessor, +) +from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.clip.clip_backbone import CLIPBackbone +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone from keras_hub.src.models.cspnet.cspnet_image_classifier import ( - CSPNetImageClassifier as CSPNetImageClassifier, + CSPNetImageClassifier, ) from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( - CSPNetImageClassifierPreprocessor as CSPNetImageClassifierPreprocessor, + CSPNetImageClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( - DebertaV3Backbone as DebertaV3Backbone, + DebertaV3Backbone, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( - DebertaV3MaskedLM as DebertaV3MaskedLM, + DebertaV3MaskedLM, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( - DebertaV3MaskedLMPreprocessor as DebertaV3MaskedLMPreprocessor, + DebertaV3MaskedLMPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3Classifier, + DebertaV3TextClassifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3TextClassifier, + DebertaV3TextClassifier as DebertaV3Classifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, + DebertaV3TextClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3TextClassifierPreprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer as DebertaV3Tokenizer, + DebertaV3Tokenizer, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( - DeepLabV3Backbone as DeepLabV3Backbone, + DeepLabV3Backbone, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( - DeepLabV3ImageSegmenterPreprocessor as DeepLabV3ImageSegmenterPreprocessor, + DeepLabV3ImageSegmenterPreprocessor, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( - DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter, -) -from keras_hub.src.models.densenet.densenet_backbone import ( - DenseNetBackbone as DenseNetBackbone, + DeepLabV3ImageSegmenter, ) +from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier as DenseNetImageClassifier, + DenseNetImageClassifier, ) from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( - DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor, + DenseNetImageClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_backbone import ( - DistilBertBackbone as DistilBertBackbone, + DistilBertBackbone, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( - DistilBertMaskedLM as DistilBertMaskedLM, + DistilBertMaskedLM, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor as DistilBertMaskedLMPreprocessor, + DistilBertMaskedLMPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertClassifier, + DistilBertTextClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertTextClassifier, + DistilBertTextClassifier as DistilBertClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, + DistilBertTextClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertTextClassifierPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer as DistilBertTokenizer, + DistilBertTokenizer, ) from keras_hub.src.models.efficientnet.efficientnet_backbone import ( - EfficientNetBackbone as EfficientNetBackbone, + EfficientNetBackbone, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( - EfficientNetImageClassifier as EfficientNetImageClassifier, + EfficientNetImageClassifier, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( - EfficientNetImageClassifierPreprocessor as EfficientNetImageClassifierPreprocessor, -) -from keras_hub.src.models.electra.electra_backbone import ( - ElectraBackbone as ElectraBackbone, -) -from keras_hub.src.models.electra.electra_tokenizer import ( - ElectraTokenizer as ElectraTokenizer, -) -from keras_hub.src.models.f_net.f_net_backbone import ( - FNetBackbone as FNetBackbone, -) -from keras_hub.src.models.f_net.f_net_masked_lm import ( - FNetMaskedLM as FNetMaskedLM, -) + EfficientNetImageClassifierPreprocessor, +) +from keras_hub.src.models.electra.electra_backbone import ElectraBackbone +from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer +from keras_hub.src.models.esm.esm_backbone import ESMBackbone +from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone +from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier +from keras_hub.src.models.esm.esm_classifier_preprocessor import ( + ESMProteinClassifierPreprocessor, +) +from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM +from keras_hub.src.models.esm.esm_masked_plm import ( + ESMMaskedPLM as ESM2MaskedPLM, +) +from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( + ESMMaskedPLMPreprocessor, +) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer +from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone +from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( - FNetMaskedLMPreprocessor as FNetMaskedLMPreprocessor, + FNetMaskedLMPreprocessor, ) +from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier from keras_hub.src.models.f_net.f_net_text_classifier import ( FNetTextClassifier as FNetClassifier, ) -from keras_hub.src.models.f_net.f_net_text_classifier import ( - FNetTextClassifier as FNetTextClassifier, -) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor as FNetPreprocessor, + FNetTextClassifierPreprocessor, ) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor as FNetTextClassifierPreprocessor, -) -from keras_hub.src.models.f_net.f_net_tokenizer import ( - FNetTokenizer as FNetTokenizer, -) -from keras_hub.src.models.falcon.falcon_backbone import ( - FalconBackbone as FalconBackbone, -) -from keras_hub.src.models.falcon.falcon_causal_lm import ( - FalconCausalLM as FalconCausalLM, + FNetTextClassifierPreprocessor as FNetPreprocessor, ) +from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone +from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( - FalconCausalLMPreprocessor as FalconCausalLMPreprocessor, -) -from keras_hub.src.models.falcon.falcon_tokenizer import ( - FalconTokenizer as FalconTokenizer, -) -from keras_hub.src.models.feature_pyramid_backbone import ( - FeaturePyramidBackbone as FeaturePyramidBackbone, -) -from keras_hub.src.models.flux.flux_model import FluxBackbone as FluxBackbone -from keras_hub.src.models.flux.flux_text_to_image import ( - FluxTextToImage as FluxTextToImage, + FalconCausalLMPreprocessor, ) +from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.models.flux.flux_model import FluxBackbone +from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( - FluxTextToImagePreprocessor as FluxTextToImagePreprocessor, -) -from keras_hub.src.models.gemma.gemma_backbone import ( - GemmaBackbone as GemmaBackbone, -) -from keras_hub.src.models.gemma.gemma_causal_lm import ( - GemmaCausalLM as GemmaCausalLM, + FluxTextToImagePreprocessor, ) +from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone +from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor as GemmaCausalLMPreprocessor, -) -from keras_hub.src.models.gemma.gemma_tokenizer import ( - GemmaTokenizer as GemmaTokenizer, -) -from keras_hub.src.models.gemma3.gemma3_backbone import ( - Gemma3Backbone as Gemma3Backbone, -) -from keras_hub.src.models.gemma3.gemma3_causal_lm import ( - Gemma3CausalLM as Gemma3CausalLM, + GemmaCausalLMPreprocessor, ) +from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( - Gemma3CausalLMPreprocessor as Gemma3CausalLMPreprocessor, -) -from keras_hub.src.models.gemma3.gemma3_tokenizer import ( - Gemma3Tokenizer as Gemma3Tokenizer, + Gemma3CausalLMPreprocessor, ) +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( - Gemma3VisionEncoder as Gemma3VisionEncoder, -) -from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone as GPT2Backbone -from keras_hub.src.models.gpt2.gpt2_causal_lm import ( - GPT2CausalLM as GPT2CausalLM, + Gemma3VisionEncoder, ) +from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( - GPT2CausalLMPreprocessor as GPT2CausalLMPreprocessor, -) -from keras_hub.src.models.gpt2.gpt2_preprocessor import ( - GPT2Preprocessor as GPT2Preprocessor, -) -from keras_hub.src.models.gpt2.gpt2_tokenizer import ( - GPT2Tokenizer as GPT2Tokenizer, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import ( - GPTNeoXBackbone as GPTNeoXBackbone, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import ( - GPTNeoXCausalLM as GPTNeoXCausalLM, + GPT2CausalLMPreprocessor, ) +from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( - GPTNeoXCausalLMPreprocessor as GPTNeoXCausalLMPreprocessor, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( - GPTNeoXTokenizer as GPTNeoXTokenizer, -) -from keras_hub.src.models.image_classifier import ( - ImageClassifier as ImageClassifier, + GPTNeoXCausalLMPreprocessor, ) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.image_classifier_preprocessor import ( - ImageClassifierPreprocessor as ImageClassifierPreprocessor, -) -from keras_hub.src.models.image_segmenter import ( - ImageSegmenter as ImageSegmenter, + ImageClassifierPreprocessor, ) +from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.image_segmenter_preprocessor import ( - ImageSegmenterPreprocessor as ImageSegmenterPreprocessor, -) -from keras_hub.src.models.image_to_image import ImageToImage as ImageToImage -from keras_hub.src.models.inpaint import Inpaint as Inpaint -from keras_hub.src.models.llama.llama_backbone import ( - LlamaBackbone as LlamaBackbone, -) -from keras_hub.src.models.llama.llama_causal_lm import ( - LlamaCausalLM as LlamaCausalLM, + ImageSegmenterPreprocessor, ) +from keras_hub.src.models.image_to_image import ImageToImage +from keras_hub.src.models.inpaint import Inpaint +from keras_hub.src.models.llama.llama_backbone import LlamaBackbone +from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( - LlamaCausalLMPreprocessor as LlamaCausalLMPreprocessor, -) -from keras_hub.src.models.llama.llama_tokenizer import ( - LlamaTokenizer as LlamaTokenizer, -) -from keras_hub.src.models.llama3.llama3_backbone import ( - Llama3Backbone as Llama3Backbone, -) -from keras_hub.src.models.llama3.llama3_causal_lm import ( - Llama3CausalLM as Llama3CausalLM, + LlamaCausalLMPreprocessor, ) +from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer +from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone +from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( - Llama3CausalLMPreprocessor as Llama3CausalLMPreprocessor, -) -from keras_hub.src.models.llama3.llama3_tokenizer import ( - Llama3Tokenizer as Llama3Tokenizer, -) -from keras_hub.src.models.masked_lm import MaskedLM as MaskedLM -from keras_hub.src.models.masked_lm_preprocessor import ( - MaskedLMPreprocessor as MaskedLMPreprocessor, -) -from keras_hub.src.models.mistral.mistral_backbone import ( - MistralBackbone as MistralBackbone, -) -from keras_hub.src.models.mistral.mistral_causal_lm import ( - MistralCausalLM as MistralCausalLM, + Llama3CausalLMPreprocessor, ) +from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_hub.src.models.masked_lm import MaskedLM +from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor +from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone +from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( - MistralCausalLMPreprocessor as MistralCausalLMPreprocessor, -) -from keras_hub.src.models.mistral.mistral_tokenizer import ( - MistralTokenizer as MistralTokenizer, -) -from keras_hub.src.models.mit.mit_backbone import MiTBackbone as MiTBackbone -from keras_hub.src.models.mit.mit_image_classifier import ( - MiTImageClassifier as MiTImageClassifier, + MistralCausalLMPreprocessor, ) +from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( - MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor, -) -from keras_hub.src.models.mixtral.mixtral_backbone import ( - MixtralBackbone as MixtralBackbone, -) -from keras_hub.src.models.mixtral.mixtral_causal_lm import ( - MixtralCausalLM as MixtralCausalLM, + MiTImageClassifierPreprocessor, ) +from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.models.mixtral.mixtral_causal_lm import MixtralCausalLM from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( - MixtralCausalLMPreprocessor as MixtralCausalLMPreprocessor, -) -from keras_hub.src.models.mixtral.mixtral_tokenizer import ( - MixtralTokenizer as MixtralTokenizer, -) -from keras_hub.src.models.mobilenet.mobilenet_backbone import ( - MobileNetBackbone as MobileNetBackbone, + MixtralCausalLMPreprocessor, ) +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier as MobileNetImageClassifier, + MobileNetImageClassifier, ) from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor, + MobileNetImageClassifierPreprocessor, ) +from keras_hub.src.models.object_detector import ObjectDetector from keras_hub.src.models.object_detector import ( ObjectDetector as ImageObjectDetector, ) -from keras_hub.src.models.object_detector import ( - ObjectDetector as ObjectDetector, -) from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, + ObjectDetectorPreprocessor, ) from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor as ObjectDetectorPreprocessor, + ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, ) -from keras_hub.src.models.opt.opt_backbone import OPTBackbone as OPTBackbone -from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM as OPTCausalLM +from keras_hub.src.models.opt.opt_backbone import OPTBackbone +from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( - OPTCausalLMPreprocessor as OPTCausalLMPreprocessor, + OPTCausalLMPreprocessor, ) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone as PaliGemmaBackbone, + PaliGemmaBackbone, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( - PaliGemmaCausalLM as PaliGemmaCausalLM, + PaliGemmaCausalLM, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor as PaliGemmaCausalLMPreprocessor, + PaliGemmaCausalLMPreprocessor, ) from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer as PaliGemmaTokenizer, -) -from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone -from keras_hub.src.models.phi3.phi3_causal_lm import ( - Phi3CausalLM as Phi3CausalLM, + PaliGemmaTokenizer, ) +from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( - Phi3CausalLMPreprocessor as Phi3CausalLMPreprocessor, + Phi3CausalLMPreprocessor, ) -from keras_hub.src.models.phi3.phi3_tokenizer import ( - Phi3Tokenizer as Phi3Tokenizer, -) -from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor +from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_hub.src.models.preprocessor import Preprocessor +from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone from keras_hub.src.models.qwen.qwen_backbone import ( QwenBackbone as Qwen2Backbone, ) -from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone as QwenBackbone +from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM from keras_hub.src.models.qwen.qwen_causal_lm import ( QwenCausalLM as Qwen2CausalLM, ) -from keras_hub.src.models.qwen.qwen_causal_lm import ( - QwenCausalLM as QwenCausalLM, -) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, + QwenCausalLMPreprocessor, ) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor as QwenCausalLMPreprocessor, + QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, ) +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen.qwen_tokenizer import ( - QwenTokenizer as QwenTokenizer, -) -from keras_hub.src.models.qwen_moe.qwen_moe_backbone import ( - QwenMoeBackbone as QwenMoeBackbone, -) -from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import ( - QwenMoeCausalLM as QwenMoeCausalLM, -) +from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone +from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import QwenMoeCausalLM from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import ( - QwenMoeCausalLMPreprocessor as QwenMoeCausalLMPreprocessor, -) -from keras_hub.src.models.resnet.resnet_backbone import ( - ResNetBackbone as ResNetBackbone, + QwenMoeCausalLMPreprocessor, ) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.models.resnet.resnet_image_classifier import ( - ResNetImageClassifier as ResNetImageClassifier, + ResNetImageClassifier, ) from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( - ResNetImageClassifierPreprocessor as ResNetImageClassifierPreprocessor, -) -from keras_hub.src.models.retinanet.retinanet_backbone import ( - RetinaNetBackbone as RetinaNetBackbone, + ResNetImageClassifierPreprocessor, ) +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_object_detector import ( - RetinaNetObjectDetector as RetinaNetObjectDetector, + RetinaNetObjectDetector, ) from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( - RetinaNetObjectDetectorPreprocessor as RetinaNetObjectDetectorPreprocessor, -) -from keras_hub.src.models.roberta.roberta_backbone import ( - RobertaBackbone as RobertaBackbone, -) -from keras_hub.src.models.roberta.roberta_masked_lm import ( - RobertaMaskedLM as RobertaMaskedLM, + RetinaNetObjectDetectorPreprocessor, ) +from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone +from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( - RobertaMaskedLMPreprocessor as RobertaMaskedLMPreprocessor, + RobertaMaskedLMPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaClassifier, + RobertaTextClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaTextClassifier, + RobertaTextClassifier as RobertaClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaPreprocessor, + RobertaTextClassifierPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaTextClassifierPreprocessor, -) -from keras_hub.src.models.roberta.roberta_tokenizer import ( - RobertaTokenizer as RobertaTokenizer, + RobertaTextClassifierPreprocessor as RobertaPreprocessor, ) +from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_backbone import ( - RoformerV2Backbone as RoformerV2Backbone, + RoformerV2Backbone, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import ( - RoformerV2MaskedLM as RoformerV2MaskedLM, + RoformerV2MaskedLM, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import ( - RoformerV2MaskedLMPreprocessor as RoformerV2MaskedLMPreprocessor, + RoformerV2MaskedLMPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import ( - RoformerV2TextClassifier as RoformerV2TextClassifier, + RoformerV2TextClassifier, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import ( - RoformerV2TextClassifierPreprocessor as RoformerV2TextClassifierPreprocessor, + RoformerV2TextClassifierPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer as RoformerV2Tokenizer, -) -from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone -from keras_hub.src.models.sam.sam_image_segmenter import ( - SAMImageSegmenter as SAMImageSegmenter, + RoformerV2Tokenizer, ) +from keras_hub.src.models.sam.sam_backbone import SAMBackbone +from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor as SAMImageSegmenterPreprocessor, -) -from keras_hub.src.models.segformer.segformer_backbone import ( - SegFormerBackbone as SegFormerBackbone, + SAMImageSegmenterPreprocessor, ) +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone from keras_hub.src.models.segformer.segformer_image_segmenter import ( - SegFormerImageSegmenter as SegFormerImageSegmenter, + SegFormerImageSegmenter, ) from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( - SegFormerImageSegmenterPreprocessor as SegFormerImageSegmenterPreprocessor, -) -from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM as Seq2SeqLM -from keras_hub.src.models.seq_2_seq_lm_preprocessor import ( - Seq2SeqLMPreprocessor as Seq2SeqLMPreprocessor, -) -from keras_hub.src.models.siglip.siglip_backbone import ( - SigLIPBackbone as SigLIPBackbone, -) -from keras_hub.src.models.siglip.siglip_preprocessor import ( - SigLIPPreprocessor as SigLIPPreprocessor, -) -from keras_hub.src.models.siglip.siglip_text_encoder import ( - SigLIPTextEncoder as SigLIPTextEncoder, -) -from keras_hub.src.models.siglip.siglip_tokenizer import ( - SigLIPTokenizer as SigLIPTokenizer, -) + SegFormerImageSegmenterPreprocessor, +) +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.models.siglip.siglip_backbone import SigLIPBackbone +from keras_hub.src.models.siglip.siglip_preprocessor import SigLIPPreprocessor +from keras_hub.src.models.siglip.siglip_text_encoder import SigLIPTextEncoder +from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer from keras_hub.src.models.siglip.siglip_vision_encoder import ( - SigLIPVisionEncoder as SigLIPVisionEncoder, + SigLIPVisionEncoder, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( - StableDiffusion3Backbone as StableDiffusion3Backbone, + StableDiffusion3Backbone, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( - StableDiffusion3ImageToImage as StableDiffusion3ImageToImage, + StableDiffusion3ImageToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( - StableDiffusion3Inpaint as StableDiffusion3Inpaint, + StableDiffusion3Inpaint, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( - StableDiffusion3TextToImage as StableDiffusion3TextToImage, + StableDiffusion3TextToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( - StableDiffusion3TextToImagePreprocessor as StableDiffusion3TextToImagePreprocessor, + StableDiffusion3TextToImagePreprocessor, ) -from keras_hub.src.models.t5.t5_backbone import T5Backbone as T5Backbone -from keras_hub.src.models.t5.t5_preprocessor import ( - T5Preprocessor as T5Preprocessor, -) -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer -from keras_hub.src.models.task import Task as Task +from keras_hub.src.models.t5.t5_backbone import T5Backbone +from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_hub.src.models.task import Task +from keras_hub.src.models.text_classifier import TextClassifier from keras_hub.src.models.text_classifier import TextClassifier as Classifier -from keras_hub.src.models.text_classifier import ( - TextClassifier as TextClassifier, -) from keras_hub.src.models.text_classifier_preprocessor import ( - TextClassifierPreprocessor as TextClassifierPreprocessor, + TextClassifierPreprocessor, ) -from keras_hub.src.models.text_to_image import TextToImage as TextToImage +from keras_hub.src.models.text_to_image import TextToImage from keras_hub.src.models.text_to_image_preprocessor import ( - TextToImagePreprocessor as TextToImagePreprocessor, -) -from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone as VGGBackbone -from keras_hub.src.models.vgg.vgg_image_classifier import ( - VGGImageClassifier as VGGImageClassifier, + TextToImagePreprocessor, ) +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( - VGGImageClassifierPreprocessor as VGGImageClassifierPreprocessor, -) -from keras_hub.src.models.vit.vit_backbone import ViTBackbone as ViTBackbone -from keras_hub.src.models.vit.vit_image_classifier import ( - ViTImageClassifier as ViTImageClassifier, + VGGImageClassifierPreprocessor, ) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( - ViTImageClassifierPreprocessor as ViTImageClassifierPreprocessor, -) -from keras_hub.src.models.vit_det.vit_det_backbone import ( - ViTDetBackbone as ViTDetBackbone, -) -from keras_hub.src.models.whisper.whisper_backbone import ( - WhisperBackbone as WhisperBackbone, -) -from keras_hub.src.models.whisper.whisper_tokenizer import ( - WhisperTokenizer as WhisperTokenizer, -) -from keras_hub.src.models.xception.xception_backbone import ( - XceptionBackbone as XceptionBackbone, + ViTImageClassifierPreprocessor, ) +from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone +from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone +from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer +from keras_hub.src.models.xception.xception_backbone import XceptionBackbone from keras_hub.src.models.xception.xception_image_classifier import ( - XceptionImageClassifier as XceptionImageClassifier, + XceptionImageClassifier, ) from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( - XceptionImageClassifierPreprocessor as XceptionImageClassifierPreprocessor, + XceptionImageClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( - XLMRobertaBackbone as XLMRobertaBackbone, + XLMRobertaBackbone, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( - XLMRobertaMaskedLM as XLMRobertaMaskedLM, + XLMRobertaMaskedLM, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( - XLMRobertaMaskedLMPreprocessor as XLMRobertaMaskedLMPreprocessor, + XLMRobertaMaskedLMPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaClassifier, + XLMRobertaTextClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaTextClassifier, + XLMRobertaTextClassifier as XLMRobertaClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, + XLMRobertaTextClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaTextClassifierPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer as XLMRobertaTokenizer, -) -from keras_hub.src.models.xlnet.xlnet_backbone import ( - XLNetBackbone as XLNetBackbone, + XLMRobertaTokenizer, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer \ No newline at end of file +from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone +from keras_hub.src.tokenizers.tokenizer import Tokenizer diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index f895749811..96818e01e7 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -4,111 +4,62 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_tokenizer import ( - AlbertTokenizer as AlbertTokenizer, -) -from keras_hub.src.models.bart.bart_tokenizer import ( - BartTokenizer as BartTokenizer, -) -from keras_hub.src.models.bert.bert_tokenizer import ( - BertTokenizer as BertTokenizer, -) -from keras_hub.src.models.bloom.bloom_tokenizer import ( - BloomTokenizer as BloomTokenizer, -) -from keras_hub.src.models.clip.clip_tokenizer import ( - CLIPTokenizer as CLIPTokenizer, -) +from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer +from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer +from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer +from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer as DebertaV3Tokenizer, + DebertaV3Tokenizer, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer as DistilBertTokenizer, -) -from keras_hub.src.models.electra.electra_tokenizer import ( - ElectraTokenizer as ElectraTokenizer, -) -from keras_hub.src.models.f_net.f_net_tokenizer import ( - FNetTokenizer as FNetTokenizer, -) -from keras_hub.src.models.falcon.falcon_tokenizer import ( - FalconTokenizer as FalconTokenizer, -) -from keras_hub.src.models.gemma.gemma_tokenizer import ( - GemmaTokenizer as GemmaTokenizer, -) -from keras_hub.src.models.gemma3.gemma3_tokenizer import ( - Gemma3Tokenizer as Gemma3Tokenizer, -) -from keras_hub.src.models.gpt2.gpt2_tokenizer import ( - GPT2Tokenizer as GPT2Tokenizer, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( - GPTNeoXTokenizer as GPTNeoXTokenizer, -) -from keras_hub.src.models.llama.llama_tokenizer import ( - LlamaTokenizer as LlamaTokenizer, -) -from keras_hub.src.models.llama3.llama3_tokenizer import ( - Llama3Tokenizer as Llama3Tokenizer, -) -from keras_hub.src.models.mistral.mistral_tokenizer import ( - MistralTokenizer as MistralTokenizer, -) -from keras_hub.src.models.mixtral.mixtral_tokenizer import ( - MixtralTokenizer as MixtralTokenizer, -) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer + DistilBertTokenizer, +) +from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer +from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer +from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer +from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer as PaliGemmaTokenizer, -) -from keras_hub.src.models.phi3.phi3_tokenizer import ( - Phi3Tokenizer as Phi3Tokenizer, + PaliGemmaTokenizer, ) +from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen.qwen_tokenizer import ( - QwenTokenizer as QwenTokenizer, -) -from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import ( - QwenMoeTokenizer as QwenMoeTokenizer, -) -from keras_hub.src.models.roberta.roberta_tokenizer import ( - RobertaTokenizer as RobertaTokenizer, -) +from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import QwenMoeTokenizer +from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer as RoformerV2Tokenizer, -) -from keras_hub.src.models.siglip.siglip_tokenizer import ( - SigLIPTokenizer as SigLIPTokenizer, -) -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer -from keras_hub.src.models.whisper.whisper_tokenizer import ( - WhisperTokenizer as WhisperTokenizer, + RoformerV2Tokenizer, ) +from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer as XLMRobertaTokenizer, -) -from keras_hub.src.tokenizers.byte_pair_tokenizer import ( - BytePairTokenizer as BytePairTokenizer, -) -from keras_hub.src.tokenizers.byte_tokenizer import ( - ByteTokenizer as ByteTokenizer, + XLMRobertaTokenizer, ) +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer +from keras_hub.src.tokenizers.byte_tokenizer import ByteTokenizer from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( - SentencePieceTokenizer as SentencePieceTokenizer, + SentencePieceTokenizer, ) from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import ( - compute_sentence_piece_proto as compute_sentence_piece_proto, + compute_sentence_piece_proto, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import ( - UnicodeCodepointTokenizer as UnicodeCodepointTokenizer, -) -from keras_hub.src.tokenizers.word_piece_tokenizer import ( - WordPieceTokenizer as WordPieceTokenizer, + UnicodeCodepointTokenizer, ) +from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( - compute_word_piece_vocabulary as compute_word_piece_vocabulary, -) \ No newline at end of file + compute_word_piece_vocabulary, +) From 6e9f817bd8c4370c095d69bafa2c248b99e71b53 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 10 May 2025 20:08:18 +0800 Subject: [PATCH 10/13] format --- keras_hub/api/__init__.py | 16 +- keras_hub/api/layers/__init__.py | 128 ++++-- keras_hub/api/metrics/__init__.py | 10 +- keras_hub/api/models/__init__.py | 636 +++++++++++++++++---------- keras_hub/api/samplers/__init__.py | 22 +- keras_hub/api/tokenizers/__init__.py | 128 ++++-- keras_hub/api/utils/__init__.py | 18 +- 7 files changed, 620 insertions(+), 338 deletions(-) diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 3796e4c7f4..2aa98bf3f9 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,12 +4,12 @@ since your modifications would be overwritten. """ -from keras_hub import layers -from keras_hub import metrics -from keras_hub import models -from keras_hub import samplers -from keras_hub import tokenizers -from keras_hub import utils -from keras_hub.src.utils.preset_utils import upload_preset +from keras_hub import layers as layers +from keras_hub import metrics as metrics +from keras_hub import models as models +from keras_hub import samplers as samplers +from keras_hub import tokenizers as tokenizers +from keras_hub import utils as utils +from keras_hub.src.utils.preset_utils import upload_preset as upload_preset from keras_hub.src.version import __version__ as __version__ -from keras_hub.src.version import version +from keras_hub.src.version import version as version diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index d42af86a3c..61eb0621b6 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -4,86 +4,128 @@ since your modifications would be overwritten. """ -from keras_hub.src.layers.modeling.alibi_bias import AlibiBias -from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator -from keras_hub.src.layers.modeling.box_matcher import BoxMatcher +from keras_hub.src.layers.modeling.alibi_bias import AlibiBias as AlibiBias +from keras_hub.src.layers.modeling.anchor_generator import ( + AnchorGenerator as AnchorGenerator, +) +from keras_hub.src.layers.modeling.box_matcher import BoxMatcher as BoxMatcher from keras_hub.src.layers.modeling.cached_multi_head_attention import ( - CachedMultiHeadAttention, + CachedMultiHeadAttention as CachedMultiHeadAttention, +) +from keras_hub.src.layers.modeling.f_net_encoder import ( + FNetEncoder as FNetEncoder, +) +from keras_hub.src.layers.modeling.masked_lm_head import ( + MaskedLMHead as MaskedLMHead, +) +from keras_hub.src.layers.modeling.non_max_supression import ( + NonMaxSuppression as NonMaxSuppression, +) +from keras_hub.src.layers.modeling.position_embedding import ( + PositionEmbedding as PositionEmbedding, ) -from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder -from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead -from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression -from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding from keras_hub.src.layers.modeling.reversible_embedding import ( - ReversibleEmbedding, + ReversibleEmbedding as ReversibleEmbedding, +) +from keras_hub.src.layers.modeling.rms_normalization import ( + RMSNormalization as RMSNormalization, +) +from keras_hub.src.layers.modeling.rotary_embedding import ( + RotaryEmbedding as RotaryEmbedding, ) -from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization -from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding from keras_hub.src.layers.modeling.sine_position_encoding import ( - SinePositionEncoding, + SinePositionEncoding as SinePositionEncoding, ) from keras_hub.src.layers.modeling.token_and_position_embedding import ( - TokenAndPositionEmbedding, + TokenAndPositionEmbedding as TokenAndPositionEmbedding, +) +from keras_hub.src.layers.modeling.transformer_decoder import ( + TransformerDecoder as TransformerDecoder, +) +from keras_hub.src.layers.modeling.transformer_encoder import ( + TransformerEncoder as TransformerEncoder, +) +from keras_hub.src.layers.preprocessing.audio_converter import ( + AudioConverter as AudioConverter, +) +from keras_hub.src.layers.preprocessing.image_converter import ( + ImageConverter as ImageConverter, ) -from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder -from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder -from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter -from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, + MaskedLMMaskGenerator as MaskedLMMaskGenerator, ) from keras_hub.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker, + MultiSegmentPacker as MultiSegmentPacker, +) +from keras_hub.src.layers.preprocessing.random_deletion import ( + RandomDeletion as RandomDeletion, +) +from keras_hub.src.layers.preprocessing.random_swap import ( + RandomSwap as RandomSwap, +) +from keras_hub.src.layers.preprocessing.start_end_packer import ( + StartEndPacker as StartEndPacker, ) -from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion -from keras_hub.src.layers.preprocessing.random_swap import RandomSwap -from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_hub.src.models.basnet.basnet_image_converter import ( - BASNetImageConverter, + BASNetImageConverter as BASNetImageConverter, +) +from keras_hub.src.models.clip.clip_image_converter import ( + CLIPImageConverter as CLIPImageConverter, ) -from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter from keras_hub.src.models.cspnet.cspnet_image_converter import ( - CSPNetImageConverter, + CSPNetImageConverter as CSPNetImageConverter, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( - DeepLabV3ImageConverter, + DeepLabV3ImageConverter as DeepLabV3ImageConverter, ) from keras_hub.src.models.densenet.densenet_image_converter import ( - DenseNetImageConverter, + DenseNetImageConverter as DenseNetImageConverter, ) from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( - EfficientNetImageConverter, + EfficientNetImageConverter as EfficientNetImageConverter, ) from keras_hub.src.models.gemma3.gemma3_image_converter import ( - Gemma3ImageConverter, + Gemma3ImageConverter as Gemma3ImageConverter, +) +from keras_hub.src.models.mit.mit_image_converter import ( + MiTImageConverter as MiTImageConverter, ) -from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( - MobileNetImageConverter, + MobileNetImageConverter as MobileNetImageConverter, ) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( - PaliGemmaImageConverter, + PaliGemmaImageConverter as PaliGemmaImageConverter, ) from keras_hub.src.models.resnet.resnet_image_converter import ( - ResNetImageConverter, + ResNetImageConverter as ResNetImageConverter, ) from keras_hub.src.models.retinanet.retinanet_image_converter import ( - RetinaNetImageConverter, + RetinaNetImageConverter as RetinaNetImageConverter, +) +from keras_hub.src.models.sam.sam_image_converter import ( + SAMImageConverter as SAMImageConverter, +) +from keras_hub.src.models.sam.sam_mask_decoder import ( + SAMMaskDecoder as SAMMaskDecoder, +) +from keras_hub.src.models.sam.sam_prompt_encoder import ( + SAMPromptEncoder as SAMPromptEncoder, ) -from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter -from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder -from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder from keras_hub.src.models.segformer.segformer_image_converter import ( - SegFormerImageConverter, + SegFormerImageConverter as SegFormerImageConverter, ) from keras_hub.src.models.siglip.siglip_image_converter import ( - SigLIPImageConverter, + SigLIPImageConverter as SigLIPImageConverter, +) +from keras_hub.src.models.vgg.vgg_image_converter import ( + VGGImageConverter as VGGImageConverter, +) +from keras_hub.src.models.vit.vit_image_converter import ( + ViTImageConverter as ViTImageConverter, ) -from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter -from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( - WhisperAudioConverter, + WhisperAudioConverter as WhisperAudioConverter, ) from keras_hub.src.models.xception.xception_image_converter import ( - XceptionImageConverter, + XceptionImageConverter as XceptionImageConverter, ) diff --git a/keras_hub/api/metrics/__init__.py b/keras_hub/api/metrics/__init__.py index 88a0a7df2b..100c2c66fb 100644 --- a/keras_hub/api/metrics/__init__.py +++ b/keras_hub/api/metrics/__init__.py @@ -4,8 +4,8 @@ since your modifications would be overwritten. """ -from keras_hub.src.metrics.bleu import Bleu -from keras_hub.src.metrics.edit_distance import EditDistance -from keras_hub.src.metrics.perplexity import Perplexity -from keras_hub.src.metrics.rouge_l import RougeL -from keras_hub.src.metrics.rouge_n import RougeN +from keras_hub.src.metrics.bleu import Bleu as Bleu +from keras_hub.src.metrics.edit_distance import EditDistance as EditDistance +from keras_hub.src.metrics.perplexity import Perplexity as Perplexity +from keras_hub.src.metrics.rouge_l import RougeL as RougeL +from keras_hub.src.metrics.rouge_n import RougeN as RougeN diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index d0e2c7333f..d8bcc90de5 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,463 +4,643 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_backbone import AlbertBackbone -from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM +from keras_hub.src.models.albert.albert_backbone import ( + AlbertBackbone as AlbertBackbone, +) +from keras_hub.src.models.albert.albert_masked_lm import ( + AlbertMaskedLM as AlbertMaskedLM, +) from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( - AlbertMaskedLMPreprocessor, + AlbertMaskedLMPreprocessor as AlbertMaskedLMPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier, + AlbertTextClassifier as AlbertClassifier, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertClassifier, + AlbertTextClassifier as AlbertTextClassifier, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor, + AlbertTextClassifierPreprocessor as AlbertPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertPreprocessor, + AlbertTextClassifierPreprocessor as AlbertTextClassifierPreprocessor, +) +from keras_hub.src.models.albert.albert_tokenizer import ( + AlbertTokenizer as AlbertTokenizer, +) +from keras_hub.src.models.backbone import Backbone as Backbone +from keras_hub.src.models.bart.bart_backbone import BartBackbone as BartBackbone +from keras_hub.src.models.bart.bart_seq_2_seq_lm import ( + BartSeq2SeqLM as BartSeq2SeqLM, ) -from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer -from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.bart.bart_backbone import BartBackbone -from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( - BartSeq2SeqLMPreprocessor, -) -from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer -from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone -from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor -from keras_hub.src.models.bert.bert_backbone import BertBackbone -from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM + BartSeq2SeqLMPreprocessor as BartSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.bart.bart_tokenizer import ( + BartTokenizer as BartTokenizer, +) +from keras_hub.src.models.basnet.basnet import ( + BASNetImageSegmenter as BASNetImageSegmenter, +) +from keras_hub.src.models.basnet.basnet_backbone import ( + BASNetBackbone as BASNetBackbone, +) +from keras_hub.src.models.basnet.basnet_preprocessor import ( + BASNetPreprocessor as BASNetPreprocessor, +) +from keras_hub.src.models.bert.bert_backbone import BertBackbone as BertBackbone +from keras_hub.src.models.bert.bert_masked_lm import ( + BertMaskedLM as BertMaskedLM, +) from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( - BertMaskedLMPreprocessor, + BertMaskedLMPreprocessor as BertMaskedLMPreprocessor, ) -from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier from keras_hub.src.models.bert.bert_text_classifier import ( BertTextClassifier as BertClassifier, ) -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor, +from keras_hub.src.models.bert.bert_text_classifier import ( + BertTextClassifier as BertTextClassifier, ) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( BertTextClassifierPreprocessor as BertPreprocessor, ) -from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer -from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone -from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( + BertTextClassifierPreprocessor as BertTextClassifierPreprocessor, +) +from keras_hub.src.models.bert.bert_tokenizer import ( + BertTokenizer as BertTokenizer, +) +from keras_hub.src.models.bloom.bloom_backbone import ( + BloomBackbone as BloomBackbone, +) +from keras_hub.src.models.bloom.bloom_causal_lm import ( + BloomCausalLM as BloomCausalLM, +) from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( - BloomCausalLMPreprocessor, -) -from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer -from keras_hub.src.models.causal_lm import CausalLM -from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_hub.src.models.clip.clip_backbone import CLIPBackbone -from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor -from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder -from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer -from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder -from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone + BloomCausalLMPreprocessor as BloomCausalLMPreprocessor, +) +from keras_hub.src.models.bloom.bloom_tokenizer import ( + BloomTokenizer as BloomTokenizer, +) +from keras_hub.src.models.causal_lm import CausalLM as CausalLM +from keras_hub.src.models.causal_lm_preprocessor import ( + CausalLMPreprocessor as CausalLMPreprocessor, +) +from keras_hub.src.models.clip.clip_backbone import CLIPBackbone as CLIPBackbone +from keras_hub.src.models.clip.clip_preprocessor import ( + CLIPPreprocessor as CLIPPreprocessor, +) +from keras_hub.src.models.clip.clip_text_encoder import ( + CLIPTextEncoder as CLIPTextEncoder, +) +from keras_hub.src.models.clip.clip_tokenizer import ( + CLIPTokenizer as CLIPTokenizer, +) +from keras_hub.src.models.clip.clip_vision_encoder import ( + CLIPVisionEncoder as CLIPVisionEncoder, +) +from keras_hub.src.models.cspnet.cspnet_backbone import ( + CSPNetBackbone as CSPNetBackbone, +) from keras_hub.src.models.cspnet.cspnet_image_classifier import ( - CSPNetImageClassifier, + CSPNetImageClassifier as CSPNetImageClassifier, ) from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( - CSPNetImageClassifierPreprocessor, + CSPNetImageClassifierPreprocessor as CSPNetImageClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( - DebertaV3Backbone, + DebertaV3Backbone as DebertaV3Backbone, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( - DebertaV3MaskedLM, + DebertaV3MaskedLM as DebertaV3MaskedLM, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( - DebertaV3MaskedLMPreprocessor, + DebertaV3MaskedLMPreprocessor as DebertaV3MaskedLMPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier, + DebertaV3TextClassifier as DebertaV3Classifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3Classifier, + DebertaV3TextClassifier as DebertaV3TextClassifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3TextClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, + DebertaV3Tokenizer as DebertaV3Tokenizer, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( - DeepLabV3Backbone, + DeepLabV3Backbone as DeepLabV3Backbone, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( - DeepLabV3ImageSegmenterPreprocessor, + DeepLabV3ImageSegmenterPreprocessor as DeepLabV3ImageSegmenterPreprocessor, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( - DeepLabV3ImageSegmenter, + DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter, +) +from keras_hub.src.models.densenet.densenet_backbone import ( + DenseNetBackbone as DenseNetBackbone, ) -from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier, + DenseNetImageClassifier as DenseNetImageClassifier, ) from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( - DenseNetImageClassifierPreprocessor, + DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_backbone import ( - DistilBertBackbone, + DistilBertBackbone as DistilBertBackbone, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( - DistilBertMaskedLM, + DistilBertMaskedLM as DistilBertMaskedLM, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor, + DistilBertMaskedLMPreprocessor as DistilBertMaskedLMPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier, + DistilBertTextClassifier as DistilBertClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertClassifier, + DistilBertTextClassifier as DistilBertTextClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertTextClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, + DistilBertTokenizer as DistilBertTokenizer, ) from keras_hub.src.models.efficientnet.efficientnet_backbone import ( - EfficientNetBackbone, + EfficientNetBackbone as EfficientNetBackbone, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( - EfficientNetImageClassifier, + EfficientNetImageClassifier as EfficientNetImageClassifier, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( - EfficientNetImageClassifierPreprocessor, + EfficientNetImageClassifierPreprocessor as EfficientNetImageClassifierPreprocessor, +) +from keras_hub.src.models.electra.electra_backbone import ( + ElectraBackbone as ElectraBackbone, +) +from keras_hub.src.models.electra.electra_tokenizer import ( + ElectraTokenizer as ElectraTokenizer, ) -from keras_hub.src.models.electra.electra_backbone import ElectraBackbone -from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer -from keras_hub.src.models.esm.esm_backbone import ESMBackbone from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone -from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier +from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESMBackbone +from keras_hub.src.models.esm.esm_classifier import ( + ESMProteinClassifier as ESMProteinClassifier, +) from keras_hub.src.models.esm.esm_classifier_preprocessor import ( - ESMProteinClassifierPreprocessor, + ESMProteinClassifierPreprocessor as ESMProteinClassifierPreprocessor, ) -from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM from keras_hub.src.models.esm.esm_masked_plm import ( ESMMaskedPLM as ESM2MaskedPLM, ) +from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( - ESMMaskedPLMPreprocessor, + ESMMaskedPLMPreprocessor as ESMMaskedPLMPreprocessor, +) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer +from keras_hub.src.models.f_net.f_net_backbone import ( + FNetBackbone as FNetBackbone, +) +from keras_hub.src.models.f_net.f_net_masked_lm import ( + FNetMaskedLM as FNetMaskedLM, ) -from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer -from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone -from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( - FNetMaskedLMPreprocessor, + FNetMaskedLMPreprocessor as FNetMaskedLMPreprocessor, ) -from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier from keras_hub.src.models.f_net.f_net_text_classifier import ( FNetTextClassifier as FNetClassifier, ) -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor, +from keras_hub.src.models.f_net.f_net_text_classifier import ( + FNetTextClassifier as FNetTextClassifier, ) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( FNetTextClassifierPreprocessor as FNetPreprocessor, ) -from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer -from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone -from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( + FNetTextClassifierPreprocessor as FNetTextClassifierPreprocessor, +) +from keras_hub.src.models.f_net.f_net_tokenizer import ( + FNetTokenizer as FNetTokenizer, +) +from keras_hub.src.models.falcon.falcon_backbone import ( + FalconBackbone as FalconBackbone, +) +from keras_hub.src.models.falcon.falcon_causal_lm import ( + FalconCausalLM as FalconCausalLM, +) from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( - FalconCausalLMPreprocessor, + FalconCausalLMPreprocessor as FalconCausalLMPreprocessor, +) +from keras_hub.src.models.falcon.falcon_tokenizer import ( + FalconTokenizer as FalconTokenizer, +) +from keras_hub.src.models.feature_pyramid_backbone import ( + FeaturePyramidBackbone as FeaturePyramidBackbone, +) +from keras_hub.src.models.flux.flux_model import FluxBackbone as FluxBackbone +from keras_hub.src.models.flux.flux_text_to_image import ( + FluxTextToImage as FluxTextToImage, ) -from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer -from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone -from keras_hub.src.models.flux.flux_model import FluxBackbone -from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( - FluxTextToImagePreprocessor, + FluxTextToImagePreprocessor as FluxTextToImagePreprocessor, +) +from keras_hub.src.models.gemma.gemma_backbone import ( + GemmaBackbone as GemmaBackbone, +) +from keras_hub.src.models.gemma.gemma_causal_lm import ( + GemmaCausalLM as GemmaCausalLM, ) -from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone -from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor, + GemmaCausalLMPreprocessor as GemmaCausalLMPreprocessor, +) +from keras_hub.src.models.gemma.gemma_tokenizer import ( + GemmaTokenizer as GemmaTokenizer, +) +from keras_hub.src.models.gemma3.gemma3_backbone import ( + Gemma3Backbone as Gemma3Backbone, +) +from keras_hub.src.models.gemma3.gemma3_causal_lm import ( + Gemma3CausalLM as Gemma3CausalLM, ) -from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer -from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone -from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( - Gemma3CausalLMPreprocessor, + Gemma3CausalLMPreprocessor as Gemma3CausalLMPreprocessor, +) +from keras_hub.src.models.gemma3.gemma3_tokenizer import ( + Gemma3Tokenizer as Gemma3Tokenizer, ) -from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( - Gemma3VisionEncoder, + Gemma3VisionEncoder as Gemma3VisionEncoder, +) +from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone as GPT2Backbone +from keras_hub.src.models.gpt2.gpt2_causal_lm import ( + GPT2CausalLM as GPT2CausalLM, ) -from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone -from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( - GPT2CausalLMPreprocessor, + GPT2CausalLMPreprocessor as GPT2CausalLMPreprocessor, +) +from keras_hub.src.models.gpt2.gpt2_preprocessor import ( + GPT2Preprocessor as GPT2Preprocessor, +) +from keras_hub.src.models.gpt2.gpt2_tokenizer import ( + GPT2Tokenizer as GPT2Tokenizer, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import ( + GPTNeoXBackbone as GPTNeoXBackbone, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import ( + GPTNeoXCausalLM as GPTNeoXCausalLM, ) -from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor -from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( - GPTNeoXCausalLMPreprocessor, + GPTNeoXCausalLMPreprocessor as GPTNeoXCausalLMPreprocessor, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( + GPTNeoXTokenizer as GPTNeoXTokenizer, +) +from keras_hub.src.models.image_classifier import ( + ImageClassifier as ImageClassifier, ) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer -from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.image_classifier_preprocessor import ( - ImageClassifierPreprocessor, + ImageClassifierPreprocessor as ImageClassifierPreprocessor, +) +from keras_hub.src.models.image_segmenter import ( + ImageSegmenter as ImageSegmenter, ) -from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.image_segmenter_preprocessor import ( - ImageSegmenterPreprocessor, + ImageSegmenterPreprocessor as ImageSegmenterPreprocessor, +) +from keras_hub.src.models.image_to_image import ImageToImage as ImageToImage +from keras_hub.src.models.inpaint import Inpaint as Inpaint +from keras_hub.src.models.llama.llama_backbone import ( + LlamaBackbone as LlamaBackbone, +) +from keras_hub.src.models.llama.llama_causal_lm import ( + LlamaCausalLM as LlamaCausalLM, ) -from keras_hub.src.models.image_to_image import ImageToImage -from keras_hub.src.models.inpaint import Inpaint -from keras_hub.src.models.llama.llama_backbone import LlamaBackbone -from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( - LlamaCausalLMPreprocessor, + LlamaCausalLMPreprocessor as LlamaCausalLMPreprocessor, +) +from keras_hub.src.models.llama.llama_tokenizer import ( + LlamaTokenizer as LlamaTokenizer, +) +from keras_hub.src.models.llama3.llama3_backbone import ( + Llama3Backbone as Llama3Backbone, +) +from keras_hub.src.models.llama3.llama3_causal_lm import ( + Llama3CausalLM as Llama3CausalLM, ) -from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer -from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone -from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( - Llama3CausalLMPreprocessor, + Llama3CausalLMPreprocessor as Llama3CausalLMPreprocessor, +) +from keras_hub.src.models.llama3.llama3_tokenizer import ( + Llama3Tokenizer as Llama3Tokenizer, +) +from keras_hub.src.models.masked_lm import MaskedLM as MaskedLM +from keras_hub.src.models.masked_lm_preprocessor import ( + MaskedLMPreprocessor as MaskedLMPreprocessor, +) +from keras_hub.src.models.mistral.mistral_backbone import ( + MistralBackbone as MistralBackbone, +) +from keras_hub.src.models.mistral.mistral_causal_lm import ( + MistralCausalLM as MistralCausalLM, ) -from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer -from keras_hub.src.models.masked_lm import MaskedLM -from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone -from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( - MistralCausalLMPreprocessor, + MistralCausalLMPreprocessor as MistralCausalLMPreprocessor, +) +from keras_hub.src.models.mistral.mistral_tokenizer import ( + MistralTokenizer as MistralTokenizer, +) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone as MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import ( + MiTImageClassifier as MiTImageClassifier, ) -from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer -from keras_hub.src.models.mit.mit_backbone import MiTBackbone -from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( - MiTImageClassifierPreprocessor, + MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor, +) +from keras_hub.src.models.mixtral.mixtral_backbone import ( + MixtralBackbone as MixtralBackbone, +) +from keras_hub.src.models.mixtral.mixtral_causal_lm import ( + MixtralCausalLM as MixtralCausalLM, ) -from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone -from keras_hub.src.models.mixtral.mixtral_causal_lm import MixtralCausalLM from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( - MixtralCausalLMPreprocessor, + MixtralCausalLMPreprocessor as MixtralCausalLMPreprocessor, +) +from keras_hub.src.models.mixtral.mixtral_tokenizer import ( + MixtralTokenizer as MixtralTokenizer, +) +from keras_hub.src.models.mobilenet.mobilenet_backbone import ( + MobileNetBackbone as MobileNetBackbone, ) -from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer -from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier, + MobileNetImageClassifier as MobileNetImageClassifier, ) from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor, + MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor, ) -from keras_hub.src.models.object_detector import ObjectDetector from keras_hub.src.models.object_detector import ( ObjectDetector as ImageObjectDetector, ) -from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor, +from keras_hub.src.models.object_detector import ( + ObjectDetector as ObjectDetector, ) from keras_hub.src.models.object_detector_preprocessor import ( ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, ) -from keras_hub.src.models.opt.opt_backbone import OPTBackbone -from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM +from keras_hub.src.models.object_detector_preprocessor import ( + ObjectDetectorPreprocessor as ObjectDetectorPreprocessor, +) +from keras_hub.src.models.opt.opt_backbone import OPTBackbone as OPTBackbone +from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM as OPTCausalLM from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( - OPTCausalLMPreprocessor, + OPTCausalLMPreprocessor as OPTCausalLMPreprocessor, ) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone, + PaliGemmaBackbone as PaliGemmaBackbone, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( - PaliGemmaCausalLM, + PaliGemmaCausalLM as PaliGemmaCausalLM, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor, + PaliGemmaCausalLMPreprocessor as PaliGemmaCausalLMPreprocessor, ) from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, + PaliGemmaTokenizer as PaliGemmaTokenizer, +) +from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone +from keras_hub.src.models.phi3.phi3_causal_lm import ( + Phi3CausalLM as Phi3CausalLM, ) -from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone -from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( - Phi3CausalLMPreprocessor, + Phi3CausalLMPreprocessor as Phi3CausalLMPreprocessor, ) -from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer -from keras_hub.src.models.preprocessor import Preprocessor -from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone +from keras_hub.src.models.phi3.phi3_tokenizer import ( + Phi3Tokenizer as Phi3Tokenizer, +) +from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor from keras_hub.src.models.qwen.qwen_backbone import ( QwenBackbone as Qwen2Backbone, ) -from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM +from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone as QwenBackbone from keras_hub.src.models.qwen.qwen_causal_lm import ( QwenCausalLM as Qwen2CausalLM, ) -from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor, +from keras_hub.src.models.qwen.qwen_causal_lm import ( + QwenCausalLM as QwenCausalLM, ) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, ) -from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer +from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( + QwenCausalLMPreprocessor as QwenCausalLMPreprocessor, +) from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone -from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import QwenMoeCausalLM +from keras_hub.src.models.qwen.qwen_tokenizer import ( + QwenTokenizer as QwenTokenizer, +) +from keras_hub.src.models.qwen_moe.qwen_moe_backbone import ( + QwenMoeBackbone as QwenMoeBackbone, +) +from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import ( + QwenMoeCausalLM as QwenMoeCausalLM, +) from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import ( - QwenMoeCausalLMPreprocessor, + QwenMoeCausalLMPreprocessor as QwenMoeCausalLMPreprocessor, +) +from keras_hub.src.models.resnet.resnet_backbone import ( + ResNetBackbone as ResNetBackbone, ) -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.models.resnet.resnet_image_classifier import ( - ResNetImageClassifier, + ResNetImageClassifier as ResNetImageClassifier, ) from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( - ResNetImageClassifierPreprocessor, + ResNetImageClassifierPreprocessor as ResNetImageClassifierPreprocessor, +) +from keras_hub.src.models.retinanet.retinanet_backbone import ( + RetinaNetBackbone as RetinaNetBackbone, ) -from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_object_detector import ( - RetinaNetObjectDetector, + RetinaNetObjectDetector as RetinaNetObjectDetector, ) from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( - RetinaNetObjectDetectorPreprocessor, + RetinaNetObjectDetectorPreprocessor as RetinaNetObjectDetectorPreprocessor, +) +from keras_hub.src.models.roberta.roberta_backbone import ( + RobertaBackbone as RobertaBackbone, +) +from keras_hub.src.models.roberta.roberta_masked_lm import ( + RobertaMaskedLM as RobertaMaskedLM, ) -from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone -from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( - RobertaMaskedLMPreprocessor, + RobertaMaskedLMPreprocessor as RobertaMaskedLMPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier, + RobertaTextClassifier as RobertaClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaClassifier, + RobertaTextClassifier as RobertaTextClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor, + RobertaTextClassifierPreprocessor as RobertaPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaPreprocessor, + RobertaTextClassifierPreprocessor as RobertaTextClassifierPreprocessor, +) +from keras_hub.src.models.roberta.roberta_tokenizer import ( + RobertaTokenizer as RobertaTokenizer, ) -from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_backbone import ( - RoformerV2Backbone, + RoformerV2Backbone as RoformerV2Backbone, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import ( - RoformerV2MaskedLM, + RoformerV2MaskedLM as RoformerV2MaskedLM, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import ( - RoformerV2MaskedLMPreprocessor, + RoformerV2MaskedLMPreprocessor as RoformerV2MaskedLMPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import ( - RoformerV2TextClassifier, + RoformerV2TextClassifier as RoformerV2TextClassifier, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import ( - RoformerV2TextClassifierPreprocessor, + RoformerV2TextClassifierPreprocessor as RoformerV2TextClassifierPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer, + RoformerV2Tokenizer as RoformerV2Tokenizer, +) +from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone +from keras_hub.src.models.sam.sam_image_segmenter import ( + SAMImageSegmenter as SAMImageSegmenter, ) -from keras_hub.src.models.sam.sam_backbone import SAMBackbone -from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor, + SAMImageSegmenterPreprocessor as SAMImageSegmenterPreprocessor, +) +from keras_hub.src.models.segformer.segformer_backbone import ( + SegFormerBackbone as SegFormerBackbone, ) -from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone from keras_hub.src.models.segformer.segformer_image_segmenter import ( - SegFormerImageSegmenter, + SegFormerImageSegmenter as SegFormerImageSegmenter, ) from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( - SegFormerImageSegmenterPreprocessor, -) -from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM -from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor -from keras_hub.src.models.siglip.siglip_backbone import SigLIPBackbone -from keras_hub.src.models.siglip.siglip_preprocessor import SigLIPPreprocessor -from keras_hub.src.models.siglip.siglip_text_encoder import SigLIPTextEncoder -from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer + SegFormerImageSegmenterPreprocessor as SegFormerImageSegmenterPreprocessor, +) +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM as Seq2SeqLM +from keras_hub.src.models.seq_2_seq_lm_preprocessor import ( + Seq2SeqLMPreprocessor as Seq2SeqLMPreprocessor, +) +from keras_hub.src.models.siglip.siglip_backbone import ( + SigLIPBackbone as SigLIPBackbone, +) +from keras_hub.src.models.siglip.siglip_preprocessor import ( + SigLIPPreprocessor as SigLIPPreprocessor, +) +from keras_hub.src.models.siglip.siglip_text_encoder import ( + SigLIPTextEncoder as SigLIPTextEncoder, +) +from keras_hub.src.models.siglip.siglip_tokenizer import ( + SigLIPTokenizer as SigLIPTokenizer, +) from keras_hub.src.models.siglip.siglip_vision_encoder import ( - SigLIPVisionEncoder, + SigLIPVisionEncoder as SigLIPVisionEncoder, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( - StableDiffusion3Backbone, + StableDiffusion3Backbone as StableDiffusion3Backbone, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( - StableDiffusion3ImageToImage, + StableDiffusion3ImageToImage as StableDiffusion3ImageToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( - StableDiffusion3Inpaint, + StableDiffusion3Inpaint as StableDiffusion3Inpaint, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( - StableDiffusion3TextToImage, + StableDiffusion3TextToImage as StableDiffusion3TextToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( - StableDiffusion3TextToImagePreprocessor, + StableDiffusion3TextToImagePreprocessor as StableDiffusion3TextToImagePreprocessor, ) -from keras_hub.src.models.t5.t5_backbone import T5Backbone -from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer -from keras_hub.src.models.task import Task -from keras_hub.src.models.text_classifier import TextClassifier +from keras_hub.src.models.t5.t5_backbone import T5Backbone as T5Backbone +from keras_hub.src.models.t5.t5_preprocessor import ( + T5Preprocessor as T5Preprocessor, +) +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.task import Task as Task from keras_hub.src.models.text_classifier import TextClassifier as Classifier +from keras_hub.src.models.text_classifier import ( + TextClassifier as TextClassifier, +) from keras_hub.src.models.text_classifier_preprocessor import ( - TextClassifierPreprocessor, + TextClassifierPreprocessor as TextClassifierPreprocessor, ) -from keras_hub.src.models.text_to_image import TextToImage +from keras_hub.src.models.text_to_image import TextToImage as TextToImage from keras_hub.src.models.text_to_image_preprocessor import ( - TextToImagePreprocessor, + TextToImagePreprocessor as TextToImagePreprocessor, +) +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone as VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier import ( + VGGImageClassifier as VGGImageClassifier, ) -from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone -from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( - VGGImageClassifierPreprocessor, + VGGImageClassifierPreprocessor as VGGImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone as ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ( + ViTImageClassifier as ViTImageClassifier, ) -from keras_hub.src.models.vit.vit_backbone import ViTBackbone -from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( - ViTImageClassifierPreprocessor, + ViTImageClassifierPreprocessor as ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit_det.vit_det_backbone import ( + ViTDetBackbone as ViTDetBackbone, +) +from keras_hub.src.models.whisper.whisper_backbone import ( + WhisperBackbone as WhisperBackbone, +) +from keras_hub.src.models.whisper.whisper_tokenizer import ( + WhisperTokenizer as WhisperTokenizer, +) +from keras_hub.src.models.xception.xception_backbone import ( + XceptionBackbone as XceptionBackbone, ) -from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone -from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone -from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_hub.src.models.xception.xception_backbone import XceptionBackbone from keras_hub.src.models.xception.xception_image_classifier import ( - XceptionImageClassifier, + XceptionImageClassifier as XceptionImageClassifier, ) from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( - XceptionImageClassifierPreprocessor, + XceptionImageClassifierPreprocessor as XceptionImageClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( - XLMRobertaBackbone, + XLMRobertaBackbone as XLMRobertaBackbone, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( - XLMRobertaMaskedLM, + XLMRobertaMaskedLM as XLMRobertaMaskedLM, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( - XLMRobertaMaskedLMPreprocessor, + XLMRobertaMaskedLMPreprocessor as XLMRobertaMaskedLMPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier, + XLMRobertaTextClassifier as XLMRobertaClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaClassifier, + XLMRobertaTextClassifier as XLMRobertaTextClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaTextClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer, + XLMRobertaTokenizer as XLMRobertaTokenizer, +) +from keras_hub.src.models.xlnet.xlnet_backbone import ( + XLNetBackbone as XLNetBackbone, ) -from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone -from keras_hub.src.tokenizers.tokenizer import Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer diff --git a/keras_hub/api/samplers/__init__.py b/keras_hub/api/samplers/__init__.py index 9feb76c669..29bfef00fc 100644 --- a/keras_hub/api/samplers/__init__.py +++ b/keras_hub/api/samplers/__init__.py @@ -4,13 +4,15 @@ since your modifications would be overwritten. """ -from keras_hub.src.samplers.beam_sampler import BeamSampler -from keras_hub.src.samplers.contrastive_sampler import ContrastiveSampler -from keras_hub.src.samplers.greedy_sampler import GreedySampler -from keras_hub.src.samplers.random_sampler import RandomSampler -from keras_hub.src.samplers.sampler import Sampler -from keras_hub.src.samplers.serialization import deserialize -from keras_hub.src.samplers.serialization import get -from keras_hub.src.samplers.serialization import serialize -from keras_hub.src.samplers.top_k_sampler import TopKSampler -from keras_hub.src.samplers.top_p_sampler import TopPSampler +from keras_hub.src.samplers.beam_sampler import BeamSampler as BeamSampler +from keras_hub.src.samplers.contrastive_sampler import ( + ContrastiveSampler as ContrastiveSampler, +) +from keras_hub.src.samplers.greedy_sampler import GreedySampler as GreedySampler +from keras_hub.src.samplers.random_sampler import RandomSampler as RandomSampler +from keras_hub.src.samplers.sampler import Sampler as Sampler +from keras_hub.src.samplers.serialization import deserialize as deserialize +from keras_hub.src.samplers.serialization import get as get +from keras_hub.src.samplers.serialization import serialize as serialize +from keras_hub.src.samplers.top_k_sampler import TopKSampler as TopKSampler +from keras_hub.src.samplers.top_p_sampler import TopPSampler as TopPSampler diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 96818e01e7..303bd190fc 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -4,62 +4,112 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer -from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer -from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer -from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer -from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.albert.albert_tokenizer import ( + AlbertTokenizer as AlbertTokenizer, +) +from keras_hub.src.models.bart.bart_tokenizer import ( + BartTokenizer as BartTokenizer, +) +from keras_hub.src.models.bert.bert_tokenizer import ( + BertTokenizer as BertTokenizer, +) +from keras_hub.src.models.bloom.bloom_tokenizer import ( + BloomTokenizer as BloomTokenizer, +) +from keras_hub.src.models.clip.clip_tokenizer import ( + CLIPTokenizer as CLIPTokenizer, +) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, + DebertaV3Tokenizer as DebertaV3Tokenizer, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, -) -from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer -from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer -from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer -from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer -from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer -from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer -from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer -from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer -from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer -from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer -from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer + DistilBertTokenizer as DistilBertTokenizer, +) +from keras_hub.src.models.electra.electra_tokenizer import ( + ElectraTokenizer as ElectraTokenizer, +) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer +from keras_hub.src.models.f_net.f_net_tokenizer import ( + FNetTokenizer as FNetTokenizer, +) +from keras_hub.src.models.falcon.falcon_tokenizer import ( + FalconTokenizer as FalconTokenizer, +) +from keras_hub.src.models.gemma.gemma_tokenizer import ( + GemmaTokenizer as GemmaTokenizer, +) +from keras_hub.src.models.gemma3.gemma3_tokenizer import ( + Gemma3Tokenizer as Gemma3Tokenizer, +) +from keras_hub.src.models.gpt2.gpt2_tokenizer import ( + GPT2Tokenizer as GPT2Tokenizer, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( + GPTNeoXTokenizer as GPTNeoXTokenizer, +) +from keras_hub.src.models.llama.llama_tokenizer import ( + LlamaTokenizer as LlamaTokenizer, +) +from keras_hub.src.models.llama3.llama3_tokenizer import ( + Llama3Tokenizer as Llama3Tokenizer, +) +from keras_hub.src.models.mistral.mistral_tokenizer import ( + MistralTokenizer as MistralTokenizer, +) +from keras_hub.src.models.mixtral.mixtral_tokenizer import ( + MixtralTokenizer as MixtralTokenizer, +) +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, + PaliGemmaTokenizer as PaliGemmaTokenizer, +) +from keras_hub.src.models.phi3.phi3_tokenizer import ( + Phi3Tokenizer as Phi3Tokenizer, ) -from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer -from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import QwenMoeTokenizer -from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_hub.src.models.qwen.qwen_tokenizer import ( + QwenTokenizer as QwenTokenizer, +) +from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import ( + QwenMoeTokenizer as QwenMoeTokenizer, +) +from keras_hub.src.models.roberta.roberta_tokenizer import ( + RobertaTokenizer as RobertaTokenizer, +) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer, + RoformerV2Tokenizer as RoformerV2Tokenizer, +) +from keras_hub.src.models.siglip.siglip_tokenizer import ( + SigLIPTokenizer as SigLIPTokenizer, +) +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.whisper.whisper_tokenizer import ( + WhisperTokenizer as WhisperTokenizer, ) -from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer -from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer, + XLMRobertaTokenizer as XLMRobertaTokenizer, +) +from keras_hub.src.tokenizers.byte_pair_tokenizer import ( + BytePairTokenizer as BytePairTokenizer, +) +from keras_hub.src.tokenizers.byte_tokenizer import ( + ByteTokenizer as ByteTokenizer, ) -from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_hub.src.tokenizers.byte_tokenizer import ByteTokenizer from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( - SentencePieceTokenizer, + SentencePieceTokenizer as SentencePieceTokenizer, ) from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import ( - compute_sentence_piece_proto, + compute_sentence_piece_proto as compute_sentence_piece_proto, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import ( - UnicodeCodepointTokenizer, + UnicodeCodepointTokenizer as UnicodeCodepointTokenizer, +) +from keras_hub.src.tokenizers.word_piece_tokenizer import ( + WordPieceTokenizer as WordPieceTokenizer, ) -from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( - compute_word_piece_vocabulary, + compute_word_piece_vocabulary as compute_word_piece_vocabulary, ) diff --git a/keras_hub/api/utils/__init__.py b/keras_hub/api/utils/__init__.py index 8ce47790b0..0bd8cb642e 100644 --- a/keras_hub/api/utils/__init__.py +++ b/keras_hub/api/utils/__init__.py @@ -4,10 +4,18 @@ since your modifications would be overwritten. """ -from keras_hub.src.utils.coco.coco_utils import coco_id_to_name -from keras_hub.src.utils.coco.coco_utils import coco_name_to_id +from keras_hub.src.utils.coco.coco_utils import ( + coco_id_to_name as coco_id_to_name, +) +from keras_hub.src.utils.coco.coco_utils import ( + coco_name_to_id as coco_name_to_id, +) +from keras_hub.src.utils.imagenet.imagenet_utils import ( + decode_imagenet_predictions as decode_imagenet_predictions, +) +from keras_hub.src.utils.imagenet.imagenet_utils import ( + imagenet_id_to_name as imagenet_id_to_name, +) from keras_hub.src.utils.imagenet.imagenet_utils import ( - decode_imagenet_predictions, + imagenet_name_to_id as imagenet_name_to_id, ) -from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name -from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id From 2815e9c3cda5119b3d4b8c6bee9772999d5d9406 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 17 May 2025 12:18:21 +0800 Subject: [PATCH 11/13] update --- keras_hub/api/__init__.py | 16 +- keras_hub/api/layers/__init__.py | 128 ++-- keras_hub/api/metrics/__init__.py | 10 +- keras_hub/api/models/__init__.py | 636 +++++++----------- keras_hub/api/samplers/__init__.py | 22 +- keras_hub/api/tokenizers/__init__.py | 128 ++-- keras_hub/api/utils/__init__.py | 18 +- keras_hub/src/models/esm/esm_backbone.py | 24 +- keras_hub/src/models/esm/esm_classifier.py | 3 - .../src/models/esm/esm_classifier_test.py | 9 + keras_hub/src/models/esm/esm_masked_plm.py | 4 - .../src/models/esm/esm_masked_plm_test.py | 9 + .../src/utils/transformers/convert_esm.py | 16 +- 13 files changed, 375 insertions(+), 648 deletions(-) diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 2aa98bf3f9..3796e4c7f4 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,12 +4,12 @@ since your modifications would be overwritten. """ -from keras_hub import layers as layers -from keras_hub import metrics as metrics -from keras_hub import models as models -from keras_hub import samplers as samplers -from keras_hub import tokenizers as tokenizers -from keras_hub import utils as utils -from keras_hub.src.utils.preset_utils import upload_preset as upload_preset +from keras_hub import layers +from keras_hub import metrics +from keras_hub import models +from keras_hub import samplers +from keras_hub import tokenizers +from keras_hub import utils +from keras_hub.src.utils.preset_utils import upload_preset from keras_hub.src.version import __version__ as __version__ -from keras_hub.src.version import version as version +from keras_hub.src.version import version diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 61eb0621b6..d42af86a3c 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -4,128 +4,86 @@ since your modifications would be overwritten. """ -from keras_hub.src.layers.modeling.alibi_bias import AlibiBias as AlibiBias -from keras_hub.src.layers.modeling.anchor_generator import ( - AnchorGenerator as AnchorGenerator, -) -from keras_hub.src.layers.modeling.box_matcher import BoxMatcher as BoxMatcher +from keras_hub.src.layers.modeling.alibi_bias import AlibiBias +from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator +from keras_hub.src.layers.modeling.box_matcher import BoxMatcher from keras_hub.src.layers.modeling.cached_multi_head_attention import ( - CachedMultiHeadAttention as CachedMultiHeadAttention, -) -from keras_hub.src.layers.modeling.f_net_encoder import ( - FNetEncoder as FNetEncoder, -) -from keras_hub.src.layers.modeling.masked_lm_head import ( - MaskedLMHead as MaskedLMHead, -) -from keras_hub.src.layers.modeling.non_max_supression import ( - NonMaxSuppression as NonMaxSuppression, -) -from keras_hub.src.layers.modeling.position_embedding import ( - PositionEmbedding as PositionEmbedding, + CachedMultiHeadAttention, ) +from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder +from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead +from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression +from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding from keras_hub.src.layers.modeling.reversible_embedding import ( - ReversibleEmbedding as ReversibleEmbedding, -) -from keras_hub.src.layers.modeling.rms_normalization import ( - RMSNormalization as RMSNormalization, -) -from keras_hub.src.layers.modeling.rotary_embedding import ( - RotaryEmbedding as RotaryEmbedding, + ReversibleEmbedding, ) +from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding from keras_hub.src.layers.modeling.sine_position_encoding import ( - SinePositionEncoding as SinePositionEncoding, + SinePositionEncoding, ) from keras_hub.src.layers.modeling.token_and_position_embedding import ( - TokenAndPositionEmbedding as TokenAndPositionEmbedding, -) -from keras_hub.src.layers.modeling.transformer_decoder import ( - TransformerDecoder as TransformerDecoder, -) -from keras_hub.src.layers.modeling.transformer_encoder import ( - TransformerEncoder as TransformerEncoder, -) -from keras_hub.src.layers.preprocessing.audio_converter import ( - AudioConverter as AudioConverter, -) -from keras_hub.src.layers.preprocessing.image_converter import ( - ImageConverter as ImageConverter, + TokenAndPositionEmbedding, ) +from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder +from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder +from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator as MaskedLMMaskGenerator, + MaskedLMMaskGenerator, ) from keras_hub.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker as MultiSegmentPacker, -) -from keras_hub.src.layers.preprocessing.random_deletion import ( - RandomDeletion as RandomDeletion, -) -from keras_hub.src.layers.preprocessing.random_swap import ( - RandomSwap as RandomSwap, -) -from keras_hub.src.layers.preprocessing.start_end_packer import ( - StartEndPacker as StartEndPacker, + MultiSegmentPacker, ) +from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion +from keras_hub.src.layers.preprocessing.random_swap import RandomSwap +from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_hub.src.models.basnet.basnet_image_converter import ( - BASNetImageConverter as BASNetImageConverter, -) -from keras_hub.src.models.clip.clip_image_converter import ( - CLIPImageConverter as CLIPImageConverter, + BASNetImageConverter, ) +from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter from keras_hub.src.models.cspnet.cspnet_image_converter import ( - CSPNetImageConverter as CSPNetImageConverter, + CSPNetImageConverter, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( - DeepLabV3ImageConverter as DeepLabV3ImageConverter, + DeepLabV3ImageConverter, ) from keras_hub.src.models.densenet.densenet_image_converter import ( - DenseNetImageConverter as DenseNetImageConverter, + DenseNetImageConverter, ) from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( - EfficientNetImageConverter as EfficientNetImageConverter, + EfficientNetImageConverter, ) from keras_hub.src.models.gemma3.gemma3_image_converter import ( - Gemma3ImageConverter as Gemma3ImageConverter, -) -from keras_hub.src.models.mit.mit_image_converter import ( - MiTImageConverter as MiTImageConverter, + Gemma3ImageConverter, ) +from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( - MobileNetImageConverter as MobileNetImageConverter, + MobileNetImageConverter, ) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( - PaliGemmaImageConverter as PaliGemmaImageConverter, + PaliGemmaImageConverter, ) from keras_hub.src.models.resnet.resnet_image_converter import ( - ResNetImageConverter as ResNetImageConverter, + ResNetImageConverter, ) from keras_hub.src.models.retinanet.retinanet_image_converter import ( - RetinaNetImageConverter as RetinaNetImageConverter, -) -from keras_hub.src.models.sam.sam_image_converter import ( - SAMImageConverter as SAMImageConverter, -) -from keras_hub.src.models.sam.sam_mask_decoder import ( - SAMMaskDecoder as SAMMaskDecoder, -) -from keras_hub.src.models.sam.sam_prompt_encoder import ( - SAMPromptEncoder as SAMPromptEncoder, + RetinaNetImageConverter, ) +from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter +from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder +from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder from keras_hub.src.models.segformer.segformer_image_converter import ( - SegFormerImageConverter as SegFormerImageConverter, + SegFormerImageConverter, ) from keras_hub.src.models.siglip.siglip_image_converter import ( - SigLIPImageConverter as SigLIPImageConverter, -) -from keras_hub.src.models.vgg.vgg_image_converter import ( - VGGImageConverter as VGGImageConverter, -) -from keras_hub.src.models.vit.vit_image_converter import ( - ViTImageConverter as ViTImageConverter, + SigLIPImageConverter, ) +from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( - WhisperAudioConverter as WhisperAudioConverter, + WhisperAudioConverter, ) from keras_hub.src.models.xception.xception_image_converter import ( - XceptionImageConverter as XceptionImageConverter, + XceptionImageConverter, ) diff --git a/keras_hub/api/metrics/__init__.py b/keras_hub/api/metrics/__init__.py index 100c2c66fb..88a0a7df2b 100644 --- a/keras_hub/api/metrics/__init__.py +++ b/keras_hub/api/metrics/__init__.py @@ -4,8 +4,8 @@ since your modifications would be overwritten. """ -from keras_hub.src.metrics.bleu import Bleu as Bleu -from keras_hub.src.metrics.edit_distance import EditDistance as EditDistance -from keras_hub.src.metrics.perplexity import Perplexity as Perplexity -from keras_hub.src.metrics.rouge_l import RougeL as RougeL -from keras_hub.src.metrics.rouge_n import RougeN as RougeN +from keras_hub.src.metrics.bleu import Bleu +from keras_hub.src.metrics.edit_distance import EditDistance +from keras_hub.src.metrics.perplexity import Perplexity +from keras_hub.src.metrics.rouge_l import RougeL +from keras_hub.src.metrics.rouge_n import RougeN diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index d8bcc90de5..d0e2c7333f 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,643 +4,463 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_backbone import ( - AlbertBackbone as AlbertBackbone, -) -from keras_hub.src.models.albert.albert_masked_lm import ( - AlbertMaskedLM as AlbertMaskedLM, -) +from keras_hub.src.models.albert.albert_backbone import AlbertBackbone +from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( - AlbertMaskedLMPreprocessor as AlbertMaskedLMPreprocessor, + AlbertMaskedLMPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertClassifier, + AlbertTextClassifier, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertTextClassifier, + AlbertTextClassifier as AlbertClassifier, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertPreprocessor, + AlbertTextClassifierPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertTextClassifierPreprocessor, -) -from keras_hub.src.models.albert.albert_tokenizer import ( - AlbertTokenizer as AlbertTokenizer, -) -from keras_hub.src.models.backbone import Backbone as Backbone -from keras_hub.src.models.bart.bart_backbone import BartBackbone as BartBackbone -from keras_hub.src.models.bart.bart_seq_2_seq_lm import ( - BartSeq2SeqLM as BartSeq2SeqLM, + AlbertTextClassifierPreprocessor as AlbertPreprocessor, ) +from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.bart.bart_backbone import BartBackbone +from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( - BartSeq2SeqLMPreprocessor as BartSeq2SeqLMPreprocessor, -) -from keras_hub.src.models.bart.bart_tokenizer import ( - BartTokenizer as BartTokenizer, -) -from keras_hub.src.models.basnet.basnet import ( - BASNetImageSegmenter as BASNetImageSegmenter, -) -from keras_hub.src.models.basnet.basnet_backbone import ( - BASNetBackbone as BASNetBackbone, -) -from keras_hub.src.models.basnet.basnet_preprocessor import ( - BASNetPreprocessor as BASNetPreprocessor, -) -from keras_hub.src.models.bert.bert_backbone import BertBackbone as BertBackbone -from keras_hub.src.models.bert.bert_masked_lm import ( - BertMaskedLM as BertMaskedLM, -) + BartSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer +from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter +from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone +from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor +from keras_hub.src.models.bert.bert_backbone import BertBackbone +from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( - BertMaskedLMPreprocessor as BertMaskedLMPreprocessor, + BertMaskedLMPreprocessor, ) +from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier from keras_hub.src.models.bert.bert_text_classifier import ( BertTextClassifier as BertClassifier, ) -from keras_hub.src.models.bert.bert_text_classifier import ( - BertTextClassifier as BertTextClassifier, -) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor as BertPreprocessor, + BertTextClassifierPreprocessor, ) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor as BertTextClassifierPreprocessor, -) -from keras_hub.src.models.bert.bert_tokenizer import ( - BertTokenizer as BertTokenizer, -) -from keras_hub.src.models.bloom.bloom_backbone import ( - BloomBackbone as BloomBackbone, -) -from keras_hub.src.models.bloom.bloom_causal_lm import ( - BloomCausalLM as BloomCausalLM, + BertTextClassifierPreprocessor as BertPreprocessor, ) +from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer +from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone +from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( - BloomCausalLMPreprocessor as BloomCausalLMPreprocessor, -) -from keras_hub.src.models.bloom.bloom_tokenizer import ( - BloomTokenizer as BloomTokenizer, -) -from keras_hub.src.models.causal_lm import CausalLM as CausalLM -from keras_hub.src.models.causal_lm_preprocessor import ( - CausalLMPreprocessor as CausalLMPreprocessor, -) -from keras_hub.src.models.clip.clip_backbone import CLIPBackbone as CLIPBackbone -from keras_hub.src.models.clip.clip_preprocessor import ( - CLIPPreprocessor as CLIPPreprocessor, -) -from keras_hub.src.models.clip.clip_text_encoder import ( - CLIPTextEncoder as CLIPTextEncoder, -) -from keras_hub.src.models.clip.clip_tokenizer import ( - CLIPTokenizer as CLIPTokenizer, -) -from keras_hub.src.models.clip.clip_vision_encoder import ( - CLIPVisionEncoder as CLIPVisionEncoder, -) -from keras_hub.src.models.cspnet.cspnet_backbone import ( - CSPNetBackbone as CSPNetBackbone, -) + BloomCausalLMPreprocessor, +) +from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.clip.clip_backbone import CLIPBackbone +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone from keras_hub.src.models.cspnet.cspnet_image_classifier import ( - CSPNetImageClassifier as CSPNetImageClassifier, + CSPNetImageClassifier, ) from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( - CSPNetImageClassifierPreprocessor as CSPNetImageClassifierPreprocessor, + CSPNetImageClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( - DebertaV3Backbone as DebertaV3Backbone, + DebertaV3Backbone, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( - DebertaV3MaskedLM as DebertaV3MaskedLM, + DebertaV3MaskedLM, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( - DebertaV3MaskedLMPreprocessor as DebertaV3MaskedLMPreprocessor, + DebertaV3MaskedLMPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3Classifier, + DebertaV3TextClassifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3TextClassifier, + DebertaV3TextClassifier as DebertaV3Classifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, + DebertaV3TextClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3TextClassifierPreprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer as DebertaV3Tokenizer, + DebertaV3Tokenizer, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( - DeepLabV3Backbone as DeepLabV3Backbone, + DeepLabV3Backbone, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( - DeepLabV3ImageSegmenterPreprocessor as DeepLabV3ImageSegmenterPreprocessor, + DeepLabV3ImageSegmenterPreprocessor, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( - DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter, -) -from keras_hub.src.models.densenet.densenet_backbone import ( - DenseNetBackbone as DenseNetBackbone, + DeepLabV3ImageSegmenter, ) +from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier as DenseNetImageClassifier, + DenseNetImageClassifier, ) from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( - DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor, + DenseNetImageClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_backbone import ( - DistilBertBackbone as DistilBertBackbone, + DistilBertBackbone, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( - DistilBertMaskedLM as DistilBertMaskedLM, + DistilBertMaskedLM, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor as DistilBertMaskedLMPreprocessor, + DistilBertMaskedLMPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertClassifier, + DistilBertTextClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertTextClassifier, + DistilBertTextClassifier as DistilBertClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, + DistilBertTextClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertTextClassifierPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer as DistilBertTokenizer, + DistilBertTokenizer, ) from keras_hub.src.models.efficientnet.efficientnet_backbone import ( - EfficientNetBackbone as EfficientNetBackbone, + EfficientNetBackbone, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( - EfficientNetImageClassifier as EfficientNetImageClassifier, + EfficientNetImageClassifier, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( - EfficientNetImageClassifierPreprocessor as EfficientNetImageClassifierPreprocessor, -) -from keras_hub.src.models.electra.electra_backbone import ( - ElectraBackbone as ElectraBackbone, -) -from keras_hub.src.models.electra.electra_tokenizer import ( - ElectraTokenizer as ElectraTokenizer, + EfficientNetImageClassifierPreprocessor, ) +from keras_hub.src.models.electra.electra_backbone import ElectraBackbone +from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer +from keras_hub.src.models.esm.esm_backbone import ESMBackbone from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone -from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESMBackbone -from keras_hub.src.models.esm.esm_classifier import ( - ESMProteinClassifier as ESMProteinClassifier, -) +from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier from keras_hub.src.models.esm.esm_classifier_preprocessor import ( - ESMProteinClassifierPreprocessor as ESMProteinClassifierPreprocessor, + ESMProteinClassifierPreprocessor, ) +from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM from keras_hub.src.models.esm.esm_masked_plm import ( ESMMaskedPLM as ESM2MaskedPLM, ) -from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( - ESMMaskedPLMPreprocessor as ESMMaskedPLMPreprocessor, -) -from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer -from keras_hub.src.models.f_net.f_net_backbone import ( - FNetBackbone as FNetBackbone, -) -from keras_hub.src.models.f_net.f_net_masked_lm import ( - FNetMaskedLM as FNetMaskedLM, + ESMMaskedPLMPreprocessor, ) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer +from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone +from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( - FNetMaskedLMPreprocessor as FNetMaskedLMPreprocessor, + FNetMaskedLMPreprocessor, ) +from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier from keras_hub.src.models.f_net.f_net_text_classifier import ( FNetTextClassifier as FNetClassifier, ) -from keras_hub.src.models.f_net.f_net_text_classifier import ( - FNetTextClassifier as FNetTextClassifier, -) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor as FNetPreprocessor, + FNetTextClassifierPreprocessor, ) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor as FNetTextClassifierPreprocessor, -) -from keras_hub.src.models.f_net.f_net_tokenizer import ( - FNetTokenizer as FNetTokenizer, -) -from keras_hub.src.models.falcon.falcon_backbone import ( - FalconBackbone as FalconBackbone, -) -from keras_hub.src.models.falcon.falcon_causal_lm import ( - FalconCausalLM as FalconCausalLM, + FNetTextClassifierPreprocessor as FNetPreprocessor, ) +from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone +from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( - FalconCausalLMPreprocessor as FalconCausalLMPreprocessor, -) -from keras_hub.src.models.falcon.falcon_tokenizer import ( - FalconTokenizer as FalconTokenizer, -) -from keras_hub.src.models.feature_pyramid_backbone import ( - FeaturePyramidBackbone as FeaturePyramidBackbone, -) -from keras_hub.src.models.flux.flux_model import FluxBackbone as FluxBackbone -from keras_hub.src.models.flux.flux_text_to_image import ( - FluxTextToImage as FluxTextToImage, + FalconCausalLMPreprocessor, ) +from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.models.flux.flux_model import FluxBackbone +from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( - FluxTextToImagePreprocessor as FluxTextToImagePreprocessor, -) -from keras_hub.src.models.gemma.gemma_backbone import ( - GemmaBackbone as GemmaBackbone, -) -from keras_hub.src.models.gemma.gemma_causal_lm import ( - GemmaCausalLM as GemmaCausalLM, + FluxTextToImagePreprocessor, ) +from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone +from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor as GemmaCausalLMPreprocessor, -) -from keras_hub.src.models.gemma.gemma_tokenizer import ( - GemmaTokenizer as GemmaTokenizer, -) -from keras_hub.src.models.gemma3.gemma3_backbone import ( - Gemma3Backbone as Gemma3Backbone, -) -from keras_hub.src.models.gemma3.gemma3_causal_lm import ( - Gemma3CausalLM as Gemma3CausalLM, + GemmaCausalLMPreprocessor, ) +from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( - Gemma3CausalLMPreprocessor as Gemma3CausalLMPreprocessor, -) -from keras_hub.src.models.gemma3.gemma3_tokenizer import ( - Gemma3Tokenizer as Gemma3Tokenizer, + Gemma3CausalLMPreprocessor, ) +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( - Gemma3VisionEncoder as Gemma3VisionEncoder, -) -from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone as GPT2Backbone -from keras_hub.src.models.gpt2.gpt2_causal_lm import ( - GPT2CausalLM as GPT2CausalLM, + Gemma3VisionEncoder, ) +from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( - GPT2CausalLMPreprocessor as GPT2CausalLMPreprocessor, -) -from keras_hub.src.models.gpt2.gpt2_preprocessor import ( - GPT2Preprocessor as GPT2Preprocessor, -) -from keras_hub.src.models.gpt2.gpt2_tokenizer import ( - GPT2Tokenizer as GPT2Tokenizer, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import ( - GPTNeoXBackbone as GPTNeoXBackbone, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import ( - GPTNeoXCausalLM as GPTNeoXCausalLM, + GPT2CausalLMPreprocessor, ) +from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( - GPTNeoXCausalLMPreprocessor as GPTNeoXCausalLMPreprocessor, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( - GPTNeoXTokenizer as GPTNeoXTokenizer, -) -from keras_hub.src.models.image_classifier import ( - ImageClassifier as ImageClassifier, + GPTNeoXCausalLMPreprocessor, ) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.image_classifier_preprocessor import ( - ImageClassifierPreprocessor as ImageClassifierPreprocessor, -) -from keras_hub.src.models.image_segmenter import ( - ImageSegmenter as ImageSegmenter, + ImageClassifierPreprocessor, ) +from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.image_segmenter_preprocessor import ( - ImageSegmenterPreprocessor as ImageSegmenterPreprocessor, -) -from keras_hub.src.models.image_to_image import ImageToImage as ImageToImage -from keras_hub.src.models.inpaint import Inpaint as Inpaint -from keras_hub.src.models.llama.llama_backbone import ( - LlamaBackbone as LlamaBackbone, -) -from keras_hub.src.models.llama.llama_causal_lm import ( - LlamaCausalLM as LlamaCausalLM, + ImageSegmenterPreprocessor, ) +from keras_hub.src.models.image_to_image import ImageToImage +from keras_hub.src.models.inpaint import Inpaint +from keras_hub.src.models.llama.llama_backbone import LlamaBackbone +from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( - LlamaCausalLMPreprocessor as LlamaCausalLMPreprocessor, -) -from keras_hub.src.models.llama.llama_tokenizer import ( - LlamaTokenizer as LlamaTokenizer, -) -from keras_hub.src.models.llama3.llama3_backbone import ( - Llama3Backbone as Llama3Backbone, -) -from keras_hub.src.models.llama3.llama3_causal_lm import ( - Llama3CausalLM as Llama3CausalLM, + LlamaCausalLMPreprocessor, ) +from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer +from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone +from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( - Llama3CausalLMPreprocessor as Llama3CausalLMPreprocessor, -) -from keras_hub.src.models.llama3.llama3_tokenizer import ( - Llama3Tokenizer as Llama3Tokenizer, -) -from keras_hub.src.models.masked_lm import MaskedLM as MaskedLM -from keras_hub.src.models.masked_lm_preprocessor import ( - MaskedLMPreprocessor as MaskedLMPreprocessor, -) -from keras_hub.src.models.mistral.mistral_backbone import ( - MistralBackbone as MistralBackbone, -) -from keras_hub.src.models.mistral.mistral_causal_lm import ( - MistralCausalLM as MistralCausalLM, + Llama3CausalLMPreprocessor, ) +from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_hub.src.models.masked_lm import MaskedLM +from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor +from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone +from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( - MistralCausalLMPreprocessor as MistralCausalLMPreprocessor, -) -from keras_hub.src.models.mistral.mistral_tokenizer import ( - MistralTokenizer as MistralTokenizer, -) -from keras_hub.src.models.mit.mit_backbone import MiTBackbone as MiTBackbone -from keras_hub.src.models.mit.mit_image_classifier import ( - MiTImageClassifier as MiTImageClassifier, + MistralCausalLMPreprocessor, ) +from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_hub.src.models.mit.mit_backbone import MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( - MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor, -) -from keras_hub.src.models.mixtral.mixtral_backbone import ( - MixtralBackbone as MixtralBackbone, -) -from keras_hub.src.models.mixtral.mixtral_causal_lm import ( - MixtralCausalLM as MixtralCausalLM, + MiTImageClassifierPreprocessor, ) +from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.models.mixtral.mixtral_causal_lm import MixtralCausalLM from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( - MixtralCausalLMPreprocessor as MixtralCausalLMPreprocessor, -) -from keras_hub.src.models.mixtral.mixtral_tokenizer import ( - MixtralTokenizer as MixtralTokenizer, -) -from keras_hub.src.models.mobilenet.mobilenet_backbone import ( - MobileNetBackbone as MobileNetBackbone, + MixtralCausalLMPreprocessor, ) +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier as MobileNetImageClassifier, + MobileNetImageClassifier, ) from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor, + MobileNetImageClassifierPreprocessor, ) +from keras_hub.src.models.object_detector import ObjectDetector from keras_hub.src.models.object_detector import ( ObjectDetector as ImageObjectDetector, ) -from keras_hub.src.models.object_detector import ( - ObjectDetector as ObjectDetector, -) from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, + ObjectDetectorPreprocessor, ) from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor as ObjectDetectorPreprocessor, + ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, ) -from keras_hub.src.models.opt.opt_backbone import OPTBackbone as OPTBackbone -from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM as OPTCausalLM +from keras_hub.src.models.opt.opt_backbone import OPTBackbone +from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( - OPTCausalLMPreprocessor as OPTCausalLMPreprocessor, + OPTCausalLMPreprocessor, ) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone as PaliGemmaBackbone, + PaliGemmaBackbone, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( - PaliGemmaCausalLM as PaliGemmaCausalLM, + PaliGemmaCausalLM, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor as PaliGemmaCausalLMPreprocessor, + PaliGemmaCausalLMPreprocessor, ) from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer as PaliGemmaTokenizer, -) -from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone -from keras_hub.src.models.phi3.phi3_causal_lm import ( - Phi3CausalLM as Phi3CausalLM, + PaliGemmaTokenizer, ) +from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone +from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( - Phi3CausalLMPreprocessor as Phi3CausalLMPreprocessor, + Phi3CausalLMPreprocessor, ) -from keras_hub.src.models.phi3.phi3_tokenizer import ( - Phi3Tokenizer as Phi3Tokenizer, -) -from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor +from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_hub.src.models.preprocessor import Preprocessor +from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone from keras_hub.src.models.qwen.qwen_backbone import ( QwenBackbone as Qwen2Backbone, ) -from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone as QwenBackbone +from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM from keras_hub.src.models.qwen.qwen_causal_lm import ( QwenCausalLM as Qwen2CausalLM, ) -from keras_hub.src.models.qwen.qwen_causal_lm import ( - QwenCausalLM as QwenCausalLM, -) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, + QwenCausalLMPreprocessor, ) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor as QwenCausalLMPreprocessor, + QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, ) +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen.qwen_tokenizer import ( - QwenTokenizer as QwenTokenizer, -) -from keras_hub.src.models.qwen_moe.qwen_moe_backbone import ( - QwenMoeBackbone as QwenMoeBackbone, -) -from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import ( - QwenMoeCausalLM as QwenMoeCausalLM, -) +from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone +from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import QwenMoeCausalLM from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import ( - QwenMoeCausalLMPreprocessor as QwenMoeCausalLMPreprocessor, -) -from keras_hub.src.models.resnet.resnet_backbone import ( - ResNetBackbone as ResNetBackbone, + QwenMoeCausalLMPreprocessor, ) +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.models.resnet.resnet_image_classifier import ( - ResNetImageClassifier as ResNetImageClassifier, + ResNetImageClassifier, ) from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( - ResNetImageClassifierPreprocessor as ResNetImageClassifierPreprocessor, -) -from keras_hub.src.models.retinanet.retinanet_backbone import ( - RetinaNetBackbone as RetinaNetBackbone, + ResNetImageClassifierPreprocessor, ) +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_object_detector import ( - RetinaNetObjectDetector as RetinaNetObjectDetector, + RetinaNetObjectDetector, ) from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( - RetinaNetObjectDetectorPreprocessor as RetinaNetObjectDetectorPreprocessor, -) -from keras_hub.src.models.roberta.roberta_backbone import ( - RobertaBackbone as RobertaBackbone, -) -from keras_hub.src.models.roberta.roberta_masked_lm import ( - RobertaMaskedLM as RobertaMaskedLM, + RetinaNetObjectDetectorPreprocessor, ) +from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone +from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( - RobertaMaskedLMPreprocessor as RobertaMaskedLMPreprocessor, + RobertaMaskedLMPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaClassifier, + RobertaTextClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaTextClassifier, + RobertaTextClassifier as RobertaClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaPreprocessor, + RobertaTextClassifierPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaTextClassifierPreprocessor, -) -from keras_hub.src.models.roberta.roberta_tokenizer import ( - RobertaTokenizer as RobertaTokenizer, + RobertaTextClassifierPreprocessor as RobertaPreprocessor, ) +from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_backbone import ( - RoformerV2Backbone as RoformerV2Backbone, + RoformerV2Backbone, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import ( - RoformerV2MaskedLM as RoformerV2MaskedLM, + RoformerV2MaskedLM, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import ( - RoformerV2MaskedLMPreprocessor as RoformerV2MaskedLMPreprocessor, + RoformerV2MaskedLMPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import ( - RoformerV2TextClassifier as RoformerV2TextClassifier, + RoformerV2TextClassifier, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import ( - RoformerV2TextClassifierPreprocessor as RoformerV2TextClassifierPreprocessor, + RoformerV2TextClassifierPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer as RoformerV2Tokenizer, -) -from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone -from keras_hub.src.models.sam.sam_image_segmenter import ( - SAMImageSegmenter as SAMImageSegmenter, + RoformerV2Tokenizer, ) +from keras_hub.src.models.sam.sam_backbone import SAMBackbone +from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor as SAMImageSegmenterPreprocessor, -) -from keras_hub.src.models.segformer.segformer_backbone import ( - SegFormerBackbone as SegFormerBackbone, + SAMImageSegmenterPreprocessor, ) +from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone from keras_hub.src.models.segformer.segformer_image_segmenter import ( - SegFormerImageSegmenter as SegFormerImageSegmenter, + SegFormerImageSegmenter, ) from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( - SegFormerImageSegmenterPreprocessor as SegFormerImageSegmenterPreprocessor, -) -from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM as Seq2SeqLM -from keras_hub.src.models.seq_2_seq_lm_preprocessor import ( - Seq2SeqLMPreprocessor as Seq2SeqLMPreprocessor, -) -from keras_hub.src.models.siglip.siglip_backbone import ( - SigLIPBackbone as SigLIPBackbone, -) -from keras_hub.src.models.siglip.siglip_preprocessor import ( - SigLIPPreprocessor as SigLIPPreprocessor, -) -from keras_hub.src.models.siglip.siglip_text_encoder import ( - SigLIPTextEncoder as SigLIPTextEncoder, -) -from keras_hub.src.models.siglip.siglip_tokenizer import ( - SigLIPTokenizer as SigLIPTokenizer, -) + SegFormerImageSegmenterPreprocessor, +) +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM +from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor +from keras_hub.src.models.siglip.siglip_backbone import SigLIPBackbone +from keras_hub.src.models.siglip.siglip_preprocessor import SigLIPPreprocessor +from keras_hub.src.models.siglip.siglip_text_encoder import SigLIPTextEncoder +from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer from keras_hub.src.models.siglip.siglip_vision_encoder import ( - SigLIPVisionEncoder as SigLIPVisionEncoder, + SigLIPVisionEncoder, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( - StableDiffusion3Backbone as StableDiffusion3Backbone, + StableDiffusion3Backbone, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( - StableDiffusion3ImageToImage as StableDiffusion3ImageToImage, + StableDiffusion3ImageToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( - StableDiffusion3Inpaint as StableDiffusion3Inpaint, + StableDiffusion3Inpaint, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( - StableDiffusion3TextToImage as StableDiffusion3TextToImage, + StableDiffusion3TextToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( - StableDiffusion3TextToImagePreprocessor as StableDiffusion3TextToImagePreprocessor, + StableDiffusion3TextToImagePreprocessor, ) -from keras_hub.src.models.t5.t5_backbone import T5Backbone as T5Backbone -from keras_hub.src.models.t5.t5_preprocessor import ( - T5Preprocessor as T5Preprocessor, -) -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer -from keras_hub.src.models.task import Task as Task +from keras_hub.src.models.t5.t5_backbone import T5Backbone +from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_hub.src.models.task import Task +from keras_hub.src.models.text_classifier import TextClassifier from keras_hub.src.models.text_classifier import TextClassifier as Classifier -from keras_hub.src.models.text_classifier import ( - TextClassifier as TextClassifier, -) from keras_hub.src.models.text_classifier_preprocessor import ( - TextClassifierPreprocessor as TextClassifierPreprocessor, + TextClassifierPreprocessor, ) -from keras_hub.src.models.text_to_image import TextToImage as TextToImage +from keras_hub.src.models.text_to_image import TextToImage from keras_hub.src.models.text_to_image_preprocessor import ( - TextToImagePreprocessor as TextToImagePreprocessor, -) -from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone as VGGBackbone -from keras_hub.src.models.vgg.vgg_image_classifier import ( - VGGImageClassifier as VGGImageClassifier, + TextToImagePreprocessor, ) +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( - VGGImageClassifierPreprocessor as VGGImageClassifierPreprocessor, -) -from keras_hub.src.models.vit.vit_backbone import ViTBackbone as ViTBackbone -from keras_hub.src.models.vit.vit_image_classifier import ( - ViTImageClassifier as ViTImageClassifier, + VGGImageClassifierPreprocessor, ) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( - ViTImageClassifierPreprocessor as ViTImageClassifierPreprocessor, -) -from keras_hub.src.models.vit_det.vit_det_backbone import ( - ViTDetBackbone as ViTDetBackbone, -) -from keras_hub.src.models.whisper.whisper_backbone import ( - WhisperBackbone as WhisperBackbone, -) -from keras_hub.src.models.whisper.whisper_tokenizer import ( - WhisperTokenizer as WhisperTokenizer, -) -from keras_hub.src.models.xception.xception_backbone import ( - XceptionBackbone as XceptionBackbone, + ViTImageClassifierPreprocessor, ) +from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone +from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone +from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer +from keras_hub.src.models.xception.xception_backbone import XceptionBackbone from keras_hub.src.models.xception.xception_image_classifier import ( - XceptionImageClassifier as XceptionImageClassifier, + XceptionImageClassifier, ) from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( - XceptionImageClassifierPreprocessor as XceptionImageClassifierPreprocessor, + XceptionImageClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( - XLMRobertaBackbone as XLMRobertaBackbone, + XLMRobertaBackbone, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( - XLMRobertaMaskedLM as XLMRobertaMaskedLM, + XLMRobertaMaskedLM, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( - XLMRobertaMaskedLMPreprocessor as XLMRobertaMaskedLMPreprocessor, + XLMRobertaMaskedLMPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaClassifier, + XLMRobertaTextClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaTextClassifier, + XLMRobertaTextClassifier as XLMRobertaClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, + XLMRobertaTextClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaTextClassifierPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer as XLMRobertaTokenizer, -) -from keras_hub.src.models.xlnet.xlnet_backbone import ( - XLNetBackbone as XLNetBackbone, + XLMRobertaTokenizer, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer +from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone +from keras_hub.src.tokenizers.tokenizer import Tokenizer diff --git a/keras_hub/api/samplers/__init__.py b/keras_hub/api/samplers/__init__.py index 29bfef00fc..9feb76c669 100644 --- a/keras_hub/api/samplers/__init__.py +++ b/keras_hub/api/samplers/__init__.py @@ -4,15 +4,13 @@ since your modifications would be overwritten. """ -from keras_hub.src.samplers.beam_sampler import BeamSampler as BeamSampler -from keras_hub.src.samplers.contrastive_sampler import ( - ContrastiveSampler as ContrastiveSampler, -) -from keras_hub.src.samplers.greedy_sampler import GreedySampler as GreedySampler -from keras_hub.src.samplers.random_sampler import RandomSampler as RandomSampler -from keras_hub.src.samplers.sampler import Sampler as Sampler -from keras_hub.src.samplers.serialization import deserialize as deserialize -from keras_hub.src.samplers.serialization import get as get -from keras_hub.src.samplers.serialization import serialize as serialize -from keras_hub.src.samplers.top_k_sampler import TopKSampler as TopKSampler -from keras_hub.src.samplers.top_p_sampler import TopPSampler as TopPSampler +from keras_hub.src.samplers.beam_sampler import BeamSampler +from keras_hub.src.samplers.contrastive_sampler import ContrastiveSampler +from keras_hub.src.samplers.greedy_sampler import GreedySampler +from keras_hub.src.samplers.random_sampler import RandomSampler +from keras_hub.src.samplers.sampler import Sampler +from keras_hub.src.samplers.serialization import deserialize +from keras_hub.src.samplers.serialization import get +from keras_hub.src.samplers.serialization import serialize +from keras_hub.src.samplers.top_k_sampler import TopKSampler +from keras_hub.src.samplers.top_p_sampler import TopPSampler diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 303bd190fc..96818e01e7 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -4,112 +4,62 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_tokenizer import ( - AlbertTokenizer as AlbertTokenizer, -) -from keras_hub.src.models.bart.bart_tokenizer import ( - BartTokenizer as BartTokenizer, -) -from keras_hub.src.models.bert.bert_tokenizer import ( - BertTokenizer as BertTokenizer, -) -from keras_hub.src.models.bloom.bloom_tokenizer import ( - BloomTokenizer as BloomTokenizer, -) -from keras_hub.src.models.clip.clip_tokenizer import ( - CLIPTokenizer as CLIPTokenizer, -) +from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer +from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer +from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer +from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer as DebertaV3Tokenizer, + DebertaV3Tokenizer, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer as DistilBertTokenizer, -) -from keras_hub.src.models.electra.electra_tokenizer import ( - ElectraTokenizer as ElectraTokenizer, -) -from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer -from keras_hub.src.models.f_net.f_net_tokenizer import ( - FNetTokenizer as FNetTokenizer, -) -from keras_hub.src.models.falcon.falcon_tokenizer import ( - FalconTokenizer as FalconTokenizer, -) -from keras_hub.src.models.gemma.gemma_tokenizer import ( - GemmaTokenizer as GemmaTokenizer, -) -from keras_hub.src.models.gemma3.gemma3_tokenizer import ( - Gemma3Tokenizer as Gemma3Tokenizer, -) -from keras_hub.src.models.gpt2.gpt2_tokenizer import ( - GPT2Tokenizer as GPT2Tokenizer, -) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( - GPTNeoXTokenizer as GPTNeoXTokenizer, -) -from keras_hub.src.models.llama.llama_tokenizer import ( - LlamaTokenizer as LlamaTokenizer, -) -from keras_hub.src.models.llama3.llama3_tokenizer import ( - Llama3Tokenizer as Llama3Tokenizer, -) -from keras_hub.src.models.mistral.mistral_tokenizer import ( - MistralTokenizer as MistralTokenizer, -) -from keras_hub.src.models.mixtral.mixtral_tokenizer import ( - MixtralTokenizer as MixtralTokenizer, -) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer + DistilBertTokenizer, +) +from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer +from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer +from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer +from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer +from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer +from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer as PaliGemmaTokenizer, -) -from keras_hub.src.models.phi3.phi3_tokenizer import ( - Phi3Tokenizer as Phi3Tokenizer, + PaliGemmaTokenizer, ) +from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer +from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen.qwen_tokenizer import ( - QwenTokenizer as QwenTokenizer, -) -from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import ( - QwenMoeTokenizer as QwenMoeTokenizer, -) -from keras_hub.src.models.roberta.roberta_tokenizer import ( - RobertaTokenizer as RobertaTokenizer, -) +from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import QwenMoeTokenizer +from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer as RoformerV2Tokenizer, -) -from keras_hub.src.models.siglip.siglip_tokenizer import ( - SigLIPTokenizer as SigLIPTokenizer, -) -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer -from keras_hub.src.models.whisper.whisper_tokenizer import ( - WhisperTokenizer as WhisperTokenizer, + RoformerV2Tokenizer, ) +from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer +from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer as XLMRobertaTokenizer, -) -from keras_hub.src.tokenizers.byte_pair_tokenizer import ( - BytePairTokenizer as BytePairTokenizer, -) -from keras_hub.src.tokenizers.byte_tokenizer import ( - ByteTokenizer as ByteTokenizer, + XLMRobertaTokenizer, ) +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer +from keras_hub.src.tokenizers.byte_tokenizer import ByteTokenizer from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( - SentencePieceTokenizer as SentencePieceTokenizer, + SentencePieceTokenizer, ) from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import ( - compute_sentence_piece_proto as compute_sentence_piece_proto, + compute_sentence_piece_proto, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import ( - UnicodeCodepointTokenizer as UnicodeCodepointTokenizer, -) -from keras_hub.src.tokenizers.word_piece_tokenizer import ( - WordPieceTokenizer as WordPieceTokenizer, + UnicodeCodepointTokenizer, ) +from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( - compute_word_piece_vocabulary as compute_word_piece_vocabulary, + compute_word_piece_vocabulary, ) diff --git a/keras_hub/api/utils/__init__.py b/keras_hub/api/utils/__init__.py index 0bd8cb642e..8ce47790b0 100644 --- a/keras_hub/api/utils/__init__.py +++ b/keras_hub/api/utils/__init__.py @@ -4,18 +4,10 @@ since your modifications would be overwritten. """ -from keras_hub.src.utils.coco.coco_utils import ( - coco_id_to_name as coco_id_to_name, -) -from keras_hub.src.utils.coco.coco_utils import ( - coco_name_to_id as coco_name_to_id, -) -from keras_hub.src.utils.imagenet.imagenet_utils import ( - decode_imagenet_predictions as decode_imagenet_predictions, -) -from keras_hub.src.utils.imagenet.imagenet_utils import ( - imagenet_id_to_name as imagenet_id_to_name, -) +from keras_hub.src.utils.coco.coco_utils import coco_id_to_name +from keras_hub.src.utils.coco.coco_utils import coco_name_to_id from keras_hub.src.utils.imagenet.imagenet_utils import ( - imagenet_name_to_id as imagenet_name_to_id, + decode_imagenet_predictions, ) +from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name +from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id diff --git a/keras_hub/src/models/esm/esm_backbone.py b/keras_hub/src/models/esm/esm_backbone.py index 5939ff0ee1..bf971fc30e 100644 --- a/keras_hub/src/models/esm/esm_backbone.py +++ b/keras_hub/src/models/esm/esm_backbone.py @@ -24,8 +24,6 @@ class ESMBackbone(Backbone): ESM2 encoder with any number of layers, heads, and embed dim.To load preset architectures and weights, use the `from_preset()` constructor. - Disclaimer: Pre-trained models are provided on an "as is" basis, without - warranties or conditions of any kind. Args: vocabulary_size: int. The size of the token vocabulary. @@ -36,16 +34,18 @@ class ESMBackbone(Backbone): intermediate_dim: int. The output dimension of the first Dense layer in a two-layer feedforward network for each transformer. dropout: float. Dropout probability for the Transformer encoder. - layer_norm_eps:bool.Should we use ln after embedding? - Since it's pre-norm, the default is false. + Defaults to 0.1 + layer_norm_eps:bool.If true, then layer norm will be + used before entering the transformer block. + Since it's pre-norm, the default is false. max_sequence_length: int. The maximum sequence length that this encoder can consume. If None, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional embeddings. position_embedding_type:esm1 use abs position embeding,esm2 use rope. so this parameter is only except for absolute and rotary. - dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use - for model computations and weights. Note that some computations, + dtype: None or str or .keras.mixed_precision.DTypePolicy. The dtype to + use for model computations and weights. Note that some computations, such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. @@ -91,10 +91,14 @@ def __init__( pad_token_id=0, **kwargs, ): - support_positon_type = ["rotary", "absolute"] - if position_embedding_type.lower() not in support_positon_type: - raise ( - f"This model only support below position embedding type: {support_positon_type}" # noqa: E501 + if position_embedding_type not in ( + "rotary", + "absolute", + ): + raise ValueError( + '`position_embedding_type` must be either `"rotary"`, or ' + '`"absolute"`. Received ' + "position_embedding_type={position_embedding_type}." ) head_size = hidden_dim // num_heads # === Layers === diff --git a/keras_hub/src/models/esm/esm_classifier.py b/keras_hub/src/models/esm/esm_classifier.py index f6225157c6..2d5ac49ad2 100644 --- a/keras_hub/src/models/esm/esm_classifier.py +++ b/keras_hub/src/models/esm/esm_classifier.py @@ -22,9 +22,6 @@ class ESMProteinClassifier(RobertaTextClassifier): `fit()`, `predict()`, and `evaluate()`. This is done by default when creating the model with `from_preset()`. - Disclaimer: Pre-trained models are provided on an "as is" basis, without - warranties or conditions of any kind. - Args: backbone: A `keras_hub.models.ESMBackbone` instance. num_classes: int. Number of classes to predict. diff --git a/keras_hub/src/models/esm/esm_classifier_test.py b/keras_hub/src/models/esm/esm_classifier_test.py index 3f6e9a1501..9e3b69a34c 100644 --- a/keras_hub/src/models/esm/esm_classifier_test.py +++ b/keras_hub/src/models/esm/esm_classifier_test.py @@ -1,4 +1,5 @@ import keras +import pytest from keras_hub.src.models.esm.esm_backbone import ESMBackbone from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier @@ -45,3 +46,11 @@ def test_classifier_basics(self): train_data=self.train_data, expected_output_shape=(2, 2), ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ESMProteinClassifier, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/esm/esm_masked_plm.py b/keras_hub/src/models/esm/esm_masked_plm.py index d7bc609e54..5506a0300f 100644 --- a/keras_hub/src/models/esm/esm_masked_plm.py +++ b/keras_hub/src/models/esm/esm_masked_plm.py @@ -27,10 +27,6 @@ class ESMMaskedPLM(MaskedLM): training and evaluation. This is done by default when creating the model with `from_preset()`. - Disclaimer: Pre-trained models are provided on an "as is" basis, without - warranties or conditions of any kind. The underlying model is provided by a - third party and subject to a separate license, available - [here](https://github.com/facebookresearch/esm). Args: backbone: A `keras_hub.models.ESM2Backbone` instance. diff --git a/keras_hub/src/models/esm/esm_masked_plm_test.py b/keras_hub/src/models/esm/esm_masked_plm_test.py index b02adc106d..bf0e9f1bb4 100644 --- a/keras_hub/src/models/esm/esm_masked_plm_test.py +++ b/keras_hub/src/models/esm/esm_masked_plm_test.py @@ -1,4 +1,5 @@ import keras +import pytest from keras_hub.src.models.esm.esm_backbone import ESMBackbone from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM @@ -48,3 +49,11 @@ def test_masked_lm_basics(self): train_data=self.train_data, expected_output_shape=(2, 5, 10), ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ESMMaskedPLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/utils/transformers/convert_esm.py b/keras_hub/src/utils/transformers/convert_esm.py index c35de05b86..1c66851165 100644 --- a/keras_hub/src/utils/transformers/convert_esm.py +++ b/keras_hub/src/utils/transformers/convert_esm.py @@ -20,19 +20,13 @@ def convert_backbone_config(transformers_config): "pad_token_id": transformers_config["pad_token_id"], "max_sequence_length": transformers_config.get( "max_position_embeddings", None - ), # 默认值为None - "layer_norm_eps": transformers_config.get( - "layer_norm_eps", 1e-12 - ), # 默认值为1e-12 + ), + "layer_norm_eps": transformers_config.get("layer_norm_eps", 1e-12), "emb_layer_norm_before": transformers_config.get( "emb_layer_norm_before", False - ), # 默认值为False - "activation": transformers_config.get( - "activation", "gelu" - ), # 默认值为"gelu" - "max_wavelength": transformers_config.get( - "max_wavelength", 10000 - ), # 默认值为10000 + ), + "activation": transformers_config.get("activation", "gelu"), + "max_wavelength": transformers_config.get("max_wavelength", 10000), } From 79e738cd905eab942ebee97ea3aad42c108dc57d Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 17 May 2025 12:19:49 +0800 Subject: [PATCH 12/13] update --- keras_hub/api/__init__.py | 16 +- keras_hub/api/layers/__init__.py | 128 ++++-- keras_hub/api/metrics/__init__.py | 10 +- keras_hub/api/models/__init__.py | 636 +++++++++++++++++---------- keras_hub/api/samplers/__init__.py | 22 +- keras_hub/api/tokenizers/__init__.py | 128 ++++-- keras_hub/api/utils/__init__.py | 18 +- 7 files changed, 620 insertions(+), 338 deletions(-) diff --git a/keras_hub/api/__init__.py b/keras_hub/api/__init__.py index 3796e4c7f4..2aa98bf3f9 100644 --- a/keras_hub/api/__init__.py +++ b/keras_hub/api/__init__.py @@ -4,12 +4,12 @@ since your modifications would be overwritten. """ -from keras_hub import layers -from keras_hub import metrics -from keras_hub import models -from keras_hub import samplers -from keras_hub import tokenizers -from keras_hub import utils -from keras_hub.src.utils.preset_utils import upload_preset +from keras_hub import layers as layers +from keras_hub import metrics as metrics +from keras_hub import models as models +from keras_hub import samplers as samplers +from keras_hub import tokenizers as tokenizers +from keras_hub import utils as utils +from keras_hub.src.utils.preset_utils import upload_preset as upload_preset from keras_hub.src.version import __version__ as __version__ -from keras_hub.src.version import version +from keras_hub.src.version import version as version diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index d42af86a3c..61eb0621b6 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -4,86 +4,128 @@ since your modifications would be overwritten. """ -from keras_hub.src.layers.modeling.alibi_bias import AlibiBias -from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator -from keras_hub.src.layers.modeling.box_matcher import BoxMatcher +from keras_hub.src.layers.modeling.alibi_bias import AlibiBias as AlibiBias +from keras_hub.src.layers.modeling.anchor_generator import ( + AnchorGenerator as AnchorGenerator, +) +from keras_hub.src.layers.modeling.box_matcher import BoxMatcher as BoxMatcher from keras_hub.src.layers.modeling.cached_multi_head_attention import ( - CachedMultiHeadAttention, + CachedMultiHeadAttention as CachedMultiHeadAttention, +) +from keras_hub.src.layers.modeling.f_net_encoder import ( + FNetEncoder as FNetEncoder, +) +from keras_hub.src.layers.modeling.masked_lm_head import ( + MaskedLMHead as MaskedLMHead, +) +from keras_hub.src.layers.modeling.non_max_supression import ( + NonMaxSuppression as NonMaxSuppression, +) +from keras_hub.src.layers.modeling.position_embedding import ( + PositionEmbedding as PositionEmbedding, ) -from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder -from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead -from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression -from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding from keras_hub.src.layers.modeling.reversible_embedding import ( - ReversibleEmbedding, + ReversibleEmbedding as ReversibleEmbedding, +) +from keras_hub.src.layers.modeling.rms_normalization import ( + RMSNormalization as RMSNormalization, +) +from keras_hub.src.layers.modeling.rotary_embedding import ( + RotaryEmbedding as RotaryEmbedding, ) -from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization -from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding from keras_hub.src.layers.modeling.sine_position_encoding import ( - SinePositionEncoding, + SinePositionEncoding as SinePositionEncoding, ) from keras_hub.src.layers.modeling.token_and_position_embedding import ( - TokenAndPositionEmbedding, + TokenAndPositionEmbedding as TokenAndPositionEmbedding, +) +from keras_hub.src.layers.modeling.transformer_decoder import ( + TransformerDecoder as TransformerDecoder, +) +from keras_hub.src.layers.modeling.transformer_encoder import ( + TransformerEncoder as TransformerEncoder, +) +from keras_hub.src.layers.preprocessing.audio_converter import ( + AudioConverter as AudioConverter, +) +from keras_hub.src.layers.preprocessing.image_converter import ( + ImageConverter as ImageConverter, ) -from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder -from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder -from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter -from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import ( - MaskedLMMaskGenerator, + MaskedLMMaskGenerator as MaskedLMMaskGenerator, ) from keras_hub.src.layers.preprocessing.multi_segment_packer import ( - MultiSegmentPacker, + MultiSegmentPacker as MultiSegmentPacker, +) +from keras_hub.src.layers.preprocessing.random_deletion import ( + RandomDeletion as RandomDeletion, +) +from keras_hub.src.layers.preprocessing.random_swap import ( + RandomSwap as RandomSwap, +) +from keras_hub.src.layers.preprocessing.start_end_packer import ( + StartEndPacker as StartEndPacker, ) -from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion -from keras_hub.src.layers.preprocessing.random_swap import RandomSwap -from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_hub.src.models.basnet.basnet_image_converter import ( - BASNetImageConverter, + BASNetImageConverter as BASNetImageConverter, +) +from keras_hub.src.models.clip.clip_image_converter import ( + CLIPImageConverter as CLIPImageConverter, ) -from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter from keras_hub.src.models.cspnet.cspnet_image_converter import ( - CSPNetImageConverter, + CSPNetImageConverter as CSPNetImageConverter, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( - DeepLabV3ImageConverter, + DeepLabV3ImageConverter as DeepLabV3ImageConverter, ) from keras_hub.src.models.densenet.densenet_image_converter import ( - DenseNetImageConverter, + DenseNetImageConverter as DenseNetImageConverter, ) from keras_hub.src.models.efficientnet.efficientnet_image_converter import ( - EfficientNetImageConverter, + EfficientNetImageConverter as EfficientNetImageConverter, ) from keras_hub.src.models.gemma3.gemma3_image_converter import ( - Gemma3ImageConverter, + Gemma3ImageConverter as Gemma3ImageConverter, +) +from keras_hub.src.models.mit.mit_image_converter import ( + MiTImageConverter as MiTImageConverter, ) -from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( - MobileNetImageConverter, + MobileNetImageConverter as MobileNetImageConverter, ) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( - PaliGemmaImageConverter, + PaliGemmaImageConverter as PaliGemmaImageConverter, ) from keras_hub.src.models.resnet.resnet_image_converter import ( - ResNetImageConverter, + ResNetImageConverter as ResNetImageConverter, ) from keras_hub.src.models.retinanet.retinanet_image_converter import ( - RetinaNetImageConverter, + RetinaNetImageConverter as RetinaNetImageConverter, +) +from keras_hub.src.models.sam.sam_image_converter import ( + SAMImageConverter as SAMImageConverter, +) +from keras_hub.src.models.sam.sam_mask_decoder import ( + SAMMaskDecoder as SAMMaskDecoder, +) +from keras_hub.src.models.sam.sam_prompt_encoder import ( + SAMPromptEncoder as SAMPromptEncoder, ) -from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter -from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder -from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder from keras_hub.src.models.segformer.segformer_image_converter import ( - SegFormerImageConverter, + SegFormerImageConverter as SegFormerImageConverter, ) from keras_hub.src.models.siglip.siglip_image_converter import ( - SigLIPImageConverter, + SigLIPImageConverter as SigLIPImageConverter, +) +from keras_hub.src.models.vgg.vgg_image_converter import ( + VGGImageConverter as VGGImageConverter, +) +from keras_hub.src.models.vit.vit_image_converter import ( + ViTImageConverter as ViTImageConverter, ) -from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter -from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( - WhisperAudioConverter, + WhisperAudioConverter as WhisperAudioConverter, ) from keras_hub.src.models.xception.xception_image_converter import ( - XceptionImageConverter, + XceptionImageConverter as XceptionImageConverter, ) diff --git a/keras_hub/api/metrics/__init__.py b/keras_hub/api/metrics/__init__.py index 88a0a7df2b..100c2c66fb 100644 --- a/keras_hub/api/metrics/__init__.py +++ b/keras_hub/api/metrics/__init__.py @@ -4,8 +4,8 @@ since your modifications would be overwritten. """ -from keras_hub.src.metrics.bleu import Bleu -from keras_hub.src.metrics.edit_distance import EditDistance -from keras_hub.src.metrics.perplexity import Perplexity -from keras_hub.src.metrics.rouge_l import RougeL -from keras_hub.src.metrics.rouge_n import RougeN +from keras_hub.src.metrics.bleu import Bleu as Bleu +from keras_hub.src.metrics.edit_distance import EditDistance as EditDistance +from keras_hub.src.metrics.perplexity import Perplexity as Perplexity +from keras_hub.src.metrics.rouge_l import RougeL as RougeL +from keras_hub.src.metrics.rouge_n import RougeN as RougeN diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index d0e2c7333f..d8bcc90de5 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -4,463 +4,643 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_backbone import AlbertBackbone -from keras_hub.src.models.albert.albert_masked_lm import AlbertMaskedLM +from keras_hub.src.models.albert.albert_backbone import ( + AlbertBackbone as AlbertBackbone, +) +from keras_hub.src.models.albert.albert_masked_lm import ( + AlbertMaskedLM as AlbertMaskedLM, +) from keras_hub.src.models.albert.albert_masked_lm_preprocessor import ( - AlbertMaskedLMPreprocessor, + AlbertMaskedLMPreprocessor as AlbertMaskedLMPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier, + AlbertTextClassifier as AlbertClassifier, ) from keras_hub.src.models.albert.albert_text_classifier import ( - AlbertTextClassifier as AlbertClassifier, + AlbertTextClassifier as AlbertTextClassifier, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor, + AlbertTextClassifierPreprocessor as AlbertPreprocessor, ) from keras_hub.src.models.albert.albert_text_classifier_preprocessor import ( - AlbertTextClassifierPreprocessor as AlbertPreprocessor, + AlbertTextClassifierPreprocessor as AlbertTextClassifierPreprocessor, +) +from keras_hub.src.models.albert.albert_tokenizer import ( + AlbertTokenizer as AlbertTokenizer, +) +from keras_hub.src.models.backbone import Backbone as Backbone +from keras_hub.src.models.bart.bart_backbone import BartBackbone as BartBackbone +from keras_hub.src.models.bart.bart_seq_2_seq_lm import ( + BartSeq2SeqLM as BartSeq2SeqLM, ) -from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer -from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.bart.bart_backbone import BartBackbone -from keras_hub.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import ( - BartSeq2SeqLMPreprocessor, -) -from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer -from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter -from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone -from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor -from keras_hub.src.models.bert.bert_backbone import BertBackbone -from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM + BartSeq2SeqLMPreprocessor as BartSeq2SeqLMPreprocessor, +) +from keras_hub.src.models.bart.bart_tokenizer import ( + BartTokenizer as BartTokenizer, +) +from keras_hub.src.models.basnet.basnet import ( + BASNetImageSegmenter as BASNetImageSegmenter, +) +from keras_hub.src.models.basnet.basnet_backbone import ( + BASNetBackbone as BASNetBackbone, +) +from keras_hub.src.models.basnet.basnet_preprocessor import ( + BASNetPreprocessor as BASNetPreprocessor, +) +from keras_hub.src.models.bert.bert_backbone import BertBackbone as BertBackbone +from keras_hub.src.models.bert.bert_masked_lm import ( + BertMaskedLM as BertMaskedLM, +) from keras_hub.src.models.bert.bert_masked_lm_preprocessor import ( - BertMaskedLMPreprocessor, + BertMaskedLMPreprocessor as BertMaskedLMPreprocessor, ) -from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier from keras_hub.src.models.bert.bert_text_classifier import ( BertTextClassifier as BertClassifier, ) -from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( - BertTextClassifierPreprocessor, +from keras_hub.src.models.bert.bert_text_classifier import ( + BertTextClassifier as BertTextClassifier, ) from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( BertTextClassifierPreprocessor as BertPreprocessor, ) -from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer -from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone -from keras_hub.src.models.bloom.bloom_causal_lm import BloomCausalLM +from keras_hub.src.models.bert.bert_text_classifier_preprocessor import ( + BertTextClassifierPreprocessor as BertTextClassifierPreprocessor, +) +from keras_hub.src.models.bert.bert_tokenizer import ( + BertTokenizer as BertTokenizer, +) +from keras_hub.src.models.bloom.bloom_backbone import ( + BloomBackbone as BloomBackbone, +) +from keras_hub.src.models.bloom.bloom_causal_lm import ( + BloomCausalLM as BloomCausalLM, +) from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import ( - BloomCausalLMPreprocessor, -) -from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer -from keras_hub.src.models.causal_lm import CausalLM -from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor -from keras_hub.src.models.clip.clip_backbone import CLIPBackbone -from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor -from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder -from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer -from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder -from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone + BloomCausalLMPreprocessor as BloomCausalLMPreprocessor, +) +from keras_hub.src.models.bloom.bloom_tokenizer import ( + BloomTokenizer as BloomTokenizer, +) +from keras_hub.src.models.causal_lm import CausalLM as CausalLM +from keras_hub.src.models.causal_lm_preprocessor import ( + CausalLMPreprocessor as CausalLMPreprocessor, +) +from keras_hub.src.models.clip.clip_backbone import CLIPBackbone as CLIPBackbone +from keras_hub.src.models.clip.clip_preprocessor import ( + CLIPPreprocessor as CLIPPreprocessor, +) +from keras_hub.src.models.clip.clip_text_encoder import ( + CLIPTextEncoder as CLIPTextEncoder, +) +from keras_hub.src.models.clip.clip_tokenizer import ( + CLIPTokenizer as CLIPTokenizer, +) +from keras_hub.src.models.clip.clip_vision_encoder import ( + CLIPVisionEncoder as CLIPVisionEncoder, +) +from keras_hub.src.models.cspnet.cspnet_backbone import ( + CSPNetBackbone as CSPNetBackbone, +) from keras_hub.src.models.cspnet.cspnet_image_classifier import ( - CSPNetImageClassifier, + CSPNetImageClassifier as CSPNetImageClassifier, ) from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( - CSPNetImageClassifierPreprocessor, + CSPNetImageClassifierPreprocessor as CSPNetImageClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( - DebertaV3Backbone, + DebertaV3Backbone as DebertaV3Backbone, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm import ( - DebertaV3MaskedLM, + DebertaV3MaskedLM as DebertaV3MaskedLM, ) from keras_hub.src.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( - DebertaV3MaskedLMPreprocessor, + DebertaV3MaskedLMPreprocessor as DebertaV3MaskedLMPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier, + DebertaV3TextClassifier as DebertaV3Classifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier import ( - DebertaV3TextClassifier as DebertaV3Classifier, + DebertaV3TextClassifier as DebertaV3TextClassifier, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor import ( - DebertaV3TextClassifierPreprocessor as DebertaV3Preprocessor, + DebertaV3TextClassifierPreprocessor as DebertaV3TextClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, + DebertaV3Tokenizer as DebertaV3Tokenizer, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( - DeepLabV3Backbone, + DeepLabV3Backbone as DeepLabV3Backbone, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( - DeepLabV3ImageSegmenterPreprocessor, + DeepLabV3ImageSegmenterPreprocessor as DeepLabV3ImageSegmenterPreprocessor, ) from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import ( - DeepLabV3ImageSegmenter, + DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter, +) +from keras_hub.src.models.densenet.densenet_backbone import ( + DenseNetBackbone as DenseNetBackbone, ) -from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone from keras_hub.src.models.densenet.densenet_image_classifier import ( - DenseNetImageClassifier, + DenseNetImageClassifier as DenseNetImageClassifier, ) from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( - DenseNetImageClassifierPreprocessor, + DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_backbone import ( - DistilBertBackbone, + DistilBertBackbone as DistilBertBackbone, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm import ( - DistilBertMaskedLM, + DistilBertMaskedLM as DistilBertMaskedLM, ) from keras_hub.src.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor, + DistilBertMaskedLMPreprocessor as DistilBertMaskedLMPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier, + DistilBertTextClassifier as DistilBertClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier import ( - DistilBertTextClassifier as DistilBertClassifier, + DistilBertTextClassifier as DistilBertTextClassifier, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_text_classifier_preprocessor import ( - DistilBertTextClassifierPreprocessor as DistilBertPreprocessor, + DistilBertTextClassifierPreprocessor as DistilBertTextClassifierPreprocessor, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, + DistilBertTokenizer as DistilBertTokenizer, ) from keras_hub.src.models.efficientnet.efficientnet_backbone import ( - EfficientNetBackbone, + EfficientNetBackbone as EfficientNetBackbone, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier import ( - EfficientNetImageClassifier, + EfficientNetImageClassifier as EfficientNetImageClassifier, ) from keras_hub.src.models.efficientnet.efficientnet_image_classifier_preprocessor import ( - EfficientNetImageClassifierPreprocessor, + EfficientNetImageClassifierPreprocessor as EfficientNetImageClassifierPreprocessor, +) +from keras_hub.src.models.electra.electra_backbone import ( + ElectraBackbone as ElectraBackbone, +) +from keras_hub.src.models.electra.electra_tokenizer import ( + ElectraTokenizer as ElectraTokenizer, ) -from keras_hub.src.models.electra.electra_backbone import ElectraBackbone -from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer -from keras_hub.src.models.esm.esm_backbone import ESMBackbone from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone -from keras_hub.src.models.esm.esm_classifier import ESMProteinClassifier +from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESMBackbone +from keras_hub.src.models.esm.esm_classifier import ( + ESMProteinClassifier as ESMProteinClassifier, +) from keras_hub.src.models.esm.esm_classifier_preprocessor import ( - ESMProteinClassifierPreprocessor, + ESMProteinClassifierPreprocessor as ESMProteinClassifierPreprocessor, ) -from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM from keras_hub.src.models.esm.esm_masked_plm import ( ESMMaskedPLM as ESM2MaskedPLM, ) +from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM from keras_hub.src.models.esm.esm_masked_plm_preprocessor import ( - ESMMaskedPLMPreprocessor, + ESMMaskedPLMPreprocessor as ESMMaskedPLMPreprocessor, +) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer +from keras_hub.src.models.f_net.f_net_backbone import ( + FNetBackbone as FNetBackbone, +) +from keras_hub.src.models.f_net.f_net_masked_lm import ( + FNetMaskedLM as FNetMaskedLM, ) -from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer -from keras_hub.src.models.f_net.f_net_backbone import FNetBackbone -from keras_hub.src.models.f_net.f_net_masked_lm import FNetMaskedLM from keras_hub.src.models.f_net.f_net_masked_lm_preprocessor import ( - FNetMaskedLMPreprocessor, + FNetMaskedLMPreprocessor as FNetMaskedLMPreprocessor, ) -from keras_hub.src.models.f_net.f_net_text_classifier import FNetTextClassifier from keras_hub.src.models.f_net.f_net_text_classifier import ( FNetTextClassifier as FNetClassifier, ) -from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( - FNetTextClassifierPreprocessor, +from keras_hub.src.models.f_net.f_net_text_classifier import ( + FNetTextClassifier as FNetTextClassifier, ) from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( FNetTextClassifierPreprocessor as FNetPreprocessor, ) -from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer -from keras_hub.src.models.falcon.falcon_backbone import FalconBackbone -from keras_hub.src.models.falcon.falcon_causal_lm import FalconCausalLM +from keras_hub.src.models.f_net.f_net_text_classifier_preprocessor import ( + FNetTextClassifierPreprocessor as FNetTextClassifierPreprocessor, +) +from keras_hub.src.models.f_net.f_net_tokenizer import ( + FNetTokenizer as FNetTokenizer, +) +from keras_hub.src.models.falcon.falcon_backbone import ( + FalconBackbone as FalconBackbone, +) +from keras_hub.src.models.falcon.falcon_causal_lm import ( + FalconCausalLM as FalconCausalLM, +) from keras_hub.src.models.falcon.falcon_causal_lm_preprocessor import ( - FalconCausalLMPreprocessor, + FalconCausalLMPreprocessor as FalconCausalLMPreprocessor, +) +from keras_hub.src.models.falcon.falcon_tokenizer import ( + FalconTokenizer as FalconTokenizer, +) +from keras_hub.src.models.feature_pyramid_backbone import ( + FeaturePyramidBackbone as FeaturePyramidBackbone, +) +from keras_hub.src.models.flux.flux_model import FluxBackbone as FluxBackbone +from keras_hub.src.models.flux.flux_text_to_image import ( + FluxTextToImage as FluxTextToImage, ) -from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer -from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone -from keras_hub.src.models.flux.flux_model import FluxBackbone -from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( - FluxTextToImagePreprocessor, + FluxTextToImagePreprocessor as FluxTextToImagePreprocessor, +) +from keras_hub.src.models.gemma.gemma_backbone import ( + GemmaBackbone as GemmaBackbone, +) +from keras_hub.src.models.gemma.gemma_causal_lm import ( + GemmaCausalLM as GemmaCausalLM, ) -from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone -from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor, + GemmaCausalLMPreprocessor as GemmaCausalLMPreprocessor, +) +from keras_hub.src.models.gemma.gemma_tokenizer import ( + GemmaTokenizer as GemmaTokenizer, +) +from keras_hub.src.models.gemma3.gemma3_backbone import ( + Gemma3Backbone as Gemma3Backbone, +) +from keras_hub.src.models.gemma3.gemma3_causal_lm import ( + Gemma3CausalLM as Gemma3CausalLM, ) -from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer -from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone -from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( - Gemma3CausalLMPreprocessor, + Gemma3CausalLMPreprocessor as Gemma3CausalLMPreprocessor, +) +from keras_hub.src.models.gemma3.gemma3_tokenizer import ( + Gemma3Tokenizer as Gemma3Tokenizer, ) -from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer from keras_hub.src.models.gemma3.gemma3_vision_encoder import ( - Gemma3VisionEncoder, + Gemma3VisionEncoder as Gemma3VisionEncoder, +) +from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone as GPT2Backbone +from keras_hub.src.models.gpt2.gpt2_causal_lm import ( + GPT2CausalLM as GPT2CausalLM, ) -from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone -from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import ( - GPT2CausalLMPreprocessor, + GPT2CausalLMPreprocessor as GPT2CausalLMPreprocessor, +) +from keras_hub.src.models.gpt2.gpt2_preprocessor import ( + GPT2Preprocessor as GPT2Preprocessor, +) +from keras_hub.src.models.gpt2.gpt2_tokenizer import ( + GPT2Tokenizer as GPT2Tokenizer, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import ( + GPTNeoXBackbone as GPTNeoXBackbone, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import ( + GPTNeoXCausalLM as GPTNeoXCausalLM, ) -from keras_hub.src.models.gpt2.gpt2_preprocessor import GPT2Preprocessor -from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm import GPTNeoXCausalLM from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( - GPTNeoXCausalLMPreprocessor, + GPTNeoXCausalLMPreprocessor as GPTNeoXCausalLMPreprocessor, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( + GPTNeoXTokenizer as GPTNeoXTokenizer, +) +from keras_hub.src.models.image_classifier import ( + ImageClassifier as ImageClassifier, ) -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer -from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.image_classifier_preprocessor import ( - ImageClassifierPreprocessor, + ImageClassifierPreprocessor as ImageClassifierPreprocessor, +) +from keras_hub.src.models.image_segmenter import ( + ImageSegmenter as ImageSegmenter, ) -from keras_hub.src.models.image_segmenter import ImageSegmenter from keras_hub.src.models.image_segmenter_preprocessor import ( - ImageSegmenterPreprocessor, + ImageSegmenterPreprocessor as ImageSegmenterPreprocessor, +) +from keras_hub.src.models.image_to_image import ImageToImage as ImageToImage +from keras_hub.src.models.inpaint import Inpaint as Inpaint +from keras_hub.src.models.llama.llama_backbone import ( + LlamaBackbone as LlamaBackbone, +) +from keras_hub.src.models.llama.llama_causal_lm import ( + LlamaCausalLM as LlamaCausalLM, ) -from keras_hub.src.models.image_to_image import ImageToImage -from keras_hub.src.models.inpaint import Inpaint -from keras_hub.src.models.llama.llama_backbone import LlamaBackbone -from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM from keras_hub.src.models.llama.llama_causal_lm_preprocessor import ( - LlamaCausalLMPreprocessor, + LlamaCausalLMPreprocessor as LlamaCausalLMPreprocessor, +) +from keras_hub.src.models.llama.llama_tokenizer import ( + LlamaTokenizer as LlamaTokenizer, +) +from keras_hub.src.models.llama3.llama3_backbone import ( + Llama3Backbone as Llama3Backbone, +) +from keras_hub.src.models.llama3.llama3_causal_lm import ( + Llama3CausalLM as Llama3CausalLM, ) -from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer -from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone -from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import ( - Llama3CausalLMPreprocessor, + Llama3CausalLMPreprocessor as Llama3CausalLMPreprocessor, +) +from keras_hub.src.models.llama3.llama3_tokenizer import ( + Llama3Tokenizer as Llama3Tokenizer, +) +from keras_hub.src.models.masked_lm import MaskedLM as MaskedLM +from keras_hub.src.models.masked_lm_preprocessor import ( + MaskedLMPreprocessor as MaskedLMPreprocessor, +) +from keras_hub.src.models.mistral.mistral_backbone import ( + MistralBackbone as MistralBackbone, +) +from keras_hub.src.models.mistral.mistral_causal_lm import ( + MistralCausalLM as MistralCausalLM, ) -from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer -from keras_hub.src.models.masked_lm import MaskedLM -from keras_hub.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_hub.src.models.mistral.mistral_backbone import MistralBackbone -from keras_hub.src.models.mistral.mistral_causal_lm import MistralCausalLM from keras_hub.src.models.mistral.mistral_causal_lm_preprocessor import ( - MistralCausalLMPreprocessor, + MistralCausalLMPreprocessor as MistralCausalLMPreprocessor, +) +from keras_hub.src.models.mistral.mistral_tokenizer import ( + MistralTokenizer as MistralTokenizer, +) +from keras_hub.src.models.mit.mit_backbone import MiTBackbone as MiTBackbone +from keras_hub.src.models.mit.mit_image_classifier import ( + MiTImageClassifier as MiTImageClassifier, ) -from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer -from keras_hub.src.models.mit.mit_backbone import MiTBackbone -from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( - MiTImageClassifierPreprocessor, + MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor, +) +from keras_hub.src.models.mixtral.mixtral_backbone import ( + MixtralBackbone as MixtralBackbone, +) +from keras_hub.src.models.mixtral.mixtral_causal_lm import ( + MixtralCausalLM as MixtralCausalLM, ) -from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone -from keras_hub.src.models.mixtral.mixtral_causal_lm import MixtralCausalLM from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( - MixtralCausalLMPreprocessor, + MixtralCausalLMPreprocessor as MixtralCausalLMPreprocessor, +) +from keras_hub.src.models.mixtral.mixtral_tokenizer import ( + MixtralTokenizer as MixtralTokenizer, +) +from keras_hub.src.models.mobilenet.mobilenet_backbone import ( + MobileNetBackbone as MobileNetBackbone, ) -from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer -from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier, + MobileNetImageClassifier as MobileNetImageClassifier, ) from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( - MobileNetImageClassifierPreprocessor, + MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor, ) -from keras_hub.src.models.object_detector import ObjectDetector from keras_hub.src.models.object_detector import ( ObjectDetector as ImageObjectDetector, ) -from keras_hub.src.models.object_detector_preprocessor import ( - ObjectDetectorPreprocessor, +from keras_hub.src.models.object_detector import ( + ObjectDetector as ObjectDetector, ) from keras_hub.src.models.object_detector_preprocessor import ( ObjectDetectorPreprocessor as ImageObjectDetectorPreprocessor, ) -from keras_hub.src.models.opt.opt_backbone import OPTBackbone -from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM +from keras_hub.src.models.object_detector_preprocessor import ( + ObjectDetectorPreprocessor as ObjectDetectorPreprocessor, +) +from keras_hub.src.models.opt.opt_backbone import OPTBackbone as OPTBackbone +from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM as OPTCausalLM from keras_hub.src.models.opt.opt_causal_lm_preprocessor import ( - OPTCausalLMPreprocessor, + OPTCausalLMPreprocessor as OPTCausalLMPreprocessor, ) -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_backbone import ( - PaliGemmaBackbone, + PaliGemmaBackbone as PaliGemmaBackbone, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm import ( - PaliGemmaCausalLM, + PaliGemmaCausalLM as PaliGemmaCausalLM, ) from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor, + PaliGemmaCausalLMPreprocessor as PaliGemmaCausalLMPreprocessor, ) from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, + PaliGemmaTokenizer as PaliGemmaTokenizer, +) +from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone +from keras_hub.src.models.phi3.phi3_causal_lm import ( + Phi3CausalLM as Phi3CausalLM, ) -from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone -from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import ( - Phi3CausalLMPreprocessor, + Phi3CausalLMPreprocessor as Phi3CausalLMPreprocessor, ) -from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer -from keras_hub.src.models.preprocessor import Preprocessor -from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone +from keras_hub.src.models.phi3.phi3_tokenizer import ( + Phi3Tokenizer as Phi3Tokenizer, +) +from keras_hub.src.models.preprocessor import Preprocessor as Preprocessor from keras_hub.src.models.qwen.qwen_backbone import ( QwenBackbone as Qwen2Backbone, ) -from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM +from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone as QwenBackbone from keras_hub.src.models.qwen.qwen_causal_lm import ( QwenCausalLM as Qwen2CausalLM, ) -from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( - QwenCausalLMPreprocessor, +from keras_hub.src.models.qwen.qwen_causal_lm import ( + QwenCausalLM as QwenCausalLM, ) from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor, ) -from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer +from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import ( + QwenCausalLMPreprocessor as QwenCausalLMPreprocessor, +) from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone -from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import QwenMoeCausalLM +from keras_hub.src.models.qwen.qwen_tokenizer import ( + QwenTokenizer as QwenTokenizer, +) +from keras_hub.src.models.qwen_moe.qwen_moe_backbone import ( + QwenMoeBackbone as QwenMoeBackbone, +) +from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm import ( + QwenMoeCausalLM as QwenMoeCausalLM, +) from keras_hub.src.models.qwen_moe.qwen_moe_causal_lm_preprocessor import ( - QwenMoeCausalLMPreprocessor, + QwenMoeCausalLMPreprocessor as QwenMoeCausalLMPreprocessor, +) +from keras_hub.src.models.resnet.resnet_backbone import ( + ResNetBackbone as ResNetBackbone, ) -from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone from keras_hub.src.models.resnet.resnet_image_classifier import ( - ResNetImageClassifier, + ResNetImageClassifier as ResNetImageClassifier, ) from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import ( - ResNetImageClassifierPreprocessor, + ResNetImageClassifierPreprocessor as ResNetImageClassifierPreprocessor, +) +from keras_hub.src.models.retinanet.retinanet_backbone import ( + RetinaNetBackbone as RetinaNetBackbone, ) -from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone from keras_hub.src.models.retinanet.retinanet_object_detector import ( - RetinaNetObjectDetector, + RetinaNetObjectDetector as RetinaNetObjectDetector, ) from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( - RetinaNetObjectDetectorPreprocessor, + RetinaNetObjectDetectorPreprocessor as RetinaNetObjectDetectorPreprocessor, +) +from keras_hub.src.models.roberta.roberta_backbone import ( + RobertaBackbone as RobertaBackbone, +) +from keras_hub.src.models.roberta.roberta_masked_lm import ( + RobertaMaskedLM as RobertaMaskedLM, ) -from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone -from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import ( - RobertaMaskedLMPreprocessor, + RobertaMaskedLMPreprocessor as RobertaMaskedLMPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier, + RobertaTextClassifier as RobertaClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier import ( - RobertaTextClassifier as RobertaClassifier, + RobertaTextClassifier as RobertaTextClassifier, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor, + RobertaTextClassifierPreprocessor as RobertaPreprocessor, ) from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import ( - RobertaTextClassifierPreprocessor as RobertaPreprocessor, + RobertaTextClassifierPreprocessor as RobertaTextClassifierPreprocessor, +) +from keras_hub.src.models.roberta.roberta_tokenizer import ( + RobertaTokenizer as RobertaTokenizer, ) -from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_hub.src.models.roformer_v2.roformer_v2_backbone import ( - RoformerV2Backbone, + RoformerV2Backbone as RoformerV2Backbone, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import ( - RoformerV2MaskedLM, + RoformerV2MaskedLM as RoformerV2MaskedLM, ) from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import ( - RoformerV2MaskedLMPreprocessor, + RoformerV2MaskedLMPreprocessor as RoformerV2MaskedLMPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import ( - RoformerV2TextClassifier, + RoformerV2TextClassifier as RoformerV2TextClassifier, ) from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import ( - RoformerV2TextClassifierPreprocessor, + RoformerV2TextClassifierPreprocessor as RoformerV2TextClassifierPreprocessor, ) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer, + RoformerV2Tokenizer as RoformerV2Tokenizer, +) +from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone +from keras_hub.src.models.sam.sam_image_segmenter import ( + SAMImageSegmenter as SAMImageSegmenter, ) -from keras_hub.src.models.sam.sam_backbone import SAMBackbone -from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import ( - SAMImageSegmenterPreprocessor, + SAMImageSegmenterPreprocessor as SAMImageSegmenterPreprocessor, +) +from keras_hub.src.models.segformer.segformer_backbone import ( + SegFormerBackbone as SegFormerBackbone, ) -from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone from keras_hub.src.models.segformer.segformer_image_segmenter import ( - SegFormerImageSegmenter, + SegFormerImageSegmenter as SegFormerImageSegmenter, ) from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import ( - SegFormerImageSegmenterPreprocessor, -) -from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM -from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor -from keras_hub.src.models.siglip.siglip_backbone import SigLIPBackbone -from keras_hub.src.models.siglip.siglip_preprocessor import SigLIPPreprocessor -from keras_hub.src.models.siglip.siglip_text_encoder import SigLIPTextEncoder -from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer + SegFormerImageSegmenterPreprocessor as SegFormerImageSegmenterPreprocessor, +) +from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM as Seq2SeqLM +from keras_hub.src.models.seq_2_seq_lm_preprocessor import ( + Seq2SeqLMPreprocessor as Seq2SeqLMPreprocessor, +) +from keras_hub.src.models.siglip.siglip_backbone import ( + SigLIPBackbone as SigLIPBackbone, +) +from keras_hub.src.models.siglip.siglip_preprocessor import ( + SigLIPPreprocessor as SigLIPPreprocessor, +) +from keras_hub.src.models.siglip.siglip_text_encoder import ( + SigLIPTextEncoder as SigLIPTextEncoder, +) +from keras_hub.src.models.siglip.siglip_tokenizer import ( + SigLIPTokenizer as SigLIPTokenizer, +) from keras_hub.src.models.siglip.siglip_vision_encoder import ( - SigLIPVisionEncoder, + SigLIPVisionEncoder as SigLIPVisionEncoder, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( - StableDiffusion3Backbone, + StableDiffusion3Backbone as StableDiffusion3Backbone, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_image_to_image import ( - StableDiffusion3ImageToImage, + StableDiffusion3ImageToImage as StableDiffusion3ImageToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_inpaint import ( - StableDiffusion3Inpaint, + StableDiffusion3Inpaint as StableDiffusion3Inpaint, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import ( - StableDiffusion3TextToImage, + StableDiffusion3TextToImage as StableDiffusion3TextToImage, ) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import ( - StableDiffusion3TextToImagePreprocessor, + StableDiffusion3TextToImagePreprocessor as StableDiffusion3TextToImagePreprocessor, ) -from keras_hub.src.models.t5.t5_backbone import T5Backbone -from keras_hub.src.models.t5.t5_preprocessor import T5Preprocessor -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer -from keras_hub.src.models.task import Task -from keras_hub.src.models.text_classifier import TextClassifier +from keras_hub.src.models.t5.t5_backbone import T5Backbone as T5Backbone +from keras_hub.src.models.t5.t5_preprocessor import ( + T5Preprocessor as T5Preprocessor, +) +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.task import Task as Task from keras_hub.src.models.text_classifier import TextClassifier as Classifier +from keras_hub.src.models.text_classifier import ( + TextClassifier as TextClassifier, +) from keras_hub.src.models.text_classifier_preprocessor import ( - TextClassifierPreprocessor, + TextClassifierPreprocessor as TextClassifierPreprocessor, ) -from keras_hub.src.models.text_to_image import TextToImage +from keras_hub.src.models.text_to_image import TextToImage as TextToImage from keras_hub.src.models.text_to_image_preprocessor import ( - TextToImagePreprocessor, + TextToImagePreprocessor as TextToImagePreprocessor, +) +from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone as VGGBackbone +from keras_hub.src.models.vgg.vgg_image_classifier import ( + VGGImageClassifier as VGGImageClassifier, ) -from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone -from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( - VGGImageClassifierPreprocessor, + VGGImageClassifierPreprocessor as VGGImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone as ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ( + ViTImageClassifier as ViTImageClassifier, ) -from keras_hub.src.models.vit.vit_backbone import ViTBackbone -from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( - ViTImageClassifierPreprocessor, + ViTImageClassifierPreprocessor as ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit_det.vit_det_backbone import ( + ViTDetBackbone as ViTDetBackbone, +) +from keras_hub.src.models.whisper.whisper_backbone import ( + WhisperBackbone as WhisperBackbone, +) +from keras_hub.src.models.whisper.whisper_tokenizer import ( + WhisperTokenizer as WhisperTokenizer, +) +from keras_hub.src.models.xception.xception_backbone import ( + XceptionBackbone as XceptionBackbone, ) -from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone -from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone -from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_hub.src.models.xception.xception_backbone import XceptionBackbone from keras_hub.src.models.xception.xception_image_classifier import ( - XceptionImageClassifier, + XceptionImageClassifier as XceptionImageClassifier, ) from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( - XceptionImageClassifierPreprocessor, + XceptionImageClassifierPreprocessor as XceptionImageClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import ( - XLMRobertaBackbone, + XLMRobertaBackbone as XLMRobertaBackbone, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm import ( - XLMRobertaMaskedLM, + XLMRobertaMaskedLM as XLMRobertaMaskedLM, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( - XLMRobertaMaskedLMPreprocessor, + XLMRobertaMaskedLMPreprocessor as XLMRobertaMaskedLMPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier, + XLMRobertaTextClassifier as XLMRobertaClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier import ( - XLMRobertaTextClassifier as XLMRobertaClassifier, + XLMRobertaTextClassifier as XLMRobertaTextClassifier, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor import ( - XLMRobertaTextClassifierPreprocessor as XLMRobertaPreprocessor, + XLMRobertaTextClassifierPreprocessor as XLMRobertaTextClassifierPreprocessor, ) from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer, + XLMRobertaTokenizer as XLMRobertaTokenizer, +) +from keras_hub.src.models.xlnet.xlnet_backbone import ( + XLNetBackbone as XLNetBackbone, ) -from keras_hub.src.models.xlnet.xlnet_backbone import XLNetBackbone -from keras_hub.src.tokenizers.tokenizer import Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer diff --git a/keras_hub/api/samplers/__init__.py b/keras_hub/api/samplers/__init__.py index 9feb76c669..29bfef00fc 100644 --- a/keras_hub/api/samplers/__init__.py +++ b/keras_hub/api/samplers/__init__.py @@ -4,13 +4,15 @@ since your modifications would be overwritten. """ -from keras_hub.src.samplers.beam_sampler import BeamSampler -from keras_hub.src.samplers.contrastive_sampler import ContrastiveSampler -from keras_hub.src.samplers.greedy_sampler import GreedySampler -from keras_hub.src.samplers.random_sampler import RandomSampler -from keras_hub.src.samplers.sampler import Sampler -from keras_hub.src.samplers.serialization import deserialize -from keras_hub.src.samplers.serialization import get -from keras_hub.src.samplers.serialization import serialize -from keras_hub.src.samplers.top_k_sampler import TopKSampler -from keras_hub.src.samplers.top_p_sampler import TopPSampler +from keras_hub.src.samplers.beam_sampler import BeamSampler as BeamSampler +from keras_hub.src.samplers.contrastive_sampler import ( + ContrastiveSampler as ContrastiveSampler, +) +from keras_hub.src.samplers.greedy_sampler import GreedySampler as GreedySampler +from keras_hub.src.samplers.random_sampler import RandomSampler as RandomSampler +from keras_hub.src.samplers.sampler import Sampler as Sampler +from keras_hub.src.samplers.serialization import deserialize as deserialize +from keras_hub.src.samplers.serialization import get as get +from keras_hub.src.samplers.serialization import serialize as serialize +from keras_hub.src.samplers.top_k_sampler import TopKSampler as TopKSampler +from keras_hub.src.samplers.top_p_sampler import TopPSampler as TopPSampler diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 96818e01e7..303bd190fc 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -4,62 +4,112 @@ since your modifications would be overwritten. """ -from keras_hub.src.models.albert.albert_tokenizer import AlbertTokenizer -from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer -from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer -from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer -from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.albert.albert_tokenizer import ( + AlbertTokenizer as AlbertTokenizer, +) +from keras_hub.src.models.bart.bart_tokenizer import ( + BartTokenizer as BartTokenizer, +) +from keras_hub.src.models.bert.bert_tokenizer import ( + BertTokenizer as BertTokenizer, +) +from keras_hub.src.models.bloom.bloom_tokenizer import ( + BloomTokenizer as BloomTokenizer, +) +from keras_hub.src.models.clip.clip_tokenizer import ( + CLIPTokenizer as CLIPTokenizer, +) from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import ( - DebertaV3Tokenizer, + DebertaV3Tokenizer as DebertaV3Tokenizer, ) from keras_hub.src.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, -) -from keras_hub.src.models.electra.electra_tokenizer import ElectraTokenizer -from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer -from keras_hub.src.models.f_net.f_net_tokenizer import FNetTokenizer -from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer -from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer -from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer -from keras_hub.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer -from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer -from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer -from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer -from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer -from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer -from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer + DistilBertTokenizer as DistilBertTokenizer, +) +from keras_hub.src.models.electra.electra_tokenizer import ( + ElectraTokenizer as ElectraTokenizer, +) +from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer +from keras_hub.src.models.f_net.f_net_tokenizer import ( + FNetTokenizer as FNetTokenizer, +) +from keras_hub.src.models.falcon.falcon_tokenizer import ( + FalconTokenizer as FalconTokenizer, +) +from keras_hub.src.models.gemma.gemma_tokenizer import ( + GemmaTokenizer as GemmaTokenizer, +) +from keras_hub.src.models.gemma3.gemma3_tokenizer import ( + Gemma3Tokenizer as Gemma3Tokenizer, +) +from keras_hub.src.models.gpt2.gpt2_tokenizer import ( + GPT2Tokenizer as GPT2Tokenizer, +) +from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import ( + GPTNeoXTokenizer as GPTNeoXTokenizer, +) +from keras_hub.src.models.llama.llama_tokenizer import ( + LlamaTokenizer as LlamaTokenizer, +) +from keras_hub.src.models.llama3.llama3_tokenizer import ( + Llama3Tokenizer as Llama3Tokenizer, +) +from keras_hub.src.models.mistral.mistral_tokenizer import ( + MistralTokenizer as MistralTokenizer, +) +from keras_hub.src.models.mixtral.mixtral_tokenizer import ( + MixtralTokenizer as MixtralTokenizer, +) +from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, + PaliGemmaTokenizer as PaliGemmaTokenizer, +) +from keras_hub.src.models.phi3.phi3_tokenizer import ( + Phi3Tokenizer as Phi3Tokenizer, ) -from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer -from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer from keras_hub.src.models.qwen.qwen_tokenizer import ( QwenTokenizer as Qwen2Tokenizer, ) -from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import QwenMoeTokenizer -from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_hub.src.models.qwen.qwen_tokenizer import ( + QwenTokenizer as QwenTokenizer, +) +from keras_hub.src.models.qwen_moe.qwen_moe_tokenizer import ( + QwenMoeTokenizer as QwenMoeTokenizer, +) +from keras_hub.src.models.roberta.roberta_tokenizer import ( + RobertaTokenizer as RobertaTokenizer, +) from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import ( - RoformerV2Tokenizer, + RoformerV2Tokenizer as RoformerV2Tokenizer, +) +from keras_hub.src.models.siglip.siglip_tokenizer import ( + SigLIPTokenizer as SigLIPTokenizer, +) +from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer +from keras_hub.src.models.whisper.whisper_tokenizer import ( + WhisperTokenizer as WhisperTokenizer, ) -from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer -from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer -from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import ( - XLMRobertaTokenizer, + XLMRobertaTokenizer as XLMRobertaTokenizer, +) +from keras_hub.src.tokenizers.byte_pair_tokenizer import ( + BytePairTokenizer as BytePairTokenizer, +) +from keras_hub.src.tokenizers.byte_tokenizer import ( + ByteTokenizer as ByteTokenizer, ) -from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_hub.src.tokenizers.byte_tokenizer import ByteTokenizer from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( - SentencePieceTokenizer, + SentencePieceTokenizer as SentencePieceTokenizer, ) from keras_hub.src.tokenizers.sentence_piece_tokenizer_trainer import ( - compute_sentence_piece_proto, + compute_sentence_piece_proto as compute_sentence_piece_proto, ) -from keras_hub.src.tokenizers.tokenizer import Tokenizer +from keras_hub.src.tokenizers.tokenizer import Tokenizer as Tokenizer from keras_hub.src.tokenizers.unicode_codepoint_tokenizer import ( - UnicodeCodepointTokenizer, + UnicodeCodepointTokenizer as UnicodeCodepointTokenizer, +) +from keras_hub.src.tokenizers.word_piece_tokenizer import ( + WordPieceTokenizer as WordPieceTokenizer, ) -from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer from keras_hub.src.tokenizers.word_piece_tokenizer_trainer import ( - compute_word_piece_vocabulary, + compute_word_piece_vocabulary as compute_word_piece_vocabulary, ) diff --git a/keras_hub/api/utils/__init__.py b/keras_hub/api/utils/__init__.py index 8ce47790b0..0bd8cb642e 100644 --- a/keras_hub/api/utils/__init__.py +++ b/keras_hub/api/utils/__init__.py @@ -4,10 +4,18 @@ since your modifications would be overwritten. """ -from keras_hub.src.utils.coco.coco_utils import coco_id_to_name -from keras_hub.src.utils.coco.coco_utils import coco_name_to_id +from keras_hub.src.utils.coco.coco_utils import ( + coco_id_to_name as coco_id_to_name, +) +from keras_hub.src.utils.coco.coco_utils import ( + coco_name_to_id as coco_name_to_id, +) +from keras_hub.src.utils.imagenet.imagenet_utils import ( + decode_imagenet_predictions as decode_imagenet_predictions, +) +from keras_hub.src.utils.imagenet.imagenet_utils import ( + imagenet_id_to_name as imagenet_id_to_name, +) from keras_hub.src.utils.imagenet.imagenet_utils import ( - decode_imagenet_predictions, + imagenet_name_to_id as imagenet_name_to_id, ) -from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name -from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id From 20d505114d0b16fe45eb06f4238a6204ea6bb5c1 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Sat, 17 May 2025 12:22:30 +0800 Subject: [PATCH 13/13] update --- keras_hub/src/models/esm/esm_backbone.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/esm/esm_backbone.py b/keras_hub/src/models/esm/esm_backbone.py index bf971fc30e..94e9c22499 100644 --- a/keras_hub/src/models/esm/esm_backbone.py +++ b/keras_hub/src/models/esm/esm_backbone.py @@ -35,9 +35,9 @@ class ESMBackbone(Backbone): a two-layer feedforward network for each transformer. dropout: float. Dropout probability for the Transformer encoder. Defaults to 0.1 - layer_norm_eps:bool.If true, then layer norm will be - used before entering the transformer block. - Since it's pre-norm, the default is false. + layer_norm_eps:bool.If true, then layer norm will be used before + entering the transformer block. + Since it's pre-norm, the default is false. max_sequence_length: int. The maximum sequence length that this encoder can consume. If None, `max_sequence_length` uses the value from sequence length. This determines the variable shape for positional