|
| 1 | +import os |
| 2 | +import anthropic |
| 3 | +import ast, astor |
| 4 | +from cfg import * |
| 5 | +import re |
| 6 | +import sys |
| 7 | +import trace_execution |
| 8 | +import os |
| 9 | +import io |
| 10 | +import pandas as pd |
| 11 | +from torchtext import data |
| 12 | +from torchtext.data import Iterator |
| 13 | +import pandas as pd |
| 14 | +import torch |
| 15 | +import model |
| 16 | +import time |
| 17 | +import config |
| 18 | +import argparse |
| 19 | + |
def generate_prompt(method_code, feedback=""):
    """Build the one-shot prompt asking the LLM for error-triggering inputs.

    Args:
        method_code: Source text of the Python snippet under test; it is
            embedded verbatim after the final "Python Method:" marker.
        feedback: Optional extra instruction text (e.g. previously tried
            inputs) embedded verbatim in the "Task" section.

    Returns:
        The complete prompt string, including the worked example.
    """
    # The template is returned as-is; only the two placeholders are filled.
    return f"""
\n\nHuman: You are a terminal. Analyze the following Python code and generate likely inputs for all variables that might raise errors. Add these generated inputs at the beginning of the code snippet.

Example:
Python Method:
if(S[0]=="A" and S[2,-1].count("C")==1):
    cnt=0
    for i in S:
        if(97<=ord(i) and ord(i)<=122):
            cnt+=1
    if(cnt==2):
        print("AC")
    else :
        print("WA")
else :
    print("WA")

Generated Input:
S = 'AtCoder'

Task:
Given the following Python method, generate likely inputs for variables:
{feedback}

Python Method:
{method_code}

Generated Input:
(No explanation needed, only one Generated Input:)
\n\nAssistant:
    """
| 53 | + |
def get_generated_inputs(claude_api_key, model, method_code, feedback=""):
    """Ask the Anthropic Messages API for candidate inputs for a snippet.

    Args:
        claude_api_key: API key passed to ``anthropic.Anthropic``.
        model: Model identifier forwarded to ``messages.create``. Inside
            this function the name shadows the imported ``model`` module.
        method_code: Source text of the snippet under test.
        feedback: Optional feedback text forwarded to ``generate_prompt``.

    Returns:
        The text of the first content item of the API response.
    """
    user_message = {
        "role": "user",
        "content": generate_prompt(method_code, feedback),
    }
    api_client = anthropic.Anthropic(api_key=claude_api_key)
    reply = api_client.messages.create(
        model=model,
        max_tokens=1024,
        messages=[user_message],
    )
    return reply.content[0].text
| 65 | + |
def add_generated_inputs_to_code(code, inputs):
    """Insert generated input assignments after the leading imports of *code*.

    The leading header of *code* (top-level ``import``/``from`` statements,
    blank lines and ``#`` comment lines) is scanned; the inputs are inserted
    immediately before the first other line. If no such line exists, they are
    inserted right after the last import (or at the top when there is none).

    Args:
        code: Source text of the program under test.
        inputs: Newline-separated generated assignments. Blank lines and
            lines starting with ``"Generated"`` (echoed prompt markers) are
            dropped.

    Returns:
        The source text with the input lines spliced in.
    """
    lines = code.split('\n')

    # Position right after the last import seen so far; used as the fallback
    # when the file consists solely of imports / blank lines / comments.
    after_last_import = 0
    insert_index = None
    for i, line in enumerate(lines):
        stripped = line.strip()
        if not stripped or stripped.startswith('#'):
            # Blank and comment lines do not end the import header.
            continue
        # The trailing space keeps identifiers such as ``important`` or
        # ``from_x`` from being mistaken for import statements.
        if line.startswith(('import ', 'from ')):
            after_last_import = i + 1
            continue
        insert_index = i
        break
    if insert_index is None:
        insert_index = after_last_import

    generated = [
        input_line
        for input_line in inputs.split('\n')
        if input_line.strip() and not input_line.startswith("Generated")
    ]
    # Slice assignment keeps the generated lines in their original order.
    lines[insert_index:insert_index] = generated

    return '\n'.join(lines)
| 84 | + |
def read_data(data_path, fields):
    """Load a CSV of serialized graphs into a list of torchtext examples.

    Args:
        data_path: Path of a CSV file with the columns ``nodes``,
            ``forward``, ``backward`` and ``target``; each cell holds a
            Python literal written as text.
        fields: Field specification passed to ``data.Example.fromlist``.

    Returns:
        A list with one example per CSV row, in file order.
    """
    def _row_to_example(row):
        # NOTE(review): eval() executes whatever text the CSV contains;
        # only use this on trusted files.
        values = [eval(row['nodes']), eval(row['forward']), eval(row['backward']),
                  eval(row['target'])]
        return data.Example.fromlist(values, fields)

    all_examples = []
    # The file is read in chunks of 100 rows rather than all at once.
    for chunk in pd.read_csv(data_path, chunksize=100):
        converted = chunk.apply(_row_to_example, axis=1)
        all_examples.extend(list(converted))
    return all_examples
| 93 | + |
# ---------------------------------------------------------------------------
# Script configuration: options, device, torchtext fields, vocabulary, model.
# ---------------------------------------------------------------------------
opt = config.parse()
# ``is None`` is the identity test recommended for None comparisons.
if opt.claude_api_key is None:
    raise Exception("Lack of CLAUDE api")
if opt.cuda_num is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(f"cuda:{opt.cuda_num}" if torch.cuda.is_available() else "cpu")

# Field definitions: statements are whitespace-tokenized and truncated to 512
# tokens; node lists and targets are truncated to 100 entries.
TEXT = data.Field(tokenize=lambda x: x.split()[:512])
NODE = data.NestedField(TEXT, preprocessing=lambda x: x[:100], include_lengths=True)
# An edge whose endpoint index exceeds 100 is replaced by the edge [1, 1].
ROW = data.Field(pad_token=1.0, use_vocab=False,
                 preprocessing=lambda x: [1, 1] if any(i > 100 for i in x) else x)
EDGE = data.NestedField(ROW)
TARGET = data.Field(use_vocab=False, preprocessing=lambda x: x[:100], pad_token=0, batch_first=True)

fields = [("nodes", NODE), ("forward", EDGE), ("backward", EDGE), ("target", TARGET)]

# The training set is loaded only to build the node vocabulary.
print('Read data...')
examples = read_data(f'data/CodeNet_train.csv', fields)
train = data.Dataset(examples, fields)
NODE.build_vocab(train, max_size=100000)

# Fixed reference example; it is written next to every fuzzed program so that
# each evaluation CSV contains two rows (the iterator uses batch_size=2).
orin_nodes = ['BEGIN', "_in = ['2', 3]", 'cont_str = _in[0] * _in[1]', 'cont_num = int(cont_str)', 'sqrt_flag = False', 'p1 = 0', 'p1 < len(range(4, 100))', 'T i = range(4, 100)[p1]', 'sqrt_flag', 'sqrt = i * i', 'cont_num == sqrt', 'T sqrt_flag = True', 'p1 += 1', "T print('Yes')", "print('No')", 'EXIT']
orin_fwd_edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (7, 9), (8, 10), (10, 11), (11, 12), (12, 9), (9, 14), (9, 15), (11, 13), (14, 16), (15, 16)]
orin_back_edges = [(13, 7)]
orin_exe_path = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1]

# Restore the trained network and switch it to inference mode.
net = model.CodeFlow(opt).to(device)
checkpoint_path = f"checkpoints/checkpoints_{opt.checkpoint}/epoch-{opt.epoch}.pt"
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net.eval()

# Directory receiving the input-augmented programs and the temporary CSV.
outpath = 'fuzz_testing_output'
os.makedirs(outpath, exist_ok=True)
| 129 | + |
def extract_inputs(generated_text):
    """Extract simple ``name = value`` assignment lines from LLM output.

    Each line is matched on its own (the pattern never spans a newline), so
    blank lines are not absorbed into a match. Surrounding whitespace is
    stripped so the lines can be inserted at module level without causing an
    ``IndentationError``. Comparison lines such as ``x == 1`` are skipped.

    Args:
        generated_text: Raw text returned by the model.

    Returns:
        The matching assignment lines joined by ``'\\n'``; an empty string
        when nothing matches.
    """
    # ``[^\S\n]`` is "whitespace except newline"; ``(?!=)`` rejects ``==``.
    input_lines = re.findall(
        r'^[^\S\n]*\w+[^\S\n]*=(?!=)[^\S\n]*\S.*$',
        generated_text,
        re.MULTILINE,
    )
    return '\n'.join(line.strip() for line in input_lines)
| 134 | + |
# ---------------------------------------------------------------------------
# Main fuzzing loop.
#
# For every file below opt.folder_path: repeatedly ask the LLM for inputs,
# prepend them to the program, build its CFG, and let the network predict the
# execution path.  A file is finished when the network predicts that EXIT is
# not reached (a runtime error) or when the per-file time budget is used up.
# NOTE(review): block nesting was reconstructed from the statement structure;
# confirm against the original file.
# ---------------------------------------------------------------------------
error_dict = {}   # file name -> [predicted error statement, traced error statement]
locate = 0        # number of files whose predicted and traced error nodes coincide
for root, _, files in os.walk(opt.folder_path):
    # Assumes names shaped like "<5 chars><number>.<ext>" (e.g. "code_12.py");
    # any other name raises ValueError here.
    files = sorted(files, key=lambda x: int(x.split('.')[0][5:]))
    for file in files:
        print(f'Fuzz testing file {file}')
        feedback_list = []
        start_time = time.time()
        time_limit = opt.time  # time limit in seconds
        repeat = True
        while repeat:
            if time.time() - start_time > time_limit:
                print(f'Time limit exceeded for file {file}')
                break
            feedback = f"\nThese inputs did not raise runtime errors, avoid to generate the same:\n{feedback_list}" if feedback_list else ""
            # Use the directory currently being walked so that files located
            # in sub-directories of opt.folder_path are found as well.
            file_path = os.path.join(root, file)
            with open(file_path, 'r') as f:
                code = f.read()
            generated_inputs = get_generated_inputs(opt.claude_api_key, opt.model, code, feedback)
            generated_inputs = extract_inputs(generated_inputs)
            print(generated_inputs)
            # Add generated inputs to the original code
            modified_code = add_generated_inputs_to_code(code, generated_inputs)
            filename = os.path.join(outpath, file)
            with open(filename, 'w') as modified_file:
                modified_file.write(modified_code)

            # NOTE(review): presumably meant to reset the block id counter of
            # the cfg module; this assigns on a fresh instance - confirm.
            BlockId().counter = 0
            try:
                with open(filename, 'r') as source_file:
                    source = source_file.read()
                compile(source, filename, 'exec')
            except Exception:
                # The augmented program is not valid Python: abort the run.
                print('Error in source code')
                sys.exit(1)
            parser = PyParser(source)
            parser.removeCommentsAndDocstrings()
            parser.formatCode()
            try:
                cfg = CFGVisitor().build(filename, ast.parse(parser.script))
            except AttributeError:
                continue
            except IndentationError:
                continue
            except TypeError:
                continue
            except SyntaxError:
                continue

            cfg.clean()
            try:
                cfg.track_execution()
            except Exception:
                print("Generated input is not valid")
                continue
            code = {}
            # Group block ids by the for-loop marker stored on each block.
            for_loop = {}
            for i in cfg.blocks:
                if cfg.blocks[i].for_loop != 0:
                    if cfg.blocks[i].for_loop not in for_loop:
                        for_loop[cfg.blocks[i].for_loop] = [i]
                    else:
                        for_loop[cfg.blocks[i].for_loop].append(i)
            first = []
            second = []
            for i in for_loop:
                # Requires at least two blocks per loop marker.
                first.append(for_loop[i][0]+1)
                second.append(for_loop[i][1])
            # orin_node entries are [block id, normalized statement, line no].
            orin_node = []
            track = {}       # normalized statement -> indices into orin_node
            track_for = {}   # block id -> index into orin_node
            for i in cfg.blocks:
                if cfg.blocks[i].stmts_to_code():
                    if int(i) == 1:
                        st = 'BEGIN'
                    elif int(i) == len(cfg.blocks):
                        st = 'EXIT'
                    else:
                        if i in first:
                            line = astor.to_source(cfg.blocks[i].for_name)
                            st = line.split('\n')[0]
                            st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
                        else:
                            st = cfg.blocks[i].stmts_to_code()
                            st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
                    orin_node.append([i, st, None])
                    if st not in track:
                        track[st] = [len(orin_node)-1]
                    else:
                        track[st].append(len(orin_node)-1)
                    track_for[i] = len(orin_node)-1
            with open(filename, 'r') as file_open:
                lines = file_open.readlines()
            # Match every source line (normalized like the statements above)
            # to a node and record its 1-based line number.
            for i in range(1, len(lines)+1):
                line = lines[i-1]
                #delete \n at the end of each line and delete all spaces
                line = line.strip()
                line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "")
                if line.startswith('elif'):
                    # "elif..." -> "if..."
                    line = line[2:]
                if line in track:
                    orin_node[track[line][0]][2] = i
                    if orin_node[track[line][0]][0] in first:
                        # Neighbours of a loop header get fractional line
                        # numbers just before / after the header line.
                        orin_node[track[line][0]-1][2] = i-0.4
                        orin_node[track[line][0]+1][2] = i+0.4
                    if len(track[line]) > 1:
                        track[line].pop(0)
            for i in second:
                max_val = 0
                for edge in cfg.edges:
                    if edge[0] == i:
                        if orin_node[track_for[edge[1]]][2] > max_val:
                            max_val = orin_node[track_for[edge[1]]][2]
                    if edge[1] == i:
                        if orin_node[track_for[edge[0]]][2] > max_val:
                            max_val = orin_node[track_for[edge[0]]][2]
                orin_node[track_for[i]][2] = max_val + 0.5
            orin_node[0][2] = 0
            orin_node[-1][2] = len(lines)+1
            # sort orin_node by the third element
            orin_node.sort(key=lambda x: x[2])

            # Build the model input: node texts plus a block id -> 1-based
            # node index mapping.
            nodes = []
            matching = {}
            for i in cfg.blocks:
                if cfg.blocks[i].stmts_to_code():
                    if int(i) == 1:
                        nodes.append('BEGIN')
                    elif int(i) == len(cfg.blocks):
                        nodes.append('EXIT')
                    else:
                        st = cfg.blocks[i].stmts_to_code()
                        st_no_space = re.sub(r"\s+", "", st)
                        # if start with if or while, delete these keywords
                        if st.startswith('if'):
                            st = st[3:]
                        elif st.startswith('while'):
                            st = st[6:]
                        if cfg.blocks[i].condition:
                            st = 'T '+ st
                        if st.endswith('\n'):
                            st = st[:-1]
                        if st.endswith(":"):
                            st = st[:-1]
                        nodes.append(st)
                    matching[i] = len(nodes)

            fwd_edges = []
            back_edges = []
            edges = {}
            for edge in cfg.edges:
                if edge not in cfg.back_edges:
                    fwd_edges.append((matching[edge[0]], matching[edge[1]]))
                else:
                    back_edges.append((matching[edge[0]], matching[edge[1]]))
                if matching[edge[0]] not in edges:
                    edges[matching[edge[0]]] = [matching[edge[1]]]
                else:
                    edges[matching[edge[0]]].append(matching[edge[1]])
            # 0/1 flag per node taken from the traced path in cfg.path.
            exe_path = [0 for i in range(len(nodes))]
            for i in range(len(cfg.path)):
                if cfg.path[i] == 1:
                    exe_path[matching[i+1]-1] = 1
            # Row 0 is the fuzzed program, row 1 the fixed reference example.
            out_nodes=[nodes, orin_nodes]
            out_fw_path=[fwd_edges, orin_fwd_edges]
            out_back_path=[back_edges, orin_back_edges]
            out_exe_path=[exe_path, orin_exe_path]
            data_example = {
                'nodes': out_nodes,
                'forward': out_fw_path,
                'backward': out_back_path,
                'target': out_exe_path,
            }

            df = pd.DataFrame(data_example)
            # Save to CSV
            df.to_csv(f'{outpath}/output.csv', index=False, quoting=1)
            examples = read_data(f'{outpath}/output.csv', fields)
            test = data.Dataset(examples, fields)
            test_iter = Iterator(test, batch_size=2, device=device, train=False,
                                 sort=False, sort_key=lambda x: len(x.nodes), sort_within_batch=False, repeat=False, shuffle=False)
            with torch.no_grad():
                for batch in test_iter:
                    x, edges, target = batch.nodes, (batch.forward, batch.backward), batch.target.float()
                    if isinstance(x, tuple):
                        pred = net(x[0], edges, x[1], x[2])
                    else:
                        pred = net(x, edges)
                    # Only the first row (the fuzzed program) is inspected.
                    pred = pred[0].squeeze()
                    pred = (pred > opt.beta).float()
                    if pred[len(nodes)-1] == 1:
                        # The last node (EXIT) is predicted as executed.
                        print("No Runtime Error")
                        feedback_list.append(generated_inputs)
                    else:
                        # The highest-index node predicted as executed is
                        # reported as the error location.
                        mask_pred = pred[:len(nodes)] == 1
                        indices_pred = torch.nonzero(mask_pred).flatten()
                        farthest_pred = indices_pred.max().item()
                        error_line = nodes[farthest_pred]
                        print(f"Runtime Error in line: {error_line}")

                        mask_target = target[0][:len(nodes)] == 1
                        indices_target = torch.nonzero(mask_target).flatten()
                        farthest_target = indices_target.max().item()
                        true_error_line = nodes[farthest_target]
                        error_dict[file] = [error_line, true_error_line]

                        if farthest_pred == farthest_target:
                            locate += 1
                        # An error was predicted: stop fuzzing this file.
                        repeat = False
| 343 | + |
# Final report.  When no runtime error was detected at all, error_dict is
# empty; report 0% instead of failing with ZeroDivisionError.
locate_true = locate / len(error_dict) * 100 if error_dict else 0.0
print(f'Fuzz testing within {opt.time}s')
# NOTE(review): ``files`` is the listing of the last directory visited by
# os.walk, so the denominator ignores files from earlier directories.
print(f'Sucessfully detect: {len(error_dict)}/{len(files)}')
print(f'Bug Localization Acc: {locate_true:.2f}%')
print(error_dict)
0 commit comments