from datetime import date  # used by the return annotations below

# Korea Standard Time (UTC+9); every "today" in this module is computed in this zone.
KST = timezone(timedelta(hours=9))

# Column layout of every frame this module builds for public.price_data.
_PRICE_COLUMNS = ["ticker", "date", "open", "high", "low", "close", "volume", "adjusted_close"]

# yfinance column labels -> public.price_data column names.
_YF_PRICE_RENAME = {
    "Open": "open",
    "High": "high",
    "Low": "low",
    "Close": "close",
    "Adj Close": "adjusted_close",
    "Volume": "volume",
}


# =============================
# Common utilities
# =============================
def today_kst() -> date:
    """Return the current calendar date in KST (UTC+9)."""
    return datetime.now(KST).date()


def get_last_date_in_table(db_name: str, table: str, date_col: str) -> date | None:
    """Return the maximum value of ``date_col`` in ``table``, or None if there is none.

    Args:
        db_name: Name handed to ``get_engine`` to obtain the SQLAlchemy engine.
        table: Table name, optionally schema-qualified (e.g. ``public.price_data``).
        date_col: Name of the date column to aggregate.

    Returns:
        The latest date as a ``date``; a ``datetime`` result is reduced to its
        date part. None when ``MAX()`` yields NULL (no rows).

    Raises:
        ValueError: If ``table`` or ``date_col`` is not a plain (dotted) identifier.
    """
    import re

    # Identifiers cannot be bound as query parameters, so they are interpolated
    # into the statement; reject anything that is not a plain dotted name first.
    identifier = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*")
    for name in (table, date_col):
        if not identifier.fullmatch(name):
            raise ValueError(f"Unsafe SQL identifier: {name!r}")

    from sqlalchemy import text

    engine = get_engine(db_name)
    with engine.connect() as conn:
        latest = conn.execute(text(f"SELECT MAX({date_col}) FROM {table};")).scalar()

    if latest is None:
        return None
    # datetime is a subclass of date, so it has to be tested first.
    if isinstance(latest, datetime):
        return latest.date()
    return latest


# =============================
# 1) Price data: fetch
# =============================
def _normalize_price_frame(raw: pd.DataFrame, ticker: str) -> pd.DataFrame:
    """Reshape one yfinance OHLCV frame into the public.price_data column layout."""
    df = raw.copy()
    # yfinance can return (field, ticker) MultiIndex columns even for a single
    # symbol; keep only the field level so the rename below can match.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)
    df = df.rename(columns=_YF_PRICE_RENAME)
    df["ticker"] = ticker
    df.index.name = "date"
    df = df.reset_index()
    # Store plain date objects (no time, no timezone).
    df["date"] = pd.to_datetime(df["date"]).dt.date
    return df[_PRICE_COLUMNS]


def fetch_price_data_from_yf(
    tickers: List[str], start: str, end: str, *, downloader: Any = None
) -> pd.DataFrame:
    """Download daily bars per ticker and return them in public.price_data layout.

    Args:
        tickers: Symbols to download, one request per symbol.
        start: Start date string passed through to the downloader.
        end: End date string passed through to the downloader.
        downloader: Optional callable used instead of ``yf.download``; it is
            called as ``downloader(ticker, start=..., end=..., auto_adjust=False)``
            and must return a DataFrame indexed by date.

    Returns:
        One frame with columns ticker, date, open, high, low, close, volume,
        adjusted_close. Tickers whose download is empty are skipped; the frame
        is empty (same columns) when nothing was downloaded.
    """
    if not tickers:
        return pd.DataFrame(columns=_PRICE_COLUMNS)

    download = downloader if downloader is not None else yf.download
    frames = []
    for t in tickers:
        print(f"[PRICE] Fetch {t} {start}~{end}")
        raw = download(t, start=start, end=end, auto_adjust=False)
        if raw is None or raw.empty:
            continue
        frames.append(_normalize_price_frame(raw, t))

    if not frames:
        return pd.DataFrame(columns=_PRICE_COLUMNS)
    return pd.concat(frames, ignore_index=True)
