Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
trained_models/
packaged_models/
.env.local
.env
__pycache__
*.pyc
.DS_Store
Expand All @@ -15,3 +16,4 @@ __pycache__
*/.DS_Store
data/sets/*
requirements.verbose.txt
*env*
93 changes: 89 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# pip install -r requirements.txt
import sys
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import os
import pandas as pd

# pylint: disable=no-name-in-module
from configs import models
Expand All @@ -13,8 +16,47 @@
from models.model_factory import ModelFactory
from utils.common import print_colored

# Maximum span of a single sub-range for each supported sampling frequency.
# Finer frequencies get shorter spans so each sub-range covers fewer rows.
_CHUNK_SPAN_BY_FREQUENCY = {
    "1min": timedelta(days=1),
    "5min": timedelta(days=15),
    "1hour": timedelta(days=182),  # Approx. 6 months
    "4hour": timedelta(days=365),  # Approx. 1 year
    "1day": timedelta(days=365),  # Approx. 1 year
}


def split_date_ranges(start_date, end_date, frequency):
    """
    Split the span between start_date and end_date into consecutive subranges.

    The maximum length of each subrange depends on the sampling frequency:
    1 day for '1min', 15 days for '5min', 182 days for '1hour' and 365 days
    for '4hour' and '1day'. The last subrange is truncated at end_date.

    Consecutive subranges share their boundary date: the end of one subrange
    is the start of the next. NOTE(review): if the data source treats the end
    date as inclusive, rows on a boundary date may be returned twice and the
    caller may need to de-duplicate - confirm against the fetcher.

    Args:
        start_date (str): Start date in the format 'YYYY-MM-DD'.
        end_date (str): End date in the format 'YYYY-MM-DD'.
        frequency (str): Frequency to split the range. Options are
            '1min', '5min', '1hour', '4hour', '1day'.

    Returns:
        list of tuples: A list of (start_date, end_date) tuples of
        'YYYY-MM-DD' strings, in chronological order. A single-day request
        (start_date == end_date) yields exactly one (day, day) tuple; an
        inverted request (start_date after end_date) yields an empty list.

    Raises:
        ValueError: If a date does not match 'YYYY-MM-DD' or the frequency
            is not one of the supported options.
    """
    date_format = "%Y-%m-%d"

    # Dates are parsed before the frequency is checked, so a malformed date
    # is reported first.
    range_start = datetime.strptime(start_date, date_format)
    range_end = datetime.strptime(end_date, date_format)

    try:
        span = _CHUNK_SPAN_BY_FREQUENCY[frequency]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are equally unsupported.
        raise ValueError(
            "Unsupported frequency. Choose from '1min', '5min', '1hour', '4hour', '1day'."
        ) from None

    # A single-day request is still a valid range; without this guard the
    # loop below would not run and the caller would silently fetch nothing.
    if range_start == range_end:
        day = range_start.strftime(date_format)
        return [(day, day)]

    date_ranges = []
    chunk_start = range_start
    while chunk_start < range_end:
        # Clamp the final chunk so it never runs past the requested end date.
        chunk_end = min(chunk_start + span, range_end)
        date_ranges.append(
            (chunk_start.strftime(date_format), chunk_end.strftime(date_format))
        )
        chunk_start = chunk_end

    return date_ranges

def select_data(fetcher, default_selection=None, file_path=None):
def select_data(fetcher, default_selection=None, file_path=None, dir_path=None):
"""Provide an interface to choose between Tiingo stock, Tiingo crypto, or CSV data."""

default_end_date = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
Expand All @@ -24,8 +66,9 @@ def select_data(fetcher, default_selection=None, file_path=None):
print("1. Tiingo Stock Data")
print("2. Tiingo Crypto Data")
print("3. Load data from CSV file")
print("4. Load data from a directory of CSV files")

selection = input("Enter your choice (1/2/3): ").strip()
selection = input("Enter your choice (1/2/3/4): ").strip()
else:
selection = default_selection

Expand All @@ -52,7 +95,14 @@ def select_data(fetcher, default_selection=None, file_path=None):
print(
f"Fetching Tiingo Stock Data for {symbol} from {start_date} to {end_date} with {frequency} frequency..."
)
return fetcher.fetch_tiingo_stock_data(symbol, start_date, end_date, frequency)
stock_data = pd.DataFrame()
date_ranges = split_date_ranges(start_date, end_date, frequency)
for start, end in date_ranges:
fetched_data = fetcher.fetch_tiingo_stock_data(symbol, start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d'), frequency)
if fetched_data is not None and not fetched_data.empty:
stock_data = pd.concat([stock_data, fetched_data], ignore_index=True)

return stock_data

if selection == "2":
print("You selected Tiingo Crypto Data.")
Expand All @@ -77,14 +127,49 @@ def select_data(fetcher, default_selection=None, file_path=None):
print(
f"Fetching Tiingo Crypto Data for {symbol} from {start_date} to {end_date} with {frequency} frequency..."
)
return fetcher.fetch_tiingo_crypto_data(symbol, start_date, end_date, frequency)
crypto_data = pd.DataFrame()
date_ranges = split_date_ranges(start_date, end_date, frequency)
for start, end in date_ranges:
fetched_data = fetcher.fetch_tiingo_crypto_data(symbol, start.strftime('%Y-%m-%d'), end.strftime('%Y-%m-%d'), frequency)
if fetched_data is not None and not fetched_data.empty:
crypto_data = pd.concat([crypto_data, fetched_data], ignore_index=True)

return crypto_data

if selection == "3":
print("You selected to load data from a CSV file.")
if file_path is None:
file_path = input("Enter the CSV file path: ").strip()
return CSVLoader.load_csv(file_path)

if selection == "4":
print("You selected to load data from a directory of CSV files.")
if dir_path is None:
dir_path = input("Enter the directory path: ").strip()

if not os.path.exists(dir_path) or not os.path.isdir(dir_path):
print("Invalid directory path.")
sys.exit(1)

combined_data = pd.DataFrame()

for filename in os.listdir(dir_path):
if filename.endswith(".csv"):
file_path = os.path.join(dir_path, filename)
print(f"Loading data from {file_path}...")
try:
data = CSVLoader.load_csv(file_path)
combined_data = pd.concat([combined_data, data], ignore_index=True)
except Exception as e:
print(f"Failed to load {file_path}: {e}")

combined_data.drop_duplicates(inplace=True)

print(f"Loaded and combined data from directory: {dir_path}")
print(f"Total rows after removing duplicates: {len(combined_data)}")

return combined_data

# Exit the program if the user enters an invalid choice
print_colored("Invalid choice", "error")
sys.exit(1)
Expand Down