-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsplit_dataset.py
22 lines (18 loc) · 944 Bytes
/
split_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""Add tokens to the captions file and split it into train, validation and test sets.

The actual work is delegated to ``preprocessing.CaptionsSplitter``; this script
only wires up the file paths and asks for confirmation before running.
All paths are relative to the current working directory.
"""
import os
from preprocessing.CaptionsSplitter import CaptionsSplitter

# Source captions file. The split files are written next to it, with a suffix
# inserted before the extension (e.g. 'dataset/CaptionsClean3_train.txt').
CAPTIONS_PATH = 'dataset/CaptionsClean3.txt'

# Output suffixes, in the same order the script unpacks the result of
# CaptionsSplitter.split_dataset(): train, validation, test.
SPLIT_SUFFIXES = ('_train', '_validation', '_test')


def _split_output_paths(captions_path, suffixes=SPLIT_SUFFIXES):
    """Return one output path per suffix, derived from *captions_path*.

    The suffix goes between the file stem and its extension, so
    'dataset/CaptionsClean3.txt' with '_train' becomes
    'dataset/CaptionsClean3_train.txt'.
    """
    root, ext = os.path.splitext(captions_path)
    return [root + suffix + ext for suffix in suffixes]


def _confirmed(answer):
    """Return True if *answer* is an affirmative reply ('y' or 'yes', any case)."""
    # Tolerate stray whitespace and capitalisation; the original accepted
    # only the exact string 'y', so 'Y' silently did nothing.
    return answer.strip().lower() in ('y', 'yes')


def main(captions_path=CAPTIONS_PATH):
    """Ask for confirmation, then tokenize, split, report on and save the captions.

    Args:
        captions_path: path of the captions file to split. The three output
            files are derived from it by ``_split_output_paths``.
    """
    if not _confirmed(input('Split Dataset?(y/n) ')):
        return

    splitter_obj = CaptionsSplitter(captions_path)
    splitter_obj.add_tokens_to_caption()
    # Explicit 3-way unpack: fails loudly if split_dataset() ever returns a
    # different number of splits, instead of silently truncating in zip().
    train_captions, val_captions, test_captions = splitter_obj.split_dataset()
    splitter_obj.print_captions_info()

    splits = (train_captions, val_captions, test_captions)
    for path, captions in zip(_split_output_paths(captions_path), splits):
        splitter_obj.save_captions_to_file(path, captions)


if __name__ == '__main__':
    main()