diff --git a/.gitignore b/.gitignore
index d1e0b2b..3401160 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 trained_models/
 packaged_models/
 .env.local
+.env
 __pycache__
 *.pyc
 .DS_Store
@@ -15,3 +16,4 @@ __pycache__
 */.DS_Store
 data/sets/*
 requirements.verbose.txt
+*env*
diff --git a/train.py b/train.py
index bdf47a0..2780be3 100644
--- a/train.py
+++ b/train.py
@@ -4,6 +4,9 @@
 # pip install -r requirements.txt
 import sys
 from datetime import datetime, timedelta
+from dateutil.relativedelta import relativedelta
+import os
+import pandas as pd
 
 # pylint: disable=no-name-in-module
 from configs import models
@@ -13,8 +16,59 @@ from models.model_factory import ModelFactory
 from utils.common import print_colored
 
 
+def split_date_ranges(start_date, end_date, frequency):
+    """
+    Split the span from start_date to end_date into consecutive chunks.
+
+    Each chunk is meant to be fetched with one request. Finer frequencies get
+    shorter chunks: 1 day for '1min', 15 days for '5min', 182 days for
+    '1hour' and 365 days for '4hour' and '1day'. Consecutive chunks share
+    their boundary date (the end of one chunk is the start of the next).
+
+    Args:
+        start_date (str): Start date in the format 'YYYY-MM-DD'.
+        end_date (str): End date in the format 'YYYY-MM-DD'.
+        frequency (str): One of '1min', '5min', '1hour', '4hour', '1day'.
+
+    Returns:
+        list of tuples: (start, end) pairs of 'YYYY-MM-DD' strings. If
+        start_date is not before end_date, the single pair
+        (start_date, end_date) is returned so the caller still issues one
+        request instead of silently fetching nothing.
+
+    Raises:
+        ValueError: If a date is malformed or the frequency is unsupported.
+    """
+    date_format = "%Y-%m-%d"
+    current_date = datetime.strptime(start_date, date_format)
+    end_date = datetime.strptime(end_date, date_format)
+
+    # Days per chunk for each supported frequency.
+    chunk_days = {
+        '1min': 1,
+        '5min': 15,
+        '1hour': 182,  # Approx. 6 months
+        '4hour': 365,  # Approx. 1 year
+        '1day': 365,  # Approx. 1 year
+    }
+    if frequency not in chunk_days:
+        raise ValueError("Unsupported frequency. Choose from '1min', '5min', '1hour', '4hour', '1day'.")
+    delta = timedelta(days=chunk_days[frequency])
+
+    # Degenerate range: hand it back untouched rather than returning [].
+    if current_date >= end_date:
+        return [(current_date.strftime(date_format), end_date.strftime(date_format))]
+
+    date_ranges = []
+    while current_date < end_date:
+        next_date = min(current_date + delta, end_date)
+        date_ranges.append((current_date.strftime(date_format), next_date.strftime(date_format)))
+        current_date = next_date
+
+    return date_ranges
+
 
-def select_data(fetcher, default_selection=None, file_path=None):
+def select_data(fetcher, default_selection=None, file_path=None, dir_path=None):
     """Provide an interface to choose between Tiingo stock, Tiingo crypto, or CSV data."""
 
     default_end_date = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
@@ -24,8 +78,9 @@ def select_data(fetcher, default_selection=None, file_path=None):
         print("1. Tiingo Stock Data")
         print("2. Tiingo Crypto Data")
         print("3. Load data from CSV file")
+        print("4. Load data from a directory of CSV files")
 
-        selection = input("Enter your choice (1/2/3): ").strip()
+        selection = input("Enter your choice (1/2/3/4): ").strip()
     else:
         selection = default_selection
 
@@ -52,7 +107,16 @@ def select_data(fetcher, default_selection=None, file_path=None):
         print(
             f"Fetching Tiingo Stock Data for {symbol} from {start_date} to {end_date} with {frequency} frequency..."
         )
-        return fetcher.fetch_tiingo_stock_data(symbol, start_date, end_date, frequency)
+        # Fetch chunk by chunk, then stitch the pieces together.
+        # split_date_ranges already returns 'YYYY-MM-DD' strings, so the
+        # bounds are passed through as-is.
+        frames = []
+        for start, end in split_date_ranges(start_date, end_date, frequency):
+            fetched_data = fetcher.fetch_tiingo_stock_data(symbol, start, end, frequency)
+            if fetched_data is not None and not fetched_data.empty:
+                frames.append(fetched_data)
+
+        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
 
     if selection == "2":
         print("You selected Tiingo Crypto Data.")
@@ -77,7 +141,16 @@ def select_data(fetcher, default_selection=None, file_path=None):
         print(
             f"Fetching Tiingo Crypto Data for {symbol} from {start_date} to {end_date} with {frequency} frequency..."
         )
-        return fetcher.fetch_tiingo_crypto_data(symbol, start_date, end_date, frequency)
+        # Fetch chunk by chunk, then stitch the pieces together.
+        # split_date_ranges already returns 'YYYY-MM-DD' strings, so the
+        # bounds are passed through as-is.
+        frames = []
+        for start, end in split_date_ranges(start_date, end_date, frequency):
+            fetched_data = fetcher.fetch_tiingo_crypto_data(symbol, start, end, frequency)
+            if fetched_data is not None and not fetched_data.empty:
+                frames.append(fetched_data)
+
+        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
 
     if selection == "3":
         print("You selected to load data from a CSV file.")
@@ -85,6 +158,39 @@ def select_data(fetcher, default_selection=None, file_path=None):
             file_path = input("Enter the CSV file path: ").strip()
         return CSVLoader.load_csv(file_path)
 
+    if selection == "4":
+        print("You selected to load data from a directory of CSV files.")
+        if dir_path is None:
+            dir_path = input("Enter the directory path: ").strip()
+
+        # isdir() is already False for paths that do not exist.
+        if not os.path.isdir(dir_path):
+            print_colored("Invalid directory path.", "error")
+            sys.exit(1)
+
+        frames = []
+        # Sort so the row order does not depend on the OS listing order.
+        for filename in sorted(os.listdir(dir_path)):
+            if not filename.endswith(".csv"):
+                continue
+            csv_path = os.path.join(dir_path, filename)
+            print(f"Loading data from {csv_path}...")
+            try:
+                data = CSVLoader.load_csv(csv_path)
+            except Exception as e:
+                print(f"Failed to load {csv_path}: {e}")
+                continue
+            if data is not None:
+                frames.append(data)
+
+        combined_data = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
+        combined_data.drop_duplicates(inplace=True)
+
+        print(f"Loaded and combined data from directory: {dir_path}")
+        print(f"Total rows after removing duplicates: {len(combined_data)}")
+
+        return combined_data
+
     # Exit the program if the user enters an invalid choice
     print_colored("Invalid choice", "error")
     sys.exit(1)