Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
*.csv

.DS_Store
.idea/
Expand Down
48 changes: 48 additions & 0 deletions cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
from argparse import ArgumentParser

import data_generator


def arg_parse():
    """Build the command-line parser for the dataset tooling.

    ``--action`` is mandatory and limited to the three actions that ``main``
    dispatches on, so a misspelled action is rejected with a usage error
    instead of being accepted and then silently ignored.

    Returns:
        ArgumentParser: the configured (not yet parsed) parser.
    """
    arg_p = ArgumentParser()
    arg_p.add_argument('--action', required=True,
                       choices=('resample', 'tick_count', 'generate_plots'),
                       help='operation to run')
    arg_p.add_argument('--input_filename',
                       help='CSV file the selected action reads')

    # Only read when action == resample.
    arg_p.add_argument('--output_filename',
                       help='where to write the resampled CSV (printed when omitted)')
    arg_p.add_argument('--resample_frequency', default='5Min',
                       help='resampling frequency passed to the resampler')

    # Only read when action == generate_plots.
    arg_p.add_argument('--use_quantiles', action='store_true',
                       help='label images with return quantiles instead of up/down')
    arg_p.add_argument('--output_dir',
                       help='directory that receives the generated images')
    return arg_p


def main():
    """Parse the command line and run the requested action.

    Actions:
        resample: resample ``--input_filename`` at ``--resample_frequency``
            and write the result to ``--output_filename``, or print it when
            no output file is given.
        tick_count: print the per-day tick counts of ``--input_filename``.
        generate_plots: generate the image dataset from ``--input_filename``
            into ``--output_dir`` with ``data_generator.generate_quantiles``
            when ``--use_quantiles`` is set, otherwise with
            ``data_generator.generate_up_down``.

    A missing ``--input_filename``, a missing ``--output_dir`` for
    ``generate_plots`` or an unknown action is reported through the parser
    (usage message and non-zero exit) rather than passing ``None`` further
    down or silently doing nothing.
    """
    parser = arg_parse()
    args = parser.parse_args()

    # Every action reads the input file, so fail early with a clear message.
    if args.input_filename is None:
        parser.error('--input_filename is required')

    if args.action == 'resample':
        from utils import file_processor
        p = file_processor(args.input_filename, args.resample_frequency)
        if args.output_filename is not None:
            p.to_csv(args.output_filename)
        else:
            print(p)
    elif args.action == 'tick_count':
        from utils import count_ticks_per_day
        count_ticks_per_day(args.input_filename)
    elif args.action == 'generate_plots':
        if args.output_dir is None:
            parser.error('--output_dir is required when --action is generate_plots')
        data_gen_func = data_generator.generate_quantiles if args.use_quantiles else data_generator.generate_up_down
        print('Using: {}'.format(data_gen_func))
        data_gen_func(args.output_dir, args.input_filename)
    else:
        parser.error('unknown action: {}'.format(args.action))

"""
assert len(args) == 4, 'Usage: python3 {} DATA_FOLDER_TO_STORE_GENERATED_DATASET ' \
'BITCOIN_MARKET_DATA_CSV_PATH USE_QUANTILES'.format(args[0])
"""


if __name__ == '__main__':
    # Set up logging (timestamp - level - message, INFO and above) before running the CLI.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    main()
74 changes: 41 additions & 33 deletions data_generator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import os
import shutil
import sys
from time import time
from uuid import uuid4

import logging
import numpy as np
import os
import pandas as pd
import shutil
from tqdm import tqdm

from data_manager import file_processor
from returns_quantization import add_returns_in_place
from utils import *
from utils import mkdir_p, save_to_file

logger = logging.getLogger(__name__)

np.set_printoptions(threshold=np.nan)
pd.set_option('display.max_rows', 500)
Expand All @@ -27,10 +28,9 @@ def get_label(btc_df, btc_slice, i, slice_size):