def _price_rows(df: pd.DataFrame) -> list:
    """Convert a price frame into plain-Python tuples in public.price_data column order."""
    columns = ["ticker", "date", "open", "high", "low", "close", "volume", "adjusted_close"]

    def _native(value):
        # Unwrap NumPy scalars (psycopg2 has no adapter for e.g. numpy.int64)
        # and send missing values as NULL instead of NaN.
        if hasattr(value, "item"):
            value = value.item()
        return None if pd.isna(value) else value

    return [
        tuple(_native(v) for v in record)
        for record in df[columns].itertuples(index=False, name=None)
    ]


def upsert_price_data(db_name: str, df: pd.DataFrame):
    """Upsert ``df`` into public.price_data, keyed on (ticker, date).

    Args:
        db_name: Name handed to ``get_db_conn`` to open the connection.
        df: Frame with columns ticker, date, open, high, low, close, volume,
            adjusted_close. An empty frame is a no-op.

    The connection is always closed; the transaction is committed only after
    every row has been sent.
    """
    if df.empty:
        print("[PRICE] No new data to upsert.")
        return

    rows = _price_rows(df)
    sql = """
        INSERT INTO public.price_data
            (ticker, date, open, high, low, close, volume, adjusted_close)
        VALUES %s
        ON CONFLICT (ticker, date) DO UPDATE SET
            open = EXCLUDED.open,
            high = EXCLUDED.high,
            low = EXCLUDED.low,
            close = EXCLUDED.close,
            volume = EXCLUDED.volume,
            adjusted_close = EXCLUDED.adjusted_close;
        """

    conn = get_db_conn(db_name)
    try:
        with conn.cursor() as cur:
            execute_values(cur, sql, rows)
        conn.commit()
        print(f"[PRICE] Upserted {len(rows)} rows into public.price_data")
    finally:
        conn.close()


def run_price_pipeline(config: Dict[str, Any]):
    """Fetch and upsert daily prices from the day after the last stored date.

    Reads ``db_name`` and ``tickers`` from ``config``; ``price_start`` (default
    "2018-01-01") is used when public.price_data holds no rows yet.
    """
    db_name = config["db_name"]
    tickers = config["tickers"]

    # Empty table -> configured start date; otherwise resume at max(date) + 1.
    # NOTE(review): max(date) is taken over the whole table, so a ticker added
    # to the list later is not backfilled before that date - confirm intended.
    last = get_last_date_in_table(db_name, "public.price_data", "date")
    if last is None:
        start_date = config.get("price_start", "2018-01-01")
    else:
        start_date = (last + timedelta(days=1)).strftime("%Y-%m-%d")

    end_date = today_kst().strftime("%Y-%m-%d")

    # Both values are zero-padded ISO strings, so string order is date order.
    # yfinance treats `end` as exclusive, so start == end is an empty window
    # and there is nothing to request.
    if start_date >= end_date:
        print("[PRICE] Already up to date.")
        return

    df = fetch_price_data_from_yf(tickers, start_date, end_date)
    upsert_price_data(db_name, df)
# =============================
# 2) Financial statements: fetch / upsert
# =============================
_FS_COLUMNS = ["ticker", "report_date", "fs_type", "item", "value", "currency", "freq"]

# (statement code, frequency, yfinance Ticker attribute) for every statement pulled.
_FS_SOURCES = [
    ("IS", "annual", "financials"),
    ("BS", "annual", "balance_sheet"),
    ("CF", "annual", "cashflow"),
    ("IS", "quarterly", "quarterly_financials"),
    ("BS", "quarterly", "quarterly_balance_sheet"),
    ("CF", "quarterly", "quarterly_cashflow"),
]


def _statement_rows(ticker: str, fs_type: str, freq: str, statement) -> list:
    """Unpivot one statement (items x report dates) into long-format row dicts."""
    if statement is None or statement.empty:
        return []

    table = statement.copy()
    # Columns are report dates, the index holds the line items.
    table.columns = pd.to_datetime(table.columns).date
    rows = []
    for report_date in table.columns:
        for item, value in table[report_date].items():
            if pd.isna(value):
                continue
            rows.append(
                {
                    "ticker": ticker,
                    "report_date": report_date,
                    "fs_type": fs_type,
                    "item": str(item),
                    "value": float(value),
                    "currency": None,  # not populated by this fetcher
                    "freq": freq,
                }
            )
    return rows


