diff --git a/AI/backtrade/__init__.py b/AI/backtrade/__init__.py new file mode 100644 index 00000000..76aabec1 --- /dev/null +++ b/AI/backtrade/__init__.py @@ -0,0 +1,3 @@ +#AI/backtrade/__init__.py +from .main import backtrade, BacktradeConfig +__all__ = ["backtrade", "BacktradeConfig"] \ No newline at end of file diff --git a/AI/backtrade/main.py b/AI/backtrade/main.py new file mode 100644 index 00000000..11672835 --- /dev/null +++ b/AI/backtrade/main.py @@ -0,0 +1,159 @@ +# backtrade/main.py +""" +한국어 주석: +- OHLCV 없이, Transformer 결정 로그(decision_log)의 price만으로 + 간소화된 백테스트를 수행하는 환경(Environment) 역할. +- 수량/포지션 결정은 backtrade/order_policy.py 모듈로 분리됨. +""" + +from __future__ import annotations +from dataclasses import dataclass +from typing import Optional, Dict, Tuple, List +import pandas as pd +import numpy as np + +from backtrade.order_policy import decide_order # 분리된 정책 모듈 import + + +# === 설정 클래스 === +@dataclass +class BacktradeConfig: + """ + 한국어 주석: + - 간소화 백테스터 설정 + - 향후 강화학습 환경 초기화 시에도 그대로 사용 가능 + """ + initial_cash: float = 100_000.0 + slippage_bps: float = 5.0 + commission_bps: float = 3.0 + risk_frac: float = 0.2 + max_positions_per_ticker: int = 1 + fill_on_same_day: bool = True + + +# === 내부 유틸 === +def _apply_price_with_slippage(price: float, side: str, slippage_bps: float) -> float: + """슬리피지를 체결가에 반영""" + adj = 1.0 + (slippage_bps / 10_000.0) * (1 if side.upper() == "BUY" else -1) + return float(price) * adj + + +def _apply_commission(value: float, commission_bps: float) -> float: + """체결 금액에 대해 bps 단위 수수료 계산""" + return abs(value) * (commission_bps / 10_000.0) + + +def _fill_date_from_signal(sig_date: pd.Timestamp, same_day: bool) -> pd.Timestamp: + """OHLCV 없이 동일일 또는 다음날 체결로 단순 처리""" + return sig_date if same_day else (sig_date + pd.Timedelta(days=1)) + + +# === 백테스트 본체 === +def backtrade( + decision_log: pd.DataFrame, + config: Optional[BacktradeConfig] = None, + run_id: Optional[str] = None, +) -> Tuple[pd.DataFrame, Dict]: + """ + 한국어 주석: + - 입력: Transformer 의사결정 로그(decision_log) + - 처리: 가격 기반 슬리피지·수수료 반영 후 체결/포지션 갱신 + - 반환: (fills_df, summary) + """ + if config is None: + config = BacktradeConfig() + + dl = decision_log.copy() + if not {"ticker", "date", "action", "price"}.issubset(dl.columns): + raise ValueError("decision_log에 'ticker','date','action','price' 컬럼이 필요합니다.") + + dl["date"] = pd.to_datetime(dl["date"]) + dl = dl.sort_values(["date", "ticker"]).reset_index(drop=True) + + cash = float(config.initial_cash) + positions: Dict[str, Dict[str, float]] = {} + records: List[Dict] = [] + + for _, r in dl.iterrows(): + ticker = str(r["ticker"]) + sig_date = pd.Timestamp(r["date"]) + sig = str(r["action"]).upper() + sig_price = float(r.get("price", np.nan)) + + xai_id = r.get("xai_report_id") + + if sig not in ("BUY", "SELL"): + continue + + fill_date = _fill_date_from_signal(sig_date, config.fill_on_same_day) + fill_price = _apply_price_with_slippage(sig_price, sig, config.slippage_bps) + + pos = positions.get(ticker, {"qty": 0, "avg": 0.0}) + cur_qty = pos["qty"] + avg_price = pos["avg"] + side = "BUY" if sig == "BUY" else "SELL" + + # === 🔹 체결 정책 호출 (외부 모듈) === + qty, trade_value = decide_order( + side=side, + cash=cash, + cur_qty=cur_qty, + avg_price=avg_price, + fill_price=fill_price, + config=config, + ) + + if qty <= 0: + continue + + # === 나머지는 환경의 기계적 계산 === + commission = _apply_commission(trade_value, config.commission_bps) + cash_after = cash - trade_value - commission + + # 포지션 업데이트 + if side == "BUY": + new_qty = cur_qty + qty + new_avg = (avg_price * cur_qty + fill_price * qty) / max(1, new_qty) + else: + new_qty = cur_qty - qty + new_avg = avg_price if new_qty > 0 else 0.0 + + pnl_realized = 0.0 + if side == "SELL": + pnl_realized = (fill_price - avg_price) * qty + + pnl_unrealized = 0.0 + + # 상태 저장 + cash = cash_after + positions[ticker] = {"qty": new_qty, "avg": new_avg} + + records.append({ + "run_id": run_id, + "xai_report_id": xai_id, + "ticker": ticker, + "signal_date": sig_date.date().isoformat(), + "signal_price": float(sig_price), + "signal": sig, + "fill_date": fill_date.date().isoformat(), + "fill_price": float(fill_price), + "qty": int(qty), + "side": side, + "value": float(trade_value), + "commission": float(commission), + "cash_after": float(cash_after), + "position_qty": int(new_qty), + "avg_price": float(new_avg), + "pnl_realized": float(pnl_realized), + "pnl_unrealized": float(pnl_unrealized), + }) + + fills = pd.DataFrame.from_records(records) + summary = { + "run_id": run_id, + "trades": int(len(fills)), + "cash_final": float(cash), + "pnl_realized_sum": float(fills["pnl_realized"].sum()) if not fills.empty else 0.0, + "commission_sum": float(fills["commission"].sum()) if not fills.empty else 0.0, + } + return fills, summary diff --git a/AI/backtrade/order_policy.py b/AI/backtrade/order_policy.py new file mode 100644 index 00000000..3846d2ee --- /dev/null +++ b/AI/backtrade/order_policy.py @@ -0,0 +1,85 @@ +# backtrade/order_policy.py +# -*- coding: utf-8 -*- +""" +한국어 주석: +- 백테스터의 '체결 수량 및 포지션 결정 로직'을 별도 모듈로 분리. +- 현재는 단순 rule-based이며, 향후 강화학습(Agent) 정책으로 교체할 수 있음. +""" + +from __future__ import annotations +from typing import Dict, Tuple + + +def decide_order( + side: str, + cash: float, + cur_qty: int, + avg_price: float, + fill_price: float, + config, +) -> Tuple[int, float]: + """ + 한국어 주석: + - 체결 수량 및 거래금액을 결정하는 핵심 함수. + - 강화학습 Agent가 교체할 대상 부분. + ----------------------------- + 입력값: + side: "BUY" 또는 "SELL" + cash: 현재 현금 잔고 + cur_qty: 현재 보유 주식 수량 + avg_price: 현재 보유 평균단가 + fill_price: 이번 체결 기준가 (슬리피지 반영 전) + config: BacktradeConfig 인스턴스 + 반환값: + (qty, trade_value) + qty: 매수/매도 수량 + trade_value: 체결 총액(+BUY 지출, -SELL 유입) + """ + + qty = 0 + trade_value = 0.0 + + if side == "BUY": + # 현금 중 risk_frac 비율만큼 투자 + cash_to_use = max(0.0, cash * config.risk_frac) + qty = int(cash_to_use // fill_price) + + # 동시 보유 제한 + if config.max_positions_per_ticker == 1 and cur_qty > 0: + qty = 0 + + trade_value = fill_price * qty # BUY → 현금 지출 (+) + + elif side == "SELL": + qty = cur_qty # 전량 청산 + trade_value = -fill_price * qty # SELL → 현금 유입 (-) + + return qty, trade_value + + +# === 확장용 RL 정책 클래스 === +class RLOrderPolicy: + """ + 한국어 주석: + - 강화학습 Agent를 위한 placeholder 클래스. + - 현재는 rule-based decide_order를 그대로 호출하지만, + 이후 RL 모델의 action 출력을 이용해 수량/금액을 결정할 수 있다. + """ + + def __init__(self, model=None): + self.model = model # RL 네트워크 or 정책 객체 + + def decide(self, state: Dict, config) -> Tuple[int, float]: + """ + state: {'cash':..., 'price':..., 'pos':..., 'side':...} + 반환: (qty, trade_value) + """ + side = state.get("side", "BUY") + return decide_order( + side=side, + cash=state.get("cash", 0.0), + cur_qty=state.get("cur_qty", 0), + avg_price=state.get("avg_price", 0.0), + fill_price=state.get("price", 0.0), + config=config, + ) diff --git a/AI/configs/config.json b/AI/configs/config.json deleted file mode 100644 index 36e05aff..00000000 --- a/AI/configs/config.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "db": { - "host": "ep-misty-lab-adgec0kl-pooler.c-2.us-east-1.aws.neon.tech", - "user": "neondb_owner", - "password": "npg_hWkg04MwGlYs", - "dbname": "neondb", - "port": 5432 - } - - - , - "report_DB": { - "host": "ep-jolly-waterfall-ads25fir-pooler.c-2.us-east-1.aws.neon.tech", - "user": "neondb_owner", - "password": "npg_lo0rC9aOyFkw", - "dbname": "neondb", - "port": 5432 - } -} diff --git a/AI/daily_data_collection/__init__.py b/AI/daily_data_collection/__init__.py new file mode 100644 index 00000000..7e5ba52f --- /dev/null +++ b/AI/daily_data_collection/__init__.py @@ -0,0 +1,3 @@ +#AI/daily_data_collection/__init__.py +from .main import run_data_collection +__all__ = ["run_data_collection"] \ No newline at end of file diff --git a/AI/daily_data_collection/main.py b/AI/daily_data_collection/main.py new file mode 100644 index 00000000..1bbbe3e0 --- /dev/null +++ b/AI/daily_data_collection/main.py @@ -0,0 +1,1165 @@ +# ====================================================================== +# SECTION 1 — Imports + 경로 설정 + 공용 유틸 + DB 전체 티커 로딩 +# ====================================================================== + +from __future__ import annotations + +import os +import sys +from datetime import datetime, timedelta, timezone +from typing import List, Dict, Any, Optional + +import numpy as np +import pandas as pd +import yfinance as yf +from psycopg2.extras import execute_values +from fredapi import Fred + +# ---------------------------------------------------------------------- +# 프로젝트 루트 경로를 sys.path에 추가 (절대경로 import 문제 해결) +# ---------------------------------------------------------------------- +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_root not in sys.path: + sys.path.append(project_root) + +# DB 연결 함수 가져오기 +from libs.utils.get_db_conn import get_db_conn, get_engine + + +# ---------------------------------------------------------------------- +# 한국 표준시 (KST) +# ---------------------------------------------------------------------- +KST = timezone(timedelta(hours=9)) + + +# ====================================================================== +# 공용 유틸 함수 +# ====================================================================== + +def today_kst() -> datetime.date: + """한국(KST) 기준 today's date 반환.""" + return datetime.now(KST).date() + + +def get_last_date_in_table(db_name: str, table: str, date_col: str) -> Optional[datetime.date]: + """ + 테이블의 날짜 컬럼(date_col)에서 MAX(date)를 얻는 함수 + """ + from sqlalchemy import text + engine = get_engine(db_name) + + with engine.connect() as conn: + res = conn.execute(text(f"SELECT MAX({date_col}) FROM {table};")).scalar() + + return res if res is not None else None + + +# ====================================================================== +# ✨ DB 전체 티커 로딩 함수 (핵심 기능 추가) +# ====================================================================== + +def get_all_tickers_from_db(db_name: str) -> List[str]: + """ + public.price_data 에서 존재하는 모든 ticker를 DISTINCT 로 가져오는 함수. + manual_backfill_all() 에서 전체 티커를 자동으로 사용하기 위해 도입. + """ + from sqlalchemy import text + engine = get_engine(db_name) + + with engine.connect() as conn: + rows = conn.execute(text(""" + SELECT DISTINCT ticker + FROM public.price_data + ORDER BY ticker; + """)).fetchall() + + return [r[0] for r in rows] + + +# ====================================================================== +# ✨ Series, numpy 자료형 등 → 스칼라 float 변환 함수 (중요) +# ====================================================================== + +def to_scalar(v): + """ + yfinance 또는 pandas Series에서 발생하는 + numpy.float64 / ndarray / Series 형태를 모두 float로 정규화. + """ + if isinstance(v, (np.generic, np.float64, np.int64)): + return float(v) + + if isinstance(v, (list, tuple, np.ndarray)): + return float(v[0]) if len(v) > 0 else None + + if isinstance(v, pd.Series): + # Series인 경우 iloc으로 첫 값만 사용 + return float(v.iloc[0]) if len(v) > 0 else None + + # 이미 파이썬 float인 경우 + try: + return float(v) + except Exception: + return None +# ====================================================================== +# SECTION 2 — PRICE_DATA FETCH / UPSERT / PIPELINE +# ====================================================================== + +# -------------------------------------------------------------- +# DataFrame → 파이썬 records (numpy / Series 모두 스칼라로 변환) +# -------------------------------------------------------------- + +def df_to_python_records_price(df: pd.DataFrame): + """ + price_data DataFrame 을 Python tuple 리스트로 변환. + - 스키마 기준 8컬럼만 사용한다. + - MultiIndex / 중복 컬럼 이름이 있어도 첫 번째 컬럼만 사용하도록 정규화한다. + """ + # MultiIndex -> 1단계 이름으로 평탄화 + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + + target_cols = ["ticker", "date", "open", "high", "low", "close", "volume", "adjusted_close"] + + # 중복 컬럼 / DataFrame 반환 케이스 안전하게 처리 + safe_dict = {} + for c in target_cols: + if c not in df.columns: + # 아예 없으면 전부 None + safe_dict[c] = pd.Series([None] * len(df), index=df.index) + continue + + col = df[c] + # 같은 이름의 컬럼이 여러 개인 경우 df[c] 가 DataFrame 이라서 + # 첫 번째 컬럼만 사용 + if isinstance(col, pd.DataFrame): + safe_dict[c] = col.iloc[:, 0] + else: + safe_dict[c] = col + + df = pd.DataFrame(safe_dict) + + records = [] + + def scalar_or_none(v): + if v is None or pd.isna(v): + return None + if isinstance(v, (int, float)): + return v + try: + return float(v) + except Exception: + return None + + # 여기서는 무조건 8개 컬럼만 존재 + for ticker, date, open_, high_, low_, close_, volume, adjusted_close in df.itertuples(index=False, name=None): + records.append( + ( + str(ticker), + date, + scalar_or_none(open_), + scalar_or_none(high_), + scalar_or_none(low_), + scalar_or_none(close_), + int(volume) if (volume is not None and not pd.isna(volume)) else None, + scalar_or_none(adjusted_close), + ) + ) + + return records + + +# -------------------------------------------------------------- +# yfinance 일봉 수집 +# -------------------------------------------------------------- +def fetch_price_data_from_yf(tickers: List[str], start: str, end: str) -> pd.DataFrame: + """ + yfinance 를 '티커 1개씩' 호출해서 + 항상 세로(long) 구조의 8컬럼 DF를 반환한다. + """ + frames = [] + + for t in tickers: + print(f"[PRICE] Fetch {t} {start}~{end}") + df = yf.download(t, start=start, end=end, auto_adjust=False) + + # MultiIndex 반환되는 경우 강제 단일화 + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + + + if df.empty: + print(f"[PRICE] No data for {t}") + continue + + # index = DatetimeIndex + df.index.name = "date" + df = df.reset_index() + + # 표준 컬럼명으로 통일 + df = df.rename( + columns={ + "Date": "date", + "Open": "open", + "High": "high", + "Low": "low", + "Close": "close", + "Adj Close": "adjusted_close", + "Volume": "volume", + } + ) + + df["ticker"] = t + df["date"] = pd.to_datetime(df["date"]).dt.date + + df = df[ + ["ticker", "date", "open", + "high", "low", "close", + "volume", "adjusted_close"] + ] + + frames.append(df) + + if not frames: + return pd.DataFrame( + columns=[ + "ticker", "date", "open", "high", "low", + "close", "volume", "adjusted_close" + ] + ) + + return pd.concat(frames, ignore_index=True) + + + +# -------------------------------------------------------------- +# price_data UPSERT +# -------------------------------------------------------------- +def upsert_price_data(db_name: str, df: pd.DataFrame, batch_size: int = 10_000) -> None: + """ + public.price_data 테이블에 (ticker, date) 기준으로 UPSERT 수행. + + - df 전체를 한 번에 records 로 바꾸지 않고, + batch_size 단위로 잘라서 DB에 넣음 → 메모리 폭발 방지. + """ + if df.empty: + print("[PRICE] No new data to upsert.") + return + + sql = """ + INSERT INTO public.price_data + (ticker, date, open, high, low, close, volume, adjusted_close) + VALUES %s + ON CONFLICT (ticker, date) DO UPDATE SET + open = EXCLUDED.open, + high = EXCLUDED.high, + low = EXCLUDED.low, + close = EXCLUDED.close, + volume = EXCLUDED.volume, + adjusted_close = EXCLUDED.adjusted_close; + """ + + conn = get_db_conn(db_name) + try: + with conn.cursor() as cur: + # df 를 행 기준으로 잘라서 차례대로 INSERT + total_rows = len(df) + for start in range(0, total_rows, batch_size): + end = min(start + batch_size, total_rows) + batch_df = df.iloc[start:end] + records = df_to_python_records_price(batch_df) + + if not records: + continue + + execute_values(cur, sql, records) + + print(f"[PRICE] Upserted rows {start} ~ {end-1} / {total_rows}") + + conn.commit() + print(f"[PRICE] Upserted TOTAL {len(df)} rows into public.price_data") + + finally: + conn.close() + + + +# -------------------------------------------------------------- +# (자동용) price pipeline: 증분 업데이트 +# -------------------------------------------------------------- +def run_price_pipeline(config: Dict[str, Any]) -> None: + db_name = config["db_name"] + tickers = config["tickers"] + + last = get_last_date_in_table(db_name, "public.price_data", "date") + + if last is None: + start_date = config.get("price_start", "2017-01-01") + print(f"[PRICE] 기존 없음 → {start_date} 부터 시작") + else: + start_date = (last + timedelta(days=1)).strftime("%Y-%m-%d") + print(f"[PRICE] 마지막 날짜 {last} 이후 → {start_date} 부터 수집") + + end_date = today_kst().strftime("%Y-%m-%d") + + if start_date > end_date: + print("[PRICE] Already up to date.") + return + + df = fetch_price_data_from_yf(tickers, start_date, end_date) + upsert_price_data(db_name, df) +# ====================================================================== +# SECTION 3 — TECHNICAL INDICATORS (계산 + UPSERT + PIPELINE) +# ====================================================================== + +# -------------------------------------------------------------- +# 기술지표 계산 (티커 1개 단위) +# -------------------------------------------------------------- +def compute_technical_indicators(df: pd.DataFrame) -> pd.DataFrame: + """ + 입력 df: ticker, date, open, high, low, close, volume + 출력 df: ticker, date, RSI, MACD, Bollinger_Bands_upper, ... MA_200 + """ + + df = df.sort_values("date").reset_index(drop=True) + + + close = df["close"] + high = df["high"] + low = df["low"] + volume = df["volume"] + # 결측치 보정 (앞뒤로 채우기) + df["close"] = df["close"].ffill().bfill() + df["high"] = df["high"].ffill().bfill() + df["low"] = df["low"].ffill().bfill() + df["volume"] = df["volume"].fillna(0) + + + # ------------------------- RSI ------------------------- + delta = close.diff() + gain = delta.clip(lower=0) + loss = -delta.clip(upper=0) + + window_rsi = 14 + avg_gain = gain.rolling(window_rsi).mean() + avg_loss = loss.rolling(window_rsi).mean() + rs = avg_gain / (avg_loss + 1e-9) + rsi = 100 - (100 / (1 + rs)) + + # ------------------------- MACD ------------------------- + ema12 = close.ewm(span=12, adjust=False).mean() + ema26 = close.ewm(span=26, adjust=False).mean() + macd = ema12 - ema26 + + # ------------------------- Bollinger Bands ------------------------- + ma20 = close.rolling(20).mean() + std20 = close.rolling(20).std() + bb_upper = ma20 + 2 * std20 + bb_lower = ma20 - 2 * std20 + + # ------------------------- ATR ------------------------- + prev_close = close.shift(1) + tr1 = high - low + tr2 = (high - prev_close).abs() + tr3 = (low - prev_close).abs() + tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) + atr = tr.rolling(14).mean() + + # ------------------------- OBV ------------------------- + obv = [0] + for i in range(1, len(close)): + if close.iloc[i] > close.iloc[i-1]: + obv.append(obv[-1] + volume.iloc[i]) + elif close.iloc[i] < close.iloc[i-1]: + obv.append(obv[-1] - volume.iloc[i]) + else: + obv.append(obv[-1]) + obv = pd.Series(obv, index=df.index) + + # ------------------------- Stochastic ------------------------- + lowest_low = low.rolling(14).min() + highest_high = high.rolling(14).max() + stochastic = (close - lowest_low) / (highest_high - lowest_low + 1e-9) * 100 + + # ------------------------- MFI ------------------------- + typical_price = (high + low + close) / 3 + raw_mf = typical_price * volume + tp_diff = typical_price.diff() + pos_mf = raw_mf.where(tp_diff > 0, 0) + neg_mf = raw_mf.where(tp_diff < 0, 0) + pos_14 = pos_mf.rolling(14).sum() + neg_14 = neg_mf.abs().rolling(14).sum() + 1e-9 + mfr = pos_14 / neg_14 + mfi = 100 - (100 / (1 + mfr)) + + # ------------------------- MA ------------------------- + ma_5 = close.rolling(5).mean() + ma_20_roll = close.rolling(20).mean() + ma_50 = close.rolling(50).mean() + ma_200 = close.rolling(200).mean() + + out = pd.DataFrame({ + "ticker": df["ticker"], + "date": df["date"], + "RSI": rsi, + "MACD": macd, + "Bollinger_Bands_upper": bb_upper, + "Bollinger_Bands_lower": bb_lower, + "ATR": atr, + "OBV": obv, + "Stochastic": stochastic, + "MFI": mfi, + "MA_5": ma_5, + "MA_20": ma_20_roll, + "MA_50": ma_50, + "MA_200": ma_200, + }) + + return out + + +# -------------------------------------------------------------- +# UPSERT for technical_indicators +# -------------------------------------------------------------- +def upsert_technical_indicators(db_name: str, df: pd.DataFrame, batch_size: int = 50_000): + if df.empty: + print("[TECH] No technical indicators to upsert.") + return + + sql = """ + INSERT INTO public.technical_indicators ( + ticker, date, + RSI, MACD, + Bollinger_Bands_upper, Bollinger_Bands_lower, + ATR, OBV, Stochastic, MFI, + MA_5, MA_20, MA_50, MA_200 + ) + VALUES %s + ON CONFLICT (ticker, date) DO UPDATE SET + RSI = EXCLUDED.RSI, + MACD = EXCLUDED.MACD, + Bollinger_Bands_upper = EXCLUDED.Bollinger_Bands_upper, + Bollinger_Bands_lower = EXCLUDED.Bollinger_Bands_lower, + ATR = EXCLUDED.ATR, + OBV = EXCLUDED.OBV, + Stochastic = EXCLUDED.Stochastic, + MFI = EXCLUDED.MFI, + MA_5 = EXCLUDED.MA_5, + MA_20 = EXCLUDED.MA_20, + MA_50 = EXCLUDED.MA_50, + MA_200 = EXCLUDED.MA_200; + """ + + conn = get_db_conn(db_name) + try: + with conn.cursor() as cur: + total = len(df) + for start in range(0, total, batch_size): + end = min(start + batch_size, total) + batch = df.iloc[start:end] + + records = [] + for _, r in batch.iterrows(): + records.append(( + r["ticker"], r["date"], + to_scalar(r["RSI"]), to_scalar(r["MACD"]), + to_scalar(r["Bollinger_Bands_upper"]), to_scalar(r["Bollinger_Bands_lower"]), + to_scalar(r["ATR"]), to_scalar(r["OBV"]), + to_scalar(r["Stochastic"]), to_scalar(r["MFI"]), + to_scalar(r["MA_5"]), to_scalar(r["MA_20"]), + to_scalar(r["MA_50"]), to_scalar(r["MA_200"]), + )) + + if not records: + continue + + execute_values(cur, sql, records) + print(f"[TECH] Upserted rows {start} ~ {end-1} / {total}") + + conn.commit() + print(f"[TECH] Upserted TOTAL {len(df)} rows into public.technical_indicators") + + finally: + conn.close() + + + +# -------------------------------------------------------------- +# 기술지표 파이프라인: DB price_data 전체 기반 +# -------------------------------------------------------------- +def run_technical_indicators_full(config: Dict[str, Any]) -> None: + """ + 기술지표 전체 FULL 재계산 (manual_backfill_all 전용) + """ + db_name = config["db_name"] + + print("[TECH-FULL] 전체 기간 price_data 로딩 중…") + + from sqlalchemy import text + engine = get_engine(db_name) + + query = text(""" + SELECT ticker, date, open, high, low, close, volume, adjusted_close + FROM public.price_data + ORDER BY ticker, date; + """) + + with engine.connect() as conn: + df_price = pd.read_sql(query, conn) + + if df_price.empty: + print("[TECH-FULL] price_data 없음 → 스킵") + return + + tickers = sorted(df_price["ticker"].unique()) + print(f"[TECH-FULL] 전체 티커 수: {len(tickers)}") + + frames = [] + + with engine.connect() as conn: + conn.execute(text("DELETE FROM public.technical_indicators;")) + print("[TECH-FULL] 기존 technical_indicators 테이블 초기화 완료.") + conn.commit() + + + for idx, t in enumerate(tickers, start=1): + print(f"[TECH-FULL] ({idx}/{len(tickers)}) {t} 계산 중…") + df_t = df_price[df_price["ticker"] == t] + tech_df = compute_technical_indicators(df_t) + frames.append(tech_df) + + full_df = pd.concat(frames, ignore_index=True) + + print(f"[TECH-FULL] 기술지표 전체 계산 완료: {len(full_df)} rows") + + # 전체 덮어쓰기 (UPSERT) + upsert_technical_indicators(db_name, full_df) + + print("[TECH-FULL] 기술지표 FULL 업서트 완료.") + +# -------------------------------------------------------------- +# 기술지표 파이프라인: 증분용 (최근 250일치만 계산) +# -------------------------------------------------------------- + +def run_technical_indicators_incremental(config: Dict[str, Any]) -> None: + """ + 기존: price_data 전체 기간을 읽고 모든 기술지표 전체 재계산 (매우 무거움) + 변경: 최근 250일만 계산 → 5~10배 이상 빠름 + """ + db_name = config["db_name"] + window_days = 250 # rolling 기간 + 여유 buffer + + # --------------------------------------------------------- + # 최신 날짜 구하고, 최근 250일 구간만 조회 + # --------------------------------------------------------- + from sqlalchemy import text + engine = get_engine(db_name) + + # price_data 에 최신 날짜 확인 + with engine.connect() as conn: + last_date = conn.execute( + text("SELECT MAX(date) FROM public.price_data") + ).scalar() + + if last_date is None: + print("[TECH] price_data empty → 기술지표 계산 불가") + return + + start_date = last_date - timedelta(days=window_days) + + print(f"[TECH] 최근 {window_days}일 ({start_date} ~ {last_date}) 데이터만 사용") + + # 최근 250일 price data만 로딩 + query = text(""" + SELECT ticker, date, open, high, low, close, volume + FROM public.price_data + WHERE date >= :start_date + ORDER BY ticker, date + """) + + with engine.connect() as conn: + df_price = pd.read_sql(query, conn, params={"start_date": start_date}) + + if df_price.empty: + print("[TECH] 최근 데이터 없음 → 기술지표 계산 스킵") + return + + # --------------------------------------------------------- + # 티커별 기술지표 계산 + # --------------------------------------------------------- + tickers = sorted(df_price["ticker"].unique()) + print(f"[TECH] 대상 티커 수: {len(tickers)}") + + frames = [] + for idx, t in enumerate(tickers, start=1): + print(f"[TECH] ({idx}/{len(tickers)}) {t} 지표 계산") + df_t = df_price[df_price["ticker"] == t] + tech_df = compute_technical_indicators(df_t) + frames.append(tech_df) + + tech_recent = pd.concat(frames, ignore_index=True) + + print(f"[TECH] 기술지표 계산 완료: {len(tech_recent)} rows") + + # --------------------------------------------------------- + # 최근 250일 데이터만 UPSERT + # --------------------------------------------------------- + upsert_technical_indicators(db_name, tech_recent) + + print("[TECH] 최근 250일 기술지표 업데이트 완료") + + +# ====================================================================== +# SECTION 4 — MACROECONOMIC INDICATORS (FETCH + UPSERT + PIPELINE) +# ====================================================================== + +# -------------------------------------------------------------- +# Macro Data Fetch from FRED +# -------------------------------------------------------------- + +# FRED API 읽어오기 +FRED_API_KEY = os.getenv("FRED_API_KEY") +def get_fred_client() -> Fred | None: + """ + 한국어 주석: + - FRED API 클라이언트를 생성하는 헬퍼 함수. + - FRED_API_KEY가 없거나 잘못 설정된 경우 None을 반환하고, + 매크로 데이터 수집을 스킵할 수 있도록 한다. + """ + if not FRED_API_KEY: + print("[WARN] FRED_API_KEY가 설정되어 있지 않아 매크로 데이터 수집을 건너뜁니다.") + return None + + try: + fred = Fred(api_key=FRED_API_KEY) + return fred + except ValueError as e: + # fredapi 내부에서 키가 잘못되었을 때 발생하는 에러를 안전하게 처리 + print(f"[WARN] FRED API 초기화 실패: {e}") + return None + +def fetch_macro_from_fred(series_map: Dict[str, str], start: str, end: str) -> pd.DataFrame: + dates = pd.date_range(start, end, freq="D") + out = pd.DataFrame({"date": dates}) + out["date"] = out["date"].dt.date + + fred = get_fred_client() + if fred is None: + print("[MACRO] FRED 클라이언트 초기화 실패 → 매크로 데이터 수집 건너뜀") + return pd.DataFrame(columns=["date"] + list(series_map.keys())) + + for col_name, fred_symbol in series_map.items(): + print(f"[MACRO] Fetch {col_name} ({fred_symbol}) {start}~{end}") + s = fred.get_series(fred_symbol, observation_start=start, observation_end=end) + s = s.reset_index().rename(columns={"index": "date", 0: col_name}) + s["date"] = pd.to_datetime(s["date"]).dt.date + out = out.merge(s, on="date", how="left") + + return out + +# -------------------------------------------------------------- +# Macro UPSERT +# -------------------------------------------------------------- +def upsert_macro(db_name: str, df: pd.DataFrame): + if df.empty: + print("[MACRO] No macro data to upsert.") + return + + # 누락된 컬럼 채우기 + required_cols = [ + "cpi", "gdp", "ppi", "jolt", + "cci", "interest_rate", "trade_balance" + ] + for c in required_cols: + if c not in df.columns: + df[c] = None + + # pandas/Numpy → Python scalar + records = [] + for _, r in df.iterrows(): + records.append(( + r["date"], + to_scalar(r["cpi"]), + to_scalar(r["gdp"]), + to_scalar(r["ppi"]), + to_scalar(r["jolt"]), + to_scalar(r["cci"]), + to_scalar(r["interest_rate"]), + to_scalar(r["trade_balance"]), + )) + + sql = """ + INSERT INTO public.macroeconomic_indicators ( + date, cpi, gdp, ppi, jolt, + cci, interest_rate, trade_balance + ) + VALUES %s + ON CONFLICT (date) DO UPDATE SET + cpi = EXCLUDED.cpi, + gdp = EXCLUDED.gdp, + ppi = EXCLUDED.ppi, + jolt = EXCLUDED.jolt, + cci = EXCLUDED.cci, + interest_rate = EXCLUDED.interest_rate, + trade_balance = EXCLUDED.trade_balance; + """ + + conn = get_db_conn(db_name) + try: + with conn.cursor() as cur: + execute_values(cur, sql, records) + conn.commit() + print(f"[MACRO] Upserted {len(df)} rows into public.macroeconomic_indicators") + + finally: + conn.close() + + +# -------------------------------------------------------------- +# Macro Pipeline (증분) +# -------------------------------------------------------------- +def run_macro_pipeline(config: Dict[str, Any]): + db_name = config["db_name"] + series_map = config["macro_series"] + + if not series_map: + print("[MACRO] macro_series is empty → skip macroeconomic_indicators") + return + + # 테이블 마지막 날짜 + last = get_last_date_in_table( + db_name, + "public.macroeconomic_indicators", + "date" + ) + + if last is None: + start_date = config.get("macro_start", "2017-01-01") + print(f"[MACRO] 기존 없음 → {start_date} 부터 시작") + else: + start_date = (last + timedelta(days=1)).strftime("%Y-%m-%d") + print(f"[MACRO] 마지막 날짜 {last} 이후 → {start_date} 부터 수집") + + end_date = today_kst().strftime("%Y-%m-%d") + + if start_date > end_date: + print("[MACRO] Already up to date.") + return + + df_macro = fetch_macro_from_fred(series_map, start_date, end_date) + upsert_macro(db_name, df_macro) +# ====================================================================== +# SECTION 5 — COMPANY FUNDAMENTALS (FETCH + UPSERT + PIPELINE) +# ====================================================================== + +# -------------------------------------------------------------- +# yfinance 재무제표 Fetch +# -------------------------------------------------------------- +def fetch_company_fundamentals_from_yf(tickers: List[str]) -> pd.DataFrame: + """ + 각 티커에 대해 다음 데이터를 가져와 company_fundamentals 형태로 변환: + - annual financials (Income Statement) + - annual balance sheet + - EPS, PE_ratio (yfinance.info 에서) + """ + rows = [] + + for t in tickers: + print(f"[FUND] Fetch fundamentals for {t}") + yf_t = yf.Ticker(t) + + fs = yf_t.financials # 손익계산서 + bs = yf_t.balance_sheet # 재무상태표 + + if fs is None or fs.empty: + print(f"[FUND] No financials for {t}") + continue + if bs is None or bs.empty: + print(f"[FUND] No balance_sheet for {t}") + continue + + fs = fs.copy() + bs = bs.copy() + + # --- 여기 부분만 수정 --- + def normalize_cols_to_date(idx): + # 이미 DatetimeIndex면 .date (ndarray of date) 사용 + if isinstance(idx, pd.DatetimeIndex): + return idx.date + # 그 외에는 to_datetime 후 .date + return pd.to_datetime(idx).date + + fs.columns = normalize_cols_to_date(fs.columns) + bs.columns = normalize_cols_to_date(bs.columns) + # ------------------------ + + # 계정명 매칭 함수 + def find_row(df, names): + for n in names: + if n in df.index: + return df.loc[n] + df_map = {idx.lower().replace(" ", ""): idx for idx in df.index} + for n in names: + key = n.lower().replace(" ", "") + if key in df_map: + return df.loc[df_map[key]] + return None + + revenue_row = find_row(fs, ["Total Revenue", "Revenue"]) + net_income_row = find_row(fs, ["Net Income", "NetIncome"]) + assets_row = find_row(bs, ["Total Assets"]) + liab_row = find_row(bs, ["Total Liab", "Total Liabilities"]) + equity_row = find_row(bs, ["Total Stockholder Equity", "Total Equity"]) + + info = {} + try: + info = yf_t.info or {} + except: + pass + + eps_value = info.get("trailingEps") + pe_value = info.get("trailingPE") + + report_dates = sorted(set(fs.columns) | set(bs.columns)) + + for d in report_dates: + rows.append({ + "ticker": t, + "date": d, + "revenue": float(revenue_row[d]) if (revenue_row is not None and d in revenue_row and pd.notna(revenue_row[d])) else None, + "net_income": float(net_income_row[d]) if (net_income_row is not None and d in net_income_row and pd.notna(net_income_row[d])) else None, + "total_assets": float(assets_row[d]) if (assets_row is not None and d in assets_row and pd.notna(assets_row[d])) else None, + "total_liabilities": float(liab_row[d]) if (liab_row is not None and d in liab_row and pd.notna(liab_row[d])) else None, + "equity": float(equity_row[d]) if (equity_row is not None and d in equity_row and pd.notna(equity_row[d])) else None, + "EPS": float(eps_value) if eps_value is not None else None, + "PE_ratio": float(pe_value) if pe_value is not None else None, + }) + + if not rows: + return pd.DataFrame( + columns=[ + "ticker", "date", + "revenue", "net_income", + "total_assets", "total_liabilities", + "equity", "EPS", "PE_ratio" + ] + ) + + return pd.DataFrame(rows) + + + +# -------------------------------------------------------------- +# UPSERT Company Fundamentals +# -------------------------------------------------------------- +def upsert_company_fundamentals(db_name: str, df: pd.DataFrame): + if df.empty: + print("[FUND] No fundamentals to upsert.") + return + + # pandas → python scalar + records = [] + for _, r in df.iterrows(): + records.append(( + r["ticker"], r["date"], + to_scalar(r["revenue"]), + to_scalar(r["net_income"]), + to_scalar(r["total_assets"]), + to_scalar(r["total_liabilities"]), + to_scalar(r["equity"]), + to_scalar(r["EPS"]), + to_scalar(r["PE_ratio"]), + )) + + sql = """ + INSERT INTO public.company_fundamentals ( + ticker, date, + revenue, net_income, + total_assets, total_liabilities, + equity, EPS, PE_ratio + ) + VALUES %s + ON CONFLICT (ticker, date) DO UPDATE SET + revenue = EXCLUDED.revenue, + net_income = EXCLUDED.net_income, + total_assets = EXCLUDED.total_assets, + total_liabilities = EXCLUDED.total_liabilities, + equity = EXCLUDED.equity, + EPS = EXCLUDED.EPS, + PE_ratio = EXCLUDED.PE_ratio; + """ + + conn = get_db_conn(db_name) + try: + with conn.cursor() as cur: + execute_values(cur, sql, records) + conn.commit() + print(f"[FUND] Upserted {len(df)} rows into public.company_fundamentals") + + finally: + conn.close() + + +# -------------------------------------------------------------- +# Fundamentals Pipeline (전체 업데이트) +# -------------------------------------------------------------- +def run_company_fundamentals_pipeline(config: Dict[str, Any]): + db_name = config["db_name"] + tickers = config["tickers_for_fund"] + + df = fetch_company_fundamentals_from_yf(tickers) + upsert_company_fundamentals(db_name, df) +# ====================================================================== +# SECTION 6 — DAILY AUTO DATA COLLECTION (증분 업데이트) +# ====================================================================== + +def run_data_collection() -> None: + """ + pipeline/run_pipeline.py 에서 "STEP 0" 으로 호출되는 자동 데이터 수집 함수. + + ▣ 동작 + 1) price_data: MAX(date) 이후 구간만 yfinance 로 추가 수집 후 UPSERT + 2) technical_indicators: price_data 전체 기반으로 다시 계산 후 UPSERT + 3) macroeconomic_indicators: MAX(date)+1 이후 구간만 수집 + 4) company_fundamentals: 전체 재수집 (annual/quarterly 로 증분 개념이 애매해서) + """ + + print("\n=== [STEP 0] DAILY DATA COLLECTION (AUTO, incremental) ===") + + db_name = "db" + today_str = today_kst().strftime("%Y-%m-%d") + + # ------------------------------------------------------------------ + # 1) 가격/기술지표 대상 티커 (환경변수 or 기본값) + # ------------------------------------------------------------------ + # 1) DB에서 모든 티커 가져오기 + tickers = get_all_tickers_from_db(db_name) + + # DB가 비어있으면 기본값 사용 + if not tickers: + print("[INFO] DB에 티커가 없어서 기본 유니버스 사용: AAPL, MSFT, TSLA") + tickers = ["AAPL", "MSFT", "TSLA"] + + + # ------------------------------------------------------------------ + # 2) 펀더멘털 대상 티커 + # ------------------------------------------------------------------ + tickers_for_fund = get_all_tickers_from_db(db_name) + + # DB가 비어있으면 기본값 사용 + if not tickers_for_fund: + print("[INFO] DB에 티커가 없어서 기본 유니버스 사용: AAPL, MSFT, TSLA") + tickers_for_fund = ["AAPL", "MSFT", "TSLA"] + + + # ------------------------------------------------------------------ + # 3) 거시지표 시리즈 매핑 + # ------------------------------------------------------------------ + macro_series = { + # 예시. 실제 쓰는 yfinance symbol 로 교체 가능 + "cpi": "CPIAUCSL", + "gdp": "GDP", + "ppi": "PPIACO", + "jolt": "JTSJOL", + "cci": "CONCCONF", + "interest_rate": "^TNX", + } + + config = { + "db_name": db_name, + "tickers": tickers, + "price_start": "2017-01-01", + "tickers_for_fund": tickers_for_fund, + "macro_series": macro_series, + "macro_start": "2017-01-01", + } + + # ------------------------------------------------------------------ + # [1] price_data 증분 업데이트 + # ------------------------------------------------------------------ + print("\n[1] price_data incremental update…") + run_price_pipeline(config) + + # ------------------------------------------------------------------ + # [2] technical_indicators 전체 재계산 + # ------------------------------------------------------------------ + print("\n[2] technical_indicators incremental recompute…") + run_technical_indicators_incremental(config) + + # ------------------------------------------------------------------ + # [3] macroeconomic_indicators 증분 업데이트 + # ------------------------------------------------------------------ + print("\n[3] macroeconomic_indicators incremental update…") + run_macro_pipeline(config) + + # ------------------------------------------------------------------ + # [4] company_fundamentals 전체 업데이트 + # ------------------------------------------------------------------ + print("\n[4] company_fundamentals full update…") + run_company_fundamentals_pipeline(config) + + print("=== [STEP 0] DAILY DATA COLLECTION DONE ===\n") +# ====================================================================== +# SECTION 7 — MANUAL BACKFILL ALL (FULL REFILL) +# ====================================================================== + +def manual_backfill_all() -> None: + """ + ⚠ 전체 백필(Backfill) 모드 + ----------------------------------------- + 이 함수는 "모든 티커 × 모든 스키마 × 전체 기간"을 다시 채운다. + 즉, 다음 순서로 강제 재수집한다: + + 1) price_data (OHLCV 전체 재수집) + 2) technical_indicators (price_data 전체 기반 재계산) + 3) macroeconomic_indicators (전체 재수집) + 4) company_fundamentals (전체 재수집) + + ※ 이 작업은 매우 오래 걸릴 수 있음. + 실행 전 반드시 티커 개수와 시작일 범위를 확인해야 함. + """ + + db_name = "db" + today_str = today_kst().strftime("%Y-%m-%d") + + # ------------------------------------------------------------ + # 1) 가격/기술지표 대상 티커 + # ------------------------------------------------------------ + tickers = get_all_tickers_from_db(db_name) + if not tickers: + print("[INFO] DB에 티커가 없어서 기본 유니버스 사용: AAPL, MSFT, TSLA") + tickers = ["AAPL", "MSFT", "TSLA"] + + # ------------------------------------------------------------ + # 2) 펀더멘털 대상 티커 + # ------------------------------------------------------------ + tickers_for_fund = get_all_tickers_from_db(db_name) + if not tickers_for_fund: + print("[INFO] DB에 티커가 없어서 기본 유니버스 사용: AAPL, MSFT, TSLA") + tickers_for_fund = ["AAPL", "MSFT", "TSLA"] + + # ------------------------------------------------------------ + # 3) 거시지표 시리즈 + # ------------------------------------------------------------ + macro_series = { + # 필요 시 실제 프로젝트에서 수정 + "cpi": "CPIAUCSL", + "gdp": "GDP", + "ppi": "PPIACO", + "jolt": "JTSJOL", + } + + # ------------------------------------------------------------ + # 4) 백필 시작일 설정 + # ------------------------------------------------------------ + price_backfill_start = os.getenv("PRICE_BACKFILL_START", "2017-01-01") + macro_backfill_start = os.getenv("MACRO_BACKFILL_START", "2017-01-01") + + # ------------------------------------------------------------ + # ⚠️ 진짜 위험한 작업이므로 경고 출력 + # ------------------------------------------------------------ + print( + "\n⚠⚠⚠ [전체 백필 모드 경고] ⚠⚠⚠\n" + "이 명령은 다음 작업을 '모두 강제로' 다시 수행합니다.\n" + " 1) 모든 티커의 전체 OHLCV → price_data 를 다시 채움\n" + " 2) price_data 전체 기반으로 technical_indicators 전체 재계산\n" + " 3) 거시지표 전체 재수집 → macroeconomic_indicators 다시 기록\n" + " 4) 모든 펀더멘털 티커에 대해 재무제표 전체 재수집\n\n" + "■ 작업량이 매우 큽니다. (티커 수 × 기간 × 스키마)\n" + "■ 네트워크 상황에 따라 몇십 분~몇 시간 걸릴 수 있습니다.\n" + "■ 지금 설정된 값은 다음과 같습니다:\n" + "---------------------------------------------------------------\n" + f" 오늘(KST): {today_str}\n" + f" price_data 티커: {tickers}\n" + f" fundamentals 티커: {tickers_for_fund}\n" + f" macro 시리즈: {macro_series}\n" + f" 가격 백필 시작일: {price_backfill_start}\n" + f" 거시 백필 시작일: {macro_backfill_start}\n" + "---------------------------------------------------------------\n" + "※ 의도한 값이 맞는지 반드시 확인하세요.\n" + "※ 중간에 CTRL+C 로 중단해도 그 시점까지는 DB에 반영됨.\n" + "---------------------------------------------------------------\n" + ) + + # ================================================================== + # STEP 1 — PRICE_DATA 전체 백필 (티커도 나눠서 처리) + # ================================================================== + print("\n[STEP 1/4] price_data FULL backfill 시작…") + + tickers_per_batch = 20 # 티커도 20개 단위로 끊어서 처리 (원하면 조정 가능) + + for i in range(0, len(tickers), tickers_per_batch): + batch_tickers = tickers[i : i + tickers_per_batch] + print(f"[STEP 1/4] price_data batch {i} ~ {i + len(batch_tickers) - 1} 티커 처리 중...") + + # 이 배치의 모든 티커에 대해 OHLCV 수집 + df_price_batch = fetch_price_data_from_yf( + batch_tickers, + price_backfill_start, + today_str, + ) + + # 이 배치만 DB에 UPSERT (행 단위 batch_size 는 upsert_price_data 에서 또 나눔) + upsert_price_data(db_name, df_price_batch, batch_size=10_000) + + print("[STEP 1/4] price_data FULL backfill 완료.\n") + + + # ================================================================== + # STEP 2 — TECHNICAL_INDICATORS 전체 재계산/백필 + # ================================================================== + print("[STEP 2/4] technical_indicators FULL 재계산…") + config_for_tech = { + "db_name": db_name, + "tickers": tickers + } + run_technical_indicators_full(config_for_tech) + print("[STEP 2/4] technical_indicators FULL backfill 완료.\n") + + # ================================================================== + # STEP 3 — MACROECONOMIC_INDICATORS 전체 백필 + # ================================================================== + if macro_series: + print("[STEP 3/4] macroeconomic_indicators FULL backfill 시작…") + df_macro = fetch_macro_from_fred(macro_series, macro_backfill_start, today_str) + upsert_macro(db_name, df_macro) + print("[STEP 3/4] macroeconomic_indicators FULL backfill 완료.\n") + else: + print("[STEP 3/4] macroeconomic_indicators: macro_series 비어 있어서 skip.\n") + + # ================================================================== + # STEP 4 — COMPANY FUNDAMENTALS 전체 백필 + # ================================================================== + print("[STEP 4/4] company_fundamentals FULL backfill 시작…") + config_for_fund = { + "db_name": db_name, + "tickers_for_fund": tickers_for_fund + } + run_company_fundamentals_pipeline(config_for_fund) + print("[STEP 4/4] company_fundamentals FULL backfill 완료.\n") + + print("✅ 전체 백필 완료. (manual_backfill_all)") +# ====================================================================== +# SECTION 8 — MAIN ENTRYPOINT +# ====================================================================== + +if __name__ == "__main__": + """ + 이 파일을 직접 실행하면 “전체 백필 모드”가 실행된다. + + 사용 예: + python AI/daily_data_collection/main.py + + 주의: + - manual_backfill_all() 은 매우 무거운 작업이다. + - 다음 환경변수를 통해 대상 티커/기간을 조절할 수 있다: + + PRICE_INGEST_TICKERS : price_data + technical 대상 + FUND_INGEST_TICKERS : company_fundamentals 대상 + PRICE_BACKFILL_START : OHLCV 백필 시작일 + MACRO_BACKFILL_START : macro 백필 시작일 + + - pipeline/run_pipeline.py 에서 자동 실행되는 함수는 + run_data_collection() + 이고, manual_backfill_all() 은 직접 실행할 때만 사용해야 한다. + """ + manual_backfill_all() diff --git a/AI/finder/financial_eval.py b/AI/finder/financial_eval.py index 9a7be3ab..581a81b6 100644 --- a/AI/finder/financial_eval.py +++ b/AI/finder/financial_eval.py @@ -1,139 +1,87 @@ +# -*- coding: utf-8 -*- +""" +한국어 주석: +- 이 파일은 기존 config.json/psycopg2 기반 코드를 + 프로젝트 최신 표준 DB 유틸(get_engine, get_db_conn) 기반으로 완전히 리팩터링한 버전입니다. + +- 모든 DB 접속은 .env 기반 환경변수에서 읽고 + get_engine("db") 로 일관성 있게 연결합니다. +""" + import pandas as pd -import time -import json -import re from datetime import datetime - import warnings warnings.filterwarnings("ignore", category=FutureWarning) -import psycopg2 +# 🚀 신규 DB 엔진 유틸 (환경변수 기반) +from AI.libs.utils.get_db_conn import get_engine + +# ---------------------------------------------------------------------- +# 1️⃣ DB에서 재무제표 불러오기 +# ---------------------------------------------------------------------- +def load_company_fundamentals(db_name="db"): + """ + company_fundamentals 테이블 전체를 불러오는 함수. + JSON 설정 불필요. 환경변수 기반 DB 연결(get_engine) 사용. + """ + query = "SELECT * FROM company_fundamentals;" + engine = get_engine(db_name) -from pathlib import Path -import os + df = pd.read_sql(query, engine) + print(f"[INFO] Loaded fundamentals: {len(df)} rows") + return df -def load_config(path: str = "AI/configs/config.json") -> dict: #컨피그 파일 열기 - p=Path(path) - if not p.exists(): - alt = Path(__file__).parent / path - if alt.exists(): - p=alt - else: - raise FileNotFoundError(f"config file not found: {path}") - with p.open("r", encoding="utf-8") as f: - cfg = json.load(f) - if "db" not in cfg: - raise ValueError("config.json에 'db'섹션이 없습니다.") - # port 정수 보정 - if isinstance(cfg["db"].get("port",5432),str): #타입 검사, 없는경우 5432 사용 - cfg["db"]["port"]=int(cfg["db"]["port"]) - return cfg - -cfg = load_config("AI/configs/config.json") -db_cfg = cfg["db"] - -# DB 연결 & 데이터 로드 - -query = "SELECT * FROM company_fundamentals;" -with psycopg2.connect(**db_cfg) as conn: - company = pd.read_sql(query, conn) - -print(company.head()) - -ticker_list = company['ticker'].unique() # 498 - -#중복 정의라 주석처리했습니다. 확인후 수정 바람. -# 결측치 처리 -#def fill_financials(df: pd.DataFrame, industry_median: pd.DataFrame = None): -# df = df.copy() -# -# # 1. Forward/Backward fill (연속 시계열 기준) -# df = df.ffill().bfill() -# -# # 2. Equity 보정 (자산 - 부채) -# if 'Assets' in df.columns and 'Liabilities' in df.columns: -# df['Equity'] = df['Assets'] - df['Liabilities'] -# -# # 3. EPS 보정 (순이익 / 자본 or 주식수 정보 있을 경우) -# if 'NetIncome' in df.columns and 'EPS' in df.columns: -# df['EPS'] = df['EPS'].fillna(df['NetIncome'] / (df['Equity'].replace(0, pd.NA))) -# -# # 4. 업계 중앙값으로 채우기 (옵션) -# if industry_median is not None: -# df = df.fillna(industry_median) -# -# return df +# ---------------------------------------------------------------------- +# 2️⃣ 재무제표 결측치 보정 함수 (원본과 동일) +# ---------------------------------------------------------------------- def fill_financials(df, industry_medians=None): df = df.copy() - # 1. 시계열 정렬 df['date'] = pd.to_datetime(df['date']) df = df.sort_values(['ticker', 'date']) - # 2. 그룹별 처리 (기업별) def fill_group(g): g = g.sort_values('date') - - # revenue, net_income 보간 - g['revenue'] = g['revenue'].interpolate(method='linear').fillna(method='ffill').fillna(method='bfill') - g['net_income'] = g['net_income'].interpolate(method='linear').fillna(method='ffill').fillna(method='bfill') - - # equity 보정 (assets, liabilities 둘 다 있는 경우만) + + g['revenue'] = g['revenue'].interpolate().ffill().bfill() + g['net_income'] = g['net_income'].interpolate().ffill().bfill() + + # equity 보정 mask = g['equity'].isna() & g['total_assets'].notna() & g['total_liabilities'].notna() g.loc[mask, 'equity'] = g.loc[mask, 'total_assets'] - g.loc[mask, 'total_liabilities'] - + # liabilities 보정 mask = g['total_liabilities'].isna() & g['total_assets'].notna() & g['equity'].notna() g.loc[mask, 'total_liabilities'] = g.loc[mask, 'total_assets'] - g.loc[mask, 'equity'] - + # assets 보정 mask = g['total_assets'].isna() & g['equity'].notna() & g['total_liabilities'].notna() g.loc[mask, 'total_assets'] = g.loc[mask, 'equity'] + g.loc[mask, 'total_liabilities'] - - # 남은 결측치 시계열 보간 + + # 나머지 보간 for col in ['total_assets', 'total_liabilities', 'equity']: - g[col] = g[col].interpolate(method='linear').fillna(method='ffill').fillna(method='bfill') - - # eps 보간 - g['eps'] = g['eps'].interpolate(method='linear').fillna(method='ffill').fillna(method='bfill') - - # pe_ratio는 같은 기업의 중앙값으로 + g[col] = g[col].interpolate().ffill().bfill() + + g['eps'] = g['eps'].interpolate().ffill().bfill() + + # PE Ratio 기업중앙값 g['pe_ratio'] = g['pe_ratio'].fillna(g['pe_ratio'].median()) - + return g - + df = df.groupby('ticker').apply(fill_group).reset_index(drop=True) - # 3. 업계 평균으로 남은 결측치 채우기 (옵션) if industry_medians is not None: df = df.fillna(df['ticker'].map(industry_medians)) - - return df - -c_df = fill_financials(company) + return df -# 연간 재무제표로 변환 +# ---------------------------------------------------------------------- +# 3️⃣ 연간 재무제표 집계 +# ---------------------------------------------------------------------- def aggregate_yearly_financials(df): - """ - 연간 재무제표로 변환 - - revenue, net_income, eps: 연간 평균 (Flow) - - total_assets, total_liabilities, equity: 연말 값 (Stock) - - Parameters - ---------- - df : DataFrame - columns = [ticker, date, revenue, net_income, total_assets, total_liabilities, equity, eps] - - Returns - ------- - DataFrame - 연간 집계 데이터 - """ - - # 날짜 변환 & 연도 추출 df["date"] = pd.to_datetime(df["date"]) df["year"] = df["date"].dt.year @@ -145,30 +93,26 @@ def aggregate_yearly_financials(df): for (ticker, year), g in df.groupby(["ticker", "year"]): flow_data = g[flow_cols].mean() stock_data = g.loc[g["date"].idxmax(), stock_cols] - + yearly_data = {"ticker": ticker, "year": year} yearly_data.update(flow_data.to_dict()) yearly_data.update(stock_data.to_dict()) yearly_list.append(yearly_data) - yearly_df = pd.DataFrame(yearly_list) - - return yearly_df - -year_df = aggregate_yearly_financials(c_df) -year_df.head() + return pd.DataFrame(yearly_list) -# 평가표 생성 +# ---------------------------------------------------------------------- +# 4️⃣ 안정성 평가 점수 계산 +# ---------------------------------------------------------------------- def stability_score(df): results = [] - # 기업별 평가 for ticker, g in df.groupby("ticker"): latest = g.iloc[0] - # 1. Debt Ratio (총부채 / 총자산) + # Debt Ratio debt_ratio = latest["total_liabilities"] / latest["total_assets"] if latest["total_assets"] else None if debt_ratio is not None: if debt_ratio < 0.4: debt_score = 5 @@ -177,9 +121,9 @@ def stability_score(df): elif debt_ratio < 1.0: debt_score = 2 else: debt_score = 1 else: - debt_score = 3 # 결측 시 중립점 + debt_score = 3 - # 2. ROA (순이익 / 총자산) + # ROA roa = latest["net_income"] / latest["total_assets"] if latest["total_assets"] else None if roa is not None: if roa >= 0.08: roa_score = 5 @@ -190,7 +134,7 @@ def stability_score(df): else: roa_score = 3 - # 3. ROE (순이익 / 자본) + # ROE roe = latest["net_income"] / latest["equity"] if latest["equity"] else None if roe is not None: if roe >= 0.12: roe_score = 5 @@ -201,11 +145,16 @@ def stability_score(df): else: roe_score = 3 - # 4. 매출 성장률 (Revenue Growth: 최근 2년) + # 매출 성장률 if len(g) >= 2: - rev_growth = (g.iloc[-1]["revenue"] - g.iloc[-2]["revenue"]) / g.iloc[-2]["revenue"] if g.iloc[-2]["revenue"] else None + prev, curr = g.iloc[-2], g.iloc[-1] + if prev["revenue"] != 0: + rev_growth = (curr["revenue"] - prev["revenue"]) / prev["revenue"] + else: + rev_growth = None else: rev_growth = None + if rev_growth is not None: if rev_growth >= 0.10: rev_score = 5 elif rev_growth >= 0.05: rev_score = 4 @@ -214,13 +163,9 @@ def stability_score(df): else: rev_score = 3 - # 5. EPS - if latest["eps"] is not None: - eps_score = 5 if latest["eps"] > 0 else 1 - else: - eps_score = 3 + # EPS + eps_score = 5 if latest["eps"] > 0 else 1 if latest["eps"] is not None else 3 - # 최종 점수 (평균) total_score = round((debt_score + roa_score + roe_score + rev_score + eps_score) / 5, 2) results.append({ @@ -235,13 +180,25 @@ def stability_score(df): return pd.DataFrame(results) -recent_y_df = year_df[year_df['year'] > datetime.now().year-3].groupby("ticker")[['revenue', 'net_income', 'eps', - 'total_assets', 'total_liabilities', - 'equity']].mean() -eval_df = stability_score(recent_y_df) -eval_df.sort_values(by='stability_score', ascending=False, inplace=True) -# save -eval_df.to_csv(f'data/stability_score_{datetime.now().year}.csv', index=False) +# ---------------------------------------------------------------------- +# 5️⃣ 실행 파이프라인 +# ---------------------------------------------------------------------- +if __name__ == "__main__": + print("[STEP] Load fundamentals") + company = load_company_fundamentals("db") + + print("[STEP] Fill missing values") + c_df = fill_financials(company) + + print("[STEP] Aggregate yearly") + year_df = aggregate_yearly_financials(c_df) + + print("[STEP] Evaluate stability") + recent_y_df = year_df[year_df['year'] >= datetime.now().year - 3].groupby("ticker").mean() + eval_df = stability_score(recent_y_df) + eval_df.sort_values("stability_score", ascending=False, inplace=True) + eval_df.to_csv(f"data/stability_score_{datetime.now().year}.csv", index=False) + print("[DONE] stability score exported") diff --git a/AI/libs/core/__init__.py b/AI/libs/core/__init__.py new file mode 100644 index 00000000..7ace3c8e --- /dev/null +++ b/AI/libs/core/__init__.py @@ -0,0 +1,3 @@ +#AI/libs/core/__init__.py +from libs.core.pipeline import run_pipeline +__all__ = ["run_pipeline"] \ No newline at end of file diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index 6622ad09..47e2d809 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -1,184 +1,370 @@ +# pipeline/run_pipeline.py +""" +한국어 주석 (개요): +- 본 파일은 "주간 자동 파이프라인"의 전체 흐름을 오케스트레이션합니다. + (Finder → Transformer → XAI 리포트 → Backtrade → DB 저장) + +[전체 플로우] +1) Finder + - 시장/전략 조건에 맞는 종목 목록(ticker list)을 선정합니다. + +2) Transformer + - 선택된 종목들의 OHLCV 데이터를 DB에서 가져옵니다(fetch_ohlcv). + - LSTM/Rule 기반 등의 Transformer 로직을 통해 의사결정 로그(DataFrame)를 생성합니다. + - 이 의사결정 로그는 XAI와 Backtrade에서 모두 공통으로 사용됩니다. + +3) XAI (e.g. GROQ 등 LLM 기반 설명 생성) + - 각 의사결정에 대해 feature_name / feature_score를 기반으로 + "왜 이 신호가 나왔는지"에 대한 자연어 리포트를 생성합니다. + - 결과는 xai_reports 테이블에 먼저 저장됩니다. + - 이 때 생성된 xai_reports.id를 decision_log(logs_df)에 xai_report_id로 심습니다. + +4) Backtrade + - xai_report_id가 포함된 의사결정 로그(decision_log)를 받아, + price 컬럼을 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행합니다. + - 백테스트 결과(fills_df)는 xai_report_id를 그대로 보존한 상태로 executions 테이블에 저장됩니다. + +[주의 사항] +- Transformer가 생성하는 decision_log(DataFrame)는 최소한 아래 컬럼을 포함해야 합니다: + ['ticker', 'date', 'action', 'price', + 'feature_name1', 'feature_name2', 'feature_name3', + 'feature_score1', 'feature_score2', 'feature_score3'] +- GROQ_API_KEY 환경변수가 없으면 XAI 단계는 자동으로 스킵됩니다. +""" + import os import sys +import json from typing import List, Dict, Optional, Tuple -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone, timedelta + import pandas as pd -# --- 프로젝트 루트 경로 설정 --- +# ---------------------------------------------------------------------- +# 프로젝트 루트 경로 설정 +# ---------------------------------------------------------------------- project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(project_root) -# ------------------------------ - -# --- 모듈 import --- -from finder.main import run_finder -from transformer.main import run_transformer -from libs.utils.fetch_ohlcv import fetch_ohlcv -from xai.run_xai import run_xai -from libs.utils.get_db_conn import get_db_conn -from libs.utils.save_reports_to_db import save_reports_to_db -# --------------------------------- - -# DB 이름 상수(실제 등록된 키와 반드시 일치해야 함) -MARKET_DB_NAME = "db" # 시세/원천 데이터 DB -REPORT_DB_NAME = "report_DB" # 리포트 저장 DB - -# === (신규 전용) 필수 컬럼: inference 로그 → XAI 변환에 필요한 것만 강제 === + +# ---------------------------------------------------------------------- +# 외부 모듈 import (각 단계별 역할) +# ---------------------------------------------------------------------- +from daily_data_collection.main import run_data_collection # 0) 시세/시장 데이터 수집 +from finder.main import run_finder # 1) 종목 발굴 +from transformer.main import run_transformer # 2) 신호 생성(의사결정 로그 생성) +from backtrade.main import backtrade, BacktradeConfig # 3) 백트레이딩(간소화 체결 엔진) +from libs.utils.save_executions_to_db import save_executions_to_db # 3.5) 체결내역 DB 저장 +from xai.run_xai import run_xai # 4) XAI 리포트 텍스트 생성 +from libs.utils.save_reports_to_db import save_reports_to_db # 4.5) XAI 리포트 DB 저장 (id 반환) +from libs.utils.fetch_ohlcv import fetch_ohlcv # (Transformer용) OHLCV 수집 헬퍼 + +# ---------------------------------------------------------------------- +# 타입 별칭 (다른 모듈과의 데이터 계약을 명확히 하기 위함) +# ---------------------------------------------------------------------- +ReportRow = Tuple[str, str, float, str, str] # (ticker, signal, price, date, report_text) + +# ---------------------------------------------------------------------- +# DB 이름 상수 +# ---------------------------------------------------------------------- +MARKET_DB_NAME = "db" # 시세/시장 데이터(DB) 명. +REPORT_DB_NAME = "db" # 체결내역 / XAI 리포트를 저장하는 DB 명 + +# ---------------------------------------------------------------------- +# XAI 및 Backtrade에서 공통으로 요구하는 "결정 로그 필수 컬럼" 정의 +# ---------------------------------------------------------------------- REQUIRED_LOG_COLS = { - "ticker", "date", "action", "price", + "ticker", + "date", + "action", + "price", # XAI evidence 구성에 꼭 필요한 신규 컬럼 - "feature_name1", "feature_name2", "feature_name3", - "feature_score1", "feature_score2", "feature_score3", - # (원하면 로깅/모니터링용 확률도 계속 받되 필수는 아님) + "feature_name1", + "feature_name2", + "feature_name3", + "feature_score1", + "feature_score2", + "feature_score3", } +# ---------------------------------------------------------------------- +# 주간 티커 캐시 파일 경로 +# ---------------------------------------------------------------------- +TICKER_CACHE_PATH = os.path.join(project_root, "weekly_tickers.json") + + +# ====================================================================== +# 유틸리티 함수 모음 +# ====================================================================== + +def _utcnow() -> datetime: + """현재 시각을 UTC 기준 datetime으로 반환.""" + return datetime.now(timezone.utc) + + +def _to_iso_date(v) -> str: + """값을 'YYYY-MM-DD' 문자열로 변환.""" + try: + if isinstance(v, (pd.Timestamp, datetime)): + return v.strftime("%Y-%m-%d") + return str(v) + except Exception: + return str(v) + + +def _to_float(v, fallback: float = 0.0) -> float: + """값을 float로 변환, 실패 시 fallback.""" + try: + f = float(v) + if pd.isna(f): + return float(fallback) + return f + except Exception: + return float(fallback) + + +# ====================================================================== +# 1) Finder: 주간 종목 추출 단계 +# ====================================================================== + def run_weekly_finder() -> List[str]: """ - 주간 종목 발굴(Finder)을 실행하고 결과(종목 리스트)를 반환합니다. + Finder 모듈을 실행하여 후보 티커 리스트를 반환합니다. + (주에 한 번, 월요일에만 실제 실행) """ print("--- [PIPELINE-STEP 1] Finder 모듈 실행 시작 ---") - # top_tickers = run_finder() # TODO: 종목 선정 이슈 해결 후 사용 - top_tickers = ["AAPL", "MSFT", "GOOGL"] # 임시 데이터 - print("--- [PIPELINE-STEP 1] Finder 모듈 실행 완료 ---") - return top_tickers + try: + tickers = run_finder() + if not tickers: + tickers = ["AAPL", "MSFT", "GOOGL"] + except Exception as e: + print(f"[WARN] Finder 실행 중 오류 발생: {e} → 임시 티커 리스트를 사용합니다.") + tickers = ["AAPL", "MSFT", "GOOGL"] + + print(f"--- [PIPELINE-STEP 1] Finder 완료: tickers={tickers} ---") + return tickers -def _utcnow() -> datetime: - return datetime.now(timezone.utc) + +def save_weekly_tickers(tickers: List[str]) -> None: + """월요일에 산출한 티커 리스트를 로컬 캐시 파일에 저장.""" + try: + with open(TICKER_CACHE_PATH, "w", encoding="utf-8") as f: + json.dump({"tickers": tickers, "saved_at": _utcnow().isoformat()}, f, ensure_ascii=False) + print(f"[INFO] 주간 티커 캐시 저장 완료: {TICKER_CACHE_PATH}") + except Exception as e: + print(f"[WARN] 주간 티커 캐시 저장 실패: {e}") + + +def load_weekly_tickers() -> List[str]: + """캐시 파일에서 주간 티커 리스트를 불러옴.""" + if not os.path.exists(TICKER_CACHE_PATH): + print(f"[WARN] 주간 티커 캐시 파일이 존재하지 않습니다: {TICKER_CACHE_PATH}") + return [] + + try: + with open(TICKER_CACHE_PATH, "r", encoding="utf-8") as f: + data = json.load(f) + tickers = data.get("tickers", []) + print(f"[INFO] 캐시에서 주간 티커 로드: {tickers}") + return tickers + except Exception as e: + print(f"[WARN] 주간 티커 캐시 로드 실패: {e}") + return [] + + +# ====================================================================== +# 2) Transformer: 신호/의사결정 로그 생성 단계 +# ====================================================================== def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: """ - 종목 리스트를 받아 Transformer 모듈을 실행하고, 신호(결정 로그)를 반환합니다. + 종목 리스트에 대해 DB에서 OHLCV를 수집하고 Transformer를 호출하여 + 의사결정 로그(DataFrame)를 생성합니다. """ print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 시작 ---") if not tickers: - print("[WARN] 빈 종목 리스트가 입력되어 Transformer를 건너뜁니다.") + print("[WARN] 빈 종목 리스트가 입력되어 Transformer 단계를 건너뜁니다.") return pd.DataFrame() - # end_date = _utcnow() # 서버 사용 시 - end_date = datetime.strptime("2024-10-30", "%Y-%m-%d") # 임시 고정 날짜 + end_date = datetime.now().date() start_date = end_date - timedelta(days=600) all_ohlcv_df: List[pd.DataFrame] = [] + for ticker in tickers: try: ohlcv_df = fetch_ohlcv( ticker=ticker, start=start_date.strftime("%Y-%m-%d"), end=end_date.strftime("%Y-%m-%d"), - db_name=db_name + db_name=db_name, ) + if ohlcv_df is None or ohlcv_df.empty: print(f"[WARN] OHLCV 미수집: {ticker}") continue + ohlcv_df = ohlcv_df.copy() + + if not pd.api.types.is_datetime64_any_dtype(ohlcv_df["date"]): + ohlcv_df["date"] = pd.to_datetime(ohlcv_df["date"]) + ohlcv_df["ticker"] = ticker all_ohlcv_df.append(ohlcv_df) + except Exception as e: - print(f"[ERROR] OHLCV 수집 실패({ticker}): {e}") + print(f"[ERROR] OHLCV 수집 실패 (ticker={ticker}): {e}") if not all_ohlcv_df: - print("[ERROR] 어떤 티커에서도 OHLCV 데이터를 가져오지 못했습니다.") + print("[ERROR] 어떤 티커에서도 OHLCV 데이터를 가져오지 못했습니다. Transformer를 종료합니다.") return pd.DataFrame() raw_data = pd.concat(all_ohlcv_df, ignore_index=True) + raw_data = raw_data.sort_values(["ticker", "date"]).reset_index(drop=True) finder_df = pd.DataFrame(tickers, columns=["ticker"]) + transformer_result: Dict = run_transformer( finder_df=finder_df, seq_len=60, pred_h=1, - raw_data=raw_data + raw_data=raw_data, ) or {} logs_df: pd.DataFrame = transformer_result.get("logs", pd.DataFrame()) + if logs_df is None or logs_df.empty: - print("[WARN] Transformer 결과 로그가 비어 있습니다.") + print("[WARN] Transformer 결과 로그(logs)가 비어 있습니다.") return pd.DataFrame() - # === 신규 포맷 강제 체크 === missing_cols = REQUIRED_LOG_COLS - set(logs_df.columns) if missing_cols: - print(f"[ERROR] 결정 로그에 필수 컬럼 누락(신규 포맷 전용): {sorted(missing_cols)}") + print(f"[ERROR] 결정 로그에 필수 컬럼 누락 (신규 포맷 전용): {sorted(missing_cols)}") return pd.DataFrame() print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 완료 ---") return logs_df -# --- 안전 변환 유틸 --- -def _to_iso_date(v) -> str: - try: - if isinstance(v, (pd.Timestamp, datetime)): - return v.strftime("%Y-%m-%d") - return str(v) - except Exception: - return str(v) -def _to_float(v, fallback=0.0) -> float: - try: - f = float(v) - if pd.isna(f): - return float(fallback) - return f - except Exception: - return float(fallback) +# ====================================================================== +# 3) Backtrade: 의사결정 로그 기반 체결/포지션 계산 단계 +# ====================================================================== -# --- XAI 리포트: 5-튜플(rows)로 반환 --- -def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, str, str]]: +def run_backtrade(decision_log: pd.DataFrame) -> pd.DataFrame: """ - 반환: List[(ticker, signal, price, date, report_text)] - XAI 포맷: - { - "ticker": "...", - "date": "YYYY-MM-DD", - "signal": "BUY|HOLD|SELL", - "price": float, - "evidence": [ - {"feature_name": str, "contribution": float}, # 0~1 점수 권장 - ... - ] - } - ※ 신규 포맷 전용: - - feature_name1~3, feature_score1~3 필수 + Transformer에서 생성된 의사결정 로그(decision_log)의 price 컬럼을 + OHLCV 없이 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행합니다. + + 주의: + - decision_log는 xai_report_id 컬럼을 포함할 수 있으며, + backtrade() 구현이 해당 컬럼을 드롭하지 않으면 fills_df에도 그대로 보존됩니다. + """ + print("--- [PIPELINE-STEP 4] Backtrade 실행 시작 ---") + + if decision_log is None or decision_log.empty: + print("[WARN] Backtrade: 비어있는 결정 로그가 입력되었습니다. 체결을 수행하지 않습니다.") + return pd.DataFrame() + + run_id = _utcnow().strftime("run-%Y%m%d-%H%M%S") + + cfg = BacktradeConfig( + initial_cash=100_000.0, + slippage_bps=5.0, + commission_bps=3.0, + risk_frac=0.2, + max_positions_per_ticker=1, + fill_on_same_day=True, + ) + + fills_df, summary = backtrade( + decision_log=decision_log, + config=cfg, + run_id=run_id, + ) + + if fills_df is None or fills_df.empty: + print("[WARN] Backtrade: 생성된 체결 내역이 없습니다.") + return pd.DataFrame() + + print( + f"--- [PIPELINE-STEP 4] Backtrade 완료: " + f"trades={len(fills_df)}, " + f"cash_final={summary.get('cash_final')}, " + f"pnl_realized_sum={summary.get('pnl_realized_sum')} ---" + ) + return fills_df + + +# ====================================================================== +# 4) XAI 리포트: 설명 텍스트 생성 단계 +# ====================================================================== + +def run_xai_report(decision_log: pd.DataFrame) -> List[ReportRow]: + """ + Transformer 결정 로그를 입력으로 받아, 각 행(의사결정)에 대한 + XAI 설명 리포트(자연어 텍스트)를 생성합니다. + + 반환: + - ReportRow 리스트: (ticker, signal, price, date_s, report_text) + 이 리스트의 순서는 decision_log에서 XAI를 생성한 "행 순서"와 동일합니다. + (단, 여기서 일부 행을 스킵할 수도 있으므로, logs_df 행 수와 1:1 대응이라는 보장은 없음) """ print("--- [PIPELINE-STEP 3] XAI 리포트 생성 시작 ---") + api_key = os.environ.get("GROQ_API_KEY") if not api_key: - print("[STOP] GROQ_API_KEY 미설정: XAI 리포트 단계를 건너뜁니다.") + print("[STOP] GROQ_API_KEY 환경변수가 설정되어 있지 않아 XAI 리포트를 생성하지 않습니다.") return [] if decision_log is None or decision_log.empty: print("[WARN] 비어있는 결정 로그가 입력되어 XAI 리포트를 생성하지 않습니다.") return [] - # 신규 포맷 강제(안전망) - for c in ["feature_name1","feature_name2","feature_name3", - "feature_score1","feature_score2","feature_score3"]: + for c in [ + "feature_name1", + "feature_name2", + "feature_name3", + "feature_score1", + "feature_score2", + "feature_score3", + ]: if c not in decision_log.columns: print(f"[ERROR] XAI: 신규 포맷 필수 컬럼 누락: {c}") return [] - rows: List[Tuple[str, str, float, str, str]] = [] + rows: List[ReportRow] = [] for _, row in decision_log.iterrows(): ticker = str(row.get("ticker", "UNKNOWN")) date_s = _to_iso_date(row.get("date", "")) - signal = str(row.get("action", "")) # action -> signal - price = _to_float(row.get("price", 0.0)) + signal = str(row.get("action", "")) # action → signal + price = _to_float(row.get("price", 0.0)) + + # 한국어 주석: + # - ticker / signal / date_s 중 하나라도 비어 있으면 + # save_reports_to_db 단계에서 필터링되어 id 개수가 줄어들어 + # logs_df 와 xai_ids 길이가 어긋날 수 있다. + # - 따라서 여기서부터 필수 값이 비어 있는 행은 XAI 생성 자체를 스킵한다. + if not ticker or not signal or not date_s: + print(f"[WARN] XAI 입력 값 누락으로 스킵 (ticker={ticker}, signal={signal}, date={date_s})") + continue - # === 신규 포맷 전용 evidence === evidence: List[Dict[str, float]] = [] for i in (1, 2, 3): name = row.get(f"feature_name{i}") score = row.get(f"feature_score{i}") - # 이름/점수 모두 있어야 추가 + if name is None or str(name).strip() == "": continue if score is None or pd.isna(score): continue - evidence.append({ - "feature_name": str(name), - "contribution": _to_float(score, 0.0) # 0~1 정규화 점수 - }) + + evidence.append( + { + "feature_name": str(name), + "contribution": _to_float(score, 0.0), + } + ) decision_payload = { "ticker": ticker, @@ -190,11 +376,11 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, st try: report_text = run_xai(decision_payload, api_key) - report_text = str(report_text) # 혹시 모를 비문자 타입 대비 - print(f"--- {ticker} XAI 리포트 생성 완료 ---") + report_text = str(report_text) + print(f"--- [XAI] {ticker} 리포트 생성 완료 ---") except Exception as e: report_text = f"[ERROR] XAI 리포트 생성 실패: {e}" - print(f"--- {ticker} XAI 리포트 생성 중 오류: {e} ---") + print(f"--- [XAI] {ticker} 리포트 생성 중 오류: {e} ---") rows.append((ticker, signal, price, date_s, report_text)) @@ -202,40 +388,152 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, st return rows -def run_pipeline() -> Optional[List[str]]: +# ====================================================================== +# 5) 전체 파이프라인 오케스트레이션 +# ====================================================================== + +def run_pipeline() -> Optional[List[ReportRow]]: """ - 전체 파이프라인(Finder -> Transformer -> XAI)을 실행합니다. + 전체 파이프라인(Finder → Transformer → XAI → Backtrade → DB 저장)을 + 한 번에 실행하는 엔트리 포인트 함수. + + - Finder: 주 1회, 월요일에만 실제 실행 (티커를 캐시에 저장) + - 나머지(Transformer, XAI, Backtrade, DB 저장): 매일 실행 + → 평일에는 캐시에서 티커를 읽어서 사용 """ - # 1) Finder - tickers = run_weekly_finder() - if not tickers: - print("[STOP] Finder에서 종목을 찾지 못해 파이프라인을 중단합니다.") - return None + today = datetime.now() # 서버 로컬 시간 기준 (필요시 timezone 조정 가능) + weekday = today.weekday() # 월=0, 화=1, ..., 일=6 - # 2) Transformer + #------------------------------- + # 0) 주가 데이터 저장 실행 + #------------------------------- + print("--- [PIPELINE-STEP 0] 주가 데이터 수집 실행 시작 ---") + try: + # run_data_collection() + print("--- [PIPELINE-STEP 0] 주가 데이터 수집 실행 완료 ---") + except Exception as e: + print(f"[WARN] 데이터 수집 실행 중 오류 발생: {e} → 계속 진행합니다.") + + # ------------------------------- + # 1) Finder: 월요일에만 실행 + # ------------------------------- + if weekday == 0: # 월요일 + print("[INFO] 오늘은 월요일입니다. Finder를 실행합니다.") + tickers = run_weekly_finder() + if not tickers: + print("[STOP] Finder에서 종목을 찾지 못해 파이프라인을 중단합니다.") + return None + # 월요일에 산출한 티커를 캐시에 저장 + save_weekly_tickers(tickers) + else: + print("[INFO] 오늘은 월요일이 아니므로 Finder를 실행하지 않습니다. 캐시에서 티커를 불러옵니다.") + tickers = load_weekly_tickers() + if not tickers: + print("[WARN] 캐시된 티커가 없어 임시로 Finder를 한 번 실행합니다.") + tickers = run_weekly_finder() + if not tickers: + print("[STOP] 임시 Finder에서도 종목을 찾지 못해 파이프라인을 중단합니다.") + return None + + # ------------------------------- + # 2) Transformer: 매일 실행 + # ------------------------------- logs_df = run_signal_transformer(tickers, MARKET_DB_NAME) if logs_df is None or logs_df.empty: - print("[STOP] Transformer에서 신호를 생성하지 못해 파이프라인을 중단합니다.") + print("[STOP] Transformer에서 유효한 신호를 생성하지 못해 파이프라인을 중단합니다.") return None - # 3) XAI + # ------------------------------- + # 3) XAI 리포트 생성: 매일 실행 (환경변수 없으면 자동 스킵) + # ------------------------------- reports = run_xai_report(logs_df) - # 4) 저장 - save_reports_to_db(reports, REPORT_DB_NAME) + # 3.5) XAI 리포트 DB 저장 → 생성된 id 리스트 수신 + try: + xai_ids = save_reports_to_db(reports, REPORT_DB_NAME) + print("[INFO] XAI 리포트를 DB에 저장했습니다.") + except Exception as e: + print(f"[WARN] XAI 리포트 DB 저장 실패: {e}") + xai_ids = [] + + # 3.7) logs_df에 xai_report_id 심기 + # ------------------------------------------------ + # ✅ 기존: 길이가 다르면 전부 NULL 처리 + # ❌ 문제: XAI 리포트 수 != logs_df row 수 인 경우가 발생 + # ✅ 수정: (ticker, date, signal) 키로 매핑해서 가능한 행에만 ID 부여 + # ------------------------------------------------ + logs_df = logs_df.copy().reset_index(drop=True) + + # 기본값: 전부 None + logs_df["xai_report_id"] = None + + if reports and xai_ids: + if len(xai_ids) != len(reports): + print( + f"[WARN] XAI ID 개수({len(xai_ids)})와 리포트 행 수({len(reports)})가 달라 " + "완전한 1:1 대응은 아닐 수 있습니다. (최소한의 매핑만 수행)" + ) + + # (ticker, date, signal) → xai_id 매핑 딕셔너리 생성 + mapping: Dict[Tuple[str, str, str], int] = {} + for (ticker, signal, price, date_s, _report_text), xai_id in zip(reports, xai_ids): + if not ticker or not signal or not date_s: + continue + key = (str(ticker), str(date_s), str(signal)) + mapping[key] = xai_id + + if not mapping: + print("[WARN] 유효한 XAI 매핑이 없어 xai_report_id를 None으로 유지합니다.") + else: + def _lookup_xai_id(row: pd.Series) -> Optional[int]: + t = str(row.get("ticker", "")) + d = _to_iso_date(row.get("date", "")) + s = str(row.get("action", "")) + return mapping.get((t, d, s)) + + logs_df["xai_report_id"] = logs_df.apply(_lookup_xai_id, axis=1) + + # 디버그용: 실제로 몇 개의 row에 ID가 들어갔는지 확인 + assigned_count = logs_df["xai_report_id"].notna().sum() + print(f"[INFO] decision_log에 xai_report_id 매핑 완료 (할당된 row 수: {assigned_count})") + + else: + if not reports: + print("[INFO] 생성된 XAI 리포트가 없어 xai_report_id는 모두 NULL입니다.") + elif not xai_ids: + print("[INFO] XAI ID가 비어 있어 xai_report_id는 모두 NULL입니다.") + + # ------------------------------- + # 4) Backtrade: 매일 실행 + # ------------------------------- + fills_df = run_backtrade(logs_df) + + # ------------------------------- + # 5) executions 테이블에 체결 내역 저장: 매일 실행 + # ------------------------------- + try: + save_executions_to_db(fills_df, REPORT_DB_NAME) + print("[INFO] 체결 내역을 DB에 저장했습니다.") + except Exception as e: + print(f"[WARN] 체결 내역 DB 저장 실패: {e}") return reports -# --- 테스트 실행 --- + +# ====================================================================== +# 스크립트 단독 실행 시 테스트용 엔트리 포인트 +# ====================================================================== if __name__ == "__main__": - print(">>> 파이프라인 (Finder -> Transformer -> XAI) 테스트를 시작합니다.") + print(">>> 파이프라인 (Finder → Transformer → XAI → Backtrade) 테스트를 시작합니다.") final_reports = run_pipeline() + print("\n>>> 최종 반환 결과 (XAI Reports):") if final_reports: for report in final_reports: print(report) else: print("생성된 리포트가 없습니다.") + print("\n---") print("테스트가 정상적으로 완료되었다면, 위 '최종 반환 결과'에 각 종목에 대한 XAI 리포트가 출력되어야 합니다.") print("---") diff --git a/AI/libs/utils/get_db_conn.py b/AI/libs/utils/get_db_conn.py index 6ac43771..e58a3305 100644 --- a/AI/libs/utils/get_db_conn.py +++ b/AI/libs/utils/get_db_conn.py @@ -1,100 +1,109 @@ # AI/libs/utils/get_db_conn.py -# 한국어 주석: JSON 설정에서 DB 접속정보를 읽어 -# 1) psycopg2 Connection (로우 커넥션) -# 2) SQLAlchemy Engine (권장, 커넥션 풀/프리핑) -# 을 생성하는 유틸. 중복 로딩 방지를 위해 캐시 사용. +# 환경변수 기반 DB 연결 유틸 from __future__ import annotations import os import sys -import json -from typing import Dict, Any, Optional +from typing import Dict, Any from pathlib import Path from urllib.parse import quote_plus import psycopg2 from sqlalchemy import create_engine -# --- 프로젝트 루트 경로 설정 --- -project_root = Path(__file__).resolve().parents[3] # .../AI/libs/utils/get_db_conn.py 기준 상위 3단계 -sys.path.append(str(project_root)) -# -------------------------------- - -# 필수 키(포트는 선택) -REQUIRED_KEYS = {"host", "user", "password", "dbname"} - -# JSON 설정 캐시 (프로세스 내 1회 로드) -_CONFIG_CACHE: Optional[Dict[str, Dict[str, Any]]] = None - -def _config_path() -> Path: - """configs/config.json 경로를 안전하게 계산""" - return project_root/"AI"/"configs"/"config.json" +# ---------------------------------------------------------------------- +# 프로젝트 루트 경로 설정 +# ---------------------------------------------------------------------- +project_root = Path(__file__).resolve().parents[3] +sys.path.append(str(project_root)) -def _load_configs() -> Dict[str, Dict[str, Any]]: +# ---------------------------------------------------------------------- +# 내부 헬퍼: prefix 정규화 +# ---------------------------------------------------------------------- +def _normalize_prefix(name: str) -> str: """ - - configs/config.json을 읽어서 {db_name: {host, user, password, dbname, port?, sslmode?}} 형태로 반환 - - 파일은 깃에 올리지 않는 것을 권장(민감정보) + 한국어 주석: + - db_name="db" → "DB_" + - db_name="DB" → "DB_" + - db_name="DB_" → "DB_" + - 이미 REPORT_DB_처럼 끝이 "_" 로 끝나는 경우 → 그대로 사용 + - REPORT_DB → REPORT_DB_ + - market_db → MARKET_DB_ """ - global _CONFIG_CACHE - if _CONFIG_CACHE is not None: - return _CONFIG_CACHE + if not name: + return "DB_" + + up = name.upper() + + # 이미 정확한 환경변수 prefix 구조인 경우 (예: REPORT_DB_) + if up.endswith("_"): + return up - path = _config_path() - if not path.exists(): - raise FileNotFoundError(f"[DB CONFIG] 설정 파일이 없습니다: {path}") + # REPORT_DB → REPORT_DB_ + if "_" in up and not up.endswith("_"): + return up + "_" - try: - with path.open("r", encoding="utf-8") as f: - data: Dict[str, Dict[str, Any]] = json.load(f) - except json.JSONDecodeError as e: - raise ValueError(f"[DB CONFIG] JSON 파싱 오류: {e}") from e + # db → DB_ + return up + "_" - # 간단한 구조 검증(필수키 확인은 get_* 함수에서 db별로 수행) - if not isinstance(data, dict) or not data: - raise ValueError("[DB CONFIG] 최상위 JSON은 비어있지 않은 객체여야 합니다.") - _CONFIG_CACHE = data - return _CONFIG_CACHE +# ---------------------------------------------------------------------- +# 환경변수에서 DB 설정 가져오기 +# ---------------------------------------------------------------------- +REQUIRED_ENV_KEYS = ["HOST", "USER", "PASSWORD", "NAME"] -def _get_db_config(db_name: str) -> Dict[str, Any]: +def _load_db_env(prefix: str) -> Dict[str, Any]: """ - - 특정 db_name에 해당하는 설정 블록을 반환 - - 필수 키(host, user, password, dbname) 존재 검증 + prefix(DB_, REPORT_DB_, MARKET_DB_ 등)에 맞는 환경변수 로딩 + 예: DB_HOST, REPORT_DB_HOST, MARKET_DB_PASSWORD... """ - if not db_name or not isinstance(db_name, str): - raise ValueError("db_name must be a non-empty string") + cfg = {} - configs = _load_configs() - cfg = configs.get(db_name) - if not cfg: - raise KeyError(f"[DB CONFIG] '{db_name}' 설정 블록을 찾을 수 없습니다. (configs/config.json)") + for key in os.environ: + if key.startswith(prefix): + # DB_HOST → host + sub = key.replace(prefix, "").lower() + cfg[sub] = os.environ[key] + + # 필수값 검사 + missing = [] + for key in REQUIRED_ENV_KEYS: + env_name = prefix + key + if env_name not in os.environ: + missing.append(env_name) - missing = REQUIRED_KEYS - set(cfg.keys()) if missing: - raise KeyError(f"[DB CONFIG] '{db_name}'에 필수 키 누락: {sorted(missing)}") + raise EnvironmentError( + f"[DB CONFIG ERROR] 아래 환경변수가 필요합니다:\n" + f" {missing}\n" + f"예시:\n" + f" export {prefix}HOST=...\n" + f" export {prefix}USER=...\n" + f" export {prefix}PASSWORD=...\n" + f" export {prefix}NAME=..." + ) + + # 기본 포트 + cfg.setdefault("port", "5432") return cfg +# ---------------------------------------------------------------------- +# SQLAlchemy URL 생성 +# ---------------------------------------------------------------------- def _build_sqlalchemy_url(cfg: Dict[str, Any]) -> str: - """ - - SQLAlchemy용 PostgreSQL URI를 안전하게 구성 - - 비밀번호/유저의 특수문자를 URL 인코딩(quote_plus)로 보호 - - 예: postgresql+psycopg2://user:pass@host:port/dbname?sslmode=require - """ - user = quote_plus(str(cfg["user"])) - password = quote_plus(str(cfg["password"])) - host = str(cfg["host"]) - port = int(cfg.get("port", 5432)) - dbname = str(cfg["dbname"]) + user = quote_plus(cfg["user"]) + password = quote_plus(cfg["password"]) + host = cfg["host"] + port = cfg.get("port", "5432") + dbname = cfg["name"] base = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{dbname}" - # 선택 옵션: sslmode - # (Neon/클라우드 Postgres의 경우 require가 흔함) sslmode = cfg.get("sslmode") if sslmode: return f"{base}?sslmode={sslmode}" @@ -102,28 +111,35 @@ def _build_sqlalchemy_url(cfg: Dict[str, Any]) -> str: return base -def get_db_conn(db_name: str): +# ---------------------------------------------------------------------- +# psycopg2 커넥션 +# ---------------------------------------------------------------------- +def get_db_conn(db_name: str = "DB_"): """ - - psycopg2 로우 커넥션 생성(직접 커서 열어 사용할 때) - - pandas 경고가 싫다면 read_sql에는 get_engine() 사용을 권장 + 한국어 주석: + - db_name="db" 또는 "DB" → 자동으로 prefix="DB_" + - db_name="report_db" → "REPORT_DB_" + - 이미 "REPORT_DB_" 라면 그대로 사용 """ - cfg = _get_db_config(db_name) + + prefix = _normalize_prefix(db_name) + cfg = _load_db_env(prefix) + return psycopg2.connect( host=cfg["host"], user=cfg["user"], password=cfg["password"], - dbname=cfg["dbname"], + dbname=cfg["name"], port=int(cfg.get("port", 5432)), - sslmode=cfg.get("sslmode", None), # 필요 시 자동 적용 + sslmode=cfg.get("sslmode"), ) -def get_engine(db_name: str): - """ - - SQLAlchemy Engine 생성(권장) - - 커넥션 풀 + pre_ping으로 죽은 연결 사전 감지 → 운영 안정성↑ - - pandas.read_sql, 대량입출력 등에서 사용 - """ - cfg = _get_db_config(db_name) +# ---------------------------------------------------------------------- +# SQLAlchemy 엔진 +# ---------------------------------------------------------------------- +def get_engine(db_name: str = "DB_"): + prefix = _normalize_prefix(db_name) + cfg = _load_db_env(prefix) url = _build_sqlalchemy_url(cfg) return create_engine(url, pool_pre_ping=True) diff --git a/AI/libs/utils/save_executions_to_db.py b/AI/libs/utils/save_executions_to_db.py new file mode 100644 index 00000000..0eeec7cd --- /dev/null +++ b/AI/libs/utils/save_executions_to_db.py @@ -0,0 +1,261 @@ +# -*- coding: utf-8 -*- +""" +한국어 주석: +- executions 테이블 저장 → portfolio_positions 갱신 → portfolio_summary 갱신 +- 하나의 executions INSERT가 발생하면 계좌 전체 상태를 즉시 업데이트한다. +""" + +from __future__ import annotations +from typing import Optional +from sqlalchemy import text +from datetime import datetime, timezone +import pandas as pd + +from libs.utils.get_db_conn import get_engine # 기존 프로젝트 헬퍼 사용 + + +# ------------------------------------------------------------------- +# 공용 헬퍼 +# ------------------------------------------------------------------- +def _utcnow_iso() -> str: + """ISO 포맷 UTC 타임스탬프 문자열""" + return datetime.now(timezone.utc).isoformat() + +# ------------------------------------------------------------------- +# 📌 portfolio_positions 업데이트 +# ------------------------------------------------------------------- +def update_portfolio_position(conn, execution: dict): + """ + 한국어 주석: + - executions에 새 체결 1건이 기록될 때 호출 + - portfolio_positions는 ticker당 1행만 유지 (UNIQUE) + """ + + ticker = execution["ticker"] + qty = int(execution["qty"]) + fill_price = float(execution["fill_price"]) + side = execution["side"].upper() + realized_pnl = float(execution["pnl_realized"]) + + # --- 기존 포지션 로드 + old = conn.execute(text(""" + SELECT position_qty, avg_price, pnl_realized_cum + FROM portfolio_positions + WHERE ticker = :ticker + """), {"ticker": ticker}).fetchone() + + if old is None: + # 신규 포지션 (BUY만 가능) + if side != "BUY": + raise ValueError(f"[ERROR] 보유량 없이 SELL 발생: {ticker}") + + new_qty = qty + new_avg_price = fill_price + new_realized_cum = 0.0 + + else: + old_qty = old.position_qty + old_avg_price = float(old.avg_price) + old_realized_cum = float(old.pnl_realized_cum) + + if side == "BUY": + new_qty = old_qty + qty + new_avg_price = (old_avg_price * old_qty + fill_price * qty) / new_qty + new_realized_cum = old_realized_cum + + elif side == "SELL": + new_qty = old_qty - qty + new_realized_cum = old_realized_cum + realized_pnl + + if new_qty == 0: + new_avg_price = 0.0 + else: + new_avg_price = old_avg_price + + # 평가 관련 데이터 + current_price = fill_price + market_value = new_qty * current_price + pnl_unrealized = (current_price - new_avg_price) * new_qty + + # UPSERT + conn.execute(text(""" + INSERT INTO portfolio_positions + (ticker, position_qty, avg_price, + current_price, market_value, pnl_unrealized, + pnl_realized_cum, updated_at) + VALUES + (:ticker, :q, :avg, + :cp, :mv, :pnl_u, + :pnl_r, NOW()) + ON CONFLICT (ticker) + DO UPDATE SET + position_qty = EXCLUDED.position_qty, + avg_price = EXCLUDED.avg_price, + current_price = EXCLUDED.current_price, + market_value = EXCLUDED.market_value, + pnl_unrealized = EXCLUDED.pnl_unrealized, + pnl_realized_cum = EXCLUDED.pnl_realized_cum, + updated_at = NOW(); + """), { + "ticker": ticker, + "q": new_qty, + "avg": new_avg_price, + "cp": current_price, + "mv": market_value, + "pnl_u": pnl_unrealized, + "pnl_r": new_realized_cum, + }) + + +# ------------------------------------------------------------------- +# 📌 portfolio_summary 업데이트 +# ------------------------------------------------------------------- +def update_portfolio_summary(conn, fill_date: str): + """ + 한국어 주석: + - 계좌 전체 요약(자산, 평가금액, 수익률)을 fill_date 기준으로 M2M 업데이트. + - executions → portfolio_positions → portfolio_summary 순으로 호출됨. + """ + + # 1) 최신 현금(cash): executions.cash_after 기준 + cash_row = conn.execute(text(""" + SELECT cash_after + FROM executions + ORDER BY id DESC + LIMIT 1; + """)).fetchone() + + cash = float(cash_row.cash_after) if cash_row else 0.0 + + # 2) 전체 평가금액 = portfolio_positions.market_value 합 + mv_row = conn.execute(text(""" + SELECT COALESCE(SUM(market_value), 0) AS mv + FROM portfolio_positions; + """)).fetchone() + + market_value = float(mv_row.mv) + + total_asset = cash + market_value + + # 3) 누적 실현손익 + realized_row = conn.execute(text(""" + SELECT COALESCE(SUM(pnl_realized_cum), 0) AS pnl_r + FROM portfolio_positions; + """)).fetchone() + + pnl_realized_cum = float(realized_row.pnl_r) + + # 4) 미실현손익 합계 + unrealized_row = conn.execute(text(""" + SELECT COALESCE(SUM(pnl_unrealized), 0) AS pnl_u + FROM portfolio_positions; + """)).fetchone() + + pnl_unrealized = float(unrealized_row.pnl_u) + + # 5) 초기 원금 불러오기 (없으면 total_asset으로 설정) + init = conn.execute(text(""" + SELECT initial_capital + FROM portfolio_summary + ORDER BY date ASC + LIMIT 1; + """)).fetchone() + + if init is None: + initial_capital = total_asset + else: + initial_capital = float(init.initial_capital) + + return_rate = (total_asset / initial_capital) - 1 + + # UPSERT + conn.execute(text(""" + INSERT INTO portfolio_summary + (date, total_asset, cash, market_value, + pnl_unrealized, pnl_realized_cum, + initial_capital, return_rate, created_at) + VALUES + (:d, :ta, :cash, :mv, + :pnl_u, :pnl_r, + :init, :rr, NOW()) + ON CONFLICT (date) + DO UPDATE SET + total_asset = EXCLUDED.total_asset, + cash = EXCLUDED.cash, + market_value = EXCLUDED.market_value, + pnl_unrealized = EXCLUDED.pnl_unrealized, + pnl_realized_cum = EXCLUDED.pnl_realized_cum, + return_rate = EXCLUDED.return_rate, + created_at = NOW(); + """), { + "d": fill_date, + "ta": total_asset, + "cash": cash, + "mv": market_value, + "pnl_u": pnl_unrealized, + "pnl_r": pnl_realized_cum, + "init": initial_capital, + "rr": return_rate, + }) + + +# ------------------------------------------------------------------- +# 📌 메인 함수: executions + positions + summary +# ------------------------------------------------------------------- +def save_executions_to_db(rows_df: pd.DataFrame, db_name: str) -> None: + """ + 한국어 주석: + - rows_df 전체를 executions 테이블에 저장 + - 각 행 마다 portfolio_positions 갱신 + - 마지막으로 portfolio_summary 갱신 + """ + + if rows_df is None or rows_df.empty: + return + + engine = get_engine(db_name) + + # xai_report_id 없으면 NULL로 + if "xai_report_id" not in rows_df.columns: + rows_df = rows_df.copy() + print("[WARN] xai_report_id 없음. NULL 처리.") + rows_df["xai_report_id"] = None + + payload = rows_df.to_dict(orient="records") + + with engine.begin() as conn: + + # ============================================================= + # 1) executions 테이블 INSERT (배치) + # ============================================================= + insert_sql = text(""" + INSERT INTO executions + (run_id, xai_report_id, + ticker, signal_date, signal_price, signal, + fill_date, fill_price, qty, side, + value, commission, cash_after, + position_qty, avg_price, + pnl_realized, pnl_unrealized, created_at) + VALUES + (:run_id, :xai_report_id, + :ticker, :signal_date, :signal_price, :signal, + :fill_date, :fill_price, :qty, :side, + :value, :commission, :cash_after, + :position_qty, :avg_price, + :pnl_realized, :pnl_unrealized, NOW()) + """) + + conn.execute(insert_sql, payload) + + # ============================================================= + # 2) portfolio_positions 갱신 (행 단위) + # ============================================================= + for ex in payload: + update_portfolio_position(conn, ex) + + # ============================================================= + # 3) portfolio_summary 갱신 (마지막 fill_date 기준) + # ============================================================= + last_fill_date = payload[-1]["fill_date"] + update_portfolio_summary(conn, last_fill_date) + diff --git a/AI/libs/utils/save_reports_to_db.py b/AI/libs/utils/save_reports_to_db.py index 6cda73d5..77fc4a5c 100644 --- a/AI/libs/utils/save_reports_to_db.py +++ b/AI/libs/utils/save_reports_to_db.py @@ -1,56 +1,41 @@ # libs/utils/save_reports_to_db.py +# -*- coding: utf-8 -*- +""" +한국어 주석: +- XAI 리포트 저장 +- 체결/자산(cash) 업데이트 로직 없이, xai_reports 테이블에 INSERT만 수행한다. +- DB 스키마는 절대 변경하지 않는다(테이블/컬럼 생성/수정 X). +- 입력 rows 형식: List[Tuple[ticker, signal, price, date_str, report_text]] + +변경점: +- 기존에는 "실제 INSERT 된 행 수(int)"를 반환했으나, + 이제는 "각 INSERT 행의 xai_reports.id 리스트(List[int])"를 반환한다. + (rows의 순서와 id 리스트의 순서는 동일하다.) +""" + from __future__ import annotations -from typing import Iterable, Tuple, List, Optional +from typing import Iterable, Tuple, List from datetime import datetime, timezone -from decimal import Decimal -import os from sqlalchemy import text +from libs.utils.get_db_conn import get_engine # 프로젝트 표준 엔진 헬퍼 -# 내부 유틸에서 엔진만 사용 (스키마는 절대 변경 X) -from libs.utils.get_db_conn import get_engine - -ReportRow = Tuple[str, str, float, str, str] # (ticker, signal, price, date_str, report_text) +# (ticker, signal, price, date_str, report_text) +ReportRow = Tuple[str, str, float, str, str] -# ----- 환경 변수로 자산 테이블/컬럼 지정 (기본값 제공) ----- -ASSETS_TABLE = os.getenv("ASSETS_TABLE", "assets") -ASSETS_ID_COLUMN = os.getenv("ASSETS_ID_COLUMN", "id") -ASSETS_CASH_COLUMN = os.getenv("ASSETS_CASH_COLUMN", "cash") -ASSETS_ROW_ID = os.getenv("ASSETS_ROW_ID", "1") # ----- 유틸 ----- def utcnow() -> datetime: + """한국어 주석: 현재 UTC 시간을 반환(생성 시각 created_at 기록용).""" return datetime.now(timezone.utc) -def _to_decimal(x) -> Decimal: - if isinstance(x, Decimal): - return x - try: - return Decimal(str(x)) - except Exception: - return Decimal(0) - -def _fetch_current_cash(conn) -> Optional[Decimal]: - sql = text(f""" - SELECT {ASSETS_CASH_COLUMN} - FROM public.{ASSETS_TABLE} - WHERE {ASSETS_ID_COLUMN} = :rid - FOR UPDATE - """) - row = conn.execute(sql, {"rid": ASSETS_ROW_ID}).fetchone() - if not row: - return None - return _to_decimal(row[0]) - -def _update_cash(conn, new_cash: Decimal) -> None: - sql = text(f""" - UPDATE public.{ASSETS_TABLE} - SET {ASSETS_CASH_COLUMN} = :cash - WHERE {ASSETS_ID_COLUMN} = :rid - """) - conn.execute(sql, {"cash": str(new_cash), "rid": ASSETS_ROW_ID}) def _build_insert_params(rows: Iterable[ReportRow], created_at: datetime) -> List[dict]: + """ + 한국어 주석: + - 파이프라인에서 넘어온 리포트 튜플들을 DB INSERT용 딕셔너리 리스트로 변환한다. + - 방어적 필터링: ticker/signal/date 가 비어 있으면 해당 행은 건너뜀. + """ out: List[dict] = [] for (ticker, signal, price, date_s, report_text) in rows: if not ticker or not signal or not date_s: @@ -59,47 +44,50 @@ def _build_insert_params(rows: Iterable[ReportRow], created_at: datetime) -> Lis "ticker": ticker, "signal": signal, "price": float(price), - "date": date_s, # 'YYYY-MM-DD' + "date": date_s, # 'YYYY-MM-DD' "report": str(report_text), - "created_at": created_at, + "created_at": created_at, # TIMESTAMPTZ }) return out -# ----- 메인: 1주 고정 체결 + 자산 업데이트 + 리포트 저장 ----- -def save_reports_to_db(rows: List[ReportRow], db_name: str) -> int: + +# ----- 메인: 리포트 저장(INSERT only) ----- +def save_reports_to_db(rows: List[ReportRow], db_name: str) -> List[int]: """ - 요구사항: - - 저장 '직전'에 티커/시그널/가격을 보고 1주만 체결 - - 매 건 체결 후 잔여 현금(자산) 업데이트 - - DB 스키마 변경 금지 (xai_reports는 기존대로 INSERT만) + 한국어 주석: + - 입력된 XAI 리포트(rows)를 public.xai_reports 테이블에 INSERT 한다. + - 테이블 스키마는 건드리지 않는다(생성/ALTER 하지 않음). + - 반환값: 실제 INSERT 된 각 행의 xai_reports.id 리스트. + (rows의 순서에서 유효한 행만 추려낸 순서와 동일) """ if not rows: - print("[INFO] 저장할 리포트가 없습니다.") - return 0 + print("[INFO] 저장할 XAI 리포트가 없습니다.") + return [] engine = get_engine(db_name) created_at = utcnow() - # INSERT 템플릿 (스키마는 건드리지 않음) + # INSERT 템플릿 (스키마는 기존 그대로 사용) + RETURNING id insert_sql = text(""" INSERT INTO public.xai_reports (ticker, signal, price, date, report, created_at) VALUES (:ticker, :signal, :price, :date, :report, :created_at) + RETURNING id """) - inserted = 0 + inserted_ids: List[int] = [] + with engine.begin() as conn: - # 현금 락 걸고 읽기 - current_cash = _fetch_current_cash(conn) - if current_cash is None: - # 자산 테이블이 없거나 행이 없으면 바로 저장만 수행 - print(f"[WARN] 자산 테이블 public.{ASSETS_TABLE}에서 행을 찾지 못했어요. 체결 없이 리포트만 저장할게요.") - params = _build_insert_params(rows, created_at) - if params: - # 청크 삽입 - CHUNK = 1000 - for i in range(0, len(params), CHUNK): - batch = params[i:i+CHUNK] - conn.execute(insert_sql, batch) - inserted += len(batch) - print(f"--- {inserted}개의 XAI 리포트가 저장되었습니다 (자산 미적용). ---") - return inserted \ No newline at end of file + params = _build_insert_params(rows, created_at) + if not params: + print("[WARN] 유효한 XAI 리포트 파라미터가 없어 저장을 생략합니다.") + return [] + + # 리포트 개수가 엄청 많지 않을 것으로 가정하고, id를 받기 위해 한 행씩 INSERT + for p in params: + result = conn.execute(insert_sql, p) + new_id = result.scalar() + if new_id is not None: + inserted_ids.append(int(new_id)) + + print(f"--- {len(inserted_ids)}개의 XAI 리포트가 저장되었습니다. ---") + return inserted_ids diff --git a/AI/requirements.txt b/AI/requirements.txt index bdeab51d..b520683e 100644 --- a/AI/requirements.txt +++ b/AI/requirements.txt @@ -11,4 +11,5 @@ yfinance groq requests beautifulsoup4 -pathlib \ No newline at end of file +pathlib +fredapi \ No newline at end of file diff --git a/AI/tests/quick_db_check.py b/AI/tests/quick_db_check.py index f066bd8f..fdcf254c 100644 --- a/AI/tests/quick_db_check.py +++ b/AI/tests/quick_db_check.py @@ -1,48 +1,59 @@ # quick_db_check.py + +""" +DB 연결을 빠르게 테스트하는 스크립트. +- 프로젝트 루트(sisc-web) 자동 계산 +- .env 자동 로드 +- 환경변수 기반 get_db_conn 사용 +""" + import os import sys -import json -from typing import Dict, Union +from dotenv import load_dotenv -import psycopg2 +# ----------------------------- +# 1) 프로젝트 루트 계산 (중요!) +# ----------------------------- +# 현재 위치: sisc-web/AI/tests/quick_db_check.py +project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(project_root) +# ----------------------------- +# 2) .env 파일 로드 +# ----------------------------- +load_dotenv(os.path.join(project_root, ".env")) + +# ----------------------------- +# 3) DB 유틸 +# ----------------------------- +from AI.libs.utils.get_db_conn import get_db_conn -# --- 프로젝트 루트 경로 설정 --------------------------------------------------- -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(project_root) -# --- 설정 파일 로드 ------------------------------------------------------------ -cfg_path = os.path.join(project_root, "configs", "config.json") - -config: Dict = {} -if os.path.isfile(cfg_path): - with open(cfg_path, "r", encoding="utf-8") as f: - config = json.load(f) - print("[INFO] configs/config.json 로드 완료") -else: - print(f"[WARN] 설정 파일이 없습니다: {cfg_path}") - -db_cfg: Union[str, Dict] = (config or {}).get("db", {}) - -# --- DB 연결 테스트 ------------------------------------------------------------ -conn = None -try: - if isinstance(db_cfg, dict): - with psycopg2.connect(**db_cfg) as conn: - with conn.cursor() as cur: - cur.execute("SELECT version();") - print("✅ 연결 성공:", cur.fetchone()[0]) - cur.execute("SELECT current_database(), current_user;") - db, user = cur.fetchone() - print(f"ℹ️ DB/USER: {db} / {user}") - else: - with psycopg2.connect(dsn=str(db_cfg)) as conn: - with conn.cursor() as cur: - ... -except Exception as e: - print("❌ 연결 실패:", repr(e)) -finally: - if conn: +def quick_db_check(db_name: str = "db"): + print(f"[INFO] DB 연결 테스트 시작 (db_name='{db_name}')") + + try: + conn = get_db_conn(db_name) + except Exception as e: + print("❌ DB 연결 실패 (커넥션 생성 오류):", repr(e)) + return + + try: + with conn.cursor() as cur: + cur.execute("SELECT version();") + version = cur.fetchone()[0] + print("✅ 연결 성공:", version) + + cur.execute("SELECT current_database(), current_user;") + db, user = cur.fetchone() + print(f"ℹ️ DB = {db}, USER = {user}") + + except Exception as e: + print("❌ 쿼리 실행 실패:", repr(e)) + finally: conn.close() - print("DB 연결 종료") + print("🔌 DB 연결 종료") + +if __name__ == "__main__": + quick_db_check("db") diff --git a/AI/tests/test_transfomer.py b/AI/tests/test_transfomer.py deleted file mode 100644 index 8953394c..00000000 --- a/AI/tests/test_transfomer.py +++ /dev/null @@ -1,209 +0,0 @@ -# AI/tests/test_transformer_live_fetch.py -# -*- coding: utf-8 -*- -""" -[목적] -- Transformer 모듈만 실제 OHLCV로 테스트 (저장 없음, 출력만) -- 데이터 수집은 프로젝트 표준 유틸: libs.utils.fetch_ohlcv.fetch_ohlcv 를 '그대로' 사용 - (즉, DB 우선 조회 + 실패/결측 시 야후 파이낸스 API 폴백 로직은 fetch_ohlcv 내부 정책을 따름) - -[실행] -> cd AI -> python -m tests.test_transformer_live_fetch -또는 -> python tests/test_transformer_live_fetch.py -""" - -import os -import sys -import json -from typing import Dict, List, Optional -from datetime import datetime, timedelta -import time -import random - -import pandas as pd -import numpy as np - -# --- 프로젝트 루트 경로 설정 --------------------------------------------------- -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(project_root) -# --------------------------------------------------------------------------- - -# --- 모듈 임포트 -------------------------------------------------------------- -from transformer.main import run_transformer # Transformer 단독 테스트 대상 -from libs.utils.fetch_ohlcv import fetch_ohlcv # ★ 표준 OHLCV 수집 유틸(반드시 이걸 사용) -# --------------------------------------------------------------------------- - - -# ============================================================================= -# (옵션) 안전한 fetch 래퍼: 일시적 실패/429 등에 대비한 재시도 -# - fetch_ohlcv 내부에서도 재시도/폴백이 구현되어 있을 수 있으나, -# 테스트 안정성을 위해 여기서 추가로 얇은 재시도를 감쌉니다. -# ============================================================================= -def safe_fetch_ohlcv( - ticker: str, - start: str, - end: str, - config: Optional[Dict] = None, - max_retries: int = 5, - base_sleep: float = 0.8 -) -> pd.DataFrame: - """ - fetch_ohlcv 호출을 얇게 감싸는 재시도 래퍼. - - 429/일시 네트워크 오류 같은 경우를 대비하여 지수 백오프 + 지터 적용 - - fetch_ohlcv가 raise하면 여기서 재시도 후 최종 raise - """ - attempt = 0 - while True: - try: - df = fetch_ohlcv( - ticker=ticker, - start=start, - end=end, - config=(config or {}) - ) - return df - except Exception as e: - attempt += 1 - if attempt >= max_retries: - raise - # 지수 백오프 + 약간의 랜덤 지터 - sleep_s = base_sleep * (2 ** (attempt - 1)) + random.uniform(0, 0.6) - print(f"[WARN] {ticker} fetch_ohlcv 실패({attempt}/{max_retries}) -> {e} | {sleep_s:.2f}s 대기 후 재시도") - time.sleep(sleep_s) - - -# ============================================================================= -# Transformer 단독 라이브 테스트 -# ============================================================================= -def run_transform_only_live_with_fetch(): - """ - - configs/config.json 로드 → db 설정 전달 - - 티커별로 fetch_ohlcv 호출(★ 프로젝트 유틸 사용) → raw_data 결합 - - run_transformer 호출 → 출력만 수행(저장 없음) - """ - print("=== [TEST] Transformer 단독(실데이터) 테스트 시작 — using libs.utils.fetch_ohlcv ===") - - # ---------------------------------------------------------------------- - # (A) 설정/입력 - # ---------------------------------------------------------------------- - cfg_path = os.path.join(project_root, "configs", "config.json") - config: Dict = {} - if os.path.isfile(cfg_path): - try: - with open(cfg_path, "r", encoding="utf-8") as f: - config = json.load(f) - print("[INFO] configs/config.json 로드 완료") - except Exception as e: - print(f"[WARN] config 로드 실패(빈 설정으로 진행): {e}") - - db_config = (config or {}).get("db", {}) # fetch_ohlcv에 그대로 넘김 - - # 테스트 티커/기간 - tickers: List[str] = ["AAPL", "MSFT", "GOOGL"] # 필요 시 교체 - end_dt = datetime.now() - start_dt = end_dt - timedelta(days=600) - start_str = start_dt.strftime("%Y-%m-%d") - end_str = end_dt.strftime("%Y-%m-%d") - - seq_len = 60 - pred_h = 1 - transformer_cfg: Dict = { - "transformer": { - "features": ["open", "high", "low", "close", "volume","adjusted_close"], - "target": "close", - "scaler": "standard" - } - } - - # ---------------------------------------------------------------------- - # (B) fetch_ohlcv 로 실데이터 가져오기 (티커별 → concat) - # ---------------------------------------------------------------------- - raw_parts: List[pd.DataFrame] = [] - for tkr in tickers: - try: - print(f"[INFO] 수집 시작: {tkr} ({start_str} → {end_str})") - df = safe_fetch_ohlcv( - ticker=tkr, - start=start_str, - end=end_str, - config=db_config, # ★ fetch_ohlcv는 내부에서 DB/야후 폴백 처리 - max_retries=5, - base_sleep=0.8 - ) - if df is None or df.empty: - print(f"[WARN] {tkr} 데이터가 비어 있습니다.") - else: - # 스키마 정합성(Transformer가 기대하는 컬럼 존재 검사) - required = ["ticker", "date", "open", "high", "low", "close", "volume", "adjusted_close"] - missing = [c for c in required if c not in df.columns] - if missing: - raise ValueError(f"{tkr} 수집 데이터에 필수 컬럼 누락: {missing}") - # 날짜형 변환 보정 - if not np.issubdtype(df["date"].dtype, np.datetime64): - df["date"] = pd.to_datetime(df["date"], errors="coerce") - raw_parts.append(df.reset_index(drop=True)) - print(f"[INFO] {tkr} 수집 완료: {len(df):,} rows") - finally: - # API rate 제한 완화(티커 사이 간격) - time.sleep(0.6 + random.uniform(0, 0.6)) - - if not raw_parts: - print("[ERROR] 모든 소스에서 OHLCV 확보 실패(fetch_ohlcv 사용).") - return - - raw_data = pd.concat(raw_parts, ignore_index=True) - - # ---------------------------------------------------------------------- - # (C) Transformer 호출 - # ---------------------------------------------------------------------- - finder_df = pd.DataFrame({"ticker": tickers}) - - print("[INFO] run_transformer 호출 중...") - result = run_transformer( - finder_df=finder_df, - seq_len=seq_len, - pred_h=pred_h, - raw_data=raw_data, - config=transformer_cfg - ) - - logs_df = result.get("logs", pd.DataFrame()) if isinstance(result, dict) else pd.DataFrame() - meta = {k: v for k, v in result.items() if k != "logs"} if isinstance(result, dict) else {} - - # ---------------------------------------------------------------------- - # (D) 출력만(저장 없음) - # ---------------------------------------------------------------------- - print("\n--- [RESULT] Transformer 반환 메타 키 ---") - print(list(meta.keys())) - - print("\n--- [RESULT] 결정 로그(logs) 미리보기 ---") - if not logs_df.empty: - if not np.issubdtype(logs_df["date"].dtype, np.datetime64): - logs_df["date"] = pd.to_datetime(logs_df["date"], errors="coerce") - print(logs_df.head(10).to_string(index=False)) - else: - print("logs_df가 비어 있습니다. Transformer 내부 로직을 확인하세요.") - - if not logs_df.empty: - if "action" in logs_df.columns: - print("\n--- [STATS] 액션별 건수 ---") - print(logs_df["action"].value_counts()) - if {"ticker", "date"}.issubset(logs_df.columns): - print("\n--- [STATS] 티커별 최근 신호 2건 ---") - latest = ( - logs_df.sort_values(["ticker", "date"], ascending=[True, False]) - .groupby("ticker") - .head(2) - .reset_index(drop=True) - ) - print(latest.to_string(index=False)) - - print(f"\n=== [TEST] 종료: 총 원시행(raw_data) = {len(raw_data):,} ===") - - -# ----------------------------------------------------------------------------- -# 엔트리포인트 -# ----------------------------------------------------------------------------- -if __name__ == "__main__": - run_transform_only_live_with_fetch() diff --git a/AI/tests/test_transformer_backtrader.py b/AI/tests/test_transformer_backtrader.py new file mode 100644 index 00000000..757e2822 --- /dev/null +++ b/AI/tests/test_transformer_backtrader.py @@ -0,0 +1,194 @@ +# AI/tests/test_transformer_backtrader.py +# 사용불가. 추후 수정 필요 +import os +import sys +from typing import List, Dict, Optional +from datetime import datetime, timedelta +import pandas as pd +import pathlib + +# --- 프로젝트/레포 경로 설정 --------------------------------------------------- +_this_file = os.path.abspath(__file__) +project_root = os.path.dirname(os.path.dirname(_this_file)) # .../transformer +repo_root = os.path.dirname(project_root) # .../ +libs_root = os.path.join(repo_root, "libs") # .../libs + +# sys.path에 중복 없이 추가 +for p in (project_root, repo_root, libs_root): + if p not in sys.path: + sys.path.append(p) +# ------------------------------------------------------------------------------ + +# ---------------------------------------------------------------------- +# Transformer 실행 +# ---------------------------------------------------------------------- +from transformer.main import run_transformer + +from typing import Optional +import pandas as pd +from sqlalchemy import text + +# DB용 유틸: SQLAlchemy Engine 생성 함수 사용 (get_engine) +from libs.utils.get_db_conn import get_engine + +def fetch_ohlcv( + ticker: List[str], + start: str, + end: str, + interval: str = "1d", + db_name: str = "db", +) -> pd.DataFrame: + """ + 특정 티커, 날짜 범위의 OHLCV 데이터를 DB에서 불러오기 (SQLAlchemy 엔진 사용) + + Args: + ticker (List[str]): 종목 코드 리스트 (예: ["AAPL", "MSFT", "GOOGL"]) + start (str): 시작일자 'YYYY-MM-DD' (inclusive) + end (str): 종료일자 'YYYY-MM-DD' (inclusive) + interval (str): 데이터 간격 ('1d' 등) - 현재 테이블이 일봉만 제공하면 무시됨 + db_name (str): get_engine()가 참조할 설정 블록 이름 (예: "db", "report_DB") + + Returns: + pd.DataFrame: 컬럼 = [ticker, date, open, high, low, close, adjusted_close, volume] + (date 컬럼은 pandas datetime으로 변환됨) + """ + + # 1) SQLAlchemy engine 얻기 (configs/config.json 기준) + engine = get_engine(db_name) + + # 2) 쿼리: named parameter(:tickers 등) 사용 -> 안전하고 가독성 좋음 + query = text(""" + SELECT ticker, date, open, high, low, close, adjusted_close, volume + FROM public.price_data + WHERE ticker IN :tickers -- 수정된 부분 + AND date BETWEEN :start AND :end + ORDER BY date; + """) + + # 3) DB에서 읽기 (with 문으로 커넥션 자동 정리) + with engine.connect() as conn: + df = pd.read_sql( + query, + con=conn, # 꼭 키워드 인자로 con=conn + params={"tickers": tuple(ticker), "start": start, "end": end}, # 수정된 부분 + ) + + # 4) 후처리: 컬럼 정렬 및 date 타입 통일 + if df is None or df.empty: + # 빈 DataFrame이면 일관된 컬럼 스키마로 반환 + return pd.DataFrame(columns=["ticker", "date", "open", "high", "low", "close", "adjusted_close", "volume"]) + + # date 컬럼을 datetime으로 변경 (UTC로 맞추고 싶으면 pd.to_datetime(..., utc=True) 사용) + if "date" in df.columns: + df["date"] = pd.to_datetime(df["date"]) + + # 선택: 컬럼 순서 고정 (일관성 유지) + desired_cols = ["ticker", "date", "open", "high", "low", "close", "adjusted_close", "volume"] + # 존재하는 컬럼만 가져오기 + cols_present = [c for c in desired_cols if c in df.columns] + df = df.loc[:, cols_present] + + return df + + + +def run_transformer_for_test(finder_df: pd.DataFrame, raw_data: pd.DataFrame, start_date: str, end_date: str) -> pd.DataFrame: + """ + Transformer 모델을 실행하여 매매 신호를 예측하는 함수 + :param finder_df: 종목 리스트 + :param raw_data: OHLCV 시계열 데이터 + :param start_date: 시작 날짜 + :param end_date: 종료 날짜 + :return: 매매 신호를 포함한 DataFrame + """ + transformer_result = run_transformer( + finder_df=finder_df, + seq_len=64, + pred_h=5, + raw_data=raw_data, + run_date=end_date, # 예측 날짜 + weights_path=None, # 가중치 경로. transformer/main.py 내부에서 기본 경로 사용 + interval="1d" + ) + + logs_df = transformer_result.get("logs", pd.DataFrame()) + return logs_df + +# ---------------------------------------------------------------------- +# Backtrader 실행 +# ---------------------------------------------------------------------- +from backtrader import Cerebro, Strategy +from backtrader.feeds import PandasData + +class SimpleStrategy(Strategy): + """ + Backtrader 전략 + """ + def __init__(self, logs_df: pd.DataFrame): + self.order = None + self.buy_price = None + self.logs_df = logs_df # 매매 신호를 담은 logs_df를 클래스에 저장 + + def next(self): + # 매수/매도 신호에 따라 거래 진행 + if self.order: + return # 이미 주문이 있으면 아무것도 하지 않음 + + for _, row in self.logs_df.iterrows(): + if row['action'] == 'BUY' and self.data.datetime.date(0) == pd.to_datetime(row['date']).date(): + self.buy_price = row['predicted_price'] + self.order = self.buy(size=1) # 예시: 1주 매수 + + elif row['action'] == 'SELL' and self.data.datetime.date(0) == pd.to_datetime(row['date']).date(): + if self.buy_price: + sell_price = row['predicted_price'] + profit = (sell_price - self.buy_price) / self.buy_price * 100 # 수익률 계산 + print(f"Profit from {row['ticker']}: {profit:.2f}%") + self.order = self.sell(size=1) # 예시: 1주 매도 + self.buy_price = None + +# ---------------------------------------------------------------------- +# 테스트 실행 (2024년 1월 1일부터 12월 31일까지의 데이터로 테스트) +# ---------------------------------------------------------------------- +def test_transformer_backtrader(): + """ + 1년 동안 Transformer 모델을 통해 매매 신호를 예측하고, + Backtrader를 사용하여 수익률을 계산하는 테스트 함수 + """ + start_date = "2024-01-01" + end_date = "2024-12-31" + db_name = "db" # DB 이름 + + # 1. Transformer 모델을 통한 예측 신호 생성 + finder_df = pd.DataFrame({"ticker": ["AAPL", "MSFT", "GOOGL"]}) + + # DB에서 OHLCV 데이터 가져오기 + raw_data = fetch_ohlcv(finder_df['ticker'].tolist(), start_date, end_date, db_name=db_name) + + # Transformer 모델 실행 + logs_df = run_transformer_for_test(finder_df, raw_data, start_date, end_date) + + if logs_df.empty: + print("Transformer 모델에서 예측된 신호가 없습니다.") + return + + # 2. Backtrader 전략 실행 (매매 시뮬레이션) + ohlcv_data_feed = PandasData(dataname=raw_data) + + # Cerebro 엔진 설정 + cerebro = Cerebro() + cerebro.addstrategy(SimpleStrategy, logs_df=logs_df) # logs_df를 전략에 전달 + cerebro.adddata(ohlcv_data_feed) + + # 초기 자본금 및 수수료 설정 + cerebro.broker.set_cash(100000) # 초기 자본금 설정 + cerebro.broker.set_commission(commission=0.001) # 거래 수수료 설정 + + # 3. 백테스트 실행 + cerebro.run() + + # 최종 자본금 출력 + print(f"Final Portfolio Value: {cerebro.broker.getvalue():.2f}") + +if __name__ == "__main__": + test_transformer_backtrader() diff --git a/AI/transformer/modules/inference.py b/AI/transformer/modules/inference.py index 55aa394e..7a468d3d 100644 --- a/AI/transformer/modules/inference.py +++ b/AI/transformer/modules/inference.py @@ -61,7 +61,7 @@ def run_inference( """ tickers = finder_df["ticker"].astype(str).tolist() if raw_data is None or raw_data.empty: - print("[INFER] raw_data empty -> empty logs") + print("[INFER] raw_data empty -> empty logs") return {"logs": pd.DataFrame(columns=[ "ticker","date","action","price","weight", "feature1","feature2","feature3", diff --git a/AI/weekly_tickers.json b/AI/weekly_tickers.json new file mode 100644 index 00000000..182836de --- /dev/null +++ b/AI/weekly_tickers.json @@ -0,0 +1 @@ +{"tickers": ["AAPL", "MSFT", "GOOGL"], "saved_at": "2025-11-24T06:30:32.207866+00:00"} \ No newline at end of file