-
Notifications
You must be signed in to change notification settings - Fork 280
Added LayoutLMv3 #2178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Added LayoutLMv3 #2178
Changes from all commits
ae79d15
737f03a
455a140
d92c8c4
0948f95
3c02f78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""LayoutLMv3 document classifier.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This file needs to be empty; all the imports are handled in the keras_hub/api directory and will be automatically generated whenever you run the API generation script. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. pending |
||
|
||
from keras_hub.src.models.layoutlmv3.document_classifier.layoutlmv3_document_classifier import LayoutLMv3DocumentClassifier | ||
from keras_hub.src.models.layoutlmv3.document_classifier.layoutlmv3_document_classifier_preprocessor import LayoutLMv3DocumentClassifierPreprocessor |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import LayoutLMv3Backbone | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This file is mainly to register presets; follow other models to understand the format we follow. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. pending |
||
from keras_hub.src.models.layoutlmv3.layoutlmv3_document_classifier import LayoutLMv3DocumentClassifier | ||
from keras_hub.src.models.layoutlmv3.layoutlmv3_document_classifier_preprocessor import LayoutLMv3DocumentClassifierPreprocessor | ||
from keras_hub.src.models.layoutlmv3.layoutlmv3_tokenizer import LayoutLMv3Tokenizer | ||
from keras_hub.src.models.layoutlmv3.layoutlmv3_transformer import LayoutLMv3Transformer | ||
from keras_hub.src.models.layoutlmv3.layoutlmv3_presets import layoutlmv3_presets, backbone_presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
# Public names re-exported by this module.
__all__ = [
    "LayoutLMv3Backbone",
    "LayoutLMv3DocumentClassifier",
    "LayoutLMv3DocumentClassifierPreprocessor",
    "LayoutLMv3Tokenizer",
    "LayoutLMv3Transformer",
    "layoutlmv3_presets",
]