def fetch_financials_from_yf(tickers: List[str], *, ticker_factory: Any = None) -> pd.DataFrame:
    """Collect annual and quarterly statements per ticker in long format.

    Statement codes: IS = income statement, BS = balance sheet, CF = cash flow.

    Args:
        tickers: Symbols to fetch.
        ticker_factory: Optional callable used instead of ``yf.Ticker``; it is
            called with the symbol and must return an object exposing the
            attributes ``financials``, ``balance_sheet``, ``cashflow``,
            ``quarterly_financials``, ``quarterly_balance_sheet`` and
            ``quarterly_cashflow``.

    Returns:
        Frame with columns ticker, report_date, fs_type, item, value, currency,
        freq; missing (NaN) cells are omitted. Empty (same columns) when
        nothing was collected.
    """
    if not tickers:
        return pd.DataFrame(columns=_FS_COLUMNS)

    make_ticker = ticker_factory if ticker_factory is not None else yf.Ticker
    rows = []
    for t in tickers:
        print(f"[FS] Fetch financials for {t}")
        yf_t = make_ticker(t)
        for fs_type, freq, attr in _FS_SOURCES:
            rows.extend(_statement_rows(t, fs_type, freq, getattr(yf_t, attr)))

    if not rows:
        return pd.DataFrame(columns=_FS_COLUMNS)
    return pd.DataFrame(rows)


def _financial_rows(df: pd.DataFrame) -> list:
    """Convert a financials frame into plain-Python tuples, one per conflict key."""
    # The upsert key has no `freq`, so an annual and a quarterly row for the
    # same report date collide. A single INSERT ... ON CONFLICT DO UPDATE may
    # not touch the same row twice, so keep only the last occurrence (the row
    # that would win a row-by-row upsert).
    # NOTE(review): the colliding annual value is dropped; adding `freq` to the
    # unique key would avoid that - confirm against the table definition.
    unique = df.drop_duplicates(
        subset=["ticker", "report_date", "fs_type", "item"], keep="last"
    )

    def _native(value):
        # Unwrap NumPy scalars and send missing values as NULL.
        if hasattr(value, "item"):
            value = value.item()
        return None if pd.isna(value) else value

    return [
        tuple(_native(v) for v in record)
        for record in unique[_FS_COLUMNS].itertuples(index=False, name=None)
    ]


def upsert_financials(db_name: str, df: pd.DataFrame):
    """Upsert ``df`` into public.financials, keyed on (ticker, report_date, fs_type, item).

    An existing currency is kept when the incoming one is NULL. An empty frame
    is a no-op. The connection is always closed.
    """
    if df.empty:
        print("[FS] No financials to upsert.")
        return

    rows = _financial_rows(df)
    sql = """
        INSERT INTO public.financials
            (ticker, report_date, fs_type, item, value, currency, freq)
        VALUES %s
        ON CONFLICT (ticker, report_date, fs_type, item) DO UPDATE SET
            value = EXCLUDED.value,
            currency = COALESCE(EXCLUDED.currency, public.financials.currency),
            freq = EXCLUDED.freq;
        """

    conn = get_db_conn(db_name)
    try:
        with conn.cursor() as cur:
            execute_values(cur, sql, rows)
        conn.commit()
        print(f"[FS] Upserted {len(rows)} rows into public.financials")
    finally:
        conn.close()


def run_financials_pipeline(config: Dict[str, Any]):
    """Fetch statements for ``config["tickers_for_fs"]`` and upsert them into ``config["db_name"]``."""
    db_name = config["db_name"]
    tickers = config["tickers_for_fs"]
    df = fetch_financials_from_yf(tickers)
    upsert_financials(db_name, df)
# =============================
# 3) Macro indicators: fetch / upsert
# =============================
_MACRO_COLUMNS = ["series_code", "date", "value", "meta"]


