Skip to content

Commit 8f4d974

Browse files
file level split
1 parent 55e7bb6 commit 8f4d974

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

file_level_split.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import glob
2+
import os
3+
import random
4+
from math import ceil
5+
from shutil import copyfile as cp
6+
7+
8+
# Root of the source corpus; it is walked recursively to collect input files.
data_dir = "/data2/edinella/java-small-og/"
# Root that receives the "training", "test" and "validation" sub-folders.
out_dir = "/data2/edinella/java-small-og-fs/"

# Fraction of all files assigned to the training split (rounded up).
TRAIN_SPLIT = 0.8
# Fraction of all files assigned to the validation split (rounded up);
# the test split receives whatever files remain after training + validation.
TEST_VAL_SPLIT = 0.1
13+
14+
def copy_files(files, folder, dest_root=None):
    """Copy every path in ``files`` into ``<dest_root>/<folder>``.

    Each file is renamed to its position in ``files`` followed by ``.java``
    (``0.java``, ``1.java``, ...), so the original file names and directory
    layout are not preserved.

    Args:
        files: Iterable of source file paths, copied in iteration order.
        folder: Name of the sub-directory (under ``dest_root``) to copy into.
        dest_root: Directory that contains ``folder``. Defaults to the
            module-level ``out_dir`` when omitted.
    """
    if dest_root is None:
        dest_root = out_dir
    target_dir = os.path.join(dest_root, folder)
    # shutil.copyfile does not create missing parent directories, so make
    # sure the target exists instead of relying on the caller having done it.
    os.makedirs(target_dir, exist_ok=True)
    for index, src in enumerate(files):
        cp(src, os.path.join(target_dir, str(index) + ".java"))
17+
18+
# Collect the path of every file found anywhere under data_dir.
all_files = []
for dirpath, _dirnames, filenames in os.walk(data_dir):
    all_files.extend(os.path.join(dirpath, name) for name in filenames)

# Randomise the order so each split is a random sample of the corpus.
# NOTE: the shuffle is unseeded, so every run produces a different split.
random.shuffle(all_files)

# Split boundaries. Training and validation sizes are rounded up with ceil();
# the test split takes whatever remains, so it can be smaller than the
# validation split (and empty for very small corpora).
num_files = len(all_files)
train_end = ceil(TRAIN_SPLIT * num_files)
val_end = train_end + ceil(TEST_VAL_SPLIT * num_files)

train = all_files[:train_end]
val = all_files[train_end:val_end]
test = all_files[val_end:]

# Create the output tree idempotently: makedirs(exist_ok=True) also creates
# out_dir itself and neither fails nor skips sub-folders when out_dir
# already exists from an earlier run.
# NOTE: files left over from an earlier run are not removed.
for split_name in ("training", "test", "validation"):
    os.makedirs(os.path.join(out_dir, split_name), exist_ok=True)

copy_files(train, "training")
copy_files(test, "test")
copy_files(val, "validation")
45+

0 commit comments

Comments
 (0)