# Attach the preset configs to the backbone so that
# `LayoutLMv3Backbone.from_preset(...)` can resolve them by name.
register_presets(backbone_presets, LayoutLMv3Backbone)
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Copyright 2024 The Keras Hub Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
Comment on lines
+1
to
+15
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove this |
||
import os | ||
import numpy as np | ||
from keras import testing_utils | ||
from keras import ops | ||
from keras import backend | ||
from keras.testing import test_case | ||
from ..layoutlmv3.layoutlmv3_backbone import LayoutLMv3Backbone | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No relative imports |
||
|
||
class LayoutLMv3BackboneTest(test_case.TestCase):
    """Unit tests for `LayoutLMv3Backbone`.

    Covers forward-pass output shapes, save/load round-tripping, preset
    loading, variable input shapes, attention masking, and gradient flow.
    """

    def setUp(self):
        super().setUp()
        # Deliberately tiny config so the tests run fast. With a 112x112
        # image and 16x16 patches the vision branch contributes
        # (112 // 16) ** 2 = 49 patch tokens.
        self.backbone = LayoutLMv3Backbone(
            vocab_size=100,
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=128,
            image_size=(112, 112),
            patch_size=16,
        )

        # `keras.ops` has no `random` namespace in Keras 3; random tensors
        # come from `keras.random`. Integer-valued inputs must use
        # `randint` — `uniform` only supports float dtypes.
        from keras import random as keras_random

        self.batch_size = 2
        self.seq_length = 16
        self.input_ids = keras_random.randint(
            (self.batch_size, self.seq_length), minval=0, maxval=100
        )
        self.bbox = keras_random.randint(
            (self.batch_size, self.seq_length, 4), minval=0, maxval=100
        )
        self.attention_mask = ops.ones(
            (self.batch_size, self.seq_length), dtype="int32"
        )
        self.image = keras_random.uniform(
            (self.batch_size, 112, 112, 3), minval=0.0, maxval=1.0
        )

        self.inputs = {
            "input_ids": self.input_ids,
            "bbox": self.bbox,
            "attention_mask": self.attention_mask,
            "image": self.image,
        }

    def test_valid_call(self):
        """Forward pass returns correctly shaped outputs."""
        outputs = self.backbone(self.inputs)
        self.assertIn("sequence_output", outputs)
        self.assertIn("pooled_output", outputs)
        # Sequence = text tokens + 49 image patches + 1 CLS token.
        self.assertEqual(
            outputs["sequence_output"].shape,
            (self.batch_size, self.seq_length + 49 + 1, 64),
        )
        self.assertEqual(outputs["pooled_output"].shape, (self.batch_size, 64))

    def test_save_and_load(self):
        """Saving and restoring the backbone preserves its outputs."""
        from keras import saving

        outputs = self.backbone(self.inputs)
        # Keras 3 requires an explicit `.keras` file path, and the loader
        # lives in `keras.saving` — `keras.backend` has no `saving` module.
        path = os.path.join(self.get_temp_dir(), "backbone.keras")
        self.backbone.save(path)
        restored_backbone = saving.load_model(path)
        restored_outputs = restored_backbone(self.inputs)
        self.assertAllClose(
            outputs["sequence_output"], restored_outputs["sequence_output"]
        )
        self.assertAllClose(
            outputs["pooled_output"], restored_outputs["pooled_output"]
        )

    def test_from_preset(self):
        """Creating a backbone from a preset yields the expected outputs."""
        from keras import random as keras_random

        # NOTE(review): downloads preset weights — requires network access.
        backbone = LayoutLMv3Backbone.from_preset("layoutlmv3_base")
        inputs = {
            "input_ids": keras_random.randint((2, 16), minval=0, maxval=100),
            "bbox": keras_random.randint((2, 16, 4), minval=0, maxval=100),
            "attention_mask": ops.ones((2, 16), dtype="int32"),
            "image": keras_random.uniform((2, 112, 112, 3)),
        }
        outputs = backbone(inputs)
        self.assertIn("sequence_output", outputs)
        self.assertIn("pooled_output", outputs)

    def test_backbone_with_different_input_shapes(self):
        """The backbone handles varying sequence lengths and batch sizes."""
        from keras import random as keras_random

        # Varying sequence lengths at a fixed batch size.
        for seq_len in (32, 128):
            inputs = {
                "input_ids": keras_random.randint(
                    (self.batch_size, seq_len), minval=0, maxval=100
                ),
                "bbox": keras_random.randint(
                    (self.batch_size, seq_len, 4), minval=0, maxval=100
                ),
                "attention_mask": ops.ones(
                    (self.batch_size, seq_len), dtype="int32"
                ),
                "image": self.image,
            }
            outputs = self.backbone(inputs)
            self.assertEqual(
                outputs["sequence_output"].shape,
                (self.batch_size, seq_len + 49 + 1, 64),
            )

        # Varying batch sizes at a fixed sequence length.
        for batch_size in (1, 4):
            inputs = {
                "input_ids": keras_random.randint(
                    (batch_size, self.seq_length), minval=0, maxval=100
                ),
                "bbox": keras_random.randint(
                    (batch_size, self.seq_length, 4), minval=0, maxval=100
                ),
                "attention_mask": ops.ones(
                    (batch_size, self.seq_length), dtype="int32"
                ),
                "image": keras_random.uniform((batch_size, 112, 112, 3)),
            }
            outputs = self.backbone(inputs)
            self.assertEqual(
                outputs["sequence_output"].shape,
                (batch_size, self.seq_length + 49 + 1, 64),
            )

    def test_backbone_with_attention_mask(self):
        """The backbone accepts a mask with padded (zeroed) positions."""
        # The original test scattered zeros at positions 32 and 48, which
        # are out of range for seq_length == 16. Instead, mask out the
        # second half of every sequence to emulate padding.
        mask = np.ones((self.batch_size, self.seq_length), dtype="int32")
        mask[:, self.seq_length // 2 :] = 0
        attention_mask = ops.convert_to_tensor(mask)

        inputs = {
            "input_ids": self.input_ids,
            "bbox": self.bbox,
            "attention_mask": attention_mask,
            "image": self.image,
        }

        outputs = self.backbone(inputs)
        self.assertIsInstance(outputs, dict)
        self.assertIn("sequence_output", outputs)
        self.assertIn("pooled_output", outputs)

    def test_backbone_gradient(self):
        """Gradients flow to every trainable variable (TF backend only)."""
        # `GradientTape` is a TensorFlow API; `keras.backend` exposes no
        # tape, so this check only runs under the TensorFlow backend.
        if backend.backend() != "tensorflow":
            self.skipTest("Gradient tape test requires the TensorFlow backend.")
        import tensorflow as tf

        with tf.GradientTape() as tape:
            outputs = self.backbone(self.inputs)
            loss = ops.mean(outputs["pooled_output"])

        # Every trainable variable should receive a finite gradient.
        gradients = tape.gradient(loss, self.backbone.trainable_variables)
        for grad in gradients:
            self.assertIsNotNone(grad)
            self.assertFalse(ops.all(ops.isnan(grad)))
            self.assertFalse(ops.all(ops.isinf(grad)))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
"""LayoutLMv3 document classifier implementation. | ||
|
||
This module implements a document classification model using the LayoutLMv3 backbone. | ||
""" | ||
|
||
from typing import Dict, List, Optional, Union | ||
Comment on lines
+1
to
+6
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need of this |
||
|
||
from keras import backend, layers, ops | ||
from keras.saving import register_keras_serializable | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.backbone import Backbone | ||
|
||
from .layoutlmv3_backbone import LayoutLMv3Backbone | ||
from .layoutlmv3_document_classifier_preprocessor import LayoutLMv3DocumentClassifierPreprocessor | ||
|
||
@keras_hub_export("keras_hub.models.LayoutLMv3DocumentClassifier")
class LayoutLMv3DocumentClassifier(layers.Layer):
    """Document classifier built on a LayoutLMv3 backbone.

    Adds a dropout + softmax classification head on top of the backbone's
    pooled output.

    Args:
        backbone: A `LayoutLMv3Backbone` instance.
        num_classes: int. Number of output classes. Defaults to 2.
        dropout: float. Dropout rate applied before the classification
            head. Defaults to 0.1.
        **kwargs: Additional keyword arguments passed to the parent class.

    Example:
    ```python
    # Initialize classifier from preset.
    classifier = LayoutLMv3DocumentClassifier.from_preset("layoutlmv3_base")

    # Classify a document.
    outputs = classifier({
        "input_ids": input_ids,
        "bbox": bbox,
        "attention_mask": attention_mask,
        "image": image,
    })
    ```
    """

    def __init__(
        self,
        backbone,
        num_classes=2,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.backbone = backbone
        self.num_classes = num_classes
        self.dropout = dropout
        # Sub-layers must be created once here, not inside `call`: the
        # original code instantiated `Dropout`/`Dense` on every forward
        # pass, so the head's weights were recreated each call and were
        # never tracked or trainable.
        self.dropout_layer = layers.Dropout(dropout, name="classifier_dropout")
        self.classifier_head = layers.Dense(
            num_classes,
            activation="softmax",
            name="classifier",
        )

    def call(self, inputs, training=None):
        """Run the backbone and classification head.

        Args:
            inputs: dict with keys `"input_ids"`, `"bbox"`,
                `"attention_mask"`, and `"image"`.
            training: Optional bool. Forwarded to the dropout layer so it
                is only active during training.

        Returns:
            A `(batch_size, num_classes)` tensor of class probabilities.
        """
        backbone_outputs = self.backbone(inputs)
        pooled_output = backbone_outputs["pooled_output"]
        x = self.dropout_layer(pooled_output, training=training)
        return self.classifier_head(x)

    def get_config(self):
        """Return a JSON-serializable config for this classifier."""
        from keras import saving

        config = super().get_config()
        config.update(
            {
                # Serialize the backbone to a config dict; storing the raw
                # layer object (as before) is not JSON-serializable.
                "backbone": saving.serialize_keras_object(self.backbone),
                "num_classes": self.num_classes,
                "dropout": self.dropout,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the classifier from `get_config` output."""
        from keras import saving

        config = dict(config)
        if isinstance(config.get("backbone"), dict):
            config["backbone"] = saving.deserialize_keras_object(
                config["backbone"]
            )
        return cls(**config)

    @classmethod
    def from_preset(
        cls,
        preset,
        num_classes=2,
        dropout=0.1,
        **kwargs,
    ):
        """Create a LayoutLMv3 document classifier from a preset.

        Args:
            preset: string. Must be one of "layoutlmv3_base",
                "layoutlmv3_large".
            num_classes: int. Number of classes to classify documents into.
            dropout: float. Dropout probability for the classification head.
            **kwargs: Additional keyword arguments.

        Returns:
            A `LayoutLMv3DocumentClassifier` instance.
        """
        backbone = LayoutLMv3Backbone.from_preset(preset)
        return cls(
            backbone=backbone,
            num_classes=num_classes,
            dropout=dropout,
            **kwargs,
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this directory and file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still needs to be removed