def fetch_macro_from_yf(
    series_map: Dict[str, str], start: str, end: str, *, downloader: Any = None
) -> pd.DataFrame:
    """Download one series per entry of ``series_map`` and keep its closing value.

    Args:
        series_map: ``{internal_code: yfinance_symbol}``,
            e.g. ``{"US10Y": "^TNX", "KOSPI": "^KS11", "KRWUSD": "KRW=X"}``.
        start: Start date string passed through to the downloader.
        end: End date string passed through to the downloader.
        downloader: Optional callable used instead of ``yf.download``; it is
            called as ``downloader(symbol, start=..., end=..., auto_adjust=False)``.

    Returns:
        Frame with columns series_code, date, value, meta. ``value`` is the
        "Close" column, rows with a missing close are dropped and ``meta`` is
        always None. Empty (same columns) when nothing was downloaded.
    """
    if not series_map:
        return pd.DataFrame(columns=_MACRO_COLUMNS)

    download = downloader if downloader is not None else yf.download
    rows = []
    for code, yf_symbol in series_map.items():
        print(f"[MACRO] Fetch {code}({yf_symbol}) {start}~{end}")
        raw = download(yf_symbol, start=start, end=end, auto_adjust=False)
        if raw is None or raw.empty:
            continue

        frame = raw.copy()
        # yfinance can return (field, ticker) MultiIndex columns even for a
        # single symbol; keep the field level so "Close" is a plain column.
        if isinstance(frame.columns, pd.MultiIndex):
            frame.columns = frame.columns.get_level_values(0)
        if "Close" not in frame.columns:
            continue

        # Only the close is stored as the series value.
        close = frame["Close"].dropna()
        dates = pd.to_datetime(close.index).date
        rows.extend(
            (code, day, float(value), None)
            for day, value in zip(dates, close.to_numpy())
        )

    return pd.DataFrame(rows, columns=_MACRO_COLUMNS)


def _macro_rows(df: pd.DataFrame) -> list:
    """Convert a macro frame into plain-Python tuples in public.macro_data column order."""

    def _native(value):
        # Unwrap NumPy scalars and send missing values as NULL.
        if hasattr(value, "item"):
            value = value.item()
        return None if pd.isna(value) else value

    return [
        tuple(_native(v) for v in record)
        for record in df[_MACRO_COLUMNS].itertuples(index=False, name=None)
    ]


def upsert_macro(db_name: str, df: pd.DataFrame):
    """Upsert ``df`` into public.macro_data, keyed on (series_code, date).

    An existing ``meta`` is kept when the incoming one is NULL. An empty frame
    is a no-op. The connection is always closed.
    """
    if df.empty:
        print("[MACRO] No macro data to upsert.")
        return

    rows = _macro_rows(df)
    sql = """
        INSERT INTO public.macro_data
            (series_code, date, value, meta)
        VALUES %s
        ON CONFLICT (series_code, date) DO UPDATE SET
            value = EXCLUDED.value,
            meta = COALESCE(EXCLUDED.meta, public.macro_data.meta);
        """

    conn = get_db_conn(db_name)
    try:
        with conn.cursor() as cur:
            execute_values(cur, sql, rows)
        conn.commit()
        print(f"[MACRO] Upserted {len(rows)} rows into public.macro_data")
    finally:
        conn.close()


def run_macro_pipeline(config: Dict[str, Any]):
    """Fetch and upsert macro series from the day after the last stored date.

    Reads ``db_name`` and ``macro_series`` from ``config``; ``macro_start``
    (default "2010-01-01") is used when public.macro_data holds no rows yet.
    """
    db_name = config["db_name"]
    series_map = config["macro_series"]

    # NOTE(review): max(date) is taken over all series, so a series that lags
    # the others (or is added later) is not backfilled - confirm intended.
    last = get_last_date_in_table(db_name, "public.macro_data", "date")
    if last is None:
        start_date = config.get("macro_start", "2010-01-01")
    else:
        start_date = (last + timedelta(days=1)).strftime("%Y-%m-%d")
    end_date = today_kst().strftime("%Y-%m-%d")

    # Zero-padded ISO strings compare like dates. yfinance treats `end` as
    # exclusive, so start == end is an empty window.
    if start_date >= end_date:
        print("[MACRO] Already up to date.")
        return

    df = fetch_macro_from_yf(series_map, start_date, end_date)
    upsert_macro(db_name, df)
return + + df = fetch_macro_from_yf(series_map, start_date, end_date) + upsert_macro(db_name, df) + + +# ============================= +# 메인: 하루 한 번 돌릴 거 +# ============================= +def run_all(): + today = today_kst().strftime("%Y-%m-%d") + print(f"=== DAILY INGEST ({today}) ===") + + config = { + "db_name": "db", # get_db_conn 에서 쓰는 이름 (config.json 의 키) + + # 1) 주가 + "tickers": ["AAPL", "MSFT", "TSLA", "^KS11"], # 네가 원하는 티커들 + "price_start": "2018-01-01", + + # 2) 재무제표를 받을 티커 (보통은 개별 주식만) + "tickers_for_fs": ["AAPL", "MSFT", "TSLA"], + + # 3) 거시지표: {내부코드: yfinance 심볼} + "macro_series": { + "US10Y": "^TNX", + "KOSPI": "^KS11", + "KRWUSD": "KRW=X", + }, + "macro_start": "2010-01-01", + } + + # 필요한 것만 골라서 돌리면 됨 + run_price_pipeline(config) + run_financials_pipeline(config) + run_macro_pipeline(config) + + print("=== DAILY INGEST DONE ===") + + +if __name__ == "__main__": + run_all() diff --git a/AI/libs/utils/save_reports_to_db.py b/AI/libs/utils/save_reports_to_db.py index 6b833fdf..6cda73d5 100644 --- a/AI/libs/utils/save_reports_to_db.py +++ b/AI/libs/utils/save_reports_to_db.py @@ -1,53 +1,56 @@ -# libs/core/save_reports_to_db.py +# libs/utils/save_reports_to_db.py from __future__ import annotations -from typing import Iterable, Tuple, List +from typing import Iterable, Tuple, List, Optional from datetime import datetime, timezone -import sys -from sqlalchemy import create_engine, text +from decimal import Decimal import os -# --- 프로젝트 루트 경로 설정 --- -project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(project_root) -# ------------------------------ +from sqlalchemy import text +# 내부 유틸에서 엔진만 사용 (스키마는 절대 변경 X) from libs.utils.get_db_conn import get_engine -ReportRow = Tuple[str, str, float, str, str] +ReportRow = Tuple[str, str, float, str, str] # (ticker, signal, price, date_str, report_text) +# ----- 환경 변수로 자산 테이블/컬럼 지정 (기본값 제공) ----- +ASSETS_TABLE = os.getenv("ASSETS_TABLE", "assets") 
# Asset table column names and target row id, overridable via environment variables.
ASSETS_ID_COLUMN = os.getenv("ASSETS_ID_COLUMN", "id")
ASSETS_CASH_COLUMN = os.getenv("ASSETS_CASH_COLUMN", "cash")
ASSETS_ROW_ID = os.getenv("ASSETS_ROW_ID", "1")


# ----- Utilities -----
def utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _to_decimal(x) -> Decimal:
    """Convert ``x`` to ``Decimal`` through its string form, falling back to 0.

    Going through ``str`` keeps floats at their printed value (0.1 -> 0.1).
    Values that do not parse as a number, and non-finite values
    (NaN / Infinity), yield ``Decimal(0)`` so they cannot poison later cash
    arithmetic.
    """
    from decimal import InvalidOperation

    if isinstance(x, Decimal):
        value = x
    else:
        try:
            value = Decimal(str(x))
        except InvalidOperation:
            # Not a numeric string (e.g. None, "abc"): deliberate best-effort fallback.
            return Decimal(0)
    return value if value.is_finite() else Decimal(0)


def _fetch_current_cash(conn) -> Optional[Decimal]:
    """Read the configured cash value and lock its row for the current transaction.

    Args:
        conn: Connection on which the ``SELECT ... FOR UPDATE`` is executed.

    Returns:
        The cash value as ``Decimal``, or None when no row matches ``ASSETS_ROW_ID``.
    """
    # Table and column names come from the ASSETS_* settings and are
    # interpolated into the statement; only the row id is a bound parameter.
    # NOTE(review): a missing table raises here rather than returning None -
    # confirm that callers expect that.
    sql = text(f"""
        SELECT {ASSETS_CASH_COLUMN}
        FROM public.{ASSETS_TABLE}
        WHERE {ASSETS_ID_COLUMN} = :rid
        FOR UPDATE
    """)
    row = conn.execute(sql, {"rid": ASSETS_ROW_ID}).fetchone()
    if not row:
        return None
    return _to_decimal(row[0])
def _update_cash(conn, new_cash: Decimal) -> None:
    """Write ``new_cash`` into the configured cash column of the configured asset row."""
    # Table and column names come from the ASSETS_* settings; the value and the
    # row id are bound parameters. The Decimal is sent as its string form.
    sql = text(f"""
        UPDATE public.{ASSETS_TABLE}
        SET {ASSETS_CASH_COLUMN} = :cash
        WHERE {ASSETS_ID_COLUMN} = :rid
    """)
    conn.execute(sql, {"cash": str(new_cash), "rid": ASSETS_ROW_ID})


def _build_insert_params(rows: "Iterable[ReportRow]", created_at: datetime) -> List[dict]:
    """Turn report tuples into named-parameter dicts for the xai_reports INSERT.

    Rows with an empty ticker, signal or date string are skipped. ``price`` is
    coerced with ``float`` and ``report_text`` with ``str``; every row gets the
    same ``created_at``.
    """
    out: List[dict] = []
    for (ticker, signal, price, date_s, report_text) in rows:
        if not ticker or not signal or not date_s:
            continue
        out.append({
            "ticker": ticker,
            "signal": signal,
            "price": float(price),
            "date": date_s,  # 'YYYY-MM-DD'
            "report": str(report_text),
            "created_at": created_at,
        })
    return out


# ----- Main entry: save reports (asset row is read and locked first) -----
def save_reports_to_db(rows: "List[ReportRow]", db_name: str) -> int:
    """Insert XAI reports into public.xai_reports and return how many were inserted.

    The asset row is read (and locked) inside the same transaction as the
    inserts. The table schema is never modified; rows are only inserted.

    NOTE(review): the one-share execution and cash update described for this
    function are not implemented here; cash is read but left untouched, which
    is what the "(자산 미적용)" log line reports.

    Args:
        rows: ``(ticker, signal, price, date_str, report_text)`` tuples.
        db_name: Name handed to ``get_engine``.

    Returns:
        Number of inserted rows; 0 when ``rows`` is empty or no row is valid.
    """
    if not rows:
        print("[INFO] 저장할 리포트가 없습니다.")
        return 0

    engine = get_engine(db_name)
    created_at = utcnow()
    params = _build_insert_params(rows, created_at)

    # INSERT template (the schema is not touched).
    insert_sql = text("""
        INSERT INTO public.xai_reports (ticker, signal, price, date, report, created_at)
        VALUES (:ticker, :signal, :price, :date, :report, :created_at)
    """)

    chunk_size = 1000
    inserted = 0
    with engine.begin() as conn:
        # Lock and read the cash row first.
        current_cash = _fetch_current_cash(conn)
        if current_cash is None:
            print(f"[WARN] 자산 테이블 public.{ASSETS_TABLE}에서 행을 찾지 못했어요. 체결 없이 리포트만 저장할게요.")

        # Reports are saved whether or not an asset row exists.
        for start in range(0, len(params), chunk_size):
            batch = params[start:start + chunk_size]
            conn.execute(insert_sql, batch)
            inserted += len(batch)

    print(f"--- {inserted}개의 XAI 리포트가 저장되었습니다 (자산 미적용). ---")
    return inserted
"2024-10-31", # --- 시퀀스/라벨 --- @@ -561,8 +564,8 @@ def run_training(config: dict): "epochs": 50, "batch_size": 512, - # --- 출력 경로 --- - "model_out": "AI/transformer/weights/initial.weights.h5", + # --- 출력 경로 (오늘날짜 적용) --- + "model_out": f"AI/transformer/weights/{today_str}.weights.h5", "scaler_out": "AI/transformer/scaler/scaler.pkl", # --- 기타 ---