def generate_up_down(data_folder, bitcoin_file):
def get_price_direction(btc_df, btc_slice, i, slice_size):
# last_price = btc_slice[-2:-1]['price_close'].values[0] #this is actually the second last price
last_price = btc_slice[-1:]['price_close'].values[0] #one option to get the correct last price
# last_price = btc_df[i + slice_size - 1:i + slice_size]['price_close'].values[0] #another option to get the correct last price

last_price = btc_slice[-1:]['price_close'].values[0] # one option to get the correct last price
# another option to get the correct last price
# last_price = btc_df[i + slice_size - 1:i + slice_size]['price_close'].values[0]
next_price = btc_df[i + slice_size:i + slice_size + 1]['price_close'].values[0]
if last_price < next_price:
class_name = 'UP'
Expand All @@ -41,8 +41,23 @@ def get_price_direction(btc_df, btc_slice, i, slice_size):
return generate_cnn_dataset(data_folder, bitcoin_file, get_price_direction)


def dump_example_to_file(df, i, slice_size, get_class_name, data_folder, is_test):
    """Save one window of ``df`` as a PNG filed under its class label.

    The window is ``df[i:i + slice_size]``. Its label comes from
    ``get_class_name(df, window, i, slice_size)`` and the image is stored as
    ``<data_folder>/<train|test>/<label>/<i>.png``.

    Args:
        df: price DataFrame the window is cut from.
        i: start row of the window.
        slice_size: number of rows in the window.
        get_class_name: callable returning the label (directory name).
        data_folder: root directory of the generated dataset.
        is_test: store under ``test`` when true, otherwise under ``train``.

    Returns:
        The path of the image file.

    Raises:
        Exception: if the window contains any NaN value.
    """
    window = df[i:i + slice_size]
    has_gaps = window.isnull().values.any()
    if has_gaps:
        # Prices can be discontinuous: a 5min bucket in which nothing traded
        # shows up as NaN. Such a window is treated as invalid and rejected.
        # This is most likely near the start of the data set, where volumes
        # are low.
        raise Exception('NaN values detected. Please remove them.')

    split_name = 'test' if is_test else 'train'
    label = get_class_name(df, window, i, slice_size)
    target_dir = os.path.join(data_folder, split_name, label)
    mkdir_p(target_dir)

    image_path = os.path.join(target_dir, str(i)) + '.png'
    save_to_file(window, filename=image_path)
    return image_path


def generate_cnn_dataset(data_folder, bitcoin_file, get_class_name):
btc_df = file_processor(bitcoin_file)
btc_df = pd.read_csv(bitcoin_file)
btc_df, levels = add_returns_in_place(btc_df)

print('-' * 80)
Expand All @@ -52,34 +67,27 @@ def generate_cnn_dataset(data_folder, bitcoin_file, get_class_name):
print(levels)
print('-' * 80)

slice_size = 40
test_every_steps = 10
n = len(btc_df) - slice_size

shutil.rmtree(data_folder, ignore_errors=True)
for epoch in range(int(1e6)):
st = time()

i = np.random.choice(n)
btc_slice = btc_df[i:i + slice_size]
slice_size = 40
cutoff = int(len(btc_df) * 0.9)
btc_df_tr = btc_df[:cutoff]
btc_df_te = btc_df[cutoff:]

if btc_slice.isnull().values.any():
# sometimes prices are discontinuous and nothing happened in one 5min bucket.
# in that case, we consider this slice as wrong and we raise an exception.
# it's likely to happen at the beginning of the data set where the volumes are low.
raise Exception('NaN values detected. Please remove them.')
bar_train = tqdm(range(len(btc_df_tr) - slice_size - 1))
for i in bar_train:
filename = dump_example_to_file(btc_df_tr, i, slice_size, get_class_name, data_folder, is_test=False)
bar_train.set_description(filename)
bar_train.close()

class_name = get_class_name(btc_df, btc_slice, i, slice_size)
save_dir = os.path.join(data_folder, 'train', class_name)
if epoch % test_every_steps == 0:
save_dir = os.path.join(data_folder, 'test', class_name)
mkdir_p(save_dir)
filename = save_dir + '/' + str(uuid4()) + '.png'
save_to_file(btc_slice, filename=filename)
print('epoch = {0}, time = {1:.3f}, filename = {2}'.format(str(epoch).zfill(8), time() - st, filename))
bar_test = tqdm(range(len(btc_df_te) - slice_size - 1))
for i in bar_test:
filename = dump_example_to_file(btc_df_te, i, slice_size, get_class_name, data_folder, is_test=True)
bar_test.set_description(filename)
bar_test.close()


def main():
logging.basicConfig(format='%(asctime)12s - %(levelname)s - %(message)s', level=logging.INFO)
args = sys.argv
assert len(args) == 4, 'Usage: python3 {} DATA_FOLDER_TO_STORE_GENERATED_DATASET ' \
'BITCOIN_MARKET_DATA_CSV_PATH USE_QUANTILES'.format(args[0])
Expand Down
25 changes: 0 additions & 25 deletions data_manager.py

This file was deleted.

2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
numpy
pandas
matplotlib
mpl_finance
tqdm
2 changes: 1 addition & 1 deletion returns_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pandas as pd

from data_manager import file_processor
from utils import file_processor
from utils import compute_returns


Expand Down
35 changes: 35 additions & 0 deletions start.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env bash
#
# Run the whole pipeline on a raw tick CSV: set up the virtualenv on first use,
# count ticks per day, resample the ticks, then generate the plots.
#
# Usage:   start.sh <input_filename> [extra arguments for generate_plots...]
# Example: start.sh coinbaseUSD.csv --use_quantiles

set -e

if [[ $# -eq 0 ]] ; then
    # A usage error must not look like success to the caller.
    echo "Please give at least one argument to the script: <input_filename> <arguments>." >&2
    echo "Example is: $0 coinbaseUSD.csv --use_quantiles" >&2
    exit 1
fi

WORK_DIR=~/deep-learning-bitcoin

# Key the setup on the virtualenv itself rather than on the folder: if a
# previous run created the folder but failed before the venv was ready,
# the environment is rebuilt instead of failing on `source` forever.
if [ -f "${WORK_DIR}/venv/bin/activate" ]; then
    echo "Virtualenv already exists."
    source "${WORK_DIR}/venv/bin/activate"
else
    mkdir -p "${WORK_DIR}"
    virtualenv -p python3.6 "${WORK_DIR}/venv"
    source "${WORK_DIR}/venv/bin/activate"
    pip install -r requirements.txt
fi

INPUT_FILENAME=$1
shift  # everything after the input file is forwarded to the generate_plots step
RESAMPLE_FREQUENCY=5Min

INPUT_FILENAME_WO_EXT=$(basename "${INPUT_FILENAME}" .csv)
RESAMPLE_FILENAME="${WORK_DIR}/${INPUT_FILENAME_WO_EXT}_${RESAMPLE_FREQUENCY}.csv"


python cli.py --action tick_count --input_filename "${INPUT_FILENAME}"

python cli.py --action resample --input_filename "${INPUT_FILENAME}" --resample_frequency "${RESAMPLE_FREQUENCY}" --output_filename "${RESAMPLE_FILENAME}"

python cli.py --action generate_plots --input_filename "${RESAMPLE_FILENAME}" --output_dir "${WORK_DIR}/plots/" "$@"
68 changes: 52 additions & 16 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import datetime

import matplotlib
import pandas as pd

matplotlib.use('Agg')

import matplotlib.pyplot as plt
from mpl_finance import candlestick2_ohlc


def compute_returns(p):
    """Return the row-over-row percentage change of ``p['price_close']``.

    Each value is ``100 * (close - previous_close) / previous_close``. The
    first row, which has no predecessor, and any other undefined value are
    reported as 0.
    """
    current = p['price_close']
    previous = current.shift(1)
    pct_change = 100 * ((current - previous) / previous)
    return pct_change.fillna(0)


def plot_p(df):
import matplotlib.pyplot as plt
from matplotlib.finance import candlestick2_ohlc
def prepare_plot(df):
fig, ax = plt.subplots()
candlestick2_ohlc(ax,
df['price_open'].values,
Expand All @@ -20,23 +26,17 @@ def plot_p(df):
colorup='g',
colordown='r',
alpha=1)
plt.grid(True)
return fig


def plot_p(df):
    """Build the chart for ``df`` with ``prepare_plot`` and display it.

    Prints ``Done.`` once ``plt.show()`` has returned.
    """
    # The figure returned by prepare_plot is not needed: pyplot shows it.
    prepare_plot(df)
    plt.show()
    print('Done.')


def save_to_file(df, filename):
    """Render the chart for ``df`` with ``prepare_plot`` and write it to ``filename``.

    The chart is built exactly once, through ``prepare_plot``; the function
    no longer draws a second, never-closed figure of its own via the old
    ``matplotlib.finance`` import.

    Args:
        df: price DataFrame handed to ``prepare_plot``.
        filename: destination path of the image.
    """
    fig = prepare_plot(df)
    try:
        fig.savefig(filename)
    finally:
        # Release the figure even when saving fails; this function is called
        # once per generated image, so leaked figures would pile up.
        plt.close(fig)

Expand All @@ -51,3 +51,39 @@ def mkdir_p(path):
pass
else:
raise


def raw_read(data_file):
    """Load a raw tick CSV into a DataFrame indexed by UTC datetime.

    The file must be headerless with three comma-separated columns: Unix
    timestamp in seconds, price, volume.

    The timestamps are converted in UTC (as naive datetimes), so the index
    really matches its ``DateTime_UTC`` name regardless of the time zone of
    the machine running the script. Fractional timestamps are truncated to
    whole seconds.

    Args:
        data_file: path of the CSV file to read.

    Returns:
        DataFrame with columns ``price`` and ``volume`` and a datetime index
        named ``DateTime_UTC``.
    """
    print('Reading bitcoin market data file: {}.'.format(data_file))
    d = pd.read_csv(data_file, sep=',', header=None, index_col=0,
                    names=['timestamp', 'price', 'volume'])
    # unit='s' interprets the epoch seconds as UTC, unlike
    # datetime.fromtimestamp, which converts to the local time zone.
    d.index = pd.to_datetime(d.index.astype('int64'), unit='s')
    d.index.names = ['DateTime_UTC']
    return d


def file_processor(data_file, resample_frequency='1H'):
    """Resample the raw ticks of ``data_file`` into OHLC price bars plus volume.

    Args:
        data_file: path of the raw tick file, read with ``raw_read``.
        resample_frequency: bucket size handed to ``resample``.

    Returns:
        DataFrame with one row per bucket and the columns ``price_open``,
        ``price_high``, ``price_low``, ``price_close`` and ``volume``.
        Buckets containing NaN values are left out.
    """
    ticks = raw_read(data_file)

    bars = ticks['price'].resample(resample_frequency).ohlc()
    bars.columns = ['price_open', 'price_high', 'price_low', 'price_close']
    bars['volume'] = ticks['volume'].resample(resample_frequency).sum()

    # Sometimes there is no data for a while (for example a whole hour), which
    # leaves NaN buckets of resample_frequency in that stretch. The convention
    # here is to drop those buckets instead of keeping or filling them.
    return bars.dropna()


def count_ticks_per_day(data_file):
    """Print a ``date, tick_count`` header and then one line per day of ``data_file``.

    The file is read with ``raw_read`` and split into days with ``to_days``.
    Each day is labelled with the date part of its first timestamp.
    """
    ticks = raw_read(data_file)
    print('date, tick_count')
    for day_frame in to_days(ticks):
        first_stamp = str(day_frame.index[0])
        day_label = first_stamp.split(' ')[0]
        print('{}, {}'.format(day_label, len(day_frame)))


def to_days(df):
    """Split ``df`` into a list with one DataFrame per day (DataFrame to List<Day>).

    Rows are grouped by the year, month and day of the index.
    """
    stamps = df.index
    calendar_day = [stamps.year, stamps.month, stamps.day]
    return [day_frame for _, day_frame in df.groupby(calendar_day)]