This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit dc73868 (parent: 59e5bcc)
Author: Jules Pondard

Add vanilla MCTS/UCT algorithm for options

Generate options using the UCT algorithm; useful for benchmarking.
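
The file below implements vanilla UCT (Upper Confidence bounds applied to Trees): each tree depth fixes one kernel hyperparameter, leaves are scored by random rollouts, and the descent picks children by the UCB1 rule. As a reference, here is a minimal sketch of that selection score (the function name is illustrative, not part of the diff):

    import numpy as np

    def uct_score(child_value, child_visits, parent_visits, C=1.0):
        # exploitation term: mean reward observed below this child
        mean = child_value / child_visits
        # exploration term: grows with parent visits, shrinks as the child is revisited
        bonus = C * np.sqrt(2.0 * np.log(parent_visits) / child_visits)
        return mean + bonus

getBestChild in the file computes exactly this indicator; getBestChild2 drops the exploration bonus to make the final, purely greedy choice at each depth.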

File tree: 1 file changed, +179 −0 lines
import tensor_comprehensions as tc
import torch
import utils
import numpy as np
from visdom import Visdom

viz = Visdom()


class Node:
    """One tree node; depth `pos` is the index of the next hyperparameter to choose."""

    def __init__(self, father=None, new_act=0):
        self.value = 0              # sum of rollout values backed up through this node
        self.values = []            # individual values, recorded at the root by backup()
        self.nbVisits = 0
        self.nbChildrenSeen = 0     # children are expanded in increasing action order
        self.pos = 0
        self.children = []
        self.parent = father
        self.stateVector = [0] * utils.NB_HYPERPARAMS
        if father is not None:
            self.pos = father.pos + 1
            self.stateVector = father.stateVector[:]
            self.stateVector[self.pos - 1] = new_act

    def getRoot(self):
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def getParent(self):
        return self.parent

    def notRoot(self):
        return self.parent is not None


class MCTS:
    def __init__(self):
        self.C = 1  # UCT exploration constant, to tune

        self.exptuner_config = utils.ExpTunerConfig()
        self.exptuner_config.set_convolution_tc()

        self.nbActions = self.exptuner_config.cat_sz  # number of choices per hyperparameter
        self.tree = Node()

        self.best_rewards = []
        self.rws = []

        self.curIter = 0
        self.curr_best = 0
        self.running_reward = 0
        self.win0 = viz.line(X=np.arange(5), Y=np.random.rand(5))

    def main_search(self, starting_pos):
        """Run UCT iterations below `starting_pos` and return the best action found."""
        node = starting_pos
        ttNbIters = 10  # lower bound on the number of iterations per decision
        for _ in range(max(ttNbIters, self.nbActions[node.pos])):
            leaf = self.getLeaf(node)
            val = self.evaluate(leaf)
            self.backup(leaf, val)
        _, action = self.getBestChild2(node)
        return action

    def take_action(self, node, act):
        """Return the child of `node` reached by `act`, creating it on first use."""
        if node.nbChildrenSeen > act:
            return node.children[act]
        new_child = Node(father=node, new_act=act)
        node.children.append(new_child)
        node.nbChildrenSeen += 1
        return node.children[-1]

    def getLeaf(self, node):
        """Descend the tree by UCT until reaching an unexpanded or unvisited node."""
        first = True
        while node.pos < utils.NB_HYPERPARAMS and (first or node.nbVisits != 0):
            first = False
            pos = node.pos
            if node.nbChildrenSeen == self.nbActions[pos]:
                # fully expanded: follow the UCT-best child
                node, _ = self.getBestChild(node)
            else:
                # expand the next unseen action and stop there
                act = node.nbChildrenSeen
                self.take_action(node, act)
                return node.children[-1]
        return node

    def getBestChild2(self, node):
        """Greedy choice: pick the child with the best mean value (no exploration bonus)."""
        bestIndic = 0.
        bestAction = 0
        first = True
        pos = node.pos
        for act in range(self.nbActions[pos]):
            child = node.children[act]
            indic = child.value / child.nbVisits
            if first or indic > bestIndic:
                bestIndic = indic
                bestAction = act
                first = False
        return node.children[bestAction], bestAction

    def getBestChild(self, node):
        """UCT choice: mean child value plus the UCB1 exploration bonus."""
        bestIndic = 0.
        bestAction = 0
        first = True
        pos = node.pos
        for act in range(self.nbActions[pos]):
            child = node.children[act]
            indic = (child.value / child.nbVisits
                     + self.C * np.sqrt(2 * np.log(node.nbVisits) / child.nbVisits))
            if first or indic > bestIndic:
                bestIndic = indic
                bestAction = act
                first = False
        return node.children[bestAction], bestAction

    def saveReward(self, reward, opts):
        """Track the running and best rewards and refresh the Visdom plot periodically."""
        INTER_DISP = 20
        if self.curIter == 0:
            self.running_reward = reward
            self.curr_best = reward
        if self.curIter == 0 or reward > self.curr_best:
            print(-reward)
            print(opts)
        self.curIter += 1
        self.running_reward = self.running_reward * 0.99 + reward * 0.01
        self.curr_best = max(self.curr_best, reward)
        self.best_rewards.append(-self.curr_best)
        self.rws.append(-self.running_reward)
        if self.curIter % INTER_DISP == 0:
            viz.line(X=np.column_stack((np.arange(self.curIter), np.arange(self.curIter))),
                     Y=np.column_stack((np.array(self.rws), np.array(self.best_rewards))),
                     win=self.win0, opts=dict(legend=["Geometric run", "Best time"]))

    def randomSampleScoreFrom(self, node):
        """Finish the partial configuration at `node` with random choices and score it."""
        pos = node.pos
        optsVector = node.stateVector[:]  # copy: the rollout must not mutate the node's state
        for i in range(utils.NB_HYPERPARAMS - pos):
            a = np.random.randint(self.nbActions[i + pos])
            optsVector[i + pos] = a
        # reward is the negative log of the kernel runtime: faster kernels score higher
        reward = -np.log(utils.evalTime(optsVector, self.exptuner_config))
        self.saveReward(reward, optsVector)
        return reward

    def evaluate(self, leaf):
        """Estimate a leaf's value as the mean of nb_iters random rollouts."""
        score = 0
        nb_iters = 5
        for _ in range(nb_iters):
            score += self.randomSampleScoreFrom(leaf)
        return score / nb_iters

    def backup(self, leaf, val):
        """Propagate a rollout value from `leaf` back up to the root."""
        node = leaf
        while node.notRoot():
            node.nbVisits += 1
            node.value += val
            node = node.getParent()
        # update the root as well
        node.nbVisits += 1
        node.value += val
        node.values.append(val)


mcts = MCTS()

# Fix hyperparameters one at a time: search from the current node, commit the
# best action, then descend and repeat for the next hyperparameter.
opts = []
curr_node = mcts.tree
for i in range(utils.NB_HYPERPARAMS):
    opts.append(mcts.main_search(curr_node))
    curr_node = mcts.take_action(curr_node, opts[-1])
print(opts)
opts = np.array(opts).astype(int)
print(utils.evalTime(opts.tolist(), mcts.exptuner_config))
utils.print_opt(opts)
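
The script depends on a companion utils module that is not part of this diff; only the names it must expose are visible here (NB_HYPERPARAMS, ExpTunerConfig with set_convolution_tc() and cat_sz, evalTime, print_opt). A hypothetical stub, with sizes and timings made up purely to dry-run the search logic without Tensor Comprehensions:

    # utils.py -- hypothetical stand-in, not the repository's real module;
    # the values below are assumptions for illustration only.
    import numpy as np

    NB_HYPERPARAMS = 26  # assumed count of tunable mapping options

    class ExpTunerConfig:
        def set_convolution_tc(self):
            # the real method would set up a TC convolution benchmark; here we
            # only fabricate the per-hyperparameter action counts the search reads
            self.cat_sz = np.full(NB_HYPERPARAMS, 4, dtype=int)

    def evalTime(opts, config):
        # stand-in for compiling and timing the kernel with the given options
        return 1.0 + 0.01 * float(np.sum(opts))

    def print_opt(opts):
        print("options:", list(opts))

Note that the Visdom plotting assumes a server is already running (python -m visdom.server) before the script is launched.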
