# funsearch_gls_tsp_api.py
import http.client
import json
import multiprocessing
import pickle
import time
from argparse import ArgumentParser
from typing import Collection, Any, Tuple
import tiktoken
import requests
import os
# Thanks to Fei Liu: the evaluation of gls_tsp is supported by the AEL module
from gls_tsp.eval_helper import ael_evaluation
from gls_tsp import utils
from funsearch_impl import funsearch
from funsearch_impl import config
from funsearch_impl import sampler
from funsearch_impl import evaluator_accelerate
from funsearch_impl import evaluator
from funsearch_impl import code_manipulation
parser = ArgumentParser()
parser.add_argument('--run', type=int)
# parser.add_argument('--port', type=int, default=11045)
parser.add_argument('--config', type=str, default='run_runtime_llm_config.json')
parser.add_argument('--resume_run', default=False, action='store_true')
parser.add_argument('--llm', type=str, default='gemini-1.5-flash')
parser.add_argument('--key', type=str)
args = parser.parse_args()
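# Illustrative invocations (a sketch; the model names and the key are placeholders, not values fixed by this file):
#   python funsearch_gls_tsp_api.py --run 1 --llm gemini-1.5-flash --key <YOUR_API_KEY>
#   python funsearch_gls_tsp_api.py --run 2 --llm gpt-3.5-turbo --key <YOUR_API_KEY> --resume_run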
def _trim_preface_of_body(sample: str) -> str:
"""Trim the redundant descriptions/symbols/'def' declaration before the function body.
Please see my comments in sampler.LLM (in sampler.py).
Since the LLM used in this file is not a pure code completion LLM, this trim function is required.
-Example sample (function & description generated by LLM):
-------------------------------------
This is the optimized function ...
def priority_v2(...) -> ...:
return ...
This function aims to ...
-------------------------------------
    -This function removes the description above the function's signature, as well as the signature itself.
    -The indentation of the code is preserved.
    -Return value of this function:
-------------------------------------
return ...
This function aims to ...
-------------------------------------
"""
lines = sample.splitlines()
func_body_lineno = 0
find_def_declaration = False
for lineno, line in enumerate(lines):
# find the first 'def' statement in the given code
        if line.startswith('def'):
func_body_lineno = lineno
find_def_declaration = True
break
if find_def_declaration:
code = ''
for line in lines[func_body_lineno + 1:]:
code += line + '\n'
return code
return sample
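# Illustrative behaviour of _trim_preface_of_body (a sketch; the sample text and function signature below are hypothetical):
#   sample = ('Here is the optimized function:\n'
#             'def update_edge_distance_v2(edge_distance, local_opt_tour, edge_n_used):\n'
#             '    return edge_distance * 1.1\n')
#   _trim_preface_of_body(sample)  ->  '    return edge_distance * 1.1\n'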
class LLMAPI(sampler.LLM):
"""Language model that predicts continuation of provided source code.
"""
def __init__(self, samples_per_prompt: int, timeout=30, trim=True):
super().__init__(samples_per_prompt)
        additional_prompt = ('Complete a different and more complex Python function. '
                             'Be creative and you can insert multiple if-else and for-loop constructs in the code logic. '
                             'Only output the Python code, no descriptions. '
                             'Do not repeat update_edge_distance_v0 in your response.')
self._additional_prompt = additional_prompt
self._trim = trim
self._timeout = timeout
self.prompt_tokens = 0
self.completion_tokens = 0
self.model = 'gemini'
def cal_usage_LLM(self, lst_prompt, lst_completion, encoding_name="cl100k_base"):
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding(encoding_name)
for i in range(len(lst_prompt)):
if 'gemini' in self.model:
self.prompt_tokens += len(encoding.encode(lst_prompt[i][0] + " " + lst_prompt[i][1]))
else:
for message in lst_prompt[i]:
for key, value in message.items():
self.prompt_tokens += len(encoding.encode(value))
self.completion_tokens += len(encoding.encode(lst_completion[i]))
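    # Illustrative token accounting (a sketch): cl100k_base is used as an approximation, since tiktoken
    # does not ship a Gemini tokenizer.
    #   encoding = tiktoken.get_encoding('cl100k_base')
    #   len(encoding.encode('some prompt text'))      -> added to self.prompt_tokens
    #   len(encoding.encode('some completion text'))  -> added to self.completion_tokens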
def draw_samples(self, prompt: str) -> Collection[str]:
"""Returns multiple predicted continuations of `prompt`."""
return [self._draw_sample(prompt) for _ in range(self._samples_per_prompt)]
def _draw_sample(self, content: str) -> str:
prompt = '\n'.join([content, self._additional_prompt])
        if args.key is None:
            raise ValueError('An API key is required (pass it via --key)!')
        while True:
            try:
if 'gemini' in args.llm:
conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
payload = json.dumps({
"contents": [{"parts": [{"text": prompt}]}]
})
headers = {
'Content-Type': 'application/json'
}
conn.request("POST",
"/v1beta/models/gemini-1.5-flash-latest:generateContent?key=" + args.key,
payload, headers)
res = conn.getresponse()
data = res.read().decode("utf-8")
data = json.loads(data)
# print("Prompt:", prompt)
response = data['candidates'][0]['content']['parts'][0]['text']
else:
conn = http.client.HTTPSConnection("api.openai.com", timeout=self._timeout)
payload = json.dumps({
"model": args.llm,
"messages": [
{
"role": "user",
"content": prompt
}
]
})
headers = {
'Authorization': 'Bearer ' + args.key,
'Content-Type': 'application/json'
}
conn.request("POST", "/v1/chat/completions", payload, headers)
res = conn.getresponse()
data = res.read().decode("utf-8")
data = json.loads(data)
response = data['choices'][0]['message']['content']
self.cal_usage_LLM([['', prompt]], [response])
print(f"LLM usage: prompt_tokens = {self.prompt_tokens}, completion_tokens = {self.completion_tokens}")
append_to_file(f'logs/run{args.run}/main.log',
f"LLM usage: prompt_tokens = {self.prompt_tokens}, completion_tokens = {self.completion_tokens}")
# print("Response:", response)
# trim function
if self._trim:
response = _trim_preface_of_body(response)
return response
except Exception:
time.sleep(2)
continue
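# For reference, the response shapes parsed in _draw_sample above (hedged: field names follow the public
# REST APIs at the time of writing and may change):
#   Gemini v1beta generateContent: {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}
#   OpenAI chat completions:       {"choices": [{"message": {"content": "..."}}]}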
class Sandbox(evaluator.Sandbox):
"""Sandbox for executing generated code. Implemented by RZ.
    RZ: The sandbox returns the 'score' of the program and:
    1) prevents the generated code from doing harm (e.g., accessing the internet or using too much RAM);
    2) stops the execution of the code in time (avoids endless loops).
"""
def __init__(
self, verbose=False,
numba_accelerate=True
):
"""
Args:
verbose : Print evaluate information.
numba_accelerate: Use numba to accelerate the evaluation. It should be noted that not all numpy functions
support numba acceleration, such as np.piecewise().
"""
self._verbose = verbose
self._numba_accelerate = numba_accelerate
def run(
self,
program: str,
function_to_run: str, # RZ: refers to the name of the function to run (e.g., 'evaluate')
function_to_evolve: str, # RZ: accelerate the code by decorating @numba.jit() on function_to_evolve.
inputs: Any, # refers to the dataset
test_input: str, # refers to the current instance
timeout_seconds: int,
# **kwargs # RZ: add this
) -> tuple[Any, bool]:
"""Returns `function_to_run(test_input)` and whether execution succeeded.
RZ: If the generated code (generated by LLM) is executed successfully,
the output of this function is the score of a given program.
The evaluate time limitation and exception handling modules are implemented within AEL's evaluation module.
"""
try:
if self._numba_accelerate:
program = evaluator_accelerate.add_numba_decorator(
program=program,
function_name=[function_to_evolve]
)
# compile the program, and maps the global func/var/class name to its address
all_globals_namespace = {}
# execute the program, map func/var/class to global namespace
exec(program, all_globals_namespace)
# get the pointer of 'function_to_evolve', which will be sent to AEL's evaluation module later
function_to_evolve_pointer = all_globals_namespace[function_to_evolve]
            ael_evaluator = ael_evaluation.Evaluation()
            # run the evaluation
            results = ael_evaluator.evaluate(heuristic_func=function_to_evolve_pointer)
# make sure the score is int or float
if results is not None:
if not isinstance(results, (int, float)):
results = (None, False)
else:
                    # negate the score: FunSearch maximizes, so bigger is better
results = (-results, True) # convert to FunSearch result format
else:
results = None, False
        except Exception:
results = None, False
return results
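    # Illustrative return values of run() (a sketch; the number is made up): (-7.3, True) on a successful
    # evaluation (the score is negated so that larger is better), and (None, False) on any failure.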
def append_to_file(file_path, text):
# Ensure the directory exists
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Append the text to the file, creating it if it doesn't exist
with open(file_path, 'a') as file:
file.write(text + '\n')
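# Illustrative usage of append_to_file (a sketch): creates logs/run1/ if needed and appends one line.
#   append_to_file('logs/run1/main.log', 'sampling started')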
# Note: the `if __name__ == '__main__'` guard is required because the inner code uses multiprocessing for evaluation.
if __name__ == '__main__':
class_config = config.ClassConfig(llm_class=LLMAPI, sandbox_class=Sandbox)
    run_config = config.Config(samples_per_prompt=4, evaluate_timeout_seconds=100)
    global_max_sample_num = 4  # if set to None, FunSearch will run indefinitely
funsearch.main(
specification=utils.specification,
inputs=[None],
        config=run_config,
max_sample_nums=global_max_sample_num,
class_config=class_config,
log_dir=f'logs/run{args.run}',
resume_run=args.resume_run
)