Commit 49b3efd

add scripts, update utils
1 parent d13504c commit 49b3efd

15 files changed: +251, -144 lines

README.md (+1, -1)

@@ -35,7 +35,7 @@ Explorer is a PyTorch reinforcement learning framework for **exploring** new ide
 | | ├── DDQN
 | | ├── NoisyNetDQN
 | | ├── BootstrappedDQN
-| | └── MeDQN_Uniform, MeDQN_Real
+| | └── MeDQN: MeDQN(U), MeDQN(R)
 | ├── Maxmin DQN ── Ensemble DQN
 | └── Averaged DQN
 └── REINFORCE

analysis.py (+5, -3)

@@ -10,7 +10,7 @@ def get_process_result_dict(result, config_idx, mode='Train'):
     'Env': result['Env'][0],
     'Agent': result['Agent'][0],
     'Config Index': config_idx,
-    'Return (mean)': result['Return'][-100:].mean() if mode=='Train' else result['Return'][-5:].mean()
+    'Return (mean)': result['Return'][-100:].mean(skipna=False) if mode=='Train' else result['Return'][-5:].mean(skipna=False)
   }
   return result_dict
 
@@ -19,7 +19,7 @@ def get_csv_result_dict(result, config_idx, mode='Train'):
     'Env': result['Env'][0],
     'Agent': result['Agent'][0],
     'Config Index': config_idx,
-    'Return (mean)': result['Return (mean)'].mean(),
+    'Return (mean)': result['Return (mean)'].mean(skipna=False),
     'Return (se)': result['Return (mean)'].sem(ddof=0)
   }
   return result_dict
@@ -29,6 +29,8 @@ def get_csv_result_dict(result, config_idx, mode='Train'):
     'merged': True,
     'x_label': 'Step',
     'y_label': 'Average Return',
+    # 'rolling_score_window': 20,
+    'rolling_score_window': -1,
     'hue_label': 'Agent',
     'show': False,
     'imgType': 'png',
@@ -39,7 +41,7 @@ def get_csv_result_dict(result, config_idx, mode='Train'):
     'ylim': {'min': None, 'max': None},
     'EMA': True,
     'loc': 'lower right',
-    'sweep_keys': [],
+    'sweep_keys': ['optimizer/actor_kwargs/lr', 'optimizer/critic_kwargs/lr', 'optimizer/reward_kwargs/lr'],
     'sort_by': ['Return (mean)', 'Return (se)'],
     'ascending': [False, True],
     'runs': 1
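
A note on the skipna=False change above: pandas' Series.mean skips NaN values by default, so a run with missing returns could silently look complete in the summary CSV. With skipna=False the mean itself becomes NaN and the gap stays visible. A minimal sketch with made-up data standing in for result['Return']:

```python
import pandas as pd

returns = pd.Series([10.0, 12.0, None, 15.0])  # one missing return

print(returns.mean())              # 12.33... -- the NaN is silently skipped
print(returns.mean(skipna=False))  # nan      -- the incomplete run is flagged
```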

clean.sh (new file, +4)

@@ -0,0 +1,4 @@
+# Clean: remove all job indexes, and output
+rm -f job_idx_*
+rm -rf ./output/*
+rm -rf slurm-*.out

copyfile.sh (deleted, -8)

find_config.py (+7, -6)

@@ -2,8 +2,8 @@
 from utils.sweeper import Sweeper
 
 
-def find_one_run():
-  agent_config = 'mc_medqn.json'
+def find_cfg_idx():
+  agent_config = 'MERL_mc_medqn.json'
   config_file = os.path.join('./configs/', agent_config)
   sweeper = Sweeper(config_file)
   for i in range(1, 1+sweeper.config_dicts['num_combinations']):
@@ -13,7 +13,7 @@ def find_one_run():
   print()
 
 
-def find_many_runs():
+def get_cfg_idx_for_runs():
   l = [23,146,150,147,255,207,133,130,114,55,235,210,138,82,140,209,228,69,71,353,317]
   l.sort()
   print('len(l)=', len(l))
@@ -22,9 +22,10 @@ def find_many_runs():
     for x in l:
       ll.append(x+360*r)
   print('len(ll)=', len(ll))
-  print(*ll)
+  for x in ll:
+    print(x, end=',')
 
 
 if __name__ == "__main__":
-  find_one_run()
-  # find_many_runs()
+  find_cfg_idx()
+  # get_cfg_idx_for_runs()
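
For context on get_cfg_idx_for_runs: the sweep appears to lay out repeated runs in contiguous blocks of config indexes, so configuration x in run r gets index x + N*r, with N = 360 combinations per run in this file. A small illustrative sketch of that arithmetic (the values are made up; N comes from the sweep config):

```python
base_indexes = [23, 55, 69]  # config indexes selected from the first run (illustrative)
num_combinations = 360       # combinations per run, matching the x + 360*r offset above
num_runs = 10

all_indexes = [x + num_combinations * r
               for r in range(num_runs)
               for x in base_indexes]

# Comma-separated, matching the script's print(x, end=',') output
print(','.join(map(str, all_indexes)))
```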

git_commit_id.sh (new file, +1)

@@ -0,0 +1 @@
+git rev-parse --short HEAD

move_log.sh (new file, +19)

@@ -0,0 +1,19 @@
+# Get git commit id
+git_id=$(git rev-parse --short HEAD)
+printf "$git_id\n"
+
+# Create old_log directory
+mkdir old_logs
+dest_dir=./old_logs/logs-$git_id/
+mkdir $dest_dir
+
+# Compress log files.
+cd logs
+files=$(ls)
+for filename in $files
+do
+  printf "zip and move $filename to old_logs...\n"
+  zip -rq $filename.zip ./$filename
+  mv -f $filename.zip ../$dest_dir
+done
+cd ..
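
For anyone who prefers doing this archiving step from Python, a rough standard-library equivalent of move_log.sh (the logs/ and old_logs/ layout is taken from the script above; this is a sketch, not part of the commit):

```python
import shutil
import subprocess
from pathlib import Path

# Short commit id, as in move_log.sh
git_id = subprocess.run(['git', 'rev-parse', '--short', 'HEAD'],
                        capture_output=True, text=True, check=True).stdout.strip()

dest_dir = Path('old_logs') / f'logs-{git_id}'
dest_dir.mkdir(parents=True, exist_ok=True)

# Zip each experiment directory under logs/ and place the archive in old_logs/logs-<id>/
for log_dir in Path('logs').iterdir():
    if log_dir.is_dir():
        print(f'zip and move {log_dir.name} to {dest_dir} ...')
        shutil.make_archive(str(dest_dir / log_dir.name), 'zip',
                            root_dir='logs', base_dir=log_dir.name)
```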

requirements.txt (-2)

@@ -1,5 +1,3 @@
-gym==0.23.1
-gym_games==1.0.4
 matplotlib==3.5.2
 numpy==1.22.0
 opencv_python==4.5.5.64

run.py (+43, -61)

@@ -1,5 +1,7 @@
 import os
 import sys
+import argparse
+from math import ceil
 from utils.submitter import Submitter
 
 
@@ -9,83 +11,63 @@ def make_dir(dir):
 
 
 def main(argv):
+  # python run_narval.py --job_type S
+  # python run_narval.py --job_type M
+  parser = argparse.ArgumentParser(description="Submit jobs")
+  parser.add_argument('--job_type', type=str, default='S', help='Run single (S) or multiple (M) jobs in one experiment: S, M')
+  args = parser.parse_args()
+
   sbatch_cfg = {
     # Account name
     # 'account': 'def-ashique',
     'account': 'rrg-ashique',
     # Job name
-    # 'job-name': 'minatar_dqn',
-    'job-name': 'minatar_dqn_sm',
-    # 'job-name': 'minatar_medqn_real',
-    # 'job-name': 'minatar_medqn_uniform',
+    'job-name': 'MERL_mc_dqn',
     # Job time
-    'time': '0-05:00:00',
-    # GPU/CPU type
-    'cpus-per-task': 1,
-    # Memory
-    # 'mem-per-cpu': '2500M',
-    'mem-per-cpu': '1500M',
-    # Email address
+    'time': '0-01:00:00',
+    # Email notification
     'mail-user': '[email protected]'
   }
-
-  # sbatch configs backup for different games
-  # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'catcher', '0-10:00:00', '2000M'
-  # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'copter', '0-05:00:00', '2000M'
-  # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'lunar', '0-07:00:00', '2000M'
-  # sbatch_cfg['job-name'], sbatch_cfg['time'], sbatch_cfg['mem-per-cpu'] = 'minatar', '0-05:00:00', '2500M'
-
-
-  l_dqn = [11,15,19,7,13,17,9,12,16,6,10,2,18]
-  l_dqn.sort()
-  ll_dqn = []
-  for r in range(1,10):
-    for x in l_dqn:
-      ll_dqn.append(x+20*r)
-
-  l_dqn_sm = [19,11,15,13,17,9,18,14,2,20,12]
-  l_dqn_sm.sort()
-  ll_dqn_sm = []
-  for r in range(1,10):
-    for x in l_dqn_sm:
-      ll_dqn_sm.append(x+20*r)
-
-  l_uniform = [827,267,927,155,583,691,751,351,147,747,587,277,269,273,497,357,669,433,501,509,821,205,517,577,254,270,826,746,490,510,730,430,830,734,732,736,652,888,656,892,512,496,572,592,508,352]
-  l_uniform.sort()
-  ll_uniform = []
-  for r in range(1,10):
-    for x in l_uniform:
-      ll_uniform.append(x+960*r)
-
-  l_real = [195,643,27,179,423,115,403,187,191,267,419,43,351,199,31,203,357,121,41,201,125,285,129,133,213,749,429,433,517,493,501,505,489,57,432,888,884,564,648,664,644,188,416,652,276,352,340,108,256,426,106,402,566,110,510,406,410,254,414,206,574]
-  l_real.sort()
-  ll_real = []
-  for r in range(1,10):
-    for x in l_real:
-      ll_real.append(x+960*r)
-
   general_cfg = {
     # User name
     'user': 'qlan3',
-    # Sbatch script path
-    'script-path': './sbatch.sh',
     # Check time interval in minutes
     'check-time-interval': 5,
-    # Clusters info: {name: capacity}
-    'clusters': {'Narval': 1000},
+    # Clusters info: name & capacity
+    'cluster_name': 'Narval',
+    'cluster_capacity': 996,
     # Job indexes list
-    # 'job-list': list(range(1, 20+1))
-    # 'job-list': list(range(1, 960+1))
-    # 'job-list': ll_uniform
-    # 'job-list': ll_real
-    # 'job-list': ll_dqn
-    'job-list': ll_dqn_sm
-    # 'job-list': []
+    'job-list': list(range(1, 10+1))
   }
-
   make_dir(f"output/{sbatch_cfg['job-name']}")
-  submitter = Submitter(general_cfg, sbatch_cfg)
-  submitter.submit()
+
+  if args.job_type == 'M':
+    # Max number of parallel jobs in one experiment
+    max_parallel_jobs = 4
+    mem_per_job = 16 # in GB
+    cpu_per_job = 2 # Increase cpus_per_job to 5/10 can further increase training speed.
+    mem_per_cpu = int(ceil(mem_per_job/cpu_per_job))
+    # Write to procfile for Parallel
+    with open('procfile', 'w') as f:
+      f.write(str(max_parallel_jobs))
+    sbatch_cfg['gres'] = 'gpu:1' # GPU type
+    sbatch_cfg['cpus-per-task'] = cpu_per_job*max_parallel_jobs
+    sbatch_cfg['mem-per-cpu'] = f'{mem_per_cpu}G' # Memory
+    # Sbatch script path
+    general_cfg['script-path'] = './sbatch_m.sh'
+    # Max number of jobs for Parallel
+    general_cfg['max_parallel_jobs'] = max_parallel_jobs
+    submitter = Submitter(general_cfg, sbatch_cfg)
+    submitter.multiple_submit()
+  elif args.job_type == 'S':
+    mem_per_cpu = 1500
+    sbatch_cfg['cpus-per-task'] = 1
+    sbatch_cfg['mem-per-cpu'] = f'{mem_per_cpu}M' # Memory
+    # Sbatch script path
+    general_cfg['script-path'] = './sbatch_s.sh'
+    submitter = Submitter(general_cfg, sbatch_cfg)
+    submitter.single_submit()
+
 
 if __name__=='__main__':
   main(sys.argv)
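
A usage note on the reworked run.py: --job_type S submits one config index per SLURM array task through sbatch_s.sh, while --job_type M packs several config indexes into a single allocation and runs them concurrently with GNU Parallel through sbatch_m.sh. The resource request for the M branch follows directly from the hard-coded values above; a short worked check:

```python
from math import ceil

# Values hard-coded in the M branch of run.py
max_parallel_jobs = 4   # concurrent training runs inside one allocation
mem_per_job = 16        # GB needed by one training run
cpu_per_job = 2         # CPUs given to one training run

cpus_per_task = cpu_per_job * max_parallel_jobs      # 8 CPUs requested
mem_per_cpu = int(ceil(mem_per_job / cpu_per_job))   # 8 -> requested as '8G' per CPU
total_mem = cpus_per_task * mem_per_cpu              # 64 GB for the whole allocation

print(cpus_per_task, f'{mem_per_cpu}G per CPU', f'{total_mem}G total')
```

So the allocation's total memory scales as max_parallel_jobs * mem_per_job, which matches four 16 GB runs sharing one allocation.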

sbatch_m.sh (new file, +39)

@@ -0,0 +1,39 @@
+#!/bin/bash
+# Ask SLURM to send the USR1 signal 300 seconds before end of the time limit
+#SBATCH --signal=B:USR1@300
+#SBATCH --output=output/%x/%a.txt
+#SBATCH --mail-type=ALL
+#SBATCH --exclude=nc20552,nc11001,nc11002,nc11103,nc11126,nc10303,nc20305,nc10249,nc20325,nc11124,nc20529,nc20526,nc20342,nc20354,nc30616,nc30305,nc20133,nc10220
+
+# ---------------------------------------------------------------------
+echo "Current working directory: `pwd`"
+echo "Starting run at: `date`"
+# ---------------------------------------------------------------------
+echo "Job Array ID / Job ID: $SLURM_ARRAY_JOB_ID / $SLURM_JOB_ID"
+echo "This is job $SLURM_ARRAY_TASK_ID out of $SLURM_ARRAY_TASK_COUNT jobs"
+echo "SLURM_TMPDIR: $SLURM_TMPDIR"
+echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
+# ---------------------------------------------------------------------
+cleanup()
+{
+  echo "Copy log files from temporary directory"
+  sour=$SLURM_TMPDIR/$SLURM_JOB_NAME/.
+  dest=./logs/$SLURM_JOB_NAME/
+  echo "Source directory: $sour"
+  echo "Destination directory: $dest"
+  cp -rf $sour $dest
+}
+# Call `cleanup` once we receive USR1 or EXIT signal
+trap 'cleanup' USR1 EXIT
+# ---------------------------------------------------------------------
+# export OMP_NUM_THREADS=1
+module load gcc/9.3.0 arrow/2.0.0 python/3.8 scipy-stack
+source ~/envs/tianshou/bin/activate
+
+parallel --ungroup --jobs procfile python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx {1} --slurm_dir $SLURM_TMPDIR :::: job_idx_${SLURM_JOB_NAME}_${SLURM_ARRAY_TASK_ID}.txt
+# parallel --eta --ungroup --jobs procfile python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx {1} --slurm_dir $SLURM_TMPDIR :::: job_idx_${SLURM_JOB_NAME}_${SLURM_ARRAY_TASK_ID}.txt
+# parallel --ungroup --jobs procfile python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx {1} :::: job_idx_${SLURM_JOB_NAME}_${SLURM_ARRAY_TASK_ID}.txt
+
+# ---------------------------------------------------------------------
+echo "Job finished with exit code $? at: `date`"
+# ---------------------------------------------------------------------
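
The parallel command above reads config indexes from job_idx_${SLURM_JOB_NAME}_${SLURM_ARRAY_TASK_ID}.txt via ::::, one index per line, and caps concurrency with --jobs procfile (the file run.py writes). The job-index file itself is presumably produced by utils.submitter.Submitter; a hypothetical sketch of what writing it could look like, with assumed names and values:

```python
# Hypothetical sketch: write the per-array-task index file that GNU Parallel
# consumes via `:::: job_idx_<job-name>_<array-task-id>.txt` (one index per line).
# The real file is created by utils.submitter.Submitter; everything here is assumed.
job_name = 'MERL_mc_dqn'
array_task_id = 1
config_indexes = [1, 2, 3, 4]  # config indexes to run inside this one allocation

with open(f'job_idx_{job_name}_{array_task_id}.txt', 'w') as f:
    for idx in config_indexes:
        f.write(f'{idx}\n')
```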

sbatch.sh renamed to sbatch_s.sh (+8, -6)

@@ -3,8 +3,8 @@
 #SBATCH --signal=B:USR1@300
 #SBATCH --output=output/%x/%a.txt
 #SBATCH --mail-type=ALL
-#SBATCH --mail-type=TIME_LIMIT
-#SBATCH --exclude=nc20552,nc11103,nc11126,nc10303,nc20305,nc10249,nc20325,nc11124,nc20529,nc20526,nc20342,nc20354,nc30616,nc30305,nc20133,nc10220
+#SBATCH --exclude=nc20552,nc11001,nc11002,nc11004,nc11003,nc11010,nc11011,nc11022,nc11025,nc11103,nc11126,nc10303,nc20305,nc10249,nc20325,nc11124,nc20529,nc20526,nc20342,nc20354,nc30616,nc30305,nc20133,nc10220
+
 # ---------------------------------------------------------------------
 echo "Current working directory: `pwd`"
 echo "Starting run at: `date`"
@@ -17,7 +17,7 @@ echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
 cleanup()
 {
   echo "Copy log files from temporary directory"
-  sour=$SLURM_TMPDIR/$SLURM_JOB_NAME/$SLURM_ARRAY_TASK_ID/
+  sour=$SLURM_TMPDIR/$SLURM_JOB_NAME/.
   dest=./logs/$SLURM_JOB_NAME/
   echo "Source directory: $sour"
   echo "Destination directory: $dest"
@@ -26,11 +26,13 @@ cleanup()
 # Call `cleanup` once we receive USR1 or EXIT signal
 trap 'cleanup' USR1 EXIT
 # ---------------------------------------------------------------------
-export OMP_NUM_THREADS=1
-module load gcc/9.3.0 arrow/2.0.0 python/3.7 scipy-stack
-source ~/envs/gym/bin/activate
+# export OMP_NUM_THREADS=1
+module load gcc/9.3.0 arrow/2.0.0 python/3.8 scipy-stack
+source ~/envs/tianshou/bin/activate
+
 python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx $SLURM_ARRAY_TASK_ID --slurm_dir $SLURM_TMPDIR
 # python main.py --config_file ./configs/${SLURM_JOB_NAME}.json --config_idx $SLURM_ARRAY_TASK_ID
+
 # ---------------------------------------------------------------------
 echo "Job finished with exit code $? at: `date`"
 # ---------------------------------------------------------------------

utils/logger.py (+5, -11)

@@ -5,6 +5,7 @@
 
 class Logger(object):
   def __init__(self, logs_dir, file_name='log.txt', filemode='w'):
+    self.logs_dir = logs_dir
     logging.basicConfig(
       format='%(asctime)s - %(levelname)s: %(message)s',
       filename=f'{logs_dir}{file_name}',
@@ -18,18 +19,11 @@ def __init__(self, logs_dir, file_name='log.txt', filemode='w'):
     self.warning = logger.warning
     self.error = logger.error
     self.critical = logger.critical
-
-    self.logs_dir = logs_dir
+    # Set default writer
     self.writer = None
 
   def init_writer(self):
     self.writer = SummaryWriter(self.logs_dir)
-
-  def add_scalar(self, tag, scalar_value, global_step=None):
-    self.writer.add_scalar(tag, scalar_value, global_step)
-
-  def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
-    self.writer.add_scalars(main_tag, tag_scalar_dict, global_step)
-
-  def add_histogram(self, tag, values, global_step=None):
-    self.writer.add_histogram(tag, values, global_step)
+    self.add_scalar = self.writer.add_scalar # Input: tag, scalar_value, global_step
+    self.add_scalars = self.writer.add_scalars # Input: main_tag, tag_scalar_dict, global_step
+    self.add_histogram = self.writer.add_histogram # Input: tag, values, global_step
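
A minimal usage sketch of the updated Logger (the directory name is hypothetical; logs_dir should end with a slash, as the f'{logs_dir}{file_name}' concatenation implies). Note that add_scalar, add_scalars, and add_histogram only exist after init_writer() has been called, since they are now bound directly to the SummaryWriter:

```python
from utils.logger import Logger

logger = Logger('./logs/MERL_mc_dqn/1/')  # hypothetical log directory
logger.warning('plain text logging works before the writer exists')

logger.init_writer()  # creates the SummaryWriter and binds add_scalar/add_scalars/add_histogram
logger.add_scalar('Train/Return', 123.4, global_step=1000)
logger.writer.flush()
```

Binding the writer's methods removes three thin wrapper functions, at the cost that logger.add_scalar is simply not defined until init_writer() runs.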
