Skip to content

Conversation

@TroyGarden
Copy link
Contributor

@TroyGarden TroyGarden commented Dec 4, 2025

Summary:

context

  • TorchRec relies on nn.Module.named_children() to traverse the model to find the sparse modules to shard.
  • In a normal case, every sparse module only appears once in the model hiearchy, i.e., it has only one parent module.
  • However, in some corner cases, a sparse module might have multiple parent modules. This might confuse the TorchRec sharder due to its traversing logic: the same sparse module has multiple FQNs, and hence being sharded multiple times (create multiple sharded modules according to the FQNs).
image

solution

  • cache the sharded module with the original sparse module's object id
  • when the sparse module has multiple FQNs, only the first time in the named_children traversing, a sharded module will be created.

changes

Differential Revision: D88218200

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 4, 2025
@meta-codesync
Copy link
Contributor

meta-codesync bot commented Dec 4, 2025

@TroyGarden has exported this pull request. If you are a Meta employee, you can view the originating Diff in D88218200.

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 11, 2025
Summary:

# context
* TorchRec relies on `nn.Module.named_children()` to traverse the model to find the sparse modules to shard. 
* In a normal case, every sparse module only appears once in the model hiearchy, i.e., it has **only one** parent module. 
* However, in some corner cases, a sparse module might have multiple parent modules. This might confuse the TorchRec sharder due to its traversing logic: the same sparse module has multiple FQNs, and hence being sharded multiple times (create multiple sharded modules according to the FQNs).

 {F1983896519} 
# solution
* cache the sharded module with the original sparse module's object id
* when the sparse module has multiple FQNs, only the first time in the `named_children` traversing, a sharded module will be created.

# changes
* the change is protected by a KillSwitch: [enable_module_id_cache_for_dmp_shard_modules](https://www.internalfb.com/intern/justknobs/?name=pytorch%2Ftorchrec#enable_module_id_cache_for_dmp_shard_modules)
https://fb.workplace.com/groups/429376538334034/permalink/1343336826937996/

Reviewed By: malaybag, iamzainhuda

Differential Revision: D88218200
@TroyGarden TroyGarden force-pushed the export-D88218200 branch 2 times, most recently from bbbcea6 to 89df447 Compare December 11, 2025 01:01
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 11, 2025
Summary:

# context
* TorchRec relies on `nn.Module.named_children()` to traverse the model to find the sparse modules to shard. 
* In a normal case, every sparse module only appears once in the model hiearchy, i.e., it has **only one** parent module. 
* However, in some corner cases, a sparse module might have multiple parent modules. This might confuse the TorchRec sharder due to its traversing logic: the same sparse module has multiple FQNs, and hence being sharded multiple times (create multiple sharded modules according to the FQNs).

 {F1983896519} 
# solution
* cache the sharded module with the original sparse module's object id
* when the sparse module has multiple FQNs, only the first time in the `named_children` traversing, a sharded module will be created.

# changes
* the change is protected by a KillSwitch: [enable_module_id_cache_for_dmp_shard_modules](https://www.internalfb.com/intern/justknobs/?name=pytorch%2Ftorchrec#enable_module_id_cache_for_dmp_shard_modules)
https://fb.workplace.com/groups/429376538334034/permalink/1343336826937996/

Reviewed By: malaybag, iamzainhuda

Differential Revision: D88218200
Summary:

# context
* TorchRec relies on `nn.Module.named_children()` to traverse the model to find the sparse modules to shard. 
* In a normal case, every sparse module only appears once in the model hiearchy, i.e., it has **only one** parent module. 
* However, in some corner cases, a sparse module might have multiple parent modules. This might confuse the TorchRec sharder due to its traversing logic: the same sparse module has multiple FQNs, and hence being sharded multiple times (create multiple sharded modules according to the FQNs).

 {F1983896519} 
# solution
* cache the sharded module with the original sparse module's object id
* when the sparse module has multiple FQNs, only the first time in the `named_children` traversing, a sharded module will be created.

# changes
* the change is protected by a KillSwitch: [enable_module_id_cache_for_dmp_shard_modules](https://www.internalfb.com/intern/justknobs/?name=pytorch%2Ftorchrec#enable_module_id_cache_for_dmp_shard_modules)
https://fb.workplace.com/groups/429376538334034/permalink/1343336826937996/

Reviewed By: malaybag, iamzainhuda

Differential Revision: D88218200
@meta-codesync meta-codesync bot closed this in cc8df00 Dec 11, 2025
@TroyGarden TroyGarden deleted the export-D88218200 branch December 11, 2025 16:48
@TroyGarden TroyGarden changed the title use module object id to cache the sharded modules [detailed] use module object id to cache the sharded modules Dec 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant