[Feature] Edge Feature Support in Trainer and Inferrer Pipeline #1088
base: v0.4_ef_dev
@@ -990,6 +990,14 @@ def set_encoder(model, g, config, train_task):
    # Set GNN encoders
    dropout = config.dropout if train_task else 0
    out_emb_size = config.out_emb_size if config.out_emb_size else config.hidden_size

    # Check use edge feature and GNN encoder capacity
    assert (config.edge_feat_name is None) or \
Do we need this check?
I think we have the check inside the gnn_encoder.
Not exactly. Although gnn_encoder checks this, every subclass, e.g. RGAT, HGT, Sage, and GAT/v2, would need to add the two edge feature arguments to its init function for the gnn_encoder check to work.
If we add the arguments, these APIs change and users will see the new arguments, but none of these encoders implement edge feature support. This would confuse users.
So I chose the current logic, which controls our own CLI pipeline and leaves the APIs unchanged.
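The design described above (validate in the pipeline, leave encoder constructors alone) can be sketched roughly as follows. This is an illustration only, not the code in the PR: `EDGE_FEAT_ENCODERS` and `check_edge_feat_support` are hypothetical names, and the assumption that only RGCN accepts edge features is taken from this thread.

```python
# Hypothetical sketch: names below are illustrative and not taken from the PR.
# Assumption from the review thread: only the RGCN encoder supports edge features.
EDGE_FEAT_ENCODERS = {"rgcn"}


def check_edge_feat_support(encoder_type, edge_feat_name):
    """Fail early in the pipeline if edge features are requested for an
    encoder that does not implement them.

    encoder_type: encoder name from the config, e.g. "rgcn" or "hgt".
    edge_feat_name: None when no edge features are configured.
    """
    if edge_feat_name is None:
        # No edge features requested, so every encoder is acceptable.
        return
    if encoder_type not in EDGE_FEAT_ENCODERS:
        raise ValueError(
            f"Encoder '{encoder_type}' does not support edge features; "
            f"supported encoders: {sorted(EDGE_FEAT_ENCODERS)}"
        )
```

Because the check lives in the pipeline code, the constructors of the other encoders keep their current signatures, which is the trade-off argued for in the comment above.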
_, part_config = generate_dummy_dist_graph(tmpdirname)
gdata = GSgnnData(part_config=part_config)

# Test case 4: abnormal case, set RGCN model with edge features for NC, but
Suggested change:
- # Test case 4: abnormal case, set RGCN model with edge features for NC, but
+ # Test case 4: abnormal case, set RGCN model with edge features for EC, but
Please check other comments.
inferrer1.infer(
    data=gdata,
    loader=infer_dataloader1,
    save_embed_path='/tmp/embs',
ditto
inferrer2.infer(
    data=gdata,
    loader=infer_dataloader2,
    save_embed_path='/tmp/embs',
ditto
# Test case 0: normal case, set HGT model without edge features for NC, and not provide
# edge features.
# Should complete inference process
create_config4ef(Path(tmpdirname), 'gnn_nc.yaml', encoder='hgt', use_ef=False)
We can use @pytest.mark.parametrize("encoder", ["rgcn", "hgt"]) to simplify the test code.
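As a rough illustration of this suggestion, a parametrized test might look like the sketch below. `build_config` is a hypothetical stand-in for the PR's `create_config4ef` helper (whose body is not shown here), and the dictionary keys are assumptions made only for the example.

```python
import pytest


def build_config(encoder, use_ef):
    # Hypothetical stand-in for create_config4ef; the real helper writes a
    # YAML file, while this sketch only returns a dict.
    return {"model_encoder_type": encoder, "use_edge_feat": use_ef}


@pytest.mark.parametrize("encoder", ["rgcn", "hgt"])
def test_infer_without_edge_feat(encoder):
    # pytest runs this body once per encoder name, replacing two
    # copy-pasted test cases that differ only in the encoder.
    config = build_config(encoder, use_ef=False)
    assert config["model_encoder_type"] == encoder
    assert config["use_edge_feat"] is False
```

Each parameter value is reported as a separate test, so a failure for one encoder does not hide the result for the other.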
@@ -721,6 +727,1181 @@ def check_eval(mock_do_mini_batch_inference,

check_eval()

def create_config4ef(tmp_path, file_name, encoder='rgcn', task='nc', use_ef=True):
Same comments as test_infer.py
Issue #, if available:
Description of changes:
This is the 5th PR for edge feature support. This PR includes:
- `gsf.py`, adding an assert to check whether GNN encoders support edge features, and adding edge feature arguments when building the RGCN encoder.
- `fit()` function in `GSgnnNodePredictionTrainer`, `GSgnnEdgePredictionTrainer`, and `GSgnnLinkPredictionTrainer`, to use edge features.
- `node_gnn.py`, adding edge feature support.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.