[Feature] Edge Feature Support in Trainer and Inferrer Pipeline #1088

Open
wants to merge 24 commits into
base: v0.4_ef_dev

zhjwy9343
Contributor

Issue #, if available:

Description of changes:
This is the 5th PR for edge feature support. This PR includes:

  • Modify gsf.py, adding an assert to check whether the GNN encoder supports edge features, and adding edge feature arguments when building the RGCN encoder.
  • Modify the fit() function in GSgnnNodePredictionTrainer, GSgnnEdgePredictionTrainer, and GSgnnLinkPredictionTrainer to use edge features.
  • Modify node_gnn.py, adding edge feature support.
  • Add unit test cases for trainers and inferrers.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

zhjwy9343 added the enhancement (New feature or request), ready (able to trigger the CI), and 0.4 labels Nov 13, 2024
zhjwy9343 added and removed the ready (able to trigger the CI) label Nov 13, 2024
@@ -990,6 +990,14 @@ def set_encoder(model, g, config, train_task):
# Set GNN encoders
dropout = config.dropout if train_task else 0
out_emb_size = config.out_emb_size if config.out_emb_size else config.hidden_size

# Check use edge feature and GNN encoder capacity
assert (config.edge_feat_name is None) or \
Contributor

Do we need this check?
I think we have the check inside the gnn_encoder.

Contributor Author

zhjwy9343 Nov 15, 2024

Not exactly. Although gnn_encoder will check this, all of its subclasses, e.g. RGAT, HGT, Sage, and GAT/v2, would need to add the two edge feature arguments to their init functions to allow gnn_encoder to work.

If we add the arguments, these APIs change and users will see the new arguments, but none of these encoders implements edge feature support. This will confuse users.

So I chose the current logic, which controls our own CLI pipeline and does not change the APIs.
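
To make the trade-off above concrete, here is a minimal sketch of a pipeline-level guard. It is not the assert added in gsf.py (that line is truncated in the diff above); the names `EDGE_FEAT_ENCODERS` and `check_edge_feat_support` are hypothetical, and the assumption that only RGCN consumes edge features is taken from this discussion.

```python
# Hypothetical illustration only; these names are not part of GraphStorm.
# Assumption from the discussion: only the RGCN encoder consumes edge features.
EDGE_FEAT_ENCODERS = frozenset({"rgcn"})


def check_edge_feat_support(encoder_type, edge_feat_name):
    """Fail early when edge features are requested for an unsupported encoder.

    The check lives in the pipeline, so encoder constructors that do not
    implement edge features keep their existing signatures.
    """
    if edge_feat_name is None:
        return  # No edge features requested, so any encoder is acceptable.
    assert encoder_type in EDGE_FEAT_ENCODERS, (
        f"Encoder '{encoder_type}' does not support edge features; "
        f"supported encoders: {sorted(EDGE_FEAT_ENCODERS)}"
    )


# Usage: HGT without edge features passes, RGCN with edge features passes.
check_edge_feat_support("hgt", None)
check_edge_feat_support("rgcn", "feat")
```

Calling `check_edge_feat_support("hgt", "feat")` would raise an `AssertionError` before any encoder is built, which is the behavior the abnormal test cases in this PR appear to exercise.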

python/graphstorm/model/node_gnn.py (resolved)
python/graphstorm/model/rgcn_encoder.py (outdated, resolved)
python/graphstorm/trainer/ep_trainer.py (resolved)
python/graphstorm/trainer/lp_trainer.py (resolved)
_, part_config = generate_dummy_dist_graph(tmpdirname)
gdata = GSgnnData(part_config=part_config)

# Test case 4: abnormal case, set RGCN model with edge features for NC, but
Contributor

Suggested change
- # Test case 4: abnormal case, set RGCN model with edge features for NC, but
+ # Test case 4: abnormal case, set RGCN model with edge features for EC, but

Please check other comments.

inferrer1.infer(
data=gdata,
loader=infer_dataloader1,
save_embed_path='/tmp/embs',
Contributor

ditto

inferrer2.infer(
data=gdata,
loader=infer_dataloader2,
save_embed_path='/tmp/embs',
Contributor

ditto

# Test case 0: normal case, set HGT model without edge features for NC, and not provide
# edge features.
# Should complete inference process
create_config4ef(Path(tmpdirname), 'gnn_nc.yaml', encoder='hgt', use_ef=False)
Contributor

We can use @pytest.mark.parametrize("encoder", ["rgcn", "hgt"]) to simplify the test code.
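
As a rough sketch of that suggestion: one parametrized test can replace near-identical per-encoder test bodies. `write_stub_config` below is a hypothetical stand-in for the PR's `create_config4ef` helper (its real body is not shown here), and the file contents are made up for illustration.

```python
from pathlib import Path

import pytest


def write_stub_config(tmp_path, file_name, encoder="rgcn", use_ef=True):
    """Hypothetical stand-in for the test helper: writes a tiny config file."""
    lines = [f"encoder: {encoder}", f"use_ef: {str(use_ef).lower()}"]
    cfg_path = Path(tmp_path) / file_name
    cfg_path.write_text("\n".join(lines) + "\n")
    return cfg_path


@pytest.mark.parametrize("encoder", ["rgcn", "hgt"])
def test_config_without_edge_features(tmp_path, encoder):
    # tmp_path is pytest's built-in per-test temporary directory fixture.
    cfg_path = write_stub_config(tmp_path, "gnn_nc.yaml", encoder=encoder, use_ef=False)
    text = cfg_path.read_text()
    # The same assertions run once per encoder value.
    assert f"encoder: {encoder}" in text
    assert "use_ef: false" in text
```

pytest collects this as two test items, one per encoder, so a failure report names the encoder that broke.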

@@ -721,6 +727,1181 @@ def check_eval(mock_do_mini_batch_inference,

check_eval()

def create_config4ef(tmp_path, file_name, encoder='rgcn', task='nc', use_ef=True):

Contributor

Same comments as test_infer.py

Labels: 0.4, enhancement, ready

2 participants