Implement GQE Manuscript V2 features and refactor architecture #373
Conversation
- Split transformer.py into pipeline.py and model.py for better modularity
  - pipeline.py: Contains Pipeline class for data processing
  - model.py: Contains GPT2 model class
- Updated gqe.py to import from new modules
- Improves code organization and maintainability

Signed-off-by: Kohei Nakaji <[email protected]>
- Add factory.py: Factory class for creating loss functions
- Refactor pipeline.py: Simplify code and remove redundant logic
- Update gqe.py: Improve code structure and readability
- Update loss.py: Minor improvements for better consistency
- Reduces code complexity and improves maintainability

Signed-off-by: Kohei Nakaji <[email protected]>
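A loss factory of this kind typically maps a configuration key to a loss constructor. The sketch below is illustrative only: the class bodies and the `create_loss` name are assumptions, not the actual cudaqx API (the source only confirms a `Factory` class in factory.py and a `Factory.create_temperature_scheduler` method).

```python
# Illustrative sketch of a loss factory. Class bodies and the
# create_loss name are hypothetical, not the real cudaqx API.
class ExpLoss:
    """Stand-in for an exponential logit-matching loss."""

class GRPOLoss:
    """Stand-in for a GRPO loss with a configurable clip ratio."""
    def __init__(self, clip_ratio=0.2):
        self.clip_ratio = clip_ratio

class Factory:
    _losses = {"exp": ExpLoss, "grpo": GRPOLoss}

    @classmethod
    def create_loss(cls, name, **kwargs):
        # Look up the constructor by config key and forward kwargs to it
        try:
            return cls._losses[name](**kwargs)
        except KeyError:
            raise ValueError(f"unknown loss: {name}")

loss = Factory.create_loss("grpo", clip_ratio=0.1)
```

This keeps loss selection in one place, so adding a new loss only requires registering its constructor in `_losses`.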
- Add scheduler.py: Contains TemperatureScheduler, DefaultScheduler, and CosineScheduler
- Update gqe.py: Import schedulers from new module, remove scheduler definitions
- Improves code organization and separation of concerns
- Reduces gqe.py complexity by ~70 lines

Signed-off-by: Kohei Nakaji <[email protected]>
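For readers unfamiliar with cosine annealing, a minimal sketch of a cosine temperature schedule follows. It uses the `get_inverse_temperature`/`update` interface mentioned later in this PR; the parameter names and the exact ramp are assumptions, not the real CosineScheduler.

```python
import math

# Sketch of a cosine temperature schedule: ramps the inverse
# temperature beta from beta_min to beta_max over max_iters steps.
# Parameter names and bounds are illustrative assumptions.
class CosineScheduler:
    def __init__(self, beta_min=1.0, beta_max=10.0, max_iters=100):
        self.beta_min, self.beta_max, self.max_iters = beta_min, beta_max, max_iters
        self.step = 0

    def get_inverse_temperature(self):
        # Half-cosine ramp: beta_min at step 0, beta_max at max_iters
        frac = min(self.step / self.max_iters, 1.0)
        return self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (1.0 - math.cos(math.pi * frac))

    def update(self):
        self.step += 1
```

A smooth ramp like this sharpens the sampling distribution gradually instead of in abrupt jumps.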
- Migrate from manual Fabric loop to Lightning Trainer
- Extract ReplayBuffer and BufferDataset to data.py
- Create callbacks.py with MinEnergyCallback and TrajectoryCallback
- Simplify TemperatureScheduler interface (get_inverse_temperature, update)
- Add Factory.create_temperature_scheduler method
- Move seed_everything to Pipeline.__init__
- Disable checkpointing for performance (2.5s -> 0.001s between epochs)
- Add num_sanity_val_steps=0 to suppress warnings
- Fix device placement issues for CUDA tensors
- Set DataLoader num_workers=0 to avoid pickling SpinOperator

Signed-off-by: Kohei Nakaji <[email protected]>
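A replay buffer like the ReplayBuffer extracted to data.py can be sketched as a bounded FIFO of (sequence, energy) pairs; the field names, capacity, and uniform sampling policy below are assumptions, not the actual implementation.

```python
from collections import deque
import random

# Sketch of a bounded replay buffer: stores (sequence, energy) pairs
# and evicts the oldest entry once capacity is reached. Names and the
# uniform sampling policy are illustrative assumptions.
class ReplayBuffer:
    def __init__(self, capacity=1000):
        self._buf = deque(maxlen=capacity)  # deque drops oldest when full

    def add(self, sequence, energy):
        self._buf.append((sequence, energy))

    def sample(self, batch_size):
        # Uniform sampling without replacement, capped at buffer size
        k = min(batch_size, len(self._buf))
        return random.sample(list(self._buf), k)

    def __len__(self):
        return len(self._buf)
```

Wrapping such a buffer in a torch `Dataset` (as BufferDataset presumably does) lets the Lightning Trainer consume replayed samples through a standard DataLoader.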
- Add GRPOLoss class inheriting from Loss base class
- Update loss.compute() signature to use **kwargs instead of context
- Add GRPOLoss to Factory with configurable clip_ratio
- Remove unnecessary logger check in Pipeline
- Update gqe_h2.py example with max_iters=50
- Clean up pyscf-generated files (.log, .chk)

Signed-off-by: Kohei Nakaji <[email protected]>
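To illustrate the two ideas behind a GRPO loss with a `clip_ratio`, here is a sketch on plain floats: a group-relative advantage (energies normalized against the batch mean) and a PPO-style clipped surrogate. The real GRPOLoss works on torch tensors inside a `Loss` subclass; everything here is an illustrative assumption.

```python
import math
import statistics

# Sketch of a GRPO-style clipped objective on plain floats.
# Signature and normalization details are assumptions, not the real API.
def grpo_loss(log_probs, old_log_probs, energies, clip_ratio=0.2):
    # Group-relative advantage: lower energy than the group mean is better
    mean_e = statistics.mean(energies)
    std_e = statistics.pstdev(energies) or 1.0
    advantages = [(mean_e - e) / std_e for e in energies]
    total = 0.0
    for lp, olp, adv in zip(log_probs, old_log_probs, advantages):
        ratio = math.exp(lp - olp)  # policy ratio pi_new / pi_old
        clipped = max(min(ratio, 1.0 + clip_ratio), 1.0 - clip_ratio)
        # Pessimistic (min) of clipped/unclipped surrogate, negated for a loss
        total += -min(ratio * adv, clipped * adv)
    return total / len(log_probs)
```

The clip keeps any single update from moving the policy too far from the one that generated the samples, which is what makes replayed (off-policy) data safe to train on.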
- Implement variance-based adaptive temperature scheduler
  - Adjusts temperature based on energy variance in training batches
  - Increases temperature for high variance (exploration)
  - Decreases temperature for low variance (exploitation)
- Add scheduler factory support
  - Extend Factory.create_temperature_scheduler() to support 'variance' mode
  - Configure via cfg.scheduler='variance' and cfg.target_variance
- Fix Loss classes to properly inherit from torch.nn.Module
  - Add super().__init__() calls to all Loss subclasses
  - Fix device placement issues in GFlowLogitMatching
- Enhance test coverage
  - Add test_variance_scheduler() for VarBasedScheduler unit testing
  - Add test_solvers_gqe_with_variance_scheduler() for integration testing
  - Add test_solvers_gqe_with_cosine_scheduler() for CosineScheduler
  - Add test_solvers_gqe_with_exp_loss() for ExpLogitMatching
  - Fix existing scheduler tests to use new API methods

All 11 GQE tests pass successfully.

Signed-off-by: Kohei Nakaji <[email protected]>
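The adaptation rule described above (raise temperature when batch energy variance exceeds `target_variance`, lower it otherwise) can be sketched as follows. The multiplicative update, the `rate` parameter, and the defaults are assumptions; only the direction of the adjustment and the `target_variance` knob come from the commit message.

```python
import statistics

# Sketch of a variance-based adaptive temperature scheduler: raises the
# temperature (more exploration) when batch energy variance is high and
# lowers it (more exploitation) when variance is low. The multiplicative
# rule and parameter names are illustrative, not the real VarBasedScheduler.
class VarBasedScheduler:
    def __init__(self, temperature=5.0, target_variance=0.1, rate=1.05):
        self.temperature = temperature
        self.target_variance = target_variance
        self.rate = rate

    def get_inverse_temperature(self):
        return 1.0 / self.temperature

    def update(self, energies):
        var = statistics.pvariance(energies)
        if var > self.target_variance:
            self.temperature *= self.rate   # high variance -> explore more
        else:
            self.temperature /= self.rate   # low variance -> exploit more
```

Unlike a fixed cosine ramp, this couples the schedule to the optimization signal itself, so flat energy landscapes automatically sharpen the sampling distribution.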
- Add trainer_kwargs and callbacks configuration options for customization
- Flatten config structure (cfg.trainer.* → cfg.*)
- Add operator pool utility functions (utils.py)
- Add N2 molecule example (gqe_n2.py)
- Update config docstring to match implementation

Signed-off-by: Kohei Nakaji <[email protected]>
…fault
- Add test_get_gqe_pauli_pool to verify operator pool generation
- Set enable_checkpointing=False by default in trainer config
- Minor formatting fixes in test file

Signed-off-by: Kohei Nakaji <[email protected]>
/ok to test 0010fc7

Link to related doc update follow-up (not to be included in this PR): #382
melody-ren
left a comment
Thanks @konakaji . I took a brief look and will probably have more questions later.
Could you please run yapf to format the python files that you updated/added? Thank you!
@@ -0,0 +1,44 @@
from collections import deque
We'll need a licence header for this file too
@@ -0,0 +1,38 @@
from .loss import ExpLogitMatching, GFlowLogitMatching, GRPOLoss
Please add a licence header
@@ -0,0 +1,13 @@
from transformers import GPT2LMHeadModel, GPT2Config
Please add a licence header
Signed-off-by: Kohei Nakaji <[email protected]>
Set max_iters to 50 for quick testing, with note for full training

Signed-off-by: Kohei Nakaji <[email protected]>
/ok to test 0a7b65f
Signed-off-by: Kohei Nakaji <[email protected]>
Wheels job passed here: https://github.com/NVIDIA/cudaqx/actions/runs/20110636646
Implement GQE Manuscript V2 features (GRPO loss, Replay Buffer, Variance-based scheduler)
Summary
This PR implements key features from the updated manuscript (arXiv:2401.09253v2)
including GRPO loss function, Replay Buffer mechanism, and Variance-based temperature scheduler,
along with a comprehensive refactoring of the GQE (Generative Quantum Eigensolver) internal
implementation to improve modularity and add user customization options.
Motivation
Implement the GRPO loss, Replay Buffer mechanism, and Variance-based temperature scheduler described in the updated manuscript (arXiv:2401.09253v2), which were not present in the original implementation.
Changes
Core Implementation
- Add GRPO (Group Relative Policy Optimization) loss and replay buffer mechanism as described in arXiv:2401.09253v2, set as default
- Add variance-based temperature scheduler that adjusts temperature based on energy variance as described in the manuscript V2
- Split the model code into GPT2 (model) and Pipeline (LightningModule)
- Migrate training to the Lightning Trainer framework, removing Fabric/Lightning mixing. Updated logger from Fabric to Lightning accordingly
- Extract schedulers, losses, callbacks, and data utilities into dedicated modules for better modularity
User-Facing Features
- Add `trainer_kwargs` and `callbacks` to customize Lightning Trainer behavior
- Add `utils.py` with helper functions for generating operator pools (`get_identity`, `get_gqe_pauli_pool`)

Examples and Tests
- Add N2 molecule example `gqe_n2.py` demonstrating Pauli operator pool usage

Breaking Changes
- `cfg.fabric_logger` → `cfg.lightning_logger`
- `cfg.use_fabric_logging` → `cfg.use_lightning_logging`

Testing
References
arXiv:2401.09253v2 (2024)