Added LayoutLMv3 #2178
Open
carrycooldude wants to merge 9 commits into keras-team:master from carrycooldude:feature/layoutlmv3-port
Changes from 2 commits

Commits (9)
ae79d15 added the files (carrycooldude)
737f03a Restructure LayoutLMv3 implementation to match KerasHub style (carrycooldude)
455a140 Refactor: Move LayoutLMv3 files to models directory and make code bac… (carrycooldude)
d92c8c4 refactor: Move LayoutLMv3 files to dedicated directory (carrycooldude)
0948f95 fix: Update LayoutLMv3 init files to follow correct format (carrycooldude)
3c02f78 fix: Update LayoutLMv3 backbone to follow project standards (carrycooldude)
4a79d9b refactor: remove unnecessary files and fix imports in LayoutLMv3 module (carrycooldude)
c2fed4c Add minimal stub for LayoutLMv3TransformerLayer (carrycooldude)
e828047 fix: resolve merge conflicts and complete rebase (carrycooldude)
10 changes: 10 additions & 0 deletions
keras_hub/src/models/layoutlmv3/__init__.py
@@ -0,0 +1,10 @@
"""LayoutLMv3 model.""" | ||
|
||
from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import LayoutLMv3Backbone | ||
from keras_hub.src.models.layoutlmv3.layoutlmv3_tokenizer import LayoutLMv3Tokenizer | ||
from keras_hub.src.models.layoutlmv3.document_classifier import LayoutLMv3DocumentClassifier | ||
from keras_hub.src.models.layoutlmv3.document_classifier import LayoutLMv3DocumentClassifierPreprocessor | ||
from keras_hub.src.models.layoutlmv3.layoutlmv3_presets import backbone_presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(backbone_presets, LayoutLMv3Backbone) |
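The `layoutlmv3_presets` module imported above is not part of this diff. As a rough, hypothetical sketch only, a preset table following the metadata-plus-`kaggle_handle` layout used elsewhere in KerasHub might look like the following; the parameter counts and Kaggle handles are placeholders, and the preset names are taken from the docstrings later in this PR:

```python
# Hypothetical layoutlmv3_presets.py sketch; not part of this PR's diff.
backbone_presets = {
    "layoutlmv3_base": {
        "metadata": {
            "description": "LayoutLMv3 base model for document understanding.",
            "params": 133000000,  # placeholder parameter count
            "path": "layoutlmv3",
        },
        # Placeholder handle; a real preset would point at uploaded weights.
        "kaggle_handle": "kaggle://placeholder/layoutlmv3/keras/layoutlmv3_base",
    },
    "layoutlmv3_large": {
        "metadata": {
            "description": "LayoutLMv3 large model for document understanding.",
            "params": 368000000,  # placeholder parameter count
            "path": "layoutlmv3",
        },
        "kaggle_handle": "kaggle://placeholder/layoutlmv3/keras/layoutlmv3_large",
    },
}
```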
4 changes: 4 additions & 0 deletions
keras_hub/src/models/layoutlmv3/document_classifier/__init__.py
@@ -0,0 +1,4 @@
"""LayoutLMv3 document classifier."""

from keras_hub.src.models.layoutlmv3.document_classifier.layoutlmv3_document_classifier import LayoutLMv3DocumentClassifier
from keras_hub.src.models.layoutlmv3.document_classifier.layoutlmv3_document_classifier_preprocessor import LayoutLMv3DocumentClassifierPreprocessor
103 changes: 103 additions & 0 deletions
keras_hub/src/models/layoutlmv3/document_classifier/layoutlmv3_document_classifier.py
@@ -0,0 +1,103 @@
"""LayoutLMv3 document classifier task model."""

import tensorflow as tf
from tensorflow import keras

from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import LayoutLMv3Backbone


@keras.saving.register_keras_serializable(package="keras_hub")
class LayoutLMv3DocumentClassifier(keras.Model):
    """LayoutLMv3 document classifier task model.

    This model takes text, layout (bounding boxes) and image inputs and
    outputs document classification predictions.

    Args:
        backbone: A LayoutLMv3Backbone instance.
        num_classes: int. Number of classes to classify documents into.
        dropout: float. Dropout probability for the classification head.
        activation: str or callable. The activation function to use on the
            classification head.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        backbone,
        num_classes,
        dropout=0.1,
        activation="softmax",
        **kwargs,
    ):
        inputs = {
            "input_ids": keras.Input(shape=(None,), dtype=tf.int32),
            "bbox": keras.Input(shape=(None, 4), dtype=tf.int32),
            "attention_mask": keras.Input(shape=(None,), dtype=tf.int32),
            "image": keras.Input(shape=(None, None, 3), dtype=tf.float32),
        }

        # Get backbone outputs.
        backbone_outputs = backbone(inputs)
        sequence_output = backbone_outputs["sequence_output"]
        pooled_output = backbone_outputs["pooled_output"]

        # Classification head.
        x = keras.layers.Dropout(dropout)(pooled_output)
        outputs = keras.layers.Dense(
            num_classes,
            activation=activation,
            name="classifier",
        )(x)

        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

        self.backbone = backbone
        self.num_classes = num_classes
        self.dropout = dropout
        self.activation = activation

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "backbone": keras.saving.serialize_keras_object(self.backbone),
                "num_classes": self.num_classes,
                "dropout": self.dropout,
                "activation": self.activation,
            }
        )
        return config

    @classmethod
    def from_preset(
        cls,
        preset,
        num_classes,
        dropout=0.1,
        activation="softmax",
        **kwargs,
    ):
        """Create a LayoutLMv3 document classifier from a preset.

        Args:
            preset: string. Must be one of "layoutlmv3_base",
                "layoutlmv3_large".
            num_classes: int. Number of classes to classify documents into.
            dropout: float. Dropout probability for the classification head.
            activation: str or callable. The activation function to use on
                the classification head.
            **kwargs: Additional keyword arguments.

        Returns:
            A LayoutLMv3DocumentClassifier instance.
        """
        backbone = LayoutLMv3Backbone.from_preset(preset)
        return cls(
            backbone=backbone,
            num_classes=num_classes,
            dropout=dropout,
            activation=activation,
            **kwargs,
        )
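A minimal usage sketch for the task model above, assuming the `layoutlmv3_base` preset referenced in the docstring is registered and downloadable once this PR lands; the class count, optimizer, and loss below are illustrative choices, not part of the PR:

```python
import keras

from keras_hub.src.models.layoutlmv3.document_classifier import (
    LayoutLMv3DocumentClassifier,
)

# Build the classifier from a preset backbone plus a fresh softmax head.
classifier = LayoutLMv3DocumentClassifier.from_preset(
    "layoutlmv3_base",
    num_classes=16,  # placeholder label count
)
classifier.compile(
    optimizer=keras.optimizers.Adam(5e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
# The model expects the same dict of inputs it was built with:
# {"input_ids", "bbox", "attention_mask", "image"}.
```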
184 changes: 184 additions & 0 deletions
.../src/models/layoutlmv3/document_classifier/layoutlmv3_document_classifier_preprocessor.py
@@ -0,0 +1,184 @@
"""LayoutLMv3 document classifier preprocessor.

This preprocessor inherits from Preprocessor and adds LayoutLMv3-specific
functionality for document classification.

Example:
```python
# Initialize the preprocessor.
preprocessor = LayoutLMv3DocumentClassifierPreprocessor(
    tokenizer=LayoutLMv3Tokenizer.from_preset("layoutlmv3_base"),
    sequence_length=512,
    image_size=(112, 112),
)

# Preprocess input.
features = {
    "text": ["Invoice #12345\nTotal: $100.00", "Receipt #67890\nTotal: $50.00"],
    "bbox": [
        [[0, 0, 100, 20], [0, 30, 100, 50]],  # Bounding boxes for first document.
        [[0, 0, 100, 20], [0, 30, 100, 50]],  # Bounding boxes for second document.
    ],
    "image": tf.random.uniform((2, 112, 112, 3)),  # Random images for demo.
}
preprocessed = preprocessor(features)
```
"""

import keras
import tensorflow as tf
from keras.saving import register_keras_serializable

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.layoutlmv3.layoutlmv3_tokenizer import LayoutLMv3Tokenizer
from keras_hub.src.models.preprocessor import Preprocessor


@keras_hub_export(
    [
        "keras_hub.models.LayoutLMv3DocumentClassifierPreprocessor",
        "keras_hub.models.LayoutLMv3Preprocessor",
    ]
)
@register_keras_serializable()
class LayoutLMv3DocumentClassifierPreprocessor(Preprocessor):
    """LayoutLMv3 document classifier preprocessor.

    This preprocessor inherits from Preprocessor and adds LayoutLMv3-specific
    functionality for document classification.

    Args:
        tokenizer: A LayoutLMv3Tokenizer instance.
        sequence_length: The maximum sequence length to use.
        image_size: A tuple of (height, width) for resizing images.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        tokenizer,
        sequence_length=512,
        image_size=(112, 112),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.image_size = image_size

    def call(self, x, y=None, sample_weight=None):
        """Process the inputs.

        Args:
            x: A dictionary containing:
                - "text": A string or list of strings to tokenize.
                - "image": A numpy array or list of numpy arrays of shape
                  (112, 112, 3).
                - "bbox": A list of bounding boxes for each token in the text.
            y: Any label data. Will be passed through unaltered.
            sample_weight: Any label weight data. Will be passed through
                unaltered.

        Returns:
            A tuple of (processed_inputs, y, sample_weight).
        """
        # Tokenize the text.
        tokenized = self.tokenizer(x["text"])
        input_ids = tokenized["token_ids"]
        attention_mask = tokenized["attention_mask"]

        # Process bounding boxes.
        bbox = x["bbox"]
        if isinstance(bbox, list):
            bbox = tf.ragged.constant(bbox)
            bbox = bbox.to_tensor(shape=(None, self.sequence_length, 4))

        # Process image.
        image = x["image"]
        if isinstance(image, list):
            image = tf.stack(image)
        image = tf.cast(image, tf.float32)

        # Pad or truncate inputs.
        input_ids = input_ids[:, : self.sequence_length]
        attention_mask = attention_mask[:, : self.sequence_length]
        bbox = bbox[:, : self.sequence_length]

        # Create padding mask.
        padding_mask = tf.cast(attention_mask, tf.int32)

        # Return processed inputs.
        processed_inputs = {
            "input_ids": input_ids,
            "bbox": bbox,
            "attention_mask": attention_mask,
            "image": image,
        }

        return processed_inputs, y, sample_weight

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "tokenizer": keras.saving.serialize_keras_object(self.tokenizer),
                "sequence_length": self.sequence_length,
                "image_size": self.image_size,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        if "tokenizer" in config:
            config["tokenizer"] = keras.saving.deserialize_keras_object(
                config["tokenizer"]
            )
        return cls(**config)

    @classmethod
    def from_preset(
        cls,
        preset,
        **kwargs,
    ):
        """Instantiate LayoutLMv3DocumentClassifierPreprocessor from preset.

        Args:
            preset: string. Must be one of "layoutlmv3_base",
                "layoutlmv3_large".

        Examples:
        ```python
        # Load preprocessor from preset.
        preprocessor = LayoutLMv3DocumentClassifierPreprocessor.from_preset(
            "layoutlmv3_base"
        )
        ```
        """
        if preset not in cls.presets:
            raise ValueError(
                "`preset` must be one of "
                f"{', '.join(cls.presets)}. Received: {preset}"
            )

        metadata = cls.presets[preset]
        config = metadata["config"]

        # Create tokenizer.
        tokenizer = LayoutLMv3Tokenizer.from_preset(preset)

        # Create preprocessor.
        preprocessor = cls(
            tokenizer=tokenizer,
            sequence_length=config["sequence_length"],
            image_size=config["image_size"],
            **kwargs,
        )

        return preprocessor
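An end-to-end sketch chaining the preprocessor and the classifier from this PR; the preset name comes from the docstrings above, and the sample text, boxes, and image are throwaway placeholders:

```python
import tensorflow as tf

from keras_hub.src.models.layoutlmv3.document_classifier import (
    LayoutLMv3DocumentClassifier,
    LayoutLMv3DocumentClassifierPreprocessor,
)

# Assumes the "layoutlmv3_base" preset is available via this PR's registration.
preprocessor = LayoutLMv3DocumentClassifierPreprocessor.from_preset("layoutlmv3_base")
classifier = LayoutLMv3DocumentClassifier.from_preset("layoutlmv3_base", num_classes=2)

# One document: raw text, per-token bounding boxes, and a resized page image.
features = {
    "text": ["Invoice #12345\nTotal: $100.00"],
    "bbox": [[[0, 0, 100, 20], [0, 30, 100, 50]]],
    "image": tf.random.uniform((1, 112, 112, 3)),
}

# The preprocessor's call() returns (processed_inputs, y, sample_weight).
processed, _, _ = preprocessor(features)
probs = classifier.predict(processed)  # shape: (1, num_classes)
```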
Review comment: This file is mainly to register presets; follow other models to understand the format we follow.
pending