Commit 4809ce8

CuongLC4.RE authored and committed
add fuzz testing
1 parent 47cf5e1 commit 4809ce8

53 files changed: 2499 additions & 0 deletions

cfg.py

Lines changed: 1138 additions & 0 deletions
Large diffs are not rendered by default.
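The cfg.py diff is collapsed in this view. For orientation, here is a hypothetical interface sketch inferred purely from how fuzz_testing.py (below) calls into the module; the class and attribute names mirror that usage, but the signatures, the Block class name, and all behavior are assumptions, not the committed implementation.

# Hypothetical sketch of what cfg.py is assumed to expose, inferred from fuzz_testing.py below.
import ast

class BlockId:
    counter = 0                  # fuzz_testing.py resets this before processing each file

class Block:                     # class name is a guess; only the attributes below are observed
    def __init__(self):
        self.for_loop = 0        # non-zero id when the block belongs to a for loop
        self.for_name = None     # AST node of the for statement (rendered with astor)
        self.condition = False   # True when the block is reached on a true condition
    def stmts_to_code(self):
        return ""                # block statements rendered back to source text

class CFG:
    def __init__(self):
        self.blocks = {}         # block id -> Block
        self.edges = []          # (src, dst) block-id pairs
        self.back_edges = []     # subset of edges that close loops
        self.path = []           # 0/1 flags for blocks visited during tracing
    def clean(self): ...
    def track_execution(self): ...   # runs the program and fills self.path

class PyParser:
    def __init__(self, source):
        self.script = source
    def removeCommentsAndDocstrings(self): ...
    def formatCode(self): ...

class CFGVisitor:
    def build(self, filename, tree: ast.AST) -> CFG: ...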

fuzz_testing.py

Lines changed: 348 additions & 0 deletions
@@ -0,0 +1,348 @@
import os
import anthropic
import ast, astor
from cfg import *
import re
import sys
import trace_execution
import io
import pandas as pd
from torchtext import data
from torchtext.data import Iterator
import torch
import model
import time
import config
import argparse

def generate_prompt(method_code, feedback=""):
    prompt = f"""
\n\nHuman: You are a terminal. Analyze the following Python code and generate likely inputs for all variables that might raise errors. Add these generated inputs at the beginning of the code snippet.

Example:
Python Method:
if(S[0]=="A" and S[2,-1].count("C")==1):
    cnt=0
    for i in S:
        if(97<=ord(i) and ord(i)<=122):
            cnt+=1
    if(cnt==2):
        print("AC")
    else :
        print("WA")
else :
    print("WA")

Generated Input:
S = 'AtCoder'

Task:
Given the following Python method, generate likely inputs for variables:
{feedback}

Python Method:
{method_code}

Generated Input:
(No explanation needed, only one Generated Input:)
\n\nAssistant:
"""
    return prompt

def get_generated_inputs(claude_api_key, model, method_code, feedback=""):
    client = anthropic.Anthropic(api_key=claude_api_key)
    prompt = generate_prompt(method_code, feedback)
    response = client.messages.create(
        model=model,
        max_tokens=1024,
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    return response.content[0].text

def add_generated_inputs_to_code(code, inputs):
    lines = code.split('\n')
    # Find the first non-import line
    insert_index = 0
    for i, line in enumerate(lines):
        if not line.startswith(('import', 'from', '\n')):
            insert_index = i
            break

    # Insert generated inputs at the found index
    for input_line in inputs.split('\n'):
        if input_line.startswith("Generated"):
            continue
        if input_line.strip():
            lines.insert(insert_index, input_line)
            insert_index += 1

    return '\n'.join(lines)

def read_data(data_path, fields):
    csv_data = pd.read_csv(data_path, chunksize=100)
    all_examples = []
    for n, chunk in enumerate(csv_data):
        examples = chunk.apply(lambda r: data.Example.fromlist(
            [eval(r['nodes']), eval(r['forward']), eval(r['backward']), eval(r['target'])], fields), axis=1)
        all_examples.extend(list(examples))
    return all_examples

opt = config.parse()
if opt.claude_api_key is None:
    raise Exception("Missing Claude API key")
if opt.cuda_num is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(f"cuda:{opt.cuda_num}" if torch.cuda.is_available() else "cpu")

TEXT = data.Field(tokenize=lambda x: x.split()[:512])
NODE = data.NestedField(TEXT, preprocessing=lambda x: x[:100], include_lengths=True)
ROW = data.Field(pad_token=1.0, use_vocab=False,
                 preprocessing=lambda x: [1, 1] if any(i > 100 for i in x) else x)
EDGE = data.NestedField(ROW)
TARGET = data.Field(use_vocab=False, preprocessing=lambda x: x[:100], pad_token=0, batch_first=True)

fields = [("nodes", NODE), ("forward", EDGE), ("backward", EDGE), ("target", TARGET)]

print('Read data...')
examples = read_data(f'data/CodeNet_train.csv', fields)
train = data.Dataset(examples, fields)
NODE.build_vocab(train, max_size=100000)

# Fixed reference sample; every generated sample below is paired with it so the test CSV always holds two examples.
orin_nodes = ['BEGIN', "_in = ['2', 3]", 'cont_str = _in[0] * _in[1]', 'cont_num = int(cont_str)', 'sqrt_flag = False', 'p1 = 0', 'p1 < len(range(4, 100))', 'T i = range(4, 100)[p1]', 'sqrt_flag', 'sqrt = i * i', 'cont_num == sqrt', 'T sqrt_flag = True', 'p1 += 1', "T print('Yes')", "print('No')", 'EXIT']
orin_fwd_edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (7, 9), (8, 10), (10, 11), (11, 12), (12, 9), (9, 14), (9, 15), (11, 13), (14, 16), (15, 16)]
orin_back_edges = [(13, 7)]
orin_exe_path = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1]

# Load the pretrained CodeFlow model from the chosen checkpoint.
net = model.CodeFlow(opt).to(device)
checkpoint_path = f"checkpoints/checkpoints_{opt.checkpoint}/epoch-{opt.epoch}.pt"
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net.eval()

outpath = 'fuzz_testing_output'
if not os.path.exists(outpath):
    os.makedirs(outpath)

def extract_inputs(generated_text):
    # Use a regular expression to keep only lines that are variable assignments
    input_lines = re.findall(r'^\s*\w+\s*=\s*.+$', generated_text, re.MULTILINE)
    return '\n'.join(input_lines)

error_dict = {}
locate = 0
for root, _, files in os.walk(opt.folder_path):
    files = sorted(files, key=lambda x: int(x.split('.')[0][5:]))
    for file in files:
        print(f'Fuzz testing file {file}')
        feedback_list = []
        start_time = time.time()
        time_limit = opt.time  # time limit in seconds
        repeat = True
        while repeat:
            if time.time() - start_time > time_limit:
                print(f'Time limit exceeded for file {file}')
                break
            feedback = f"\nThese inputs did not raise runtime errors, avoid generating the same:\n{feedback_list}" if feedback_list else ""
            file_path = os.path.join(opt.folder_path, file)
            with open(file_path, 'r') as f:
                code = f.read()
            generated_inputs = get_generated_inputs(opt.claude_api_key, opt.model, code, feedback)
            generated_inputs = extract_inputs(generated_inputs)
            print(generated_inputs)
            # Add generated inputs to the original code
            modified_code = add_generated_inputs_to_code(code, generated_inputs)
            filename = os.path.join(outpath, file)
            with open(filename, 'w') as modified_file:
                modified_file.write(modified_code)

            # Check that the modified program still compiles, then build its control flow graph.
            BlockId().counter = 0
            try:
                source = open(filename, 'r').read()
                compile(source, filename, 'exec')
            except:
                print('Error in source code')
                exit(1)
            parser = PyParser(source)
            parser.removeCommentsAndDocstrings()
            parser.formatCode()
            try:
                cfg = CFGVisitor().build(filename, ast.parse(parser.script))
            except AttributeError:
                continue
            except IndentationError:
                continue
            except TypeError:
                continue
            except SyntaxError:
                continue

            cfg.clean()
            try:
                cfg.track_execution()
            except Exception:
                print("Generated input is not valid")
                continue
            # Group CFG blocks that belong to the same for loop.
            code = {}
            for_loop = {}
            for i in cfg.blocks:
                if cfg.blocks[i].for_loop != 0:
                    if cfg.blocks[i].for_loop not in for_loop:
                        for_loop[cfg.blocks[i].for_loop] = [i]
                    else:
                        for_loop[cfg.blocks[i].for_loop].append(i)
            first = []
            second = []
            for i in for_loop:
                first.append(for_loop[i][0]+1)
                second.append(for_loop[i][1])
            # Assign a source line number to every CFG block so the nodes can be sorted in source order.
            orin_node = []
            track = {}
            track_for = {}
            for i in cfg.blocks:
                if cfg.blocks[i].stmts_to_code():
                    if int(i) == 1:
                        st = 'BEGIN'
                    elif int(i) == len(cfg.blocks):
                        st = 'EXIT'
                    else:
                        if i in first:
                            line = astor.to_source(cfg.blocks[i].for_name)
                            st = line.split('\n')[0]
                            st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
                        else:
                            st = cfg.blocks[i].stmts_to_code()
                            st = re.sub(r"\s+", "", st).replace('"', "'").replace("(", "").replace(")", "")
                    orin_node.append([i, st, None])
                    if st not in track:
                        track[st] = [len(orin_node)-1]
                    else:
                        track[st].append(len(orin_node)-1)
                    track_for[i] = len(orin_node)-1
            with open(filename, 'r') as file_open:
                lines = file_open.readlines()
            for i in range(1, len(lines)+1):
                line = lines[i-1]
                # strip the trailing newline and all whitespace from the source line
                line = line.strip()
                line = re.sub(r"\s+", "", line).replace('"', "'").replace("(", "").replace(")", "")
                if line.startswith('elif'):
                    line = line[2:]
                if line in track:
                    orin_node[track[line][0]][2] = i
                    if orin_node[track[line][0]][0] in first:
                        orin_node[track[line][0]-1][2] = i-0.4
                        orin_node[track[line][0]+1][2] = i+0.4
                    if len(track[line]) > 1:
                        track[line].pop(0)
            for i in second:
                max_val = 0
                for edge in cfg.edges:
                    if edge[0] == i:
                        if orin_node[track_for[edge[1]]][2] > max_val:
                            max_val = orin_node[track_for[edge[1]]][2]
                    if edge[1] == i:
                        if orin_node[track_for[edge[0]]][2] > max_val:
                            max_val = orin_node[track_for[edge[0]]][2]
                orin_node[track_for[i]][2] = max_val + 0.5
            orin_node[0][2] = 0
            orin_node[-1][2] = len(lines)+1
            # sort orin_node by its assigned line number (the third element)
            orin_node.sort(key=lambda x: x[2])

            # Render each block as a node string for the model; 'T ' marks the block reached when a condition is true.
            nodes = []
            matching = {}
            for i in cfg.blocks:
                if cfg.blocks[i].stmts_to_code():
                    if int(i) == 1:
                        nodes.append('BEGIN')
                    elif int(i) == len(cfg.blocks):
                        nodes.append('EXIT')
                    else:
                        st = cfg.blocks[i].stmts_to_code()
                        st_no_space = re.sub(r"\s+", "", st)
                        # if the statement starts with if or while, drop the keyword
                        if st.startswith('if'):
                            st = st[3:]
                        elif st.startswith('while'):
                            st = st[6:]
                        if cfg.blocks[i].condition:
                            st = 'T ' + st
                        if st.endswith('\n'):
                            st = st[:-1]
                        if st.endswith(":"):
                            st = st[:-1]
                        nodes.append(st)
                    matching[i] = len(nodes)

            # Split CFG edges into forward and back edges using the model's node numbering.
            fwd_edges = []
            back_edges = []
            edges = {}
            for edge in cfg.edges:
                if edge not in cfg.back_edges:
                    fwd_edges.append((matching[edge[0]], matching[edge[1]]))
                else:
                    back_edges.append((matching[edge[0]], matching[edge[1]]))
                if matching[edge[0]] not in edges:
                    edges[matching[edge[0]]] = [matching[edge[1]]]
                else:
                    edges[matching[edge[0]]].append(matching[edge[1]])
            # Ground-truth executed path obtained from the trace.
            exe_path = [0 for i in range(len(nodes))]
            for i in range(len(cfg.path)):
                if cfg.path[i] == 1:
                    exe_path[matching[i+1]-1] = 1
            out_nodes = [nodes, orin_nodes]
            out_fw_path = [fwd_edges, orin_fwd_edges]
            out_back_path = [back_edges, orin_back_edges]
            out_exe_path = [exe_path, orin_exe_path]
            data_example = {
                'nodes': out_nodes,
                'forward': out_fw_path,
                'backward': out_back_path,
                'target': out_exe_path,
            }

            df = pd.DataFrame(data_example)
            # Save to CSV
            df.to_csv(f'{outpath}/output.csv', index=False, quoting=1)
            examples = read_data(f'{outpath}/output.csv', fields)
            test = data.Dataset(examples, fields)
            test_iter = Iterator(test, batch_size=2, device=device, train=False,
                                 sort=False, sort_key=lambda x: len(x.nodes), sort_within_batch=False, repeat=False, shuffle=False)
            # Run CodeFlow: if the EXIT node is predicted as executed, the generated input did not trigger a runtime error.
            with torch.no_grad():
                for batch in test_iter:
                    x, edges, target = batch.nodes, (batch.forward, batch.backward), batch.target.float()
                    if isinstance(x, tuple):
                        pred = net(x[0], edges, x[1], x[2])
                    else:
                        pred = net(x, edges)
                    pred = pred[0].squeeze()
                    pred = (pred > opt.beta).float()
                    if pred[len(nodes)-1] == 1:
                        print("No Runtime Error")
                        feedback_list.append(generated_inputs)
                    else:
                        mask_pred = pred[:len(nodes)] == 1
                        indices_pred = torch.nonzero(mask_pred).flatten()
                        farthest_pred = indices_pred.max().item()
                        error_line = nodes[farthest_pred]
                        print(f"Runtime Error in line: {error_line}")

                        mask_target = target[0][:len(nodes)] == 1
                        indices_target = torch.nonzero(mask_target).flatten()
                        farthest_target = indices_target.max().item()
                        true_error_line = nodes[farthest_target]
                        error_dict[file] = [error_line, true_error_line]

                        if farthest_pred == farthest_target:
                            locate += 1
                        repeat = False

locate_true = locate/len(error_dict)*100
print(f'Fuzz testing within {opt.time}s')
print(f'Successfully detected: {len(error_dict)}/{len(files)}')
print(f'Bug Localization Acc: {locate_true:.2f}%')
print(error_dict)
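config.py is part of the same commit but not shown on this page. From the options the script reads (claude_api_key, cuda_num, model, folder_path, time, checkpoint, epoch, beta), its parse() helper is assumed to be a small argparse wrapper along the lines of the sketch below; the flag spellings, defaults, and comments are illustrative guesses, and only the attribute names come from the script above.

# Hypothetical sketch of config.parse(); only the attribute names are taken from fuzz_testing.py.
import argparse

def parse():
    p = argparse.ArgumentParser(description="Fuzz testing with CodeFlow")
    p.add_argument("--claude_api_key", type=str, default=None)  # Anthropic API key (required)
    p.add_argument("--cuda_num", type=int, default=None)        # GPU index; autodetect when omitted
    p.add_argument("--model", type=str, default=None)           # Anthropic model id to query
    p.add_argument("--folder_path", type=str, default="fuzz_testing_dataset")  # programs to fuzz
    p.add_argument("--time", type=int, default=60)              # per-file time limit in seconds
    p.add_argument("--checkpoint", type=str, default=None)      # checkpoint folder suffix
    p.add_argument("--epoch", type=int, default=None)           # which epoch's weights to load
    p.add_argument("--beta", type=float, default=0.5)           # threshold on predicted execution probability
    return p.parse_args()

A run would then look something like python fuzz_testing.py --claude_api_key <key> --model <model-id> --folder_path fuzz_testing_dataset --checkpoint <name> --epoch <n> --time 60, again with assumed flag names.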

fuzz_testing_dataset/code_1.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
ans=0
cur=0
ACGT=set("A","C","G","T")
for ss in s:
    if ss in ACGT:
        cur+=1
    else:
        ans=max(cur,ans)
        cur=0
print(max(ans,cur))

fuzz_testing_dataset/code_10.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
S = 'nikoandsolstice'
s = len(S)
if (s <= K):
    print(S)

fuzz_testing_dataset/code_11.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
s = 'CSS'
s=0
for i in range(len(s)):
    if s[i]==t[i]:
        s+=1
print(s)

fuzz_testing_dataset/code_12.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
col = N / 2
if col == 0:
    print(col)
else:
    col += 1
    print(col)

fuzz_testing_dataset/code_13.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
x=a+a*a+a*a*a
print(a)
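To make the pipeline above concrete with one of these files: code_13.py never defines a, so the script asks the model for an input and add_generated_inputs_to_code prepends it. A hypothetical round might write the following to fuzz_testing_output/code_13.py; the value a = 'x' is invented for illustration, and a string here makes the multiplication raise a TypeError at runtime, which is exactly the kind of input the loop searches for.

# fuzz_testing_output/code_13.py after one hypothetical fuzzing round.
a = 'x'          # model-proposed input (illustrative); 'x' * 'x' raises TypeError on the next line
x=a+a*a+a*a*a
print(a)

CodeFlow would then be expected to predict that execution stops at that line rather than reaching EXIT, and code_13.py would be recorded in error_dict.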

0 commit comments
