th3ch103/mlds490_hw2

HW2 — Deep Q-Network (DQN)

This repo provides a PyTorch implementation of DQN for CartPole-v1 and MsPacman-v0 with:

  • Replay Buffer
  • Target Network
  • ε-greedy policy with linear decay
  • Optional Huber loss & LR scheduler
  • Required plots (max-Q vs. episodes, reward vs. episodes with a moving average, and a 500-episode rollout histogram with mean/std)
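The replay buffer and linear ε-decay listed above can be sketched in plain Python. This is a minimal illustration, not the repo's actual code; the class and function names (`ReplayBuffer`, `epsilon`) and the default hyperparameters are assumptions.

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-size store of (s, a, r, s', done) transitions (illustrative sketch)."""
    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are evicted first

    def push(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        # Uniform random sampling decorrelates consecutive updates
        return random.sample(self.buffer, batch_size)

def epsilon(step, eps_start=1.0, eps_end=0.05, decay_steps=10_000):
    """Linearly anneal exploration rate from eps_start to eps_end over decay_steps."""
    frac = min(step / decay_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)
```

With this decay schedule, the agent acts randomly with probability `epsilon(step)` and greedily (argmax over Q-values) otherwise; after `decay_steps` the rate stays at `eps_end`.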

Setup

```bash
# Create virtual environment
python -m venv .venv_gpu
source .venv_gpu/bin/activate

# Install dependencies
pip install --upgrade pip
pip install -r requirements.txt
```

Train

CartPole

```bash
python src/train_cartpole.py --episodes 800 --gamma 0.95
```

MsPacman

```bash
python src/train_mspacman.py --episodes 5000 --gamma 0.99
```

Outputs (plots and logs) are saved in timestamped run folders under outputs/.
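The core update both training scripts perform (TD target from the target network, Huber loss on the TD error) can be sketched without any framework. This is a dependency-free illustration under the README's stated design; the function names `huber` and `td_target` are assumptions, not the repo's API.

```python
def huber(delta, kappa=1.0):
    """Huber loss: quadratic for |delta| <= kappa, linear beyond.

    Less sensitive to large TD errors than squared loss, which stabilizes training.
    """
    a = abs(delta)
    return 0.5 * a * a if a <= kappa else kappa * (a - 0.5 * kappa)

def td_target(reward, next_q_values, done, gamma=0.99):
    """y = r for terminal transitions, else r + gamma * max_a' Q_target(s', a').

    next_q_values come from the frozen target network, which is only
    synchronized with the online network every fixed number of steps.
    """
    return reward if done else reward + gamma * max(next_q_values)
```

Each minibatch step then minimizes `huber(td_target(...) - Q_online(s, a))`, with gradients flowing only through the online network.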

Generate 500-episode Rollout Histogram (after training)

Each training script automatically runs a 500-episode evaluation using the best checkpoint and saves the histogram and summary statistics to the same run folder.
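The summary statistics reported alongside the histogram reduce to the mean and standard deviation of the per-episode returns. A stdlib-only sketch (the helper name `summarize_rollouts` is hypothetical):

```python
import statistics

def summarize_rollouts(returns):
    """Mean and sample standard deviation of per-episode evaluation returns."""
    mean = statistics.fmean(returns)
    std = statistics.stdev(returns) if len(returns) > 1 else 0.0
    return mean, std
```

In the actual run, `returns` would hold the 500 episode totals collected by rolling out the greedy policy from the best checkpoint.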
