
Relation extraction llama #522

Open · wants to merge 170 commits into master

Conversation

vladd-bit (Member)

This should be the final version of the relation extraction implementation. The following were implemented:

  • Llama model support
  • ModernBERT model support
  • saving/loading for all three model variants (BERT/Llama/ModernBERT)
  • improvements to documentation
  • fixes to dataset creation from trainer exports (the way the number of non-relations was determined)
  • fixes to dataset creation from spacy docs/fake docs, removing duplicates
  • fixes to inference: input data was not filtered/validated and many invalid combinations were present, resulting in a lot of garbage relations

@vladd-bit vladd-bit self-assigned this Mar 10, 2025
@vladd-bit vladd-bit requested a review from mart-r March 12, 2025 11:39
mart-r (Collaborator) left a comment

Looking good overall, I'd say!

I think we can make this slightly easier to maintain by reducing the amount of redundant code and creating/following some protocols for tokenizers and models.

I also had some questions regarding changes to the dependencies and to some config doc strings.

Don't hesitate to rebut some of my suggestions!

PS: I didn't go into too much detail reviewing some of the more low-level changes.


-NOTE: If used along MetaCAT or additional NER, only one of the seeds will take effect
-NB! For these changes to take effect, the pipe would need to be recreated."""
+"""The seed for random number generation."""
mart-r (Collaborator):

Was there a reason the doc string was changed?
Does the RelCAT random number generation seed now operate separately from the rest of the seeds?
Is one now able to change the seed dynamically during operation?

@@ -50,9 +93,9 @@ class RelCAT(PipeRunner):

log = logging.getLogger(__name__)

-def __init__(self, cdb: CDB, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT = ConfigRelCAT(), task="train", init_model=False):
+def __init__(self, cdb: CDB, tokenizer: Union[TokenizerWrapperBERT, TokenizerWrapperModernBERT, TokenizerWrapperLlama], config: ConfigRelCAT = ConfigRelCAT(), task="train", init_model=False):
mart-r (Collaborator):

Perhaps it would be better to define a protocol that each tokenizer wrapper follows and use that protocol as the type here?
I.e. something that defines methods such as:

  • get_size
  • save
  • config_from_pretrained
  • config_from_json_file
  • model_from_pretrained

That way we'd be able to avoid checking the type (e.g. in lines 166+ in _get_model; and in lines 233+, 268, and 282+ in load).
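
For illustration, a minimal sketch of such a protocol using typing.Protocol; the signatures below are assumptions for the example, not the PR's actual API:

from typing import Any, Protocol

class TokenizerWrapperProtocol(Protocol):
    """Illustrative protocol; method signatures are assumptions."""

    def get_size(self) -> int: ...
    def save(self, dir_path: str) -> None: ...
    def config_from_pretrained(self) -> Any: ...
    def config_from_json_file(self, file_path: str) -> Any: ...
    def model_from_pretrained(self, *args: Any, **kwargs: Any) -> Any: ...

With that in place, the type hint would simply read tokenizer: TokenizerWrapperProtocol instead of growing the Union for every new backend.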

classification_logits = self.fc3(x)
return classification_logits.to(self.relcat_config.general.device)

def forward(self,
mart-r (Collaborator):

Quite similar to line 320 (ModernBertModel_RelationExtraction.forward) and line 467 (LlamaModel_RelationExtraction.forward).
Perhaps we could abstract this out so as to not maintain 3 different versions of this method?
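
As a rough illustration of the consolidation being suggested here (the attribute and method names are assumptions based on the surrounding diff context, not the PR's actual code):

import torch
from torch import nn

class BaseModel_RelationExtraction(nn.Module):
    """Hypothetical shared base; subclasses would only assign self.hf_model."""

    hf_model: nn.Module  # the wrapped BERT/ModernBERT/Llama backbone

    def forward(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        # identical for all three backends; only the backbone differs
        output = self.hf_model(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = output.last_hidden_state
        pooled_output = sequence_output[:, 0, :]  # first-token pooling, as an example
        # output2logits is also near-identical across the three classes and
        # could live once on this base as well
        return self.output2logits(pooled_output, sequence_output)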

classification_logits = self.fc3(x)
return classification_logits.to(self.relcat_config.general.device)

def forward(self,
mart-r (Collaborator):

Quite similar to line 112 (BertModel_RelationExtraction.forward) and line 467 (LlamaModel_RelationExtraction.forward).
Perhaps we could abstract this out so as to not maintain 3 different versions of this method?

@@ -170,10 +463,12 @@ def output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tens
classification_logits = self.fc3(x)
return classification_logits.to(self.relcat_config.general.device)


def forward(self,
mart-r (Collaborator):

Quite similar to line 320 (ModernBertModel_RelationExtraction.forward) and line 112 (BertModel_RelationExtraction.forward).
Perhaps we could abstract this out so as to not maintain 3 different versions of this method?

@@ -18,7 +17,7 @@ class RelData(Dataset):

log = logging.getLogger(__name__)

-def __init__(self, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT, cdb: CDB = CDB()):
+def __init__(self, tokenizer: Union[TokenizerWrapperBERT, TokenizerWrapperLlama, TokenizerWrapperModernBERT], config: ConfigRelCAT, cdb: CDB = CDB()):
mart-r (Collaborator):

Again, a protocol would help; see the comment on line 96 in rel_cat.RelCAT.__init__.

config (ConfigRelCAT): same config used in RelCAT
cdb (CDB): Optional, used to add concept ids and types to detected ents,
useful when creating datasets from MedCAT output. Defaults to CDB().
"""

self.cdb: CDB = cdb
self.config: ConfigRelCAT = config
-self.tokenizer: TokenizerWrapperBERT = tokenizer
+self.tokenizer: Union[TokenizerWrapperBERT, TokenizerWrapperLlama, TokenizerWrapperModernBERT] = tokenizer
mart-r (Collaborator):

Again, protocol would help - see comment in line 96 in rel_cat.RelCAT.__init__

mart-r commented Mar 19, 2025

I feel like you've not really addressed most of my comments so far.
Is there something that you've not pushed into here yet?

Or did you miss the comments because they were on files that don't normally show up due to having too many changes?
Referring to comments in the following files:

  • rel_dataset.py
  • models.py
  • ml_utils.py
  • rel_cat.py

vladd-bit (Member, Author)

> I feel like you've not really addressed most of my comments so far. Is there something that you've not pushed into here yet?
>
> Or did you miss the comments because they were on files that don't normally show up due to having too many changes? Referring to comments in the following files:
>
>   • rel_dataset.py
>   • models.py
>   • ml_utils.py
>   • rel_cat.py

  • the model loading/saving and assignment has been changed now, using the generic hf_model attr
  • generic tokenizer class w/ implementation inherited by all other tokenizers, reducing the codebase (rough sketch below)
    • to do: models and type hinting
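
As a rough illustration of that shared tokenizer base class (the BaseTokenizerWrapper name comes from the next comment; everything else here, including attribute and method names, is an assumption):

from transformers import PreTrainedTokenizerFast

class BaseTokenizerWrapper:
    """Sketch of a generic tokenizer wrapper; details are assumed."""

    def __init__(self, hf_tokenizer: PreTrainedTokenizerFast):
        self.hf_tokenizer = hf_tokenizer  # the wrapped HF tokenizer

    def get_size(self) -> int:
        return len(self.hf_tokenizer)

    def save(self, dir_path: str) -> None:
        self.hf_tokenizer.save_pretrained(dir_path)

class TokenizerWrapperLlama(BaseTokenizerWrapper):
    # only Llama-specific behaviour would be overridden here
    pass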


mart-r commented Mar 19, 2025

Alright, so what I'm talking about:

  • Since we have a common base class for the tokenizer (BaseTokenizerWrapper), we could use it for typing
    • E.g. rel_cat.py lines 96 and 98
    • And in ml_utils.py lines 244 and 261?
      • PS: is ml_utils.tokenize (line 261) even needed? It doesn't seem to be used
    • And in rel_dataset.py line 20
  • The base class could define ways to create the appropriate model config
    • E.g. to avoid type-checks in rel_cat.py line 159 and onwards
  • There could be a base protocol for the config as well
    • That would allow the model protocol to return something that follows this protocol
    • I.e. in rel_cat.py line 159+
  • There are multiple occurrences of near enough identical output2logits methods
    • In models.py starting at lines 58, 267, and 409
    • Perhaps these could be consolidated to avoid code duplication?
  • There are multiple occurrences of near enough identical forward methods
    • In models.py starting at lines 112, 320, and 467
    • Perhaps these could be consolidated to avoid code duplication?
  • There's a comment about logging exceptions
    • In rel_cat.py line 297
  • There's a bunch of string literals re-used in ml_utils.py
    • If we defined them as constants, we could re-use them and lower the chance of typos (see the small sketch after this list)
  • Perhaps we could have a load_tokenizer_from_file method in tokenizers.py
    • To be used in rel_cat.py line 200
    • This would even further remove the implementation-specific details away from rel_cat.py
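
A tiny illustration of the constants idea; the names and values below are made up for the example, not the literals actually used in ml_utils.py:

import os
import torch

# hypothetical module-level constants for ml_utils.py
MODEL_FILE_NAME = "model.dat"
OPTIMIZER_FILE_NAME = "optimizer.dat"

def save_state(model: torch.nn.Module, save_dir: str) -> None:
    # every call site now shares one spelling of the file name
    torch.save(model.state_dict(), os.path.join(save_dir, MODEL_FILE_NAME))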

For the abstraction, here's what I was picturing:

from abc import ABC, abstractmethod
from typing import Protocol

from torch import nn
from transformers import PreTrainedTokenizerFast

from medcat.config_rel_cat import ConfigRelCAT  # import path assumed

# protocol for the config
class ModelConfig(Protocol):
    # the attributes
    vocab_size: int
    pad_token_id: int
    # the method(s)
    def to_json_file(self, file_path: str):
        pass  # perhaps some doc string
# since the configs already adhere to this, shouldn't really need many other changes
# though at the per-implementation level there may need to be a type check to make sure these are of the correct type (i.e. in `model_from_pretrained`)

class BaseTokenizerWrapper(PreTrainedTokenizerFast, ABC):
    # existing stuff
    @abstractmethod
    def config_from_pretrained(self) -> ModelConfig:
        pass  # perhaps some doc string

    @abstractmethod
    def config_from_json_file(self, file_path: str) -> ModelConfig:
        pass  # perhaps some doc string

    @abstractmethod
    def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: ModelConfig,
                              pretrained_model_name_or_path: str = 'default') -> nn.Module:
        # with a different default specified on a per-implementation basis
        pass  # perhaps some doc string
# and then in the implementation you'd just define these on a per-implementation basis

That way we could simplify some of the code that relies on the type checks in rel_cat.py. I feel like having each implementation keep track of the additional information needed for the config init would also be preferable to having implementation-specific values in rel_cat.py.

My main concern here is that if we leave a lot of implementation-specific stuff in rel_cat.py, it becomes a bit of a mess that is hard to maintain. I.e., currently, defaults that are specific to TokenizerWrapperBERT are scattered around rel_cat.py within different methods (i.e. in get_model and load).
Ideally, we want the rel_cat.py class to only act on the abstractions (e.g. BaseTokenizerWrapper, ModelConfig, and nn.Module) and not do implementation-specific stuff.
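
To make that end state concrete, a hedged sketch of what the model construction in rel_cat.py could reduce to once everything goes through the abstractions (the method name and self.* attributes are assumptions):

# hypothetical method on RelCAT; note there are no isinstance checks left
def _get_model(self) -> nn.Module:
    # each wrapper knows which HF config and model class it needs
    model_config: ModelConfig = self.tokenizer.config_from_pretrained()
    return self.tokenizer.model_from_pretrained(self.config, model_config)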

Let me know what you think.
