Ensemble Models #77
Conversation
… from protein_prediction to dev
- The dummy param must be linked to the loss so that a gradient graph is built; otherwise training fails with: `element 0 of tensors does not require grad and does not have a grad_fn`
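The bullet above can be sketched as follows. This is a minimal illustration of the dummy-parameter trick, not the PR's actual implementation; the class and variable names are hypothetical:

```python
import torch

# If an ensemble only aggregates outputs of frozen sub-models, the loss is
# disconnected from every trainable parameter and loss.backward() raises
# "element 0 of tensors does not require grad and does not have a grad_fn".
# Multiplying by a dummy parameter initialised to 1.0 re-attaches the loss
# to the autograd graph without changing its value.
class DummyEnsemble(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dummy_param = torch.nn.Parameter(torch.ones(1))

    def forward(self, frozen_logits: torch.Tensor) -> torch.Tensor:
        # frozen_logits come from sub-models run without gradients
        return frozen_logits * self.dummy_param

model = DummyEnsemble()
frozen = torch.randn(4, 3)          # detached outputs, requires_grad=False
loss = model(frozen).mean()
loss.backward()                      # succeeds: grad_fn exists via dummy_param
```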
Test example config:

```yaml
class_path: chebai.models.ensemble.ChebiEnsemble
init_args:
  optimizer_kwargs:
    lr: 1e-3
  model_configs: {
    "chebi50": {
      "ckpt_path": "logs/chebi50_bce_unweighted/version_9/checkpoints/best_epoch=00_val_loss=0.3521_val_macro-f1=0.0381_val_micro-f1=0.1753.ckpt",
      "class_path": chebai.models.Electra,
      "labels_path": "data/chebi_v231/ChEBI50/processed/classes.json",
    },
    "chebi50_1": {
      "ckpt_path": "logs/chebi50_bce_unweighted/version_9/checkpoints/best_epoch=00_val_loss=0.3521_val_macro-f1=0.0381_val_micro-f1=0.1753.ckpt",
      "class_path": chebai.models.Electra,
      "labels_path": "data/chebi_v231/ChEBI50/processed/classes.json",
    },
  }
```
- confidence = 2 × |x − 0.5|
- for clean code and code reusability
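The confidence formula above can be illustrated with a short worked example (the function name is illustrative; `x` is assumed to be a sigmoid probability in [0, 1]):

```python
# confidence = 2 * |x - 0.5|: probabilities near 0 or 1 map to high
# confidence (close to 1), probabilities near 0.5 map to low confidence
# (close to 0).
def confidence(prob: float) -> float:
    return 2 * abs(prob - 0.5)

print(round(confidence(0.95), 2))  # 0.9 -- very confident positive
print(round(confidence(0.5), 2))   # 0.0 -- maximally uncertain
print(round(confidence(0.1), 2))   # 0.8 -- fairly confident negative
```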
Please find the latest config file (e.g. below), which works for both the Electra and GNN models:

```yaml
class_path: chebai.ensemble.FullEnsembleWMV
init_args:
  data_processed_dir_main: "data/chebi_v231/ChEBI50/processed"
  operation_mode: "prediction"
  smiles_list_file_path: "data/chebi_v231/ChEBI50/processed/smiles.txt"
  _perform_validation_checks: False  # to skip the check that flags using the same model with the same model config
  model_configs: {
    "chebi50": {
      "model_ckpt_path": "logs\\chebi50_electra_231\\version_1\\checkpoints\\best_epoch=00_val_loss=0.6766_val_macro-f1=0.7273_val_micro-f1=0.0248.ckpt",
      "model_config_file_path": "configs\\model\\electra.yml",
      "data_config_file_path": "configs\\data\\chebi\\chebi50.yml",
      "model_labels_path": "data/chebi_v231/ChEBI50/processed/classes.json",
      "wrapper_class_path": chebai.ensemble.NNWrapper,
    },
    "chebi50_1": {
      "model_ckpt_path": "logs\\chebi50_electra_231\\version_1\\checkpoints\\best_epoch=00_val_loss=0.6766_val_macro-f1=0.7273_val_micro-f1=0.0248.ckpt",
      "model_config_file_path": "configs\\model\\electra.yml",
      "data_config_file_path": "configs\\data\\chebi\\chebi50.yml",
      "model_labels_path": "data/chebi_v231/ChEBI50/processed/classes.json",
      "wrapper_class_path": chebai.ensemble.NNWrapper,
    },
    "chebi50_graph": {
      "model_ckpt_path": "logs\\regated_gnn_v231\\version_0\\checkpoints\\best_epoch=00_val_loss=0.6786_val_macro-f1=0.6364_val_micro-f1=0.0321.ckpt",
      "model_config_file_path": "..\\python-chebai-graph\\configs\\model\\gnn_res_gated.yml",
      "data_config_file_path": "..\\python-chebai-graph\\configs\\data\\chebi50_graph_properties.yml",
      "model_labels_path": "data/chebi_v231/ChEBI50/processed/classes.json",
      "wrapper_class_path": chebai.ensemble.GNNWrapper,
    },
    # "chemlog": {
    #   "model_labels_path": "data/chebi_v231/ChEBI50/processed/classes.json",
    #   "wrapper_class_path": chebai.ensemble.ChemLogWrapper,
    # },
  }
```
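Assuming the "WMV" in `FullEnsembleWMV` stands for weighted majority voting, the combination step can be sketched as follows. This is an illustration of the general technique, not the repository's implementation; all names are hypothetical:

```python
import torch

# Weighted majority voting: each model casts a binary vote per class, votes
# are weighted, summed, normalised by the total weight, and thresholded at
# a simple majority.
def weighted_majority_vote(preds: list, weights: list) -> torch.Tensor:
    """preds: list of (n_samples, n_classes) binary tensors, one per model."""
    stacked = torch.stack([w * p.float() for p, w in zip(preds, weights)])
    score = stacked.sum(dim=0) / sum(weights)  # weighted fraction of positive votes
    return (score > 0.5).long()                # majority threshold

a = torch.tensor([[1, 0, 1]])
b = torch.tensor([[1, 1, 0]])
c = torch.tensor([[0, 1, 1]])
result = weighted_majority_vote([a, b, c], weights=[1.0, 1.0, 1.0])
# class-wise weighted fractions: [2/3, 2/3, 2/3] > 0.5 -> [1, 1, 1]
```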
|
smiles_list = data_df["SMILES"].to_list() | ||
return {"logits": self.get_chemlog_results(smiles_list)} | ||
|
||
def get_chemlog_results(self, smiles_list) -> list: |
@sfluegel05, I have added a wrapper for ChemLog. What I need from the method is some kind of prediction for each class, which can be passed to a sigmoid function later on.
I am not sure how ChemLog achieves this. Can you please guide me?
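One possible answer to the question above, sketched under the assumption that ChemLog returns a crisp 0/1 decision per class rather than a continuous score: map each decision to a large-magnitude pseudo-logit so that a later sigmoid yields probabilities close to 0 or 1. The function name and scale are hypothetical, not from the PR:

```python
import torch

def decisions_to_logits(decisions: list, scale: float = 10.0) -> torch.Tensor:
    """Map per-class binary decisions (1 = class holds) to pseudo-logits.

    A decision of 1 becomes +scale and 0 becomes -scale, so sigmoid(logit)
    is ~1.0 or ~0.0 respectively.
    """
    t = torch.tensor(decisions, dtype=torch.float32)
    return (2 * t - 1) * scale

logits = decisions_to_logits([[1, 0, 1]])
probs = torch.sigmoid(logits)  # approximately [1.0, 0.0, 1.0]
```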
I moved the ensemble logic to https://github.com/ChEB-AI/python-chebifier and the property calculation to the results folder.
All right, then this should be mergeable.