
Adding a Deep Nearest Class Means Classifier model to Flair #3532

Open · wants to merge 8 commits into base: master

Conversation

sheldon-roberts (Contributor)

This PR adds a DeepNCMClassifier to flair.models.
My reasons for adding this model are outlined in issue #3531.

This model requires a TrainerPlugin because it makes the prototype updates using an after_training_batch hook. Please let me know if there is a cleaner way to handle this.
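To make the hook mechanism concrete, here is a minimal, framework-free sketch of the plugin pattern described above. The `Trainer`/`Plugin` classes and the dispatch logic are illustrative stand-ins, not Flair's actual `TrainerPlugin` API; only the hook name `after_training_batch` is taken from the discussion.

```python
# Minimal sketch of a plugin-hook pattern (illustrative stand-ins,
# not Flair's real TrainerPlugin API).

class Plugin:
    def after_training_batch(self, **kwargs):
        pass  # default hook: do nothing


class Trainer:
    def __init__(self, model, plugins):
        self.model = model
        self.plugins = plugins

    def dispatch(self, event, **kwargs):
        # call the matching hook on every registered plugin
        for plugin in self.plugins:
            getattr(plugin, event)(**kwargs)

    def train_batch(self, batch):
        # ... forward/backward/optimizer step would happen here ...
        self.dispatch("after_training_batch", model=self.model, batch=batch)


class NCMPlugin(Plugin):
    """Counts hook calls; the real plugin would update prototypes here."""

    def __init__(self):
        self.updates = 0

    def after_training_batch(self, model, batch, **kwargs):
        self.updates += 1  # stand-in for a prototype update


plugin = NCMPlugin()
trainer = Trainer(model=None, plugins=[plugin])
for batch in [[1, 2], [3, 4], [5, 6]]:
    trainer.train_batch(batch)
print(plugin.updates)  # 3
```

The point of the pattern is that the trainer stays ignorant of model-specific bookkeeping: anything that must run after each batch lives in a plugin hook.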

Example Script:

from flair.data import Corpus
from flair.datasets import TREC_50
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import DeepNCMClassifier
from flair.trainers import ModelTrainer
from flair.trainers.plugins import DeepNCMPlugin

# load the TREC dataset
corpus: Corpus = TREC_50()

# make a transformer document embedding
document_embeddings = TransformerDocumentEmbeddings("roberta-base", fine_tune=True)

# create the classifier
classifier = DeepNCMClassifier(
    document_embeddings,
    label_dictionary=corpus.make_label_dictionary(label_type="class"),
    label_type="class",
    use_encoder=False,
    mean_update_method="condensation",
)

# initialize the trainer
trainer = ModelTrainer(classifier, corpus)

# train the model
trainer.fine_tune(
    "resources/taggers/deepncm_trec",
    plugins=[DeepNCMPlugin()],
)

plonerma (Collaborator) commented Aug 19, 2024

Hello @sheldon-roberts,

Thanks a lot for your contribution! This had been buried deep in the backlog of things to implement.

I also don't see how this could be implemented without a TrainerPlugin.

What do you think about implementing this as a decoder (such as the PrototypicalDecoder), such that it can be used with the default classifier? Then it could be used with all model types (i.e. span, text, etc. classification).

Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?
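To illustrate the distance-function suggestion: at inference time, an NCM-style decoder scores each class by the distance between an embedding and that class's prototype, and different distance functions just change that scoring rule. Below is a pure-Python sketch; the function names are illustrative, not the PrototypicalDecoder's actual API.

```python
# Sketch: turning embedding-to-prototype distances into class logits,
# with two interchangeable distance functions. Illustrative names only.
import math


def euclidean_logits(embedding, prototypes):
    # negative distance: a closer prototype yields a larger logit
    return [-math.dist(embedding, p) for p in prototypes]


def cosine_logits(embedding, prototypes):
    def cos(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(x * x for x in b))
        return dot / (norm_a * norm_b)

    return [cos(embedding, p) for p in prototypes]


prototypes = [[1.0, 0.0], [0.0, 1.0]]
print(euclidean_logits([0.9, 0.1], prototypes))  # class 0 scores highest
print(cosine_logits([0.9, 0.1], prototypes))     # class 0 scores highest
```

Supporting several such functions behind one parameter is what makes the decoder approach attractive: the classifier head stays the same while the metric is swapped.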

sheldon-roberts (Contributor, Author)

Hi @plonerma, thanks for taking a look!

> What do you think about implementing this as a decoder (such as the PrototypicalDecoder), such that it can be used with the default classifier? Then it could be used with all model types (i.e. span, text, etc. classification).
>
> Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?

I really like both of these ideas! I will look into making these changes soon.

MattGPT-ai (Contributor)

> Hello @sheldon-roberts,
>
> Thanks a lot for your contribution! This had been buried deep in the backlog of things to implement.
>
> I also don't see how this could be implemented without a TrainerPlugin.
>
> What do you think about implementing this as a decoder (such as the PrototypicalDecoder), such that it can be used with the default classifier? Then it could be used with all model types (i.e. span, text, etc. classification).
>
> Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?

To avoid using a trainer plugin, could we just add a no-op function like def after_training_epoch(): pass to the base Model class, called right before or after self.dispatch("after_training_epoch", epoch=epoch) in the train_custom function?

I think this would work while the model is implemented as a class, but it might not work once it is changed to a decoder.
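The base-class hook MattGPT-ai describes can be sketched as follows. The `Model` and `train_custom` here are simplified stand-ins for Flair's classes, kept framework-free so the pattern itself is clear.

```python
# Sketch of the suggested base-class hook: the base Model carries a no-op
# after_training_epoch(), and the training loop calls it unconditionally.
# Simplified stand-ins, not Flair's actual Model or train_custom.

class Model:
    def after_training_epoch(self, epoch):
        pass  # no-op by default; subclasses may override


class DeepNCMModel(Model):
    def __init__(self):
        self.prototype_updates = 0

    def after_training_epoch(self, epoch):
        # model-specific bookkeeping runs here, no plugin required
        self.prototype_updates += 1


def train_custom(model, max_epochs):
    for epoch in range(1, max_epochs + 1):
        # ... training batches would run here ...
        model.after_training_epoch(epoch=epoch)
        # the plugin dispatch for "after_training_epoch" would follow here


model = DeepNCMModel()
train_custom(model, max_epochs=3)
print(model.prototype_updates)  # 3
```

The trade-off matches the caveat above: this hook lives on the model, so once the logic moves into a decoder object, the decoder would need its own way to be notified.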
