Relation extraction llama #522
base: master
Conversation
Looking good overall, I'd say!
I think we can make this slightly easier to maintain by reducing the amount of redundant code and by creating/following some protocols for the tokenizers and models.
I also had some questions regarding the changes to the dependencies and to some of the config doc strings.
Don't hesitate to rebut some of my suggestions!
PS: I didn't go into too much detail reviewing some of the lower-level changes.
medcat/config_rel_cat.py (outdated)

```diff
-    NOTE: If used along MetaCAT or additional NER, only one of the seeds will take effect
-    NB! For these changes to take effect, the pipe would need to be recreated."""
+    """The seed for random number generation."""
```
Was there a reason the doc string was changed?
Does the RelCAT random number generation seed now operate separately from the rest of the seeds?
Is one now able to change the seed dynamically during operation?
```diff
@@ -50,9 +93,9 @@ class RelCAT(PipeRunner):

     log = logging.getLogger(__name__)

-    def __init__(self, cdb: CDB, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT = ConfigRelCAT(), task="train", init_model=False):
+    def __init__(self, cdb: CDB, tokenizer: Union[TokenizerWrapperBERT, TokenizerWrapperModernBERT, TokenizerWrapperLlama], config: ConfigRelCAT = ConfigRelCAT(), task="train", init_model=False):
```
Perhaps it would be better to define a protocol that each tokenizer wrapper follows, and use that protocol as the type here?
I.e. something that defines methods such as:
- `get_size`
- `save`
- `config_from_pretrained`
- `config_from_json_file`
- `model_from_pretrained`

That way we'd be able to avoid checking the type (e.g. in lines 166+ in `_get_model`, and in lines 233+, 268, and 282+ in `load`).
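As a rough sketch of that suggestion (the protocol name and the dummy wrapper below are hypothetical, and the method bodies are placeholders, not MedCAT's real wrappers), a `typing.Protocol` would let the annotation stay a single type while each concrete wrapper conforms structurally, with no `Union` and no `isinstance` dispatch:

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class RelCATTokenizer(Protocol):
    """Structural type for tokenizer wrappers (hypothetical name)."""

    def get_size(self) -> int: ...
    def save(self, dir_path: str) -> None: ...
    def config_from_pretrained(self) -> Any: ...
    def config_from_json_file(self, file_path: str) -> Any: ...
    def model_from_pretrained(self, *args: Any, **kwargs: Any) -> Any: ...


# A wrapper needs no inheritance to conform -- structural typing suffices:
class DummyWrapper:
    def get_size(self) -> int:
        return 30522  # placeholder vocab size

    def save(self, dir_path: str) -> None:
        pass

    def config_from_pretrained(self) -> Any:
        return object()

    def config_from_json_file(self, file_path: str) -> Any:
        return object()

    def model_from_pretrained(self, *args: Any, **kwargs: Any) -> Any:
        return object()


# runtime_checkable allows a structural isinstance check (method presence only)
print(isinstance(DummyWrapper(), RelCATTokenizer))  # True
```

`RelCAT.__init__` could then annotate `tokenizer: RelCATTokenizer` and call the methods directly on whatever wrapper was passed in.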
```python
        classification_logits = self.fc3(x)
        return classification_logits.to(self.relcat_config.general.device)

    def forward(self,
```
Quite similar to line 320 (`ModernBertModel_RelationExtraction.forward`) and line 467 (`LlamaModel_RelationExtraction.forward`).
Perhaps we could abstract this out so as to not maintain 3 different versions of this method?
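One way to do that (a minimal sketch only; the real classes' signatures, pooling, and backbones differ, and `TinyModel_RelationExtraction` here is a made-up stand-in) is a template-method base class that owns the shared classification head and `forward`, while each subclass merely wires up its backbone encoder:

```python
import torch
import torch.nn as nn


class BaseModel_RelationExtraction(nn.Module):
    """Shared head and forward; subclasses assign `self.encoder`."""

    def __init__(self, hidden_size: int, n_classes: int):
        super().__init__()
        self.encoder: nn.Module  # assigned by each subclass (BERT/ModernBERT/Llama)
        self.fc3 = nn.Linear(hidden_size, n_classes)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        sequence_output = self.encoder(input_ids)  # backbone-specific encoding
        pooled = sequence_output.mean(dim=1)       # placeholder pooling
        return self.fc3(pooled)                    # shared classification head


class TinyModel_RelationExtraction(BaseModel_RelationExtraction):
    """Illustrative subclass; a real one would wrap a pretrained backbone."""

    def __init__(self):
        super().__init__(hidden_size=8, n_classes=3)
        self.encoder = nn.Embedding(100, 8)  # stand-in for a transformer encoder


logits = TinyModel_RelationExtraction()(torch.tensor([[1, 2, 3]]))
print(logits.shape)  # torch.Size([1, 3])
```

With this shape, only the encoder wiring (and any genuinely backbone-specific output handling) lives in the three subclasses.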
```python
        classification_logits = self.fc3(x)
        return classification_logits.to(self.relcat_config.general.device)

    def forward(self,
```
Quite similar to line 112 (`BertModel_RelationExtraction.forward`) and line 467 (`LlamaModel_RelationExtraction.forward`).
Perhaps we could abstract this out so as to not maintain 3 different versions of this method?
```diff
@@ -170,10 +463,12 @@ def output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tens

         classification_logits = self.fc3(x)
         return classification_logits.to(self.relcat_config.general.device)

     def forward(self,
```
Quite similar to line 320 (`ModernBertModel_RelationExtraction.forward`) and line 112 (`BertModel_RelationExtraction.forward`).
Perhaps we could abstract this out so as to not maintain 3 different versions of this method?
```diff
@@ -18,7 +17,7 @@ class RelData(Dataset):

     log = logging.getLogger(__name__)

-    def __init__(self, tokenizer: TokenizerWrapperBERT, config: ConfigRelCAT, cdb: CDB = CDB()):
+    def __init__(self, tokenizer: Union[TokenizerWrapperBERT, TokenizerWrapperLlama, TokenizerWrapperModernBERT], config: ConfigRelCAT, cdb: CDB = CDB()):
```
Again, a protocol would help; see my comment on line 96 in `rel_cat.RelCAT.__init__`.
```diff
             config (ConfigRelCAT): same config used in RelCAT
             cdb (CDB): Optional, used to add concept ids and types to detected ents,
                 useful when creating datasets from MedCAT output. Defaults to CDB().
         """

         self.cdb: CDB = cdb
         self.config: ConfigRelCAT = config
-        self.tokenizer: TokenizerWrapperBERT = tokenizer
+        self.tokenizer: Union[TokenizerWrapperBERT, TokenizerWrapperLlama, TokenizerWrapperModernBERT] = tokenizer
```
Again, a protocol would help; see my comment on line 96 in `rel_cat.RelCAT.__init__`.
I feel like you've not really addressed most of my comments so far. Or did you miss the comments because they were on files that don't normally show up due to having too many changes?
Alright, so what I'm talking about:

For the abstraction, here's what I was picturing:

```python
from typing import Protocol
from abc import ABC, abstractmethod


# protocol for the config
class ModelConfig(Protocol):
    # the attributes
    vocab_size: int
    pad_token_id: int

    # the method(s)
    def to_json_file(self, file_path: str):
        pass  # perhaps some doc string

# since the configs already adhere to this, shouldn't really need much in the way of other changes
# though at the per-implementation level there may need to be a type check
# to make sure these are of the correct type (i.e. in `model_from_pretrained`)


class BaseTokenizerWrapper(PreTrainedTokenizerFast, ABC):
    # existing stuff

    @abstractmethod
    def config_from_pretrained(self) -> ModelConfig:
        pass  # perhaps some doc string

    @abstractmethod
    def config_from_json_file(self, file_path: str) -> ModelConfig:
        pass  # perhaps some doc string

    def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: ModelConfig,
                              # with a different default specified on a per-implementation basis
                              pretrained_model_name_or_path: str = 'default') -> nn.Module:
        pass  # perhaps some doc string

# and then in the implementation you'd just define these on a per-implementation basis
```

That way we could simplify some of the code that relies on those type checks. My main concern here is that if we leave a lot of implementation-specific stuff in the shared code, it will be harder to maintain.

Let me know what you think.
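For illustration, a per-implementation subclass following that sketch might look like the snippet below (all names are hypothetical, the config values are made up, and the `PreTrainedTokenizerFast` base and `model_from_pretrained` are omitted so the snippet stays self-contained):

```python
from abc import ABC, abstractmethod


class FakeLlamaConfig:
    """Stand-in config satisfying the ModelConfig protocol from the sketch."""
    vocab_size = 32000   # illustrative value, not a real model's
    pad_token_id = 0

    def to_json_file(self, file_path: str) -> None:
        pass  # a real config would serialize itself here


class BaseTokenizerWrapper(ABC):
    """Trimmed-down version of the sketched base class."""

    @abstractmethod
    def config_from_pretrained(self):
        ...

    @abstractmethod
    def config_from_json_file(self, file_path: str):
        ...


class TokenizerWrapperLlamaSketch(BaseTokenizerWrapper):
    """Each concrete wrapper only fills in its own config/model loading."""

    def config_from_pretrained(self):
        return FakeLlamaConfig()

    def config_from_json_file(self, file_path: str):
        return FakeLlamaConfig()


cfg = TokenizerWrapperLlamaSketch().config_from_pretrained()
print(cfg.vocab_size)  # 32000
```

Callers then only ever see `BaseTokenizerWrapper` and `ModelConfig`, so no `isinstance` branching is needed in `RelCAT` itself.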
This should be the final version of the relation extraction implementation, the following were implemented: