From e62f32acd78db46f37426929bc72373f5776199e Mon Sep 17 00:00:00 2001 From: twq110 Date: Wed, 12 Nov 2025 21:46:27 +0900 Subject: [PATCH 1/8] =?UTF-8?q?[AI]=20SISC2-43=20[FEAT]=20=EC=9E=90?= =?UTF-8?q?=EC=82=B0=EB=B0=B0=EB=B6=84=20=EB=A1=9C=EC=A7=81=20=EC=9E=91?= =?UTF-8?q?=EC=84=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI/backtest/order_policy.py | 85 +++++++++ AI/backtest/simple_backtester.py | 157 ++++++++++++++++ AI/libs/core/pipeline.py | 249 +++++++++++++++---------- AI/libs/utils/save_executions_to_db.py | 81 ++++++++ AI/libs/utils/save_reports_to_db.py | 107 +++++------ 5 files changed, 516 insertions(+), 163 deletions(-) create mode 100644 AI/backtest/order_policy.py create mode 100644 AI/backtest/simple_backtester.py create mode 100644 AI/libs/utils/save_executions_to_db.py diff --git a/AI/backtest/order_policy.py b/AI/backtest/order_policy.py new file mode 100644 index 00000000..839095a3 --- /dev/null +++ b/AI/backtest/order_policy.py @@ -0,0 +1,85 @@ +# backtest/order_policy.py +# -*- coding: utf-8 -*- +""" +한국어 주석: +- 백테스터의 '체결 수량 및 포지션 결정 로직'을 별도 모듈로 분리. +- 현재는 단순 rule-based이며, 향후 강화학습(Agent) 정책으로 교체할 수 있음. +""" + +from __future__ import annotations +from typing import Dict, Tuple + + +def decide_order( + side: str, + cash: float, + cur_qty: int, + avg_price: float, + fill_price: float, + config, +) -> Tuple[int, float]: + """ + 한국어 주석: + - 체결 수량 및 거래금액을 결정하는 핵심 함수. + - 강화학습 Agent가 교체할 대상 부분. + ----------------------------- + 입력값: + side: "BUY" 또는 "SELL" + cash: 현재 현금 잔고 + cur_qty: 현재 보유 주식 수량 + avg_price: 현재 보유 평균단가 + fill_price: 이번 체결 기준가 (슬리피지 반영 전) + config: BacktestConfig 인스턴스 + 반환값: + (qty, trade_value) + qty: 매수/매도 수량 + trade_value: 체결 총액(+BUY 지출, -SELL 유입) + """ + + qty = 0 + trade_value = 0.0 + + if side == "BUY": + # 현금 중 risk_frac 비율만큼 투자 + cash_to_use = max(0.0, cash * config.risk_frac) + qty = int(cash_to_use // fill_price) + + # 동시 보유 제한 + if config.max_positions_per_ticker == 1 and cur_qty > 0: + qty = 0 + + trade_value = fill_price * qty # BUY → 현금 지출 (+) + + elif side == "SELL": + qty = cur_qty # 전량 청산 + trade_value = -fill_price * qty # SELL → 현금 유입 (-) + + return qty, trade_value + + +# === 확장용 RL 정책 클래스 === +class RLOrderPolicy: + """ + 한국어 주석: + - 강화학습 Agent를 위한 placeholder 클래스. + - 현재는 rule-based decide_order를 그대로 호출하지만, + 이후 RL 모델의 action 출력을 이용해 수량/금액을 결정할 수 있다. + """ + + def __init__(self, model=None): + self.model = model # RL 네트워크 or 정책 객체 + + def decide(self, state: Dict, config) -> Tuple[int, float]: + """ + state: {'cash':..., 'price':..., 'pos':..., 'side':...} + 반환: (qty, trade_value) + """ + side = state.get("side", "BUY") + return decide_order( + side=side, + cash=state.get("cash", 0.0), + cur_qty=state.get("cur_qty", 0), + avg_price=state.get("avg_price", 0.0), + fill_price=state.get("price", 0.0), + config=config, + ) diff --git a/AI/backtest/simple_backtester.py b/AI/backtest/simple_backtester.py new file mode 100644 index 00000000..97cfbd5f --- /dev/null +++ b/AI/backtest/simple_backtester.py @@ -0,0 +1,157 @@ +# backtest/simple_backtester.py +# -*- coding: utf-8 -*- +""" +한국어 주석: +- OHLCV 없이, Transformer 결정 로그(decision_log)의 price만으로 + 간소화된 백테스트를 수행하는 환경(Environment) 역할. +- 수량/포지션 결정은 backtest/order_policy.py 모듈로 분리됨. +""" + +from __future__ import annotations +from dataclasses import dataclass +from typing import Optional, Dict, Tuple, List +import pandas as pd +import numpy as np + +from backtest.order_policy import decide_order # 분리된 정책 모듈 import + + +# === 설정 클래스 === +@dataclass +class BacktestConfig: + """ + 한국어 주석: + - 간소화 백테스터 설정 + - 향후 강화학습 환경 초기화 시에도 그대로 사용 가능 + """ + initial_cash: float = 100_000.0 + slippage_bps: float = 5.0 + commission_bps: float = 3.0 + risk_frac: float = 0.2 + max_positions_per_ticker: int = 1 + fill_on_same_day: bool = True + + +# === 내부 유틸 === +def _apply_price_with_slippage(price: float, side: str, slippage_bps: float) -> float: + """슬리피지를 체결가에 반영""" + adj = 1.0 + (slippage_bps / 10_000.0) * (1 if side.upper() == "BUY" else -1) + return float(price) * adj + + +def _apply_commission(value: float, commission_bps: float) -> float: + """체결 금액에 대해 bps 단위 수수료 계산""" + return abs(value) * (commission_bps / 10_000.0) + + +def _fill_date_from_signal(sig_date: pd.Timestamp, same_day: bool) -> pd.Timestamp: + """OHLCV 없이 동일일 또는 다음날 체결로 단순 처리""" + return sig_date if same_day else (sig_date + pd.Timedelta(days=1)) + + +# === 백테스트 본체 === +def backtest( + decision_log: pd.DataFrame, + config: Optional[BacktestConfig] = None, + run_id: Optional[str] = None, +) -> Tuple[pd.DataFrame, Dict]: + """ + 한국어 주석: + - 입력: Transformer 의사결정 로그(decision_log) + - 처리: 가격 기반 슬리피지·수수료 반영 후 체결/포지션 갱신 + - 반환: (fills_df, summary) + """ + if config is None: + config = BacktestConfig() + + dl = decision_log.copy() + if not {"ticker", "date", "action", "price"}.issubset(dl.columns): + raise ValueError("decision_log에 'ticker','date','action','price' 컬럼이 필요합니다.") + + dl["date"] = pd.to_datetime(dl["date"]) + dl = dl.sort_values(["date", "ticker"]).reset_index(drop=True) + + cash = float(config.initial_cash) + positions: Dict[str, Dict[str, float]] = {} + records: List[Dict] = [] + + for _, r in dl.iterrows(): + ticker = str(r["ticker"]) + sig_date = pd.Timestamp(r["date"]) + sig = str(r["action"]).upper() + sig_price = float(r.get("price", np.nan)) + + if sig not in ("BUY", "SELL"): + continue + + fill_date = _fill_date_from_signal(sig_date, config.fill_on_same_day) + fill_price = _apply_price_with_slippage(sig_price, sig, config.slippage_bps) + + pos = positions.get(ticker, {"qty": 0, "avg": 0.0}) + cur_qty = pos["qty"] + avg_price = pos["avg"] + side = "BUY" if sig == "BUY" else "SELL" + + # === 🔹 체결 정책 호출 (외부 모듈) === + qty, trade_value = decide_order( + side=side, + cash=cash, + cur_qty=cur_qty, + avg_price=avg_price, + fill_price=fill_price, + config=config, + ) + + if qty <= 0: + continue + + # === 나머지는 환경의 기계적 계산 === + commission = _apply_commission(trade_value, config.commission_bps) + cash_after = cash - trade_value - commission + + # 포지션 업데이트 + if side == "BUY": + new_qty = cur_qty + qty + new_avg = (avg_price * cur_qty + fill_price * qty) / max(1, new_qty) + else: + new_qty = cur_qty - qty + new_avg = avg_price if new_qty > 0 else 0.0 + + pnl_realized = 0.0 + if side == "SELL": + pnl_realized = (fill_price - avg_price) * qty + + pnl_unrealized = 0.0 + + # 상태 저장 + cash = cash_after + positions[ticker] = {"qty": new_qty, "avg": new_avg} + + records.append({ + "run_id": run_id, + "ticker": ticker, + "signal_date": sig_date.date().isoformat(), + "signal_price": float(sig_price), + "signal": sig, + "fill_date": fill_date.date().isoformat(), + "fill_price": float(fill_price), + "qty": int(qty), + "side": side, + "value": float(trade_value), + "commission": float(commission), + "cash_after": float(cash_after), + "position_qty": int(new_qty), + "avg_price": float(new_avg), + "pnl_realized": float(pnl_realized), + "pnl_unrealized": float(pnl_unrealized), + }) + + fills = pd.DataFrame.from_records(records) + summary = { + "run_id": run_id, + "trades": int(len(fills)), + "cash_final": float(cash), + "pnl_realized_sum": float(fills["pnl_realized"].sum()) if not fills.empty else 0.0, + "commission_sum": float(fills["commission"].sum()) if not fills.empty else 0.0, + } + return fills, summary diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index 6622ad09..c45f1bfb 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -1,7 +1,21 @@ +# pipeline/run_pipeline.py +# -*- coding: utf-8 -*- +""" +한국어 주석: +- 본 파일은 주간 파이프라인(종목 발굴 → 신호 변환 → 백테스트(체결/수량) → XAI 리포트 → DB 저장)을 한 번에 실행한다. +- 핵심 변화: + 1) Backtester 단계 신설: Transformer 의사결정 로그의 price를 그대로 체결가 기준으로 사용(OHLCV 불필요) + 2) 체결내역을 executions 테이블에 저장(save_executions_to_db) + 3) XAI 리포트를 기존대로 reports 테이블(또는 xai_reports)에 저장(save_reports_to_db) +- 주의: + - decision_log에는 반드시 ['ticker','date','action','price']가 있어야 한다. + - XAI용 필수 feature 컬럼(feature_name1~3, feature_score1~3) 점검 로직 포함. +""" + import os import sys from typing import List, Dict, Optional, Tuple -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone import pandas as pd # --- 프로젝트 루트 경로 설정 --- @@ -10,19 +24,20 @@ # ------------------------------ # --- 모듈 import --- -from finder.main import run_finder -from transformer.main import run_transformer -from libs.utils.fetch_ohlcv import fetch_ohlcv -from xai.run_xai import run_xai -from libs.utils.get_db_conn import get_db_conn -from libs.utils.save_reports_to_db import save_reports_to_db +from finder.main import run_finder # 1) 종목 발굴 +from transformer.main import run_transformer # 2) 신호 생성(결정 로그 생성) +from xai.run_xai import run_xai # 4) XAI 리포트 텍스트 생성 +from libs.utils.save_reports_to_db import save_reports_to_db # 5) 리포트 저장 +from libs.utils.get_db_conn import get_db_conn # (옵션) DB 연결 헬퍼 +from backtest.simple_backtester import backtest, BacktestConfig # 3) 백테스팅(간소화) +from libs.utils.save_executions_to_db import save_executions_to_db # 3.5) 체결내역 저장 # --------------------------------- # DB 이름 상수(실제 등록된 키와 반드시 일치해야 함) -MARKET_DB_NAME = "db" # 시세/원천 데이터 DB -REPORT_DB_NAME = "report_DB" # 리포트 저장 DB +MARKET_DB_NAME = "db" # (현재 파이프라인에선 직접 사용 안함, 향후 확장 대비) +REPORT_DB_NAME = "db" # 체결내역/리포트 저장용 DB, 하나로 통합된 예정 -# === (신규 전용) 필수 컬럼: inference 로그 → XAI 변환에 필요한 것만 강제 === +# === (신규 전용) XAI 필수 컬럼: inference 로그 → XAI 변환에 필요한 것만 강제 === REQUIRED_LOG_COLS = { "ticker", "date", "action", "price", # XAI evidence 구성에 꼭 필요한 신규 컬럼 @@ -31,63 +46,74 @@ # (원하면 로깅/모니터링용 확률도 계속 받되 필수는 아님) } +# === 유틸 === +def _utcnow() -> datetime: + """한국어 주석: UTC now 헬퍼(테스트 재현성을 위해 따로 분리)""" + return datetime.now(timezone.utc) + +def _to_iso_date(v) -> str: + """한국어 주석: pandas.Timestamp/ datetime → YYYY-MM-DD 문자열로 안전 변환""" + try: + if isinstance(v, (pd.Timestamp, datetime)): + return v.strftime("%Y-%m-%d") + return str(v) + except Exception: + return str(v) + +def _to_float(v, fallback=0.0) -> float: + """한국어 주석: 숫자 변환 시 NaN/예외에 대해 기본값 제공""" + try: + f = float(v) + if pd.isna(f): + return float(fallback) + return f + except Exception: + return float(fallback) + +# === 1) Finder: 주간 종목 추출 === def run_weekly_finder() -> List[str]: """ - 주간 종목 발굴(Finder)을 실행하고 결과(종목 리스트)를 반환합니다. + 한국어 주석: + - Finder 모듈을 실행하여 후보 티커 리스트를 반환한다. + - 현재는 run_finder()의 구현/데이터 이슈로 인해 임시 티커를 반환할 수 있다. """ print("--- [PIPELINE-STEP 1] Finder 모듈 실행 시작 ---") - # top_tickers = run_finder() # TODO: 종목 선정 이슈 해결 후 사용 - top_tickers = ["AAPL", "MSFT", "GOOGL"] # 임시 데이터 - print("--- [PIPELINE-STEP 1] Finder 모듈 실행 완료 ---") - return top_tickers - -def _utcnow() -> datetime: - return datetime.now(timezone.utc) - + try: + # 실제 구현 연결 시 사용 + tickers = run_finder() + if not tickers: + # 비상용 임시 리스트 + tickers = ["AAPL", "MSFT", "GOOGL"] + except Exception as e: + print(f"[WARN] Finder 실행 중 오류: {e} → 임시 티커 사용") + tickers = ["AAPL", "MSFT", "GOOGL"] + + print(f"--- [PIPELINE-STEP 1] Finder 완료: tickers={tickers} ---") + return tickers + +# === 2) Transformer: 신호/로그 생성 === def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: """ - 종목 리스트를 받아 Transformer 모듈을 실행하고, 신호(결정 로그)를 반환합니다. + 한국어 주석: + - Transformer 모듈을 실행하여 의사결정 로그(decision logs)를 반환한다. + - 반환 데이터프레임 예시 컬럼: + ['ticker','date','action','price','feature_name1','feature_score1', ...] + - XAI 단계의 필수 컬럼(REQUIRED_LOG_COLS)을 사전 점검한다. """ print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 시작 ---") - if not tickers: print("[WARN] 빈 종목 리스트가 입력되어 Transformer를 건너뜁니다.") return pd.DataFrame() - # end_date = _utcnow() # 서버 사용 시 - end_date = datetime.strptime("2024-10-30", "%Y-%m-%d") # 임시 고정 날짜 - start_date = end_date - timedelta(days=600) - - all_ohlcv_df: List[pd.DataFrame] = [] - for ticker in tickers: - try: - ohlcv_df = fetch_ohlcv( - ticker=ticker, - start=start_date.strftime("%Y-%m-%d"), - end=end_date.strftime("%Y-%m-%d"), - db_name=db_name - ) - if ohlcv_df is None or ohlcv_df.empty: - print(f"[WARN] OHLCV 미수집: {ticker}") - continue - ohlcv_df = ohlcv_df.copy() - ohlcv_df["ticker"] = ticker - all_ohlcv_df.append(ohlcv_df) - except Exception as e: - print(f"[ERROR] OHLCV 수집 실패({ticker}): {e}") - - if not all_ohlcv_df: - print("[ERROR] 어떤 티커에서도 OHLCV 데이터를 가져오지 못했습니다.") - return pd.DataFrame() - - raw_data = pd.concat(all_ohlcv_df, ignore_index=True) - + # Transformer 인터페이스 규약: finder_df, seq_len, pred_h, raw_data 등 필요 시 맞춰 전달 + # (여기서는 run_transformer 내부에서 데이터 수집/전처리까지 처리한다고 가정) finder_df = pd.DataFrame(tickers, columns=["ticker"]) + transformer_result: Dict = run_transformer( finder_df=finder_df, seq_len=60, pred_h=1, - raw_data=raw_data + raw_data=None # OHLCV 비사용 파이프라인이므로 None 전달(내부에서 자체 처리할 수 있음) ) or {} logs_df: pd.DataFrame = transformer_result.get("logs", pd.DataFrame()) @@ -95,52 +121,64 @@ def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: print("[WARN] Transformer 결과 로그가 비어 있습니다.") return pd.DataFrame() - # === 신규 포맷 강제 체크 === + # 신규 포맷 강제 체크(XAI 및 백테스터가 기대하는 컬럼 존재 여부 확인) missing_cols = REQUIRED_LOG_COLS - set(logs_df.columns) if missing_cols: print(f"[ERROR] 결정 로그에 필수 컬럼 누락(신규 포맷 전용): {sorted(missing_cols)}") return pd.DataFrame() - print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 완료 ---") + print(f"--- [PIPELINE-STEP 2] Transformer 완료: logs={len(logs_df)} rows ---") return logs_df -# --- 안전 변환 유틸 --- -def _to_iso_date(v) -> str: - try: - if isinstance(v, (pd.Timestamp, datetime)): - return v.strftime("%Y-%m-%d") - return str(v) - except Exception: - return str(v) +# === 3) Backtester: 로그 price 기준 체결가/수량 산출 === +def run_backtester(decision_log: pd.DataFrame) -> pd.DataFrame: + """ + 한국어 주석: + - Transformer 의사결정 로그의 price를 체결 기준가로 사용(OHLCV 불필요) + - 슬리피지/수수료/사이징을 BacktestConfig로 제어 + - 반환: 체결내역 DataFrame (채결가, 수량, 실현손익 등 포함) + """ + print("--- [PIPELINE-STEP 3] Backtester 실행 시작 ---") + if decision_log is None or decision_log.empty: + print("[WARN] Backtester: 비어있는 결정 로그가 입력되었습니다.") + return pd.DataFrame() -def _to_float(v, fallback=0.0) -> float: - try: - f = float(v) - if pd.isna(f): - return float(fallback) - return f - except Exception: - return float(fallback) + # 동일 실행 묶음 식별자(run_id) + run_id = _utcnow().strftime("run-%Y%m%d-%H%M%S") + + cfg = BacktestConfig( + initial_cash=100_000.0, # 시작 현금 + slippage_bps=5.0, # 슬리피지 5bp + commission_bps=3.0, # 수수료 3bp + risk_frac=0.2, # 1회 진입에 현금의 20% 사용 + max_positions_per_ticker=1, + fill_on_same_day=True # 로그 가격으로 '즉시' 체결 + ) + + fills_df, summary = backtest( + decision_log=decision_log, + config=cfg, + run_id=run_id + ) + + if fills_df is None or fills_df.empty: + print("[WARN] Backtester: 생성된 체결 내역이 없습니다.") + return pd.DataFrame() -# --- XAI 리포트: 5-튜플(rows)로 반환 --- + print(f"--- [PIPELINE-STEP 3] 완료: trades={len(fills_df)}, " + f"cash_final={summary.get('cash_final')}, pnl_realized_sum={summary.get('pnl_realized_sum')} ---") + return fills_df + +# === 4) XAI 리포트: 5-튜플(rows) 생성 === def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, str, str]]: """ - 반환: List[(ticker, signal, price, date, report_text)] - XAI 포맷: - { - "ticker": "...", - "date": "YYYY-MM-DD", - "signal": "BUY|HOLD|SELL", - "price": float, - "evidence": [ - {"feature_name": str, "contribution": float}, # 0~1 점수 권장 - ... - ] - } - ※ 신규 포맷 전용: - - feature_name1~3, feature_score1~3 필수 + 한국어 주석: + - 입력: Transformer 결정 로그 + - 출력: List[(ticker, signal, price, date, report_text)] + - GROQ_API_KEY 미설정 시 XAI 단계를 건너뛰고 빈 리스트 반환 + - 필수 feature 컬럼 존재 검사 포함 """ - print("--- [PIPELINE-STEP 3] XAI 리포트 생성 시작 ---") + print("--- [PIPELINE-STEP 4] XAI 리포트 생성 시작 ---") api_key = os.environ.get("GROQ_API_KEY") if not api_key: print("[STOP] GROQ_API_KEY 미설정: XAI 리포트 단계를 건너뜁니다.") @@ -158,27 +196,22 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, st return [] rows: List[Tuple[str, str, float, str, str]] = [] - for _, row in decision_log.iterrows(): ticker = str(row.get("ticker", "UNKNOWN")) date_s = _to_iso_date(row.get("date", "")) signal = str(row.get("action", "")) # action -> signal price = _to_float(row.get("price", 0.0)) - # === 신규 포맷 전용 evidence === + # === 신규 포맷 전용 evidence(간단 직렬화) === evidence: List[Dict[str, float]] = [] for i in (1, 2, 3): name = row.get(f"feature_name{i}") score = row.get(f"feature_score{i}") - # 이름/점수 모두 있어야 추가 if name is None or str(name).strip() == "": continue if score is None or pd.isna(score): continue - evidence.append({ - "feature_name": str(name), - "contribution": _to_float(score, 0.0) # 0~1 정규화 점수 - }) + evidence.append({"feature_name": str(name), "contribution": _to_float(score, 0.0)}) decision_payload = { "ticker": ticker, @@ -190,7 +223,7 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, st try: report_text = run_xai(decision_payload, api_key) - report_text = str(report_text) # 혹시 모를 비문자 타입 대비 + report_text = str(report_text) print(f"--- {ticker} XAI 리포트 생성 완료 ---") except Exception as e: report_text = f"[ERROR] XAI 리포트 생성 실패: {e}" @@ -198,13 +231,15 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, st rows.append((ticker, signal, price, date_s, report_text)) - print("--- [PIPELINE-STEP 3] XAI 리포트 생성 완료 ---") + print("--- [PIPELINE-STEP 4] XAI 리포트 생성 완료 ---") return rows - -def run_pipeline() -> Optional[List[str]]: +# === 파이프라인 오케스트레이션 === +def run_pipeline() -> Optional[List[Tuple[str, str, float, str, str]]]: """ - 전체 파이프라인(Finder -> Transformer -> XAI)을 실행합니다. + 한국어 주석: + - 전체 파이프라인(Finder -> Transformer -> Backtester -> XAI -> 저장)을 실행한다. + - 반환: XAI 리포트 rows(5-튜플 리스트), 저장은 내부에서 수행. """ # 1) Finder tickers = run_weekly_finder() @@ -218,17 +253,31 @@ def run_pipeline() -> Optional[List[str]]: print("[STOP] Transformer에서 신호를 생성하지 못해 파이프라인을 중단합니다.") return None - # 3) XAI + # 3) Backtester (의사결정 로그 price로 체결) + fills_df = run_backtester(logs_df) + + # 3.5) 체결 내역 DB 저장(선택 사항이지만 기본 저장 권장) + try: + save_executions_to_db(fills_df, REPORT_DB_NAME) + print("[INFO] 체결 내역을 DB에 저장했습니다.") + except Exception as e: + print(f"[WARN] 체결 내역 DB 저장 실패: {e}") + + # 4) XAI reports = run_xai_report(logs_df) - # 4) 저장 - save_reports_to_db(reports, REPORT_DB_NAME) + # 5) 리포트 저장 + try: + save_reports_to_db(reports, REPORT_DB_NAME) + print("[INFO] XAI 리포트를 DB에 저장했습니다.") + except Exception as e: + print(f"[WARN] XAI 리포트 DB 저장 실패: {e}") return reports # --- 테스트 실행 --- if __name__ == "__main__": - print(">>> 파이프라인 (Finder -> Transformer -> XAI) 테스트를 시작합니다.") + print(">>> 파이프라인 (Finder -> Transformer -> Backtester -> XAI) 테스트를 시작합니다.") final_reports = run_pipeline() print("\n>>> 최종 반환 결과 (XAI Reports):") if final_reports: diff --git a/AI/libs/utils/save_executions_to_db.py b/AI/libs/utils/save_executions_to_db.py new file mode 100644 index 00000000..8ab35d26 --- /dev/null +++ b/AI/libs/utils/save_executions_to_db.py @@ -0,0 +1,81 @@ +# libs/utils/save_executions_to_db.py +# -*- coding: utf-8 -*- +""" +한국어 주석: +- 간소화 백테스터의 체결 내역(DataFrame)을 executions 테이블에 저장한다. +- 주요 컬럼: + run_id, ticker, signal_date, signal_price, signal, fill_date, fill_price, qty, side, + value, commission, cash_after, position_qty, avg_price, pnl_realized, pnl_unrealized, created_at +- DB 엔진: libs.utils.get_db_conn 모듈의 get_engine(db_name) 사용(프로젝트 규약 준수) +""" + +from __future__ import annotations +from typing import Iterable, Optional +from sqlalchemy import text +from datetime import datetime, timezone + +from libs.utils.get_db_conn import get_engine # 프로젝트 기존 헬퍼 사용 + +def _utcnow_iso() -> str: + """한국어 주석: created_at 등의 기록용 ISO8601 타임스탬프 문자열""" + return datetime.now(timezone.utc).isoformat() + +def ensure_exec_table_schema(engine) -> None: + """ + 한국어 주석: + - executions 테이블이 없으면 생성한다. + - 이미 있을 경우는 CREATE TABLE IF NOT EXISTS로 무해. + - 컬럼 타입은 PostgreSQL 기준(NUMERIC 정밀도 넉넉히 설정). + * SQLite를 쓴다면 NUMERIC이 실수로 저장되지만 문제 없이 동작. + """ + with engine.begin() as conn: + conn.execute(text(""" + CREATE TABLE IF NOT EXISTS executions ( + id SERIAL PRIMARY KEY, + run_id VARCHAR(64), + ticker VARCHAR(20) NOT NULL, + signal_date DATE NOT NULL, + signal_price NUMERIC(18,6), + signal VARCHAR(10) NOT NULL, + fill_date DATE NOT NULL, + fill_price NUMERIC(18,6) NOT NULL, + qty INTEGER NOT NULL, + side VARCHAR(5) NOT NULL, + value NUMERIC(20,6) NOT NULL, + commission NUMERIC(18,6) NOT NULL, + cash_after NUMERIC(20,6) NOT NULL, + position_qty INTEGER NOT NULL, + avg_price NUMERIC(18,6) NOT NULL, + pnl_realized NUMERIC(18,6) NOT NULL, + pnl_unrealized NUMERIC(18,6) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ); + """)) + +def save_executions_to_db(rows_df, db_name: str) -> None: + """ + 한국어 주석: + - 체결 내역 DataFrame(rows_df)을 executions 테이블에 일괄 insert 한다. + - rows_df는 backtest()가 반환한 fills_df 스키마를 그대로 따른다. + - 빈 DF가 들어오면 아무 것도 하지 않는다. + """ + if rows_df is None or rows_df.empty: + # 저장할 내용 없음 + return + + engine = get_engine(db_name) + ensure_exec_table_schema(engine) + + # dict 레코드 리스트로 변환하여 executemany 형태로 성능 확보 + payload = rows_df.to_dict(orient="records") + + with engine.begin() as conn: + sql = text(""" + INSERT INTO executions + (run_id, ticker, signal_date, signal_price, signal, fill_date, fill_price, qty, side, + value, commission, cash_after, position_qty, avg_price, pnl_realized, pnl_unrealized, created_at) + VALUES + (:run_id, :ticker, :signal_date, :signal_price, :signal, :fill_date, :fill_price, :qty, :side, + :value, :commission, :cash_after, :position_qty, :avg_price, :pnl_realized, :pnl_unrealized, NOW()) + """) + conn.execute(sql, payload) diff --git a/AI/libs/utils/save_reports_to_db.py b/AI/libs/utils/save_reports_to_db.py index 6cda73d5..1958030d 100644 --- a/AI/libs/utils/save_reports_to_db.py +++ b/AI/libs/utils/save_reports_to_db.py @@ -1,56 +1,37 @@ # libs/utils/save_reports_to_db.py +# -*- coding: utf-8 -*- +""" +한국어 주석: +- XAI 리포트 저장(원복 버전) +- 체결/자산(cash) 업데이트 로직을 모두 제거하고, + 기존처럼 xai_reports 테이블에 INSERT만 수행한다. +- DB 스키마는 절대 변경하지 않는다(테이블/컬럼 생성/수정 X). +- 입력 rows 형식: List[Tuple[ticker, signal, price, date_str, report_text]] +""" + from __future__ import annotations -from typing import Iterable, Tuple, List, Optional +from typing import Iterable, Tuple, List from datetime import datetime, timezone -from decimal import Decimal -import os from sqlalchemy import text +from libs.utils.get_db_conn import get_engine # 프로젝트 표준 엔진 헬퍼 -# 내부 유틸에서 엔진만 사용 (스키마는 절대 변경 X) -from libs.utils.get_db_conn import get_engine - -ReportRow = Tuple[str, str, float, str, str] # (ticker, signal, price, date_str, report_text) +# (ticker, signal, price, date_str, report_text) +ReportRow = Tuple[str, str, float, str, str] -# ----- 환경 변수로 자산 테이블/컬럼 지정 (기본값 제공) ----- -ASSETS_TABLE = os.getenv("ASSETS_TABLE", "assets") -ASSETS_ID_COLUMN = os.getenv("ASSETS_ID_COLUMN", "id") -ASSETS_CASH_COLUMN = os.getenv("ASSETS_CASH_COLUMN", "cash") -ASSETS_ROW_ID = os.getenv("ASSETS_ROW_ID", "1") # ----- 유틸 ----- def utcnow() -> datetime: + """한국어 주석: 현재 UTC 시간을 반환(생성 시각 created_at 기록용).""" return datetime.now(timezone.utc) -def _to_decimal(x) -> Decimal: - if isinstance(x, Decimal): - return x - try: - return Decimal(str(x)) - except Exception: - return Decimal(0) - -def _fetch_current_cash(conn) -> Optional[Decimal]: - sql = text(f""" - SELECT {ASSETS_CASH_COLUMN} - FROM public.{ASSETS_TABLE} - WHERE {ASSETS_ID_COLUMN} = :rid - FOR UPDATE - """) - row = conn.execute(sql, {"rid": ASSETS_ROW_ID}).fetchone() - if not row: - return None - return _to_decimal(row[0]) - -def _update_cash(conn, new_cash: Decimal) -> None: - sql = text(f""" - UPDATE public.{ASSETS_TABLE} - SET {ASSETS_CASH_COLUMN} = :cash - WHERE {ASSETS_ID_COLUMN} = :rid - """) - conn.execute(sql, {"cash": str(new_cash), "rid": ASSETS_ROW_ID}) def _build_insert_params(rows: Iterable[ReportRow], created_at: datetime) -> List[dict]: + """ + 한국어 주석: + - 파이프라인에서 넘어온 리포트 튜플들을 DB INSERT용 딕셔너리 리스트로 변환한다. + - 방어적 필터링: ticker/signal/date 가 비어 있으면 해당 행은 건너뜀. + """ out: List[dict] = [] for (ticker, signal, price, date_s, report_text) in rows: if not ticker or not signal or not date_s: @@ -59,28 +40,29 @@ def _build_insert_params(rows: Iterable[ReportRow], created_at: datetime) -> Lis "ticker": ticker, "signal": signal, "price": float(price), - "date": date_s, # 'YYYY-MM-DD' + "date": date_s, # 'YYYY-MM-DD' "report": str(report_text), - "created_at": created_at, + "created_at": created_at, # TIMESTAMPTZ }) return out -# ----- 메인: 1주 고정 체결 + 자산 업데이트 + 리포트 저장 ----- + +# ----- 메인: 리포트 저장(INSERT only) ----- def save_reports_to_db(rows: List[ReportRow], db_name: str) -> int: """ - 요구사항: - - 저장 '직전'에 티커/시그널/가격을 보고 1주만 체결 - - 매 건 체결 후 잔여 현금(자산) 업데이트 - - DB 스키마 변경 금지 (xai_reports는 기존대로 INSERT만) + 한국어 주석: + - 입력된 XAI 리포트(rows)를 public.xai_reports 테이블에 INSERT 한다. + - 테이블 스키마는 건드리지 않는다(생성/ALTER 하지 않음). + - 반환값: 실제 INSERT 된 행 수. """ if not rows: - print("[INFO] 저장할 리포트가 없습니다.") + print("[INFO] 저장할 XAI 리포트가 없습니다.") return 0 engine = get_engine(db_name) created_at = utcnow() - # INSERT 템플릿 (스키마는 건드리지 않음) + # INSERT 템플릿 (스키마는 기존 그대로 사용) insert_sql = text(""" INSERT INTO public.xai_reports (ticker, signal, price, date, report, created_at) VALUES (:ticker, :signal, :price, :date, :report, :created_at) @@ -88,18 +70,17 @@ def save_reports_to_db(rows: List[ReportRow], db_name: str) -> int: inserted = 0 with engine.begin() as conn: - # 현금 락 걸고 읽기 - current_cash = _fetch_current_cash(conn) - if current_cash is None: - # 자산 테이블이 없거나 행이 없으면 바로 저장만 수행 - print(f"[WARN] 자산 테이블 public.{ASSETS_TABLE}에서 행을 찾지 못했어요. 체결 없이 리포트만 저장할게요.") - params = _build_insert_params(rows, created_at) - if params: - # 청크 삽입 - CHUNK = 1000 - for i in range(0, len(params), CHUNK): - batch = params[i:i+CHUNK] - conn.execute(insert_sql, batch) - inserted += len(batch) - print(f"--- {inserted}개의 XAI 리포트가 저장되었습니다 (자산 미적용). ---") - return inserted \ No newline at end of file + params = _build_insert_params(rows, created_at) + if not params: + print("[WARN] 유효한 XAI 리포트 파라미터가 없어 저장을 생략합니다.") + return 0 + + # 대량 삽입 시 성능을 위해 청크 처리 + CHUNK = 1000 + for i in range(0, len(params), CHUNK): + batch = params[i:i + CHUNK] + conn.execute(insert_sql, batch) + inserted += len(batch) + + print(f"--- {inserted}개의 XAI 리포트가 저장되었습니다. ---") + return inserted From 38321b44ef537e5104b4cfd9e1c42fdcd5403c13 Mon Sep 17 00:00:00 2001 From: twq110 Date: Thu, 13 Nov 2025 18:16:33 +0900 Subject: [PATCH 2/8] =?UTF-8?q?[AI]=20SISC2-43-[FIX]=20ohclv=20=EB=88=84?= =?UTF-8?q?=EB=9D=BD=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI/libs/core/__init__.py | 3 ++ AI/libs/core/pipeline.py | 74 +++++++++++++++++++++++------ AI/requirements.txt | 2 +- AI/transformer/modules/inference.py | 2 +- 4 files changed, 65 insertions(+), 16 deletions(-) create mode 100644 AI/libs/core/__init__.py diff --git a/AI/libs/core/__init__.py b/AI/libs/core/__init__.py new file mode 100644 index 00000000..de6ce59a --- /dev/null +++ b/AI/libs/core/__init__.py @@ -0,0 +1,3 @@ +#AI/libs/core/pipeline.py +from libs.core.pipeline import run_pipeline +__all__ = ["run_pipeline"] \ No newline at end of file diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index c45f1bfb..2bb07ccd 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -15,7 +15,7 @@ import os import sys from typing import List, Dict, Optional, Tuple -from datetime import datetime, timezone +from datetime import datetime, timezone ,timedelta import pandas as pd # --- 프로젝트 루트 경로 설정 --- @@ -26,11 +26,14 @@ # --- 모듈 import --- from finder.main import run_finder # 1) 종목 발굴 from transformer.main import run_transformer # 2) 신호 생성(결정 로그 생성) -from xai.run_xai import run_xai # 4) XAI 리포트 텍스트 생성 -from libs.utils.save_reports_to_db import save_reports_to_db # 5) 리포트 저장 -from libs.utils.get_db_conn import get_db_conn # (옵션) DB 연결 헬퍼 from backtest.simple_backtester import backtest, BacktestConfig # 3) 백테스팅(간소화) from libs.utils.save_executions_to_db import save_executions_to_db # 3.5) 체결내역 저장 +from xai.run_xai import run_xai # 4) XAI 리포트 텍스트 생성 +from libs.utils.save_reports_to_db import save_reports_to_db # 5) 리포트 저장 +#--------------------------------- + +# --- 추가 유틸 import --- +from libs.utils.fetch_ohlcv import fetch_ohlcv # 2) OHLCV 수집 헬퍼 # --------------------------------- # DB 이름 상수(실제 등록된 키와 반드시 일치해야 함) @@ -95,25 +98,67 @@ def run_weekly_finder() -> List[str]: def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: """ 한국어 주석: - - Transformer 모듈을 실행하여 의사결정 로그(decision logs)를 반환한다. - - 반환 데이터프레임 예시 컬럼: - ['ticker','date','action','price','feature_name1','feature_score1', ...] - - XAI 단계의 필수 컬럼(REQUIRED_LOG_COLS)을 사전 점검한다. + - 종목 리스트를 받아 DB에서 OHLCV를 수집(fetch_ohlcv)하고, 이를 raw_data로 Transformer에 전달한다. + - Transformer는 전달받은 raw_data를 바탕으로 의사결정 로그(logs_df)를 생성한다. + - 반환: 의사결정 로그 DataFrame (REQUIRED_LOG_COLS 검증 포함) + - 전제: + * fetch_ohlcv(ticker, start, end, db_name) -> pd.DataFrame(columns=['date','open','high','low','close','volume',...]) + * run_transformer(finder_df, seq_len, pred_h, raw_data) -> {'logs': pd.DataFrame(...)} """ print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 시작 ---") + + # 1) 입력 방어 if not tickers: print("[WARN] 빈 종목 리스트가 입력되어 Transformer를 건너뜁니다.") return pd.DataFrame() - # Transformer 인터페이스 규약: finder_df, seq_len, pred_h, raw_data 등 필요 시 맞춰 전달 - # (여기서는 run_transformer 내부에서 데이터 수집/전처리까지 처리한다고 가정) - finder_df = pd.DataFrame(tickers, columns=["ticker"]) + # 2) 날짜 구간 설정 + # - 서버 실사용 시: end_date = _utcnow() + # - 재현 테스트/고정 시: 아래 고정값 활용 + end_date = datetime.strptime("2024-11-1", "%Y-%m-%d") # 임시 고정 날짜 + start_date = end_date - timedelta(days=600) + + # 3) 티커별 OHLCV 수집 (DB) + all_ohlcv_df: List[pd.DataFrame] = [] + for ticker in tickers: + try: + ohlcv_df = fetch_ohlcv( + ticker=ticker, + start=start_date.strftime("%Y-%m-%d"), + end=end_date.strftime("%Y-%m-%d"), + db_name=db_name, + ) + if ohlcv_df is None or ohlcv_df.empty: + print(f"[WARN] OHLCV 미수집: {ticker}") + continue + + # 스키마 안전화 및 티커 컬럼 주입 + ohlcv_df = ohlcv_df.copy() + # date 컬럼이 문자열이면 Timestamp로 변환(Transformer에서 기대하는 형식 맞추기) + if not pd.api.types.is_datetime64_any_dtype(ohlcv_df["date"]): + ohlcv_df["date"] = pd.to_datetime(ohlcv_df["date"]) + ohlcv_df["ticker"] = ticker + all_ohlcv_df.append(ohlcv_df) + except Exception as e: + print(f"[ERROR] OHLCV 수집 실패({ticker}): {e}") + + # 4) 합치기 및 검증 + if not all_ohlcv_df: + print("[ERROR] 어떤 티커에서도 OHLCV 데이터를 가져오지 못했습니다.") + return pd.DataFrame() + + raw_data = pd.concat(all_ohlcv_df, ignore_index=True) + # 정렬/인덱스 리셋 (모델 입력 일관성) + raw_data = raw_data.sort_values(["ticker", "date"]).reset_index(drop=True) + + # 5) Transformer 호출 + finder_df = pd.DataFrame(tickers, columns=["ticker"]) transformer_result: Dict = run_transformer( finder_df=finder_df, seq_len=60, pred_h=1, - raw_data=None # OHLCV 비사용 파이프라인이므로 None 전달(내부에서 자체 처리할 수 있음) + raw_data=raw_data, # ✅ DB에서 가져온 OHLCV를 그대로 전달 ) or {} logs_df: pd.DataFrame = transformer_result.get("logs", pd.DataFrame()) @@ -121,15 +166,16 @@ def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: print("[WARN] Transformer 결과 로그가 비어 있습니다.") return pd.DataFrame() - # 신규 포맷 강제 체크(XAI 및 백테스터가 기대하는 컬럼 존재 여부 확인) + # 6) XAI/백테스터 필수 컬럼 체크 missing_cols = REQUIRED_LOG_COLS - set(logs_df.columns) if missing_cols: print(f"[ERROR] 결정 로그에 필수 컬럼 누락(신규 포맷 전용): {sorted(missing_cols)}") return pd.DataFrame() - print(f"--- [PIPELINE-STEP 2] Transformer 완료: logs={len(logs_df)} rows ---") + print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 완료 ---") return logs_df + # === 3) Backtester: 로그 price 기준 체결가/수량 산출 === def run_backtester(decision_log: pd.DataFrame) -> pd.DataFrame: """ diff --git a/AI/requirements.txt b/AI/requirements.txt index bdeab51d..94f93321 100644 --- a/AI/requirements.txt +++ b/AI/requirements.txt @@ -11,4 +11,4 @@ yfinance groq requests beautifulsoup4 -pathlib \ No newline at end of file +pathlib diff --git a/AI/transformer/modules/inference.py b/AI/transformer/modules/inference.py index 55aa394e..7a468d3d 100644 --- a/AI/transformer/modules/inference.py +++ b/AI/transformer/modules/inference.py @@ -61,7 +61,7 @@ def run_inference( """ tickers = finder_df["ticker"].astype(str).tolist() if raw_data is None or raw_data.empty: - print("[INFER] raw_data empty -> empty logs") + print("[INFER] raw_data empty -> empty logs") return {"logs": pd.DataFrame(columns=[ "ticker","date","action","price","weight", "feature1","feature2","feature3", From 49b40b0d588085da8d47f7fb500bba9358c43b90 Mon Sep 17 00:00:00 2001 From: twq110 Date: Thu, 20 Nov 2025 12:45:31 +0900 Subject: [PATCH 3/8] =?UTF-8?q?[AI]=20SISC2-43=20[FEAT]=20=EB=A7=A4?= =?UTF-8?q?=EB=A7=A4=20=EB=A1=9C=EC=A7=81=20=EC=9E=91=EC=84=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI/libs/core/pipeline.py | 341 +++++++++++++++---------- AI/libs/utils/save_executions_to_db.py | 51 +++- AI/libs/utils/save_reports_to_db.py | 41 +-- 3 files changed, 278 insertions(+), 155 deletions(-) diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index 2bb07ccd..e51e2a59 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -1,61 +1,102 @@ -# pipeline/run_pipeline.py +# pipeline/run_pipeline.py # -*- coding: utf-8 -*- """ -한국어 주석: -- 본 파일은 주간 파이프라인(종목 발굴 → 신호 변환 → 백테스트(체결/수량) → XAI 리포트 → DB 저장)을 한 번에 실행한다. -- 핵심 변화: - 1) Backtester 단계 신설: Transformer 의사결정 로그의 price를 그대로 체결가 기준으로 사용(OHLCV 불필요) - 2) 체결내역을 executions 테이블에 저장(save_executions_to_db) - 3) XAI 리포트를 기존대로 reports 테이블(또는 xai_reports)에 저장(save_reports_to_db) -- 주의: - - decision_log에는 반드시 ['ticker','date','action','price']가 있어야 한다. - - XAI용 필수 feature 컬럼(feature_name1~3, feature_score1~3) 점검 로직 포함. +한국어 주석 (개요): +- 본 파일은 "주간 자동 파이프라인"의 전체 흐름을 오케스트레이션한다. + (Finder → Transformer → XAI 리포트 → Backtester → DB 저장) + +[전체 플로우] +1) Finder + - 시장/전략 조건에 맞는 종목 목록(ticker list)을 선정한다. + +2) Transformer + - 선택된 종목들의 OHLCV를 DB에서 가져온다(fetch_ohlcv). + - LSTM/Rule 기반 등의 Transformer 로직을 통해 의사결정 로그(DataFrame)를 생성한다. + - 이 의사결정 로그는 XAI와 Backtester에서 모두 공통으로 사용된다. + +3) XAI (e.g. GROQ 등 LLM 기반 설명 생성) + - 각 의사결정에 대해 feature_name / feature_score를 기반으로 + "왜 이 신호가 나왔는지"에 대한 자연어 리포트를 생성한다. + - 결과는 xai_reports 테이블에 먼저 저장된다. + - 이 때 생성된 xai_reports.id를 decision_log(logs_df)에 xai_report_id로 심는다. + +4) Backtester + - xai_report_id가 포함된 의사결정 로그(decision_log)를 받아, + price 컬럼을 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행한다. + - Backtest 결과(fills_df)는 xai_report_id를 그대로 보존한 상태로 executions에 저장된다. + +[주의 사항] +- Transformer가 생성하는 decision_log(DataFrame)는 최소한 아래 컬럼을 포함해야 한다. + ['ticker', 'date', 'action', 'price', + 'feature_name1', 'feature_name2', 'feature_name3', + 'feature_score1', 'feature_score2', 'feature_score3'] +- GROQ_API_KEY 환경변수가 없으면 XAI 단계는 자동으로 스킵된다. """ import os import sys from typing import List, Dict, Optional, Tuple -from datetime import datetime, timezone ,timedelta +from datetime import datetime, timezone, timedelta + import pandas as pd -# --- 프로젝트 루트 경로 설정 --- +# ---------------------------------------------------------------------- +# 프로젝트 루트 경로 설정 +# ---------------------------------------------------------------------- project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(project_root) -# ------------------------------ - -# --- 모듈 import --- -from finder.main import run_finder # 1) 종목 발굴 -from transformer.main import run_transformer # 2) 신호 생성(결정 로그 생성) -from backtest.simple_backtester import backtest, BacktestConfig # 3) 백테스팅(간소화) -from libs.utils.save_executions_to_db import save_executions_to_db # 3.5) 체결내역 저장 -from xai.run_xai import run_xai # 4) XAI 리포트 텍스트 생성 -from libs.utils.save_reports_to_db import save_reports_to_db # 5) 리포트 저장 -#--------------------------------- - -# --- 추가 유틸 import --- -from libs.utils.fetch_ohlcv import fetch_ohlcv # 2) OHLCV 수집 헬퍼 -# --------------------------------- - -# DB 이름 상수(실제 등록된 키와 반드시 일치해야 함) -MARKET_DB_NAME = "db" # (현재 파이프라인에선 직접 사용 안함, 향후 확장 대비) -REPORT_DB_NAME = "db" # 체결내역/리포트 저장용 DB, 하나로 통합된 예정 - -# === (신규 전용) XAI 필수 컬럼: inference 로그 → XAI 변환에 필요한 것만 강제 === + +# ---------------------------------------------------------------------- +# 외부 모듈 import (각 단계별 역할) +# ---------------------------------------------------------------------- +from finder.main import run_finder # 1) 종목 발굴 +from transformer.main import run_transformer # 2) 신호 생성(의사결정 로그 생성) +from backtest.simple_backtester import backtest, BacktestConfig # 4) 백테스팅(간소화 체결 엔진) +from libs.utils.save_executions_to_db import save_executions_to_db # 5) 체결내역 DB 저장 +from xai.run_xai import run_xai # 3) XAI 리포트 텍스트 생성 +from libs.utils.save_reports_to_db import save_reports_to_db # 3.5) XAI 리포트 DB 저장 (id 반환) +from libs.utils.fetch_ohlcv import fetch_ohlcv # (Transformer용) OHLCV 수집 헬퍼 + +# ---------------------------------------------------------------------- +# 타입 별칭 (다른 모듈과의 데이터 계약을 명확히 하기 위함) +# ---------------------------------------------------------------------- +ReportRow = Tuple[str, str, float, str, str] # (ticker, signal, price, date, report_text) + +# ---------------------------------------------------------------------- +# DB 이름 상수 +# ---------------------------------------------------------------------- +MARKET_DB_NAME = "db" # 시세/시장 데이터(DB) 명. +REPORT_DB_NAME = "db" # 체결내역 / XAI 리포트를 저장하는 DB 명 + +# ---------------------------------------------------------------------- +# XAI 및 Backtester에서 공통으로 요구하는 "결정 로그 필수 컬럼" 정의 +# ---------------------------------------------------------------------- REQUIRED_LOG_COLS = { - "ticker", "date", "action", "price", + "ticker", + "date", + "action", + "price", # XAI evidence 구성에 꼭 필요한 신규 컬럼 - "feature_name1", "feature_name2", "feature_name3", - "feature_score1", "feature_score2", "feature_score3", - # (원하면 로깅/모니터링용 확률도 계속 받되 필수는 아님) + "feature_name1", + "feature_name2", + "feature_name3", + "feature_score1", + "feature_score2", + "feature_score3", } -# === 유틸 === + +# ====================================================================== +# 유틸리티 함수 모음 +# ====================================================================== + def _utcnow() -> datetime: - """한국어 주석: UTC now 헬퍼(테스트 재현성을 위해 따로 분리)""" + """현재 시각을 UTC 기준 datetime으로 반환.""" return datetime.now(timezone.utc) + def _to_iso_date(v) -> str: - """한국어 주석: pandas.Timestamp/ datetime → YYYY-MM-DD 문자열로 안전 변환""" + """값을 'YYYY-MM-DD' 문자열로 변환.""" try: if isinstance(v, (pd.Timestamp, datetime)): return v.strftime("%Y-%m-%d") @@ -63,8 +104,9 @@ def _to_iso_date(v) -> str: except Exception: return str(v) -def _to_float(v, fallback=0.0) -> float: - """한국어 주석: 숫자 변환 시 NaN/예외에 대해 기본값 제공""" + +def _to_float(v, fallback: float = 0.0) -> float: + """값을 float로 변환, 실패 시 fallback.""" try: f = float(v) if pd.isna(f): @@ -73,53 +115,48 @@ def _to_float(v, fallback=0.0) -> float: except Exception: return float(fallback) -# === 1) Finder: 주간 종목 추출 === + +# ====================================================================== +# 1) Finder: 주간 종목 추출 단계 +# ====================================================================== + def run_weekly_finder() -> List[str]: """ - 한국어 주석: - - Finder 모듈을 실행하여 후보 티커 리스트를 반환한다. - - 현재는 run_finder()의 구현/데이터 이슈로 인해 임시 티커를 반환할 수 있다. + Finder 모듈을 실행하여 후보 티커 리스트를 반환. """ print("--- [PIPELINE-STEP 1] Finder 모듈 실행 시작 ---") try: - # 실제 구현 연결 시 사용 tickers = run_finder() if not tickers: - # 비상용 임시 리스트 tickers = ["AAPL", "MSFT", "GOOGL"] except Exception as e: - print(f"[WARN] Finder 실행 중 오류: {e} → 임시 티커 사용") + print(f"[WARN] Finder 실행 중 오류 발생: {e} → 임시 티커 리스트를 사용합니다.") tickers = ["AAPL", "MSFT", "GOOGL"] print(f"--- [PIPELINE-STEP 1] Finder 완료: tickers={tickers} ---") return tickers -# === 2) Transformer: 신호/로그 생성 === + +# ====================================================================== +# 2) Transformer: 신호/의사결정 로그 생성 단계 +# ====================================================================== + def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: """ - 한국어 주석: - - 종목 리스트를 받아 DB에서 OHLCV를 수집(fetch_ohlcv)하고, 이를 raw_data로 Transformer에 전달한다. - - Transformer는 전달받은 raw_data를 바탕으로 의사결정 로그(logs_df)를 생성한다. - - 반환: 의사결정 로그 DataFrame (REQUIRED_LOG_COLS 검증 포함) - - 전제: - * fetch_ohlcv(ticker, start, end, db_name) -> pd.DataFrame(columns=['date','open','high','low','close','volume',...]) - * run_transformer(finder_df, seq_len, pred_h, raw_data) -> {'logs': pd.DataFrame(...)} + 종목 리스트에 대해 DB에서 OHLCV를 수집하고 Transformer를 호출하여 + 의사결정 로그(DataFrame)를 생성한다. """ print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 시작 ---") - # 1) 입력 방어 if not tickers: - print("[WARN] 빈 종목 리스트가 입력되어 Transformer를 건너뜁니다.") + print("[WARN] 빈 종목 리스트가 입력되어 Transformer 단계를 건너뜁니다.") return pd.DataFrame() - # 2) 날짜 구간 설정 - # - 서버 실사용 시: end_date = _utcnow() - # - 재현 테스트/고정 시: 아래 고정값 활용 end_date = datetime.strptime("2024-11-1", "%Y-%m-%d") # 임시 고정 날짜 start_date = end_date - timedelta(days=600) - # 3) 티커별 OHLCV 수집 (DB) all_ohlcv_df: List[pd.DataFrame] = [] + for ticker in tickers: try: ohlcv_df = fetch_ohlcv( @@ -128,136 +165,158 @@ def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: end=end_date.strftime("%Y-%m-%d"), db_name=db_name, ) + if ohlcv_df is None or ohlcv_df.empty: print(f"[WARN] OHLCV 미수집: {ticker}") continue - # 스키마 안전화 및 티커 컬럼 주입 ohlcv_df = ohlcv_df.copy() - # date 컬럼이 문자열이면 Timestamp로 변환(Transformer에서 기대하는 형식 맞추기) + if not pd.api.types.is_datetime64_any_dtype(ohlcv_df["date"]): ohlcv_df["date"] = pd.to_datetime(ohlcv_df["date"]) - ohlcv_df["ticker"] = ticker + ohlcv_df["ticker"] = ticker all_ohlcv_df.append(ohlcv_df) + except Exception as e: - print(f"[ERROR] OHLCV 수집 실패({ticker}): {e}") + print(f"[ERROR] OHLCV 수집 실패 (ticker={ticker}): {e}") - # 4) 합치기 및 검증 if not all_ohlcv_df: - print("[ERROR] 어떤 티커에서도 OHLCV 데이터를 가져오지 못했습니다.") + print("[ERROR] 어떤 티커에서도 OHLCV 데이터를 가져오지 못했습니다. Transformer를 종료합니다.") return pd.DataFrame() raw_data = pd.concat(all_ohlcv_df, ignore_index=True) - # 정렬/인덱스 리셋 (모델 입력 일관성) raw_data = raw_data.sort_values(["ticker", "date"]).reset_index(drop=True) - # 5) Transformer 호출 finder_df = pd.DataFrame(tickers, columns=["ticker"]) + transformer_result: Dict = run_transformer( finder_df=finder_df, seq_len=60, pred_h=1, - raw_data=raw_data, # ✅ DB에서 가져온 OHLCV를 그대로 전달 + raw_data=raw_data, ) or {} logs_df: pd.DataFrame = transformer_result.get("logs", pd.DataFrame()) + if logs_df is None or logs_df.empty: - print("[WARN] Transformer 결과 로그가 비어 있습니다.") + print("[WARN] Transformer 결과 로그(logs)가 비어 있습니다.") return pd.DataFrame() - # 6) XAI/백테스터 필수 컬럼 체크 missing_cols = REQUIRED_LOG_COLS - set(logs_df.columns) if missing_cols: - print(f"[ERROR] 결정 로그에 필수 컬럼 누락(신규 포맷 전용): {sorted(missing_cols)}") + print(f"[ERROR] 결정 로그에 필수 컬럼 누락 (신규 포맷 전용): {sorted(missing_cols)}") return pd.DataFrame() print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 완료 ---") return logs_df -# === 3) Backtester: 로그 price 기준 체결가/수량 산출 === +# ====================================================================== +# 3) Backtester: 의사결정 로그 기반 체결/포지션 계산 단계 +# ====================================================================== + def run_backtester(decision_log: pd.DataFrame) -> pd.DataFrame: """ - 한국어 주석: - - Transformer 의사결정 로그의 price를 체결 기준가로 사용(OHLCV 불필요) - - 슬리피지/수수료/사이징을 BacktestConfig로 제어 - - 반환: 체결내역 DataFrame (채결가, 수량, 실현손익 등 포함) + Transformer에서 생성된 의사결정 로그(decision_log)의 price 컬럼을 + OHLCV 없이 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행한다. + + 주의: + - decision_log는 xai_report_id 컬럼을 포함할 수 있으며, + backtest() 구현이 해당 컬럼을 드롭하지 않으면 fills_df에도 그대로 보존된다. """ - print("--- [PIPELINE-STEP 3] Backtester 실행 시작 ---") + print("--- [PIPELINE-STEP 4] Backtester 실행 시작 ---") + if decision_log is None or decision_log.empty: - print("[WARN] Backtester: 비어있는 결정 로그가 입력되었습니다.") + print("[WARN] Backtester: 비어있는 결정 로그가 입력되었습니다. 체결을 수행하지 않습니다.") return pd.DataFrame() - # 동일 실행 묶음 식별자(run_id) run_id = _utcnow().strftime("run-%Y%m%d-%H%M%S") cfg = BacktestConfig( - initial_cash=100_000.0, # 시작 현금 - slippage_bps=5.0, # 슬리피지 5bp - commission_bps=3.0, # 수수료 3bp - risk_frac=0.2, # 1회 진입에 현금의 20% 사용 + initial_cash=100_000.0, + slippage_bps=5.0, + commission_bps=3.0, + risk_frac=0.2, max_positions_per_ticker=1, - fill_on_same_day=True # 로그 가격으로 '즉시' 체결 + fill_on_same_day=True, ) fills_df, summary = backtest( decision_log=decision_log, config=cfg, - run_id=run_id + run_id=run_id, ) if fills_df is None or fills_df.empty: print("[WARN] Backtester: 생성된 체결 내역이 없습니다.") return pd.DataFrame() - print(f"--- [PIPELINE-STEP 3] 완료: trades={len(fills_df)}, " - f"cash_final={summary.get('cash_final')}, pnl_realized_sum={summary.get('pnl_realized_sum')} ---") + print( + f"--- [PIPELINE-STEP 4] Backtester 완료: " + f"trades={len(fills_df)}, " + f"cash_final={summary.get('cash_final')}, " + f"pnl_realized_sum={summary.get('pnl_realized_sum')} ---" + ) return fills_df -# === 4) XAI 리포트: 5-튜플(rows) 생성 === -def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, str, str]]: + +# ====================================================================== +# 4) XAI 리포트: 설명 텍스트 생성 단계 +# ====================================================================== + +def run_xai_report(decision_log: pd.DataFrame) -> List[ReportRow]: """ - 한국어 주석: - - 입력: Transformer 결정 로그 - - 출력: List[(ticker, signal, price, date, report_text)] - - GROQ_API_KEY 미설정 시 XAI 단계를 건너뛰고 빈 리스트 반환 - - 필수 feature 컬럼 존재 검사 포함 + Transformer 결정 로그를 입력으로 받아, 각 행(의사결정)에 대한 + XAI 설명 리포트(자연어 텍스트)를 생성한다. """ - print("--- [PIPELINE-STEP 4] XAI 리포트 생성 시작 ---") + print("--- [PIPELINE-STEP 3] XAI 리포트 생성 시작 ---") + api_key = os.environ.get("GROQ_API_KEY") if not api_key: - print("[STOP] GROQ_API_KEY 미설정: XAI 리포트 단계를 건너뜁니다.") + print("[STOP] GROQ_API_KEY 환경변수가 설정되어 있지 않아 XAI 리포트를 생성하지 않습니다.") return [] if decision_log is None or decision_log.empty: print("[WARN] 비어있는 결정 로그가 입력되어 XAI 리포트를 생성하지 않습니다.") return [] - # 신규 포맷 강제(안전망) - for c in ["feature_name1","feature_name2","feature_name3", - "feature_score1","feature_score2","feature_score3"]: + for c in [ + "feature_name1", + "feature_name2", + "feature_name3", + "feature_score1", + "feature_score2", + "feature_score3", + ]: if c not in decision_log.columns: print(f"[ERROR] XAI: 신규 포맷 필수 컬럼 누락: {c}") return [] - rows: List[Tuple[str, str, float, str, str]] = [] + rows: List[ReportRow] = [] + for _, row in decision_log.iterrows(): ticker = str(row.get("ticker", "UNKNOWN")) date_s = _to_iso_date(row.get("date", "")) - signal = str(row.get("action", "")) # action -> signal - price = _to_float(row.get("price", 0.0)) + signal = str(row.get("action", "")) # action → signal + price = _to_float(row.get("price", 0.0)) - # === 신규 포맷 전용 evidence(간단 직렬화) === evidence: List[Dict[str, float]] = [] for i in (1, 2, 3): name = row.get(f"feature_name{i}") score = row.get(f"feature_score{i}") + if name is None or str(name).strip() == "": continue if score is None or pd.isna(score): continue - evidence.append({"feature_name": str(name), "contribution": _to_float(score, 0.0)}) + + evidence.append( + { + "feature_name": str(name), + "contribution": _to_float(score, 0.0), + } + ) decision_payload = { "ticker": ticker, @@ -270,22 +329,25 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, st try: report_text = run_xai(decision_payload, api_key) report_text = str(report_text) - print(f"--- {ticker} XAI 리포트 생성 완료 ---") + print(f"--- [XAI] {ticker} 리포트 생성 완료 ---") except Exception as e: report_text = f"[ERROR] XAI 리포트 생성 실패: {e}" - print(f"--- {ticker} XAI 리포트 생성 중 오류: {e} ---") + print(f"--- [XAI] {ticker} 리포트 생성 중 오류: {e} ---") rows.append((ticker, signal, price, date_s, report_text)) - print("--- [PIPELINE-STEP 4] XAI 리포트 생성 완료 ---") + print("--- [PIPELINE-STEP 3] XAI 리포트 생성 완료 ---") return rows -# === 파이프라인 오케스트레이션 === -def run_pipeline() -> Optional[List[Tuple[str, str, float, str, str]]]: + +# ====================================================================== +# 5) 전체 파이프라인 오케스트레이션 +# ====================================================================== + +def run_pipeline() -> Optional[List[ReportRow]]: """ - 한국어 주석: - - 전체 파이프라인(Finder -> Transformer -> Backtester -> XAI -> 저장)을 실행한다. - - 반환: XAI 리포트 rows(5-튜플 리스트), 저장은 내부에서 수행. + 전체 파이프라인(Finder → Transformer → XAI → Backtester → DB 저장)을 + 한 번에 실행하는 엔트리 포인트 함수. """ # 1) Finder tickers = run_weekly_finder() @@ -296,41 +358,60 @@ def run_pipeline() -> Optional[List[Tuple[str, str, float, str, str]]]: # 2) Transformer logs_df = run_signal_transformer(tickers, MARKET_DB_NAME) if logs_df is None or logs_df.empty: - print("[STOP] Transformer에서 신호를 생성하지 못해 파이프라인을 중단합니다.") + print("[STOP] Transformer에서 유효한 신호를 생성하지 못해 파이프라인을 중단합니다.") return None - # 3) Backtester (의사결정 로그 price로 체결) - fills_df = run_backtester(logs_df) + # 3) XAI 리포트 생성 + reports = run_xai_report(logs_df) - # 3.5) 체결 내역 DB 저장(선택 사항이지만 기본 저장 권장) + # 3.5) XAI 리포트 DB 저장 → 생성된 id 리스트 수신 try: - save_executions_to_db(fills_df, REPORT_DB_NAME) - print("[INFO] 체결 내역을 DB에 저장했습니다.") + xai_ids = save_reports_to_db(reports, REPORT_DB_NAME) + print("[INFO] XAI 리포트를 DB에 저장했습니다.") except Exception as e: - print(f"[WARN] 체결 내역 DB 저장 실패: {e}") + print(f"[WARN] XAI 리포트 DB 저장 실패: {e}") + xai_ids = [] - # 4) XAI - reports = run_xai_report(logs_df) + # 3.7) logs_df에 xai_report_id 심기 + # (길이가 맞지 않거나 XAI 저장 실패 시에는 NULL로 채워서 진행) + logs_df = logs_df.copy().reset_index(drop=True) + if xai_ids and len(xai_ids) == len(logs_df): + logs_df["xai_report_id"] = xai_ids + else: + logs_df["xai_report_id"] = None + if xai_ids and len(xai_ids) != len(logs_df): + print( + f"[WARN] XAI ID 개수({len(xai_ids)})와 decision_log 행 수({len(logs_df)})가 달라 " + "xai_report_id를 매핑하지 못했습니다. (모두 NULL 처리)" + ) + + # 4) Backtester: xai_report_id 포함 decision_log로 체결 내역 생성 + fills_df = run_backtester(logs_df) - # 5) 리포트 저장 + # 5) executions 테이블에 체결 내역 저장 try: - save_reports_to_db(reports, REPORT_DB_NAME) - print("[INFO] XAI 리포트를 DB에 저장했습니다.") + save_executions_to_db(fills_df, REPORT_DB_NAME) + print("[INFO] 체결 내역을 DB에 저장했습니다.") except Exception as e: - print(f"[WARN] XAI 리포트 DB 저장 실패: {e}") + print(f"[WARN] 체결 내역 DB 저장 실패: {e}") return reports -# --- 테스트 실행 --- + +# ====================================================================== +# 스크립트 단독 실행 시 테스트용 엔트리 포인트 +# ====================================================================== if __name__ == "__main__": - print(">>> 파이프라인 (Finder -> Transformer -> Backtester -> XAI) 테스트를 시작합니다.") + print(">>> 파이프라인 (Finder → Transformer → XAI → Backtester) 테스트를 시작합니다.") final_reports = run_pipeline() + print("\n>>> 최종 반환 결과 (XAI Reports):") if final_reports: for report in final_reports: print(report) else: print("생성된 리포트가 없습니다.") + print("\n---") print("테스트가 정상적으로 완료되었다면, 위 '최종 반환 결과'에 각 종목에 대한 XAI 리포트가 출력되어야 합니다.") print("---") diff --git a/AI/libs/utils/save_executions_to_db.py b/AI/libs/utils/save_executions_to_db.py index 8ab35d26..a9096c9d 100644 --- a/AI/libs/utils/save_executions_to_db.py +++ b/AI/libs/utils/save_executions_to_db.py @@ -4,35 +4,42 @@ 한국어 주석: - 간소화 백테스터의 체결 내역(DataFrame)을 executions 테이블에 저장한다. - 주요 컬럼: - run_id, ticker, signal_date, signal_price, signal, fill_date, fill_price, qty, side, - value, commission, cash_after, position_qty, avg_price, pnl_realized, pnl_unrealized, created_at + run_id, xai_report_id, ticker, signal_date, signal_price, signal, fill_date, fill_price, + qty, side, value, commission, cash_after, position_qty, avg_price, + pnl_realized, pnl_unrealized, created_at - DB 엔진: libs.utils.get_db_conn 모듈의 get_engine(db_name) 사용(프로젝트 규약 준수) """ from __future__ import annotations -from typing import Iterable, Optional +from typing import Optional from sqlalchemy import text from datetime import datetime, timezone from libs.utils.get_db_conn import get_engine # 프로젝트 기존 헬퍼 사용 + def _utcnow_iso() -> str: """한국어 주석: created_at 등의 기록용 ISO8601 타임스탬프 문자열""" return datetime.now(timezone.utc).isoformat() + def ensure_exec_table_schema(engine) -> None: """ 한국어 주석: - executions 테이블이 없으면 생성한다. - 이미 있을 경우는 CREATE TABLE IF NOT EXISTS로 무해. - 컬럼 타입은 PostgreSQL 기준(NUMERIC 정밀도 넉넉히 설정). - * SQLite를 쓴다면 NUMERIC이 실수로 저장되지만 문제 없이 동작. + - xai_report_id 컬럼을 추가하여 xai_reports(id)를 FK로 참조한다. """ with engine.begin() as conn: + # 테이블이 없을 때만 생성 conn.execute(text(""" CREATE TABLE IF NOT EXISTS executions ( id SERIAL PRIMARY KEY, run_id VARCHAR(64), + + xai_report_id BIGINT, -- 🔗 xai_reports.id 참조용 (NULL 허용) + ticker VARCHAR(20) NOT NULL, signal_date DATE NOT NULL, signal_price NUMERIC(18,6), @@ -52,11 +59,25 @@ def ensure_exec_table_schema(engine) -> None: ); """)) + # FK는 이미 있을 수 있으니, 한 번 시도하고 실패하면 무시 + try: + conn.execute(text(""" + ALTER TABLE executions + ADD CONSTRAINT fk_executions_xai_reports + FOREIGN KEY (xai_report_id) + REFERENCES xai_reports(id); + """)) + except Exception: + # 이미 FK가 있거나 에러가 나더라도 전체 플로우를 막지 않음 + pass + + def save_executions_to_db(rows_df, db_name: str) -> None: """ 한국어 주석: - 체결 내역 DataFrame(rows_df)을 executions 테이블에 일괄 insert 한다. - rows_df는 backtest()가 반환한 fills_df 스키마를 그대로 따른다. + - XAI 연동 시에는 rows_df에 xai_report_id 컬럼이 포함될 수 있다. - 빈 DF가 들어오면 아무 것도 하지 않는다. """ if rows_df is None or rows_df.empty: @@ -66,16 +87,30 @@ def save_executions_to_db(rows_df, db_name: str) -> None: engine = get_engine(db_name) ensure_exec_table_schema(engine) + # XAI를 안 돌렸거나 매핑이 실패한 경우를 대비하여 컬럼이 없으면 NULL로 채워서 생성 + if "xai_report_id" not in rows_df.columns: + rows_df = rows_df.copy() + print("[WARN] xai_report_id 매핑 실패 또는 XAI 미실행 감지, NULL로 저장합니다.") + rows_df["xai_report_id"] = None + # dict 레코드 리스트로 변환하여 executemany 형태로 성능 확보 payload = rows_df.to_dict(orient="records") with engine.begin() as conn: sql = text(""" INSERT INTO executions - (run_id, ticker, signal_date, signal_price, signal, fill_date, fill_price, qty, side, - value, commission, cash_after, position_qty, avg_price, pnl_realized, pnl_unrealized, created_at) + (run_id, xai_report_id, + ticker, signal_date, signal_price, signal, + fill_date, fill_price, qty, side, + value, commission, cash_after, + position_qty, avg_price, + pnl_realized, pnl_unrealized, created_at) VALUES - (:run_id, :ticker, :signal_date, :signal_price, :signal, :fill_date, :fill_price, :qty, :side, - :value, :commission, :cash_after, :position_qty, :avg_price, :pnl_realized, :pnl_unrealized, NOW()) + (:run_id, :xai_report_id, + :ticker, :signal_date, :signal_price, :signal, + :fill_date, :fill_price, :qty, :side, + :value, :commission, :cash_after, + :position_qty, :avg_price, + :pnl_realized, :pnl_unrealized, NOW()) """) conn.execute(sql, payload) diff --git a/AI/libs/utils/save_reports_to_db.py b/AI/libs/utils/save_reports_to_db.py index 1958030d..77fc4a5c 100644 --- a/AI/libs/utils/save_reports_to_db.py +++ b/AI/libs/utils/save_reports_to_db.py @@ -2,11 +2,15 @@ # -*- coding: utf-8 -*- """ 한국어 주석: -- XAI 리포트 저장(원복 버전) -- 체결/자산(cash) 업데이트 로직을 모두 제거하고, - 기존처럼 xai_reports 테이블에 INSERT만 수행한다. +- XAI 리포트 저장 +- 체결/자산(cash) 업데이트 로직 없이, xai_reports 테이블에 INSERT만 수행한다. - DB 스키마는 절대 변경하지 않는다(테이블/컬럼 생성/수정 X). - 입력 rows 형식: List[Tuple[ticker, signal, price, date_str, report_text]] + +변경점: +- 기존에는 "실제 INSERT 된 행 수(int)"를 반환했으나, + 이제는 "각 INSERT 행의 xai_reports.id 리스트(List[int])"를 반환한다. + (rows의 순서와 id 리스트의 순서는 동일하다.) """ from __future__ import annotations @@ -48,39 +52,42 @@ def _build_insert_params(rows: Iterable[ReportRow], created_at: datetime) -> Lis # ----- 메인: 리포트 저장(INSERT only) ----- -def save_reports_to_db(rows: List[ReportRow], db_name: str) -> int: +def save_reports_to_db(rows: List[ReportRow], db_name: str) -> List[int]: """ 한국어 주석: - 입력된 XAI 리포트(rows)를 public.xai_reports 테이블에 INSERT 한다. - 테이블 스키마는 건드리지 않는다(생성/ALTER 하지 않음). - - 반환값: 실제 INSERT 된 행 수. + - 반환값: 실제 INSERT 된 각 행의 xai_reports.id 리스트. + (rows의 순서에서 유효한 행만 추려낸 순서와 동일) """ if not rows: print("[INFO] 저장할 XAI 리포트가 없습니다.") - return 0 + return [] engine = get_engine(db_name) created_at = utcnow() - # INSERT 템플릿 (스키마는 기존 그대로 사용) + # INSERT 템플릿 (스키마는 기존 그대로 사용) + RETURNING id insert_sql = text(""" INSERT INTO public.xai_reports (ticker, signal, price, date, report, created_at) VALUES (:ticker, :signal, :price, :date, :report, :created_at) + RETURNING id """) - inserted = 0 + inserted_ids: List[int] = [] + with engine.begin() as conn: params = _build_insert_params(rows, created_at) if not params: print("[WARN] 유효한 XAI 리포트 파라미터가 없어 저장을 생략합니다.") - return 0 + return [] - # 대량 삽입 시 성능을 위해 청크 처리 - CHUNK = 1000 - for i in range(0, len(params), CHUNK): - batch = params[i:i + CHUNK] - conn.execute(insert_sql, batch) - inserted += len(batch) + # 리포트 개수가 엄청 많지 않을 것으로 가정하고, id를 받기 위해 한 행씩 INSERT + for p in params: + result = conn.execute(insert_sql, p) + new_id = result.scalar() + if new_id is not None: + inserted_ids.append(int(new_id)) - print(f"--- {inserted}개의 XAI 리포트가 저장되었습니다. ---") - return inserted + print(f"--- {len(inserted_ids)}개의 XAI 리포트가 저장되었습니다. ---") + return inserted_ids From b71bbb3b45d6dbd5b06d627bd74bae9c250e2837 Mon Sep 17 00:00:00 2001 From: twq110 Date: Thu, 20 Nov 2025 13:04:33 +0900 Subject: [PATCH 4/8] =?UTF-8?q?[AI]=20SISC2-43=20[REFACTOR]=20backtest=20?= =?UTF-8?q?=EB=8C=80=EC=8B=A0=20backtrader=20=EB=8B=A8=EC=96=B4=20?= =?UTF-8?q?=EC=82=AC=EC=9A=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI/{backtest => backtrader}/order_policy.py | 2 +- .../simple_backtrader.py} | 8 ++++---- AI/libs/core/pipeline.py | 16 ++++++++-------- 3 files changed, 13 insertions(+), 13 deletions(-) rename AI/{backtest => backtrader}/order_policy.py (98%) rename AI/{backtest/simple_backtester.py => backtrader/simple_backtrader.py} (96%) diff --git a/AI/backtest/order_policy.py b/AI/backtrader/order_policy.py similarity index 98% rename from AI/backtest/order_policy.py rename to AI/backtrader/order_policy.py index 839095a3..65a93df5 100644 --- a/AI/backtest/order_policy.py +++ b/AI/backtrader/order_policy.py @@ -1,4 +1,4 @@ -# backtest/order_policy.py +# backtrader/order_policy.py # -*- coding: utf-8 -*- """ 한국어 주석: diff --git a/AI/backtest/simple_backtester.py b/AI/backtrader/simple_backtrader.py similarity index 96% rename from AI/backtest/simple_backtester.py rename to AI/backtrader/simple_backtrader.py index 97cfbd5f..8cd796fc 100644 --- a/AI/backtest/simple_backtester.py +++ b/AI/backtrader/simple_backtrader.py @@ -1,10 +1,10 @@ -# backtest/simple_backtester.py +# backtrader/simple_backtrader.py # -*- coding: utf-8 -*- """ 한국어 주석: - OHLCV 없이, Transformer 결정 로그(decision_log)의 price만으로 간소화된 백테스트를 수행하는 환경(Environment) 역할. -- 수량/포지션 결정은 backtest/order_policy.py 모듈로 분리됨. +- 수량/포지션 결정은 backtrader/order_policy.py 모듈로 분리됨. """ from __future__ import annotations @@ -13,7 +13,7 @@ import pandas as pd import numpy as np -from backtest.order_policy import decide_order # 분리된 정책 모듈 import +from backtrader.order_policy import decide_order # 분리된 정책 모듈 import # === 설정 클래스 === @@ -50,7 +50,7 @@ def _fill_date_from_signal(sig_date: pd.Timestamp, same_day: bool) -> pd.Timesta # === 백테스트 본체 === -def backtest( +def backtrader( decision_log: pd.DataFrame, config: Optional[BacktestConfig] = None, run_id: Optional[str] = None, diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index e51e2a59..b88b4349 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -51,7 +51,7 @@ # ---------------------------------------------------------------------- from finder.main import run_finder # 1) 종목 발굴 from transformer.main import run_transformer # 2) 신호 생성(의사결정 로그 생성) -from backtest.simple_backtester import backtest, BacktestConfig # 4) 백테스팅(간소화 체결 엔진) +from backtrader.run_backtrader import backtest, BacktestConfig # 4) 백트레이딩(간소화 체결 엔진) from libs.utils.save_executions_to_db import save_executions_to_db # 5) 체결내역 DB 저장 from xai.run_xai import run_xai # 3) XAI 리포트 텍스트 생성 from libs.utils.save_reports_to_db import save_reports_to_db # 3.5) XAI 리포트 DB 저장 (id 반환) @@ -213,10 +213,10 @@ def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: # ====================================================================== -# 3) Backtester: 의사결정 로그 기반 체결/포지션 계산 단계 +# 3) Backtrader: 의사결정 로그 기반 체결/포지션 계산 단계 # ====================================================================== -def run_backtester(decision_log: pd.DataFrame) -> pd.DataFrame: +def run_backtrader(decision_log: pd.DataFrame) -> pd.DataFrame: """ Transformer에서 생성된 의사결정 로그(decision_log)의 price 컬럼을 OHLCV 없이 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행한다. @@ -233,7 +233,7 @@ def run_backtester(decision_log: pd.DataFrame) -> pd.DataFrame: run_id = _utcnow().strftime("run-%Y%m%d-%H%M%S") - cfg = BacktestConfig( + cfg = BacktradeConfig( initial_cash=100_000.0, slippage_bps=5.0, commission_bps=3.0, @@ -242,18 +242,18 @@ def run_backtester(decision_log: pd.DataFrame) -> pd.DataFrame: fill_on_same_day=True, ) - fills_df, summary = backtest( + fills_df, summary = backtrader( decision_log=decision_log, config=cfg, run_id=run_id, ) if fills_df is None or fills_df.empty: - print("[WARN] Backtester: 생성된 체결 내역이 없습니다.") + print("[WARN] Backtrader: 생성된 체결 내역이 없습니다.") return pd.DataFrame() print( - f"--- [PIPELINE-STEP 4] Backtester 완료: " + f"--- [PIPELINE-STEP 4] Backtrader 완료: " f"trades={len(fills_df)}, " f"cash_final={summary.get('cash_final')}, " f"pnl_realized_sum={summary.get('pnl_realized_sum')} ---" @@ -386,7 +386,7 @@ def run_pipeline() -> Optional[List[ReportRow]]: ) # 4) Backtester: xai_report_id 포함 decision_log로 체결 내역 생성 - fills_df = run_backtester(logs_df) + fills_df = run_backtrader(logs_df) # 5) executions 테이블에 체결 내역 저장 try: From decd9c14b2c515b57e29d71665930d54a78731cd Mon Sep 17 00:00:00 2001 From: twq110 Date: Thu, 20 Nov 2025 13:39:12 +0900 Subject: [PATCH 5/8] =?UTF-8?q?[AI]=20SISC2-45=20[REFACTOR]=20=EC=A3=BC?= =?UTF-8?q?=EC=84=9D=20=EC=9D=BC=EB=B6=80=20=EB=B3=80=EA=B2=BD=20=EB=B0=8F?= =?UTF-8?q?=20=ED=8C=A8=ED=82=A4=EC=A7=80=20=EC=9E=91=EC=84=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI/backtrader/__init__.py | 3 ++ AI/backtrader/simple_backtrader.py | 6 ++-- AI/libs/core/pipeline.py | 57 +++++++++++++++--------------- 3 files changed, 34 insertions(+), 32 deletions(-) create mode 100644 AI/backtrader/__init__.py diff --git a/AI/backtrader/__init__.py b/AI/backtrader/__init__.py new file mode 100644 index 00000000..2bf8edec --- /dev/null +++ b/AI/backtrader/__init__.py @@ -0,0 +1,3 @@ +#AI/backtrader/simple_backtrader.py +from backtrader.simple_backtrader import backtrader, BacktradeConfig +__all__ = ["backtrader", "BacktradeConfig"] \ No newline at end of file diff --git a/AI/backtrader/simple_backtrader.py b/AI/backtrader/simple_backtrader.py index 8cd796fc..f5387d77 100644 --- a/AI/backtrader/simple_backtrader.py +++ b/AI/backtrader/simple_backtrader.py @@ -18,7 +18,7 @@ # === 설정 클래스 === @dataclass -class BacktestConfig: +class BacktradeConfig: """ 한국어 주석: - 간소화 백테스터 설정 @@ -52,7 +52,7 @@ def _fill_date_from_signal(sig_date: pd.Timestamp, same_day: bool) -> pd.Timesta # === 백테스트 본체 === def backtrader( decision_log: pd.DataFrame, - config: Optional[BacktestConfig] = None, + config: Optional[BacktradeConfig] = None, run_id: Optional[str] = None, ) -> Tuple[pd.DataFrame, Dict]: """ @@ -62,7 +62,7 @@ def backtrader( - 반환: (fills_df, summary) """ if config is None: - config = BacktestConfig() + config = BacktraderConfig() dl = decision_log.copy() if not {"ticker", "date", "action", "price"}.issubset(dl.columns): diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index b88b4349..7043e7db 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -2,35 +2,35 @@ # -*- coding: utf-8 -*- """ 한국어 주석 (개요): -- 본 파일은 "주간 자동 파이프라인"의 전체 흐름을 오케스트레이션한다. - (Finder → Transformer → XAI 리포트 → Backtester → DB 저장) +- 본 파일은 "주간 자동 파이프라인"의 전체 흐름을 오케스트레이션합니다. + (Finder → Transformer → XAI 리포트 → Backtrader → DB 저장) [전체 플로우] 1) Finder - - 시장/전략 조건에 맞는 종목 목록(ticker list)을 선정한다. + - 시장/전략 조건에 맞는 종목 목록(ticker list)을 선정합니다. 2) Transformer - - 선택된 종목들의 OHLCV를 DB에서 가져온다(fetch_ohlcv). - - LSTM/Rule 기반 등의 Transformer 로직을 통해 의사결정 로그(DataFrame)를 생성한다. - - 이 의사결정 로그는 XAI와 Backtester에서 모두 공통으로 사용된다. + - 선택된 종목들의 OHLCV 데이터를 DB에서 가져옵니다(fetch_ohlcv). + - LSTM/Rule 기반 등의 Transformer 로직을 통해 의사결정 로그(DataFrame)를 생성합니다. + - 이 의사결정 로그는 XAI와 Backtrader에서 모두 공통으로 사용됩니다. 3) XAI (e.g. GROQ 등 LLM 기반 설명 생성) - 각 의사결정에 대해 feature_name / feature_score를 기반으로 - "왜 이 신호가 나왔는지"에 대한 자연어 리포트를 생성한다. - - 결과는 xai_reports 테이블에 먼저 저장된다. - - 이 때 생성된 xai_reports.id를 decision_log(logs_df)에 xai_report_id로 심는다. + "왜 이 신호가 나왔는지"에 대한 자연어 리포트를 생성합니다. + - 결과는 xai_reports 테이블에 먼저 저장됩니다. + - 이 때 생성된 xai_reports.id를 decision_log(logs_df)에 xai_report_id로 심습니다. -4) Backtester +4) Backtrader - xai_report_id가 포함된 의사결정 로그(decision_log)를 받아, - price 컬럼을 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행한다. - - Backtest 결과(fills_df)는 xai_report_id를 그대로 보존한 상태로 executions에 저장된다. + price 컬럼을 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행합니다. + - 백테스트 결과(fills_df)는 xai_report_id를 그대로 보존한 상태로 executions 테이블에 저장됩니다. [주의 사항] -- Transformer가 생성하는 decision_log(DataFrame)는 최소한 아래 컬럼을 포함해야 한다. +- Transformer가 생성하는 decision_log(DataFrame)는 최소한 아래 컬럼을 포함해야 합니다: ['ticker', 'date', 'action', 'price', 'feature_name1', 'feature_name2', 'feature_name3', 'feature_score1', 'feature_score2', 'feature_score3'] -- GROQ_API_KEY 환경변수가 없으면 XAI 단계는 자동으로 스킵된다. +- GROQ_API_KEY 환경변수가 없으면 XAI 단계는 자동으로 스킵됩니다. """ import os @@ -51,10 +51,10 @@ # ---------------------------------------------------------------------- from finder.main import run_finder # 1) 종목 발굴 from transformer.main import run_transformer # 2) 신호 생성(의사결정 로그 생성) -from backtrader.run_backtrader import backtest, BacktestConfig # 4) 백트레이딩(간소화 체결 엔진) -from libs.utils.save_executions_to_db import save_executions_to_db # 5) 체결내역 DB 저장 -from xai.run_xai import run_xai # 3) XAI 리포트 텍스트 생성 -from libs.utils.save_reports_to_db import save_reports_to_db # 3.5) XAI 리포트 DB 저장 (id 반환) +from backtrader.run_backtrader import backtrader, BacktradeConfig # 3) 백트레이딩(간소화 체결 엔진) +from libs.utils.save_executions_to_db import save_executions_to_db # 3.5) 체결내역 DB 저장 +from xai.run_xai import run_xai # 4) XAI 리포트 텍스트 생성 +from libs.utils.save_reports_to_db import save_reports_to_db # 4.5) XAI 리포트 DB 저장 (id 반환) from libs.utils.fetch_ohlcv import fetch_ohlcv # (Transformer용) OHLCV 수집 헬퍼 # ---------------------------------------------------------------------- @@ -69,7 +69,7 @@ REPORT_DB_NAME = "db" # 체결내역 / XAI 리포트를 저장하는 DB 명 # ---------------------------------------------------------------------- -# XAI 및 Backtester에서 공통으로 요구하는 "결정 로그 필수 컬럼" 정의 +# XAI 및 Backtrader에서 공통으로 요구하는 "결정 로그 필수 컬럼" 정의 # ---------------------------------------------------------------------- REQUIRED_LOG_COLS = { "ticker", @@ -122,7 +122,7 @@ def _to_float(v, fallback: float = 0.0) -> float: def run_weekly_finder() -> List[str]: """ - Finder 모듈을 실행하여 후보 티커 리스트를 반환. + Finder 모듈을 실행하여 후보 티커 리스트를 반환합니다. """ print("--- [PIPELINE-STEP 1] Finder 모듈 실행 시작 ---") try: @@ -144,7 +144,7 @@ def run_weekly_finder() -> List[str]: def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: """ 종목 리스트에 대해 DB에서 OHLCV를 수집하고 Transformer를 호출하여 - 의사결정 로그(DataFrame)를 생성한다. + 의사결정 로그(DataFrame)를 생성합니다. """ print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 시작 ---") @@ -219,21 +219,21 @@ def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: def run_backtrader(decision_log: pd.DataFrame) -> pd.DataFrame: """ Transformer에서 생성된 의사결정 로그(decision_log)의 price 컬럼을 - OHLCV 없이 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행한다. + OHLCV 없이 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행합니다. 주의: - decision_log는 xai_report_id 컬럼을 포함할 수 있으며, - backtest() 구현이 해당 컬럼을 드롭하지 않으면 fills_df에도 그대로 보존된다. + backtrader() 구현이 해당 컬럼을 드롭하지 않으면 fills_df에도 그대로 보존됩니다. """ - print("--- [PIPELINE-STEP 4] Backtester 실행 시작 ---") + print("--- [PIPELINE-STEP 4] Backtrader 실행 시작 ---") if decision_log is None or decision_log.empty: - print("[WARN] Backtester: 비어있는 결정 로그가 입력되었습니다. 체결을 수행하지 않습니다.") + print("[WARN] Backtrader: 비어있는 결정 로그가 입력되었습니다. 체결을 수행하지 않습니다.") return pd.DataFrame() run_id = _utcnow().strftime("run-%Y%m%d-%H%M%S") - cfg = BacktradeConfig( + cfg = BacktraderConfig( initial_cash=100_000.0, slippage_bps=5.0, commission_bps=3.0, @@ -268,7 +268,7 @@ def run_backtrader(decision_log: pd.DataFrame) -> pd.DataFrame: def run_xai_report(decision_log: pd.DataFrame) -> List[ReportRow]: """ Transformer 결정 로그를 입력으로 받아, 각 행(의사결정)에 대한 - XAI 설명 리포트(자연어 텍스트)를 생성한다. + XAI 설명 리포트(자연어 텍스트)를 생성합니다. """ print("--- [PIPELINE-STEP 3] XAI 리포트 생성 시작 ---") @@ -346,7 +346,7 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[ReportRow]: def run_pipeline() -> Optional[List[ReportRow]]: """ - 전체 파이프라인(Finder → Transformer → XAI → Backtester → DB 저장)을 + 전체 파이프라인(Finder → Transformer → XAI → Backtrader → DB 저장)을 한 번에 실행하는 엔트리 포인트 함수. """ # 1) Finder @@ -373,7 +373,6 @@ def run_pipeline() -> Optional[List[ReportRow]]: xai_ids = [] # 3.7) logs_df에 xai_report_id 심기 - # (길이가 맞지 않거나 XAI 저장 실패 시에는 NULL로 채워서 진행) logs_df = logs_df.copy().reset_index(drop=True) if xai_ids and len(xai_ids) == len(logs_df): logs_df["xai_report_id"] = xai_ids From e1834f5862995df89cfe22ee6e41ad469d5f2c90 Mon Sep 17 00:00:00 2001 From: twq110 Date: Fri, 28 Nov 2025 14:53:02 +0900 Subject: [PATCH 6/8] =?UTF-8?q?[AI]=20SISC2-43=20[FEAT]=20=EC=9E=90?= =?UTF-8?q?=EC=82=B0=EB=B0=B0=EB=B6=84=20=EB=A1=9C=EC=A7=81=20=EC=9E=91?= =?UTF-8?q?=EC=84=B1=20=EB=B0=8F=20=EC=B5=9C=EC=A2=85=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI/backtrade/__init__.py | 3 + .../main.py} | 11 +- AI/{backtrader => backtrade}/order_policy.py | 4 +- AI/backtrader/__init__.py | 3 - AI/daily_data_collection/__init__.py | 3 + AI/daily_data_collection/main.py | 1165 +++++++++++++++++ AI/daily_data_collection/test.py | 187 +++ AI/libs/core/__init__.py | 2 +- AI/libs/core/pipeline.py | 132 +- AI/libs/utils/save_executions_to_db.py | 279 +++- AI/requirements.txt | 1 + AI/tests/test_transfomer.py | 209 --- AI/tests/test_transformer_backtrader.py | 194 +++ AI/weekly_tickers.json | 1 + 14 files changed, 1877 insertions(+), 317 deletions(-) create mode 100644 AI/backtrade/__init__.py rename AI/{backtrader/simple_backtrader.py => backtrade/main.py} (94%) rename AI/{backtrader => backtrade}/order_policy.py (97%) delete mode 100644 AI/backtrader/__init__.py create mode 100644 AI/daily_data_collection/__init__.py create mode 100644 AI/daily_data_collection/main.py create mode 100644 AI/daily_data_collection/test.py delete mode 100644 AI/tests/test_transfomer.py create mode 100644 AI/tests/test_transformer_backtrader.py create mode 100644 AI/weekly_tickers.json diff --git a/AI/backtrade/__init__.py b/AI/backtrade/__init__.py new file mode 100644 index 00000000..afd617a7 --- /dev/null +++ b/AI/backtrade/__init__.py @@ -0,0 +1,3 @@ +#AI/backtrader/__init__.py +from .main import backtrade, BacktradeConfig +__all__ = ["backtrade", "BacktradeConfig"] \ No newline at end of file diff --git a/AI/backtrader/simple_backtrader.py b/AI/backtrade/main.py similarity index 94% rename from AI/backtrader/simple_backtrader.py rename to AI/backtrade/main.py index f5387d77..8185244a 100644 --- a/AI/backtrader/simple_backtrader.py +++ b/AI/backtrade/main.py @@ -1,10 +1,9 @@ -# backtrader/simple_backtrader.py -# -*- coding: utf-8 -*- +# backtrade/main.py """ 한국어 주석: - OHLCV 없이, Transformer 결정 로그(decision_log)의 price만으로 간소화된 백테스트를 수행하는 환경(Environment) 역할. -- 수량/포지션 결정은 backtrader/order_policy.py 모듈로 분리됨. +- 수량/포지션 결정은 backtrade/order_policy.py 모듈로 분리됨. """ from __future__ import annotations @@ -13,7 +12,7 @@ import pandas as pd import numpy as np -from backtrader.order_policy import decide_order # 분리된 정책 모듈 import +from backtrade.order_policy import decide_order # 분리된 정책 모듈 import # === 설정 클래스 === @@ -50,7 +49,7 @@ def _fill_date_from_signal(sig_date: pd.Timestamp, same_day: bool) -> pd.Timesta # === 백테스트 본체 === -def backtrader( +def backtrade( decision_log: pd.DataFrame, config: Optional[BacktradeConfig] = None, run_id: Optional[str] = None, @@ -62,7 +61,7 @@ def backtrader( - 반환: (fills_df, summary) """ if config is None: - config = BacktraderConfig() + config = BacktradeConfig() dl = decision_log.copy() if not {"ticker", "date", "action", "price"}.issubset(dl.columns): diff --git a/AI/backtrader/order_policy.py b/AI/backtrade/order_policy.py similarity index 97% rename from AI/backtrader/order_policy.py rename to AI/backtrade/order_policy.py index 65a93df5..3846d2ee 100644 --- a/AI/backtrader/order_policy.py +++ b/AI/backtrade/order_policy.py @@ -1,4 +1,4 @@ -# backtrader/order_policy.py +# backtrade/order_policy.py # -*- coding: utf-8 -*- """ 한국어 주석: @@ -29,7 +29,7 @@ def decide_order( cur_qty: 현재 보유 주식 수량 avg_price: 현재 보유 평균단가 fill_price: 이번 체결 기준가 (슬리피지 반영 전) - config: BacktestConfig 인스턴스 + config: BacktradeConfig 인스턴스 반환값: (qty, trade_value) qty: 매수/매도 수량 diff --git a/AI/backtrader/__init__.py b/AI/backtrader/__init__.py deleted file mode 100644 index 2bf8edec..00000000 --- a/AI/backtrader/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -#AI/backtrader/simple_backtrader.py -from backtrader.simple_backtrader import backtrader, BacktradeConfig -__all__ = ["backtrader", "BacktradeConfig"] \ No newline at end of file diff --git a/AI/daily_data_collection/__init__.py b/AI/daily_data_collection/__init__.py new file mode 100644 index 00000000..2fb0cec4 --- /dev/null +++ b/AI/daily_data_collection/__init__.py @@ -0,0 +1,3 @@ +#AI/dail_data_collection/__init__.py +from .main import run_data_collection +__all__ = ["run_data_collection"] \ No newline at end of file diff --git a/AI/daily_data_collection/main.py b/AI/daily_data_collection/main.py new file mode 100644 index 00000000..1bbbe3e0 --- /dev/null +++ b/AI/daily_data_collection/main.py @@ -0,0 +1,1165 @@ +# ====================================================================== +# SECTION 1 — Imports + 경로 설정 + 공용 유틸 + DB 전체 티커 로딩 +# ====================================================================== + +from __future__ import annotations + +import os +import sys +from datetime import datetime, timedelta, timezone +from typing import List, Dict, Any, Optional + +import numpy as np +import pandas as pd +import yfinance as yf +from psycopg2.extras import execute_values +from fredapi import Fred + +# ---------------------------------------------------------------------- +# 프로젝트 루트 경로를 sys.path에 추가 (절대경로 import 문제 해결) +# ---------------------------------------------------------------------- +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_root not in sys.path: + sys.path.append(project_root) + +# DB 연결 함수 가져오기 +from libs.utils.get_db_conn import get_db_conn, get_engine + + +# ---------------------------------------------------------------------- +# 한국 표준시 (KST) +# ---------------------------------------------------------------------- +KST = timezone(timedelta(hours=9)) + + +# ====================================================================== +# 공용 유틸 함수 +# ====================================================================== + +def today_kst() -> datetime.date: + """한국(KST) 기준 today's date 반환.""" + return datetime.now(KST).date() + + +def get_last_date_in_table(db_name: str, table: str, date_col: str) -> Optional[datetime.date]: + """ + 테이블의 날짜 컬럼(date_col)에서 MAX(date)를 얻는 함수 + """ + from sqlalchemy import text + engine = get_engine(db_name) + + with engine.connect() as conn: + res = conn.execute(text(f"SELECT MAX({date_col}) FROM {table};")).scalar() + + return res if res is not None else None + + +# ====================================================================== +# ✨ DB 전체 티커 로딩 함수 (핵심 기능 추가) +# ====================================================================== + +def get_all_tickers_from_db(db_name: str) -> List[str]: + """ + public.price_data 에서 존재하는 모든 ticker를 DISTINCT 로 가져오는 함수. + manual_backfill_all() 에서 전체 티커를 자동으로 사용하기 위해 도입. + """ + from sqlalchemy import text + engine = get_engine(db_name) + + with engine.connect() as conn: + rows = conn.execute(text(""" + SELECT DISTINCT ticker + FROM public.price_data + ORDER BY ticker; + """)).fetchall() + + return [r[0] for r in rows] + + +# ====================================================================== +# ✨ Series, numpy 자료형 등 → 스칼라 float 변환 함수 (중요) +# ====================================================================== + +def to_scalar(v): + """ + yfinance 또는 pandas Series에서 발생하는 + numpy.float64 / ndarray / Series 형태를 모두 float로 정규화. + """ + if isinstance(v, (np.generic, np.float64, np.int64)): + return float(v) + + if isinstance(v, (list, tuple, np.ndarray)): + return float(v[0]) if len(v) > 0 else None + + if isinstance(v, pd.Series): + # Series인 경우 iloc으로 첫 값만 사용 + return float(v.iloc[0]) if len(v) > 0 else None + + # 이미 파이썬 float인 경우 + try: + return float(v) + except Exception: + return None +# ====================================================================== +# SECTION 2 — PRICE_DATA FETCH / UPSERT / PIPELINE +# ====================================================================== + +# -------------------------------------------------------------- +# DataFrame → 파이썬 records (numpy / Series 모두 스칼라로 변환) +# -------------------------------------------------------------- + +def df_to_python_records_price(df: pd.DataFrame): + """ + price_data DataFrame 을 Python tuple 리스트로 변환. + - 스키마 기준 8컬럼만 사용한다. + - MultiIndex / 중복 컬럼 이름이 있어도 첫 번째 컬럼만 사용하도록 정규화한다. + """ + # MultiIndex -> 1단계 이름으로 평탄화 + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + + target_cols = ["ticker", "date", "open", "high", "low", "close", "volume", "adjusted_close"] + + # 중복 컬럼 / DataFrame 반환 케이스 안전하게 처리 + safe_dict = {} + for c in target_cols: + if c not in df.columns: + # 아예 없으면 전부 None + safe_dict[c] = pd.Series([None] * len(df), index=df.index) + continue + + col = df[c] + # 같은 이름의 컬럼이 여러 개인 경우 df[c] 가 DataFrame 이라서 + # 첫 번째 컬럼만 사용 + if isinstance(col, pd.DataFrame): + safe_dict[c] = col.iloc[:, 0] + else: + safe_dict[c] = col + + df = pd.DataFrame(safe_dict) + + records = [] + + def scalar_or_none(v): + if v is None or pd.isna(v): + return None + if isinstance(v, (int, float)): + return v + try: + return float(v) + except Exception: + return None + + # 여기서는 무조건 8개 컬럼만 존재 + for ticker, date, open_, high_, low_, close_, volume, adjusted_close in df.itertuples(index=False, name=None): + records.append( + ( + str(ticker), + date, + scalar_or_none(open_), + scalar_or_none(high_), + scalar_or_none(low_), + scalar_or_none(close_), + int(volume) if (volume is not None and not pd.isna(volume)) else None, + scalar_or_none(adjusted_close), + ) + ) + + return records + + +# -------------------------------------------------------------- +# yfinance 일봉 수집 +# -------------------------------------------------------------- +def fetch_price_data_from_yf(tickers: List[str], start: str, end: str) -> pd.DataFrame: + """ + yfinance 를 '티커 1개씩' 호출해서 + 항상 세로(long) 구조의 8컬럼 DF를 반환한다. + """ + frames = [] + + for t in tickers: + print(f"[PRICE] Fetch {t} {start}~{end}") + df = yf.download(t, start=start, end=end, auto_adjust=False) + + # MultiIndex 반환되는 경우 강제 단일화 + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + + + if df.empty: + print(f"[PRICE] No data for {t}") + continue + + # index = DatetimeIndex + df.index.name = "date" + df = df.reset_index() + + # 표준 컬럼명으로 통일 + df = df.rename( + columns={ + "Date": "date", + "Open": "open", + "High": "high", + "Low": "low", + "Close": "close", + "Adj Close": "adjusted_close", + "Volume": "volume", + } + ) + + df["ticker"] = t + df["date"] = pd.to_datetime(df["date"]).dt.date + + df = df[ + ["ticker", "date", "open", + "high", "low", "close", + "volume", "adjusted_close"] + ] + + frames.append(df) + + if not frames: + return pd.DataFrame( + columns=[ + "ticker", "date", "open", "high", "low", + "close", "volume", "adjusted_close" + ] + ) + + return pd.concat(frames, ignore_index=True) + + + +# -------------------------------------------------------------- +# price_data UPSERT +# -------------------------------------------------------------- +def upsert_price_data(db_name: str, df: pd.DataFrame, batch_size: int = 10_000) -> None: + """ + public.price_data 테이블에 (ticker, date) 기준으로 UPSERT 수행. + + - df 전체를 한 번에 records 로 바꾸지 않고, + batch_size 단위로 잘라서 DB에 넣음 → 메모리 폭발 방지. + """ + if df.empty: + print("[PRICE] No new data to upsert.") + return + + sql = """ + INSERT INTO public.price_data + (ticker, date, open, high, low, close, volume, adjusted_close) + VALUES %s + ON CONFLICT (ticker, date) DO UPDATE SET + open = EXCLUDED.open, + high = EXCLUDED.high, + low = EXCLUDED.low, + close = EXCLUDED.close, + volume = EXCLUDED.volume, + adjusted_close = EXCLUDED.adjusted_close; + """ + + conn = get_db_conn(db_name) + try: + with conn.cursor() as cur: + # df 를 행 기준으로 잘라서 차례대로 INSERT + total_rows = len(df) + for start in range(0, total_rows, batch_size): + end = min(start + batch_size, total_rows) + batch_df = df.iloc[start:end] + records = df_to_python_records_price(batch_df) + + if not records: + continue + + execute_values(cur, sql, records) + + print(f"[PRICE] Upserted rows {start} ~ {end-1} / {total_rows}") + + conn.commit() + print(f"[PRICE] Upserted TOTAL {len(df)} rows into public.price_data") + + finally: + conn.close() + + + +# -------------------------------------------------------------- +# (자동용) price pipeline: 증분 업데이트 +# -------------------------------------------------------------- +def run_price_pipeline(config: Dict[str, Any]) -> None: + db_name = config["db_name"] + tickers = config["tickers"] + + last = get_last_date_in_table(db_name, "public.price_data", "date") + + if last is None: + start_date = config.get("price_start", "2017-01-01") + print(f"[PRICE] 기존 없음 → {start_date} 부터 시작") + else: + start_date = (last + timedelta(days=1)).strftime("%Y-%m-%d") + print(f"[PRICE] 마지막 날짜 {last} 이후 → {start_date} 부터 수집") + + end_date = today_kst().strftime("%Y-%m-%d") + + if start_date > end_date: + print("[PRICE] Already up to date.") + return + + df = fetch_price_data_from_yf(tickers, start_date, end_date) + upsert_price_data(db_name, df) +# ====================================================================== +# SECTION 3 — TECHNICAL INDICATORS (계산 + UPSERT + PIPELINE) +# ====================================================================== + +# -------------------------------------------------------------- +# 기술지표 계산 (티커 1개 단위) +# -------------------------------------------------------------- +def compute_technical_indicators(df: pd.DataFrame) -> pd.DataFrame: + """ + 입력 df: ticker, date, open, high, low, close, volume + 출력 df: ticker, date, RSI, MACD, Bollinger_Bands_upper, ... MA_200 + """ + + df = df.sort_values("date").reset_index(drop=True) + + + close = df["close"] + high = df["high"] + low = df["low"] + volume = df["volume"] + # 결측치 보정 (앞뒤로 채우기) + df["close"] = df["close"].ffill().bfill() + df["high"] = df["high"].ffill().bfill() + df["low"] = df["low"].ffill().bfill() + df["volume"] = df["volume"].fillna(0) + + + # ------------------------- RSI ------------------------- + delta = close.diff() + gain = delta.clip(lower=0) + loss = -delta.clip(upper=0) + + window_rsi = 14 + avg_gain = gain.rolling(window_rsi).mean() + avg_loss = loss.rolling(window_rsi).mean() + rs = avg_gain / (avg_loss + 1e-9) + rsi = 100 - (100 / (1 + rs)) + + # ------------------------- MACD ------------------------- + ema12 = close.ewm(span=12, adjust=False).mean() + ema26 = close.ewm(span=26, adjust=False).mean() + macd = ema12 - ema26 + + # ------------------------- Bollinger Bands ------------------------- + ma20 = close.rolling(20).mean() + std20 = close.rolling(20).std() + bb_upper = ma20 + 2 * std20 + bb_lower = ma20 - 2 * std20 + + # ------------------------- ATR ------------------------- + prev_close = close.shift(1) + tr1 = high - low + tr2 = (high - prev_close).abs() + tr3 = (low - prev_close).abs() + tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) + atr = tr.rolling(14).mean() + + # ------------------------- OBV ------------------------- + obv = [0] + for i in range(1, len(close)): + if close.iloc[i] > close.iloc[i-1]: + obv.append(obv[-1] + volume.iloc[i]) + elif close.iloc[i] < close.iloc[i-1]: + obv.append(obv[-1] - volume.iloc[i]) + else: + obv.append(obv[-1]) + obv = pd.Series(obv, index=df.index) + + # ------------------------- Stochastic ------------------------- + lowest_low = low.rolling(14).min() + highest_high = high.rolling(14).max() + stochastic = (close - lowest_low) / (highest_high - lowest_low + 1e-9) * 100 + + # ------------------------- MFI ------------------------- + typical_price = (high + low + close) / 3 + raw_mf = typical_price * volume + tp_diff = typical_price.diff() + pos_mf = raw_mf.where(tp_diff > 0, 0) + neg_mf = raw_mf.where(tp_diff < 0, 0) + pos_14 = pos_mf.rolling(14).sum() + neg_14 = neg_mf.abs().rolling(14).sum() + 1e-9 + mfr = pos_14 / neg_14 + mfi = 100 - (100 / (1 + mfr)) + + # ------------------------- MA ------------------------- + ma_5 = close.rolling(5).mean() + ma_20_roll = close.rolling(20).mean() + ma_50 = close.rolling(50).mean() + ma_200 = close.rolling(200).mean() + + out = pd.DataFrame({ + "ticker": df["ticker"], + "date": df["date"], + "RSI": rsi, + "MACD": macd, + "Bollinger_Bands_upper": bb_upper, + "Bollinger_Bands_lower": bb_lower, + "ATR": atr, + "OBV": obv, + "Stochastic": stochastic, + "MFI": mfi, + "MA_5": ma_5, + "MA_20": ma_20_roll, + "MA_50": ma_50, + "MA_200": ma_200, + }) + + return out + + +# -------------------------------------------------------------- +# UPSERT for technical_indicators +# -------------------------------------------------------------- +def upsert_technical_indicators(db_name: str, df: pd.DataFrame, batch_size: int = 50_000): + if df.empty: + print("[TECH] No technical indicators to upsert.") + return + + sql = """ + INSERT INTO public.technical_indicators ( + ticker, date, + RSI, MACD, + Bollinger_Bands_upper, Bollinger_Bands_lower, + ATR, OBV, Stochastic, MFI, + MA_5, MA_20, MA_50, MA_200 + ) + VALUES %s + ON CONFLICT (ticker, date) DO UPDATE SET + RSI = EXCLUDED.RSI, + MACD = EXCLUDED.MACD, + Bollinger_Bands_upper = EXCLUDED.Bollinger_Bands_upper, + Bollinger_Bands_lower = EXCLUDED.Bollinger_Bands_lower, + ATR = EXCLUDED.ATR, + OBV = EXCLUDED.OBV, + Stochastic = EXCLUDED.Stochastic, + MFI = EXCLUDED.MFI, + MA_5 = EXCLUDED.MA_5, + MA_20 = EXCLUDED.MA_20, + MA_50 = EXCLUDED.MA_50, + MA_200 = EXCLUDED.MA_200; + """ + + conn = get_db_conn(db_name) + try: + with conn.cursor() as cur: + total = len(df) + for start in range(0, total, batch_size): + end = min(start + batch_size, total) + batch = df.iloc[start:end] + + records = [] + for _, r in batch.iterrows(): + records.append(( + r["ticker"], r["date"], + to_scalar(r["RSI"]), to_scalar(r["MACD"]), + to_scalar(r["Bollinger_Bands_upper"]), to_scalar(r["Bollinger_Bands_lower"]), + to_scalar(r["ATR"]), to_scalar(r["OBV"]), + to_scalar(r["Stochastic"]), to_scalar(r["MFI"]), + to_scalar(r["MA_5"]), to_scalar(r["MA_20"]), + to_scalar(r["MA_50"]), to_scalar(r["MA_200"]), + )) + + if not records: + continue + + execute_values(cur, sql, records) + print(f"[TECH] Upserted rows {start} ~ {end-1} / {total}") + + conn.commit() + print(f"[TECH] Upserted TOTAL {len(df)} rows into public.technical_indicators") + + finally: + conn.close() + + + +# -------------------------------------------------------------- +# 기술지표 파이프라인: DB price_data 전체 기반 +# -------------------------------------------------------------- +def run_technical_indicators_full(config: Dict[str, Any]) -> None: + """ + 기술지표 전체 FULL 재계산 (manual_backfill_all 전용) + """ + db_name = config["db_name"] + + print("[TECH-FULL] 전체 기간 price_data 로딩 중…") + + from sqlalchemy import text + engine = get_engine(db_name) + + query = text(""" + SELECT ticker, date, open, high, low, close, volume, adjusted_close + FROM public.price_data + ORDER BY ticker, date; + """) + + with engine.connect() as conn: + df_price = pd.read_sql(query, conn) + + if df_price.empty: + print("[TECH-FULL] price_data 없음 → 스킵") + return + + tickers = sorted(df_price["ticker"].unique()) + print(f"[TECH-FULL] 전체 티커 수: {len(tickers)}") + + frames = [] + + with engine.connect() as conn: + conn.execute(text("DELETE FROM public.technical_indicators;")) + print("[TECH-FULL] 기존 technical_indicators 테이블 초기화 완료.") + conn.commit() + + + for idx, t in enumerate(tickers, start=1): + print(f"[TECH-FULL] ({idx}/{len(tickers)}) {t} 계산 중…") + df_t = df_price[df_price["ticker"] == t] + tech_df = compute_technical_indicators(df_t) + frames.append(tech_df) + + full_df = pd.concat(frames, ignore_index=True) + + print(f"[TECH-FULL] 기술지표 전체 계산 완료: {len(full_df)} rows") + + # 전체 덮어쓰기 (UPSERT) + upsert_technical_indicators(db_name, full_df) + + print("[TECH-FULL] 기술지표 FULL 업서트 완료.") + +# -------------------------------------------------------------- +# 기술지표 파이프라인: 증분용 (최근 250일치만 계산) +# -------------------------------------------------------------- + +def run_technical_indicators_incremental(config: Dict[str, Any]) -> None: + """ + 기존: price_data 전체 기간을 읽고 모든 기술지표 전체 재계산 (매우 무거움) + 변경: 최근 250일만 계산 → 5~10배 이상 빠름 + """ + db_name = config["db_name"] + window_days = 250 # rolling 기간 + 여유 buffer + + # --------------------------------------------------------- + # 최신 날짜 구하고, 최근 250일 구간만 조회 + # --------------------------------------------------------- + from sqlalchemy import text + engine = get_engine(db_name) + + # price_data 에 최신 날짜 확인 + with engine.connect() as conn: + last_date = conn.execute( + text("SELECT MAX(date) FROM public.price_data") + ).scalar() + + if last_date is None: + print("[TECH] price_data empty → 기술지표 계산 불가") + return + + start_date = last_date - timedelta(days=window_days) + + print(f"[TECH] 최근 {window_days}일 ({start_date} ~ {last_date}) 데이터만 사용") + + # 최근 250일 price data만 로딩 + query = text(""" + SELECT ticker, date, open, high, low, close, volume + FROM public.price_data + WHERE date >= :start_date + ORDER BY ticker, date + """) + + with engine.connect() as conn: + df_price = pd.read_sql(query, conn, params={"start_date": start_date}) + + if df_price.empty: + print("[TECH] 최근 데이터 없음 → 기술지표 계산 스킵") + return + + # --------------------------------------------------------- + # 티커별 기술지표 계산 + # --------------------------------------------------------- + tickers = sorted(df_price["ticker"].unique()) + print(f"[TECH] 대상 티커 수: {len(tickers)}") + + frames = [] + for idx, t in enumerate(tickers, start=1): + print(f"[TECH] ({idx}/{len(tickers)}) {t} 지표 계산") + df_t = df_price[df_price["ticker"] == t] + tech_df = compute_technical_indicators(df_t) + frames.append(tech_df) + + tech_recent = pd.concat(frames, ignore_index=True) + + print(f"[TECH] 기술지표 계산 완료: {len(tech_recent)} rows") + + # --------------------------------------------------------- + # 최근 250일 데이터만 UPSERT + # --------------------------------------------------------- + upsert_technical_indicators(db_name, tech_recent) + + print("[TECH] 최근 250일 기술지표 업데이트 완료") + + +# ====================================================================== +# SECTION 4 — MACROECONOMIC INDICATORS (FETCH + UPSERT + PIPELINE) +# ====================================================================== + +# -------------------------------------------------------------- +# Macro Data Fetch from FRED +# -------------------------------------------------------------- + +# FRED API 읽어오기 +FRED_API_KEY = os.getenv("FRED_API_KEY") +def get_fred_client() -> Fred | None: + """ + 한국어 주석: + - FRED API 클라이언트를 생성하는 헬퍼 함수. + - FRED_API_KEY가 없거나 잘못 설정된 경우 None을 반환하고, + 매크로 데이터 수집을 스킵할 수 있도록 한다. + """ + if not FRED_API_KEY: + print("[WARN] FRED_API_KEY가 설정되어 있지 않아 매크로 데이터 수집을 건너뜁니다.") + return None + + try: + fred = Fred(api_key=FRED_API_KEY) + return fred + except ValueError as e: + # fredapi 내부에서 키가 잘못되었을 때 발생하는 에러를 안전하게 처리 + print(f"[WARN] FRED API 초기화 실패: {e}") + return None + +def fetch_macro_from_fred(series_map: Dict[str, str], start: str, end: str) -> pd.DataFrame: + dates = pd.date_range(start, end, freq="D") + out = pd.DataFrame({"date": dates}) + out["date"] = out["date"].dt.date + + fred = get_fred_client() + if fred is None: + print("[MACRO] FRED 클라이언트 초기화 실패 → 매크로 데이터 수집 건너뜀") + return pd.DataFrame(columns=["date"] + list(series_map.keys())) + + for col_name, fred_symbol in series_map.items(): + print(f"[MACRO] Fetch {col_name} ({fred_symbol}) {start}~{end}") + s = fred.get_series(fred_symbol, observation_start=start, observation_end=end) + s = s.reset_index().rename(columns={"index": "date", 0: col_name}) + s["date"] = pd.to_datetime(s["date"]).dt.date + out = out.merge(s, on="date", how="left") + + return out + +# -------------------------------------------------------------- +# Macro UPSERT +# -------------------------------------------------------------- +def upsert_macro(db_name: str, df: pd.DataFrame): + if df.empty: + print("[MACRO] No macro data to upsert.") + return + + # 누락된 컬럼 채우기 + required_cols = [ + "cpi", "gdp", "ppi", "jolt", + "cci", "interest_rate", "trade_balance" + ] + for c in required_cols: + if c not in df.columns: + df[c] = None + + # pandas/Numpy → Python scalar + records = [] + for _, r in df.iterrows(): + records.append(( + r["date"], + to_scalar(r["cpi"]), + to_scalar(r["gdp"]), + to_scalar(r["ppi"]), + to_scalar(r["jolt"]), + to_scalar(r["cci"]), + to_scalar(r["interest_rate"]), + to_scalar(r["trade_balance"]), + )) + + sql = """ + INSERT INTO public.macroeconomic_indicators ( + date, cpi, gdp, ppi, jolt, + cci, interest_rate, trade_balance + ) + VALUES %s + ON CONFLICT (date) DO UPDATE SET + cpi = EXCLUDED.cpi, + gdp = EXCLUDED.gdp, + ppi = EXCLUDED.ppi, + jolt = EXCLUDED.jolt, + cci = EXCLUDED.cci, + interest_rate = EXCLUDED.interest_rate, + trade_balance = EXCLUDED.trade_balance; + """ + + conn = get_db_conn(db_name) + try: + with conn.cursor() as cur: + execute_values(cur, sql, records) + conn.commit() + print(f"[MACRO] Upserted {len(df)} rows into public.macroeconomic_indicators") + + finally: + conn.close() + + +# -------------------------------------------------------------- +# Macro Pipeline (증분) +# -------------------------------------------------------------- +def run_macro_pipeline(config: Dict[str, Any]): + db_name = config["db_name"] + series_map = config["macro_series"] + + if not series_map: + print("[MACRO] macro_series is empty → skip macroeconomic_indicators") + return + + # 테이블 마지막 날짜 + last = get_last_date_in_table( + db_name, + "public.macroeconomic_indicators", + "date" + ) + + if last is None: + start_date = config.get("macro_start", "2017-01-01") + print(f"[MACRO] 기존 없음 → {start_date} 부터 시작") + else: + start_date = (last + timedelta(days=1)).strftime("%Y-%m-%d") + print(f"[MACRO] 마지막 날짜 {last} 이후 → {start_date} 부터 수집") + + end_date = today_kst().strftime("%Y-%m-%d") + + if start_date > end_date: + print("[MACRO] Already up to date.") + return + + df_macro = fetch_macro_from_fred(series_map, start_date, end_date) + upsert_macro(db_name, df_macro) +# ====================================================================== +# SECTION 5 — COMPANY FUNDAMENTALS (FETCH + UPSERT + PIPELINE) +# ====================================================================== + +# -------------------------------------------------------------- +# yfinance 재무제표 Fetch +# -------------------------------------------------------------- +def fetch_company_fundamentals_from_yf(tickers: List[str]) -> pd.DataFrame: + """ + 각 티커에 대해 다음 데이터를 가져와 company_fundamentals 형태로 변환: + - annual financials (Income Statement) + - annual balance sheet + - EPS, PE_ratio (yfinance.info 에서) + """ + rows = [] + + for t in tickers: + print(f"[FUND] Fetch fundamentals for {t}") + yf_t = yf.Ticker(t) + + fs = yf_t.financials # 손익계산서 + bs = yf_t.balance_sheet # 재무상태표 + + if fs is None or fs.empty: + print(f"[FUND] No financials for {t}") + continue + if bs is None or bs.empty: + print(f"[FUND] No balance_sheet for {t}") + continue + + fs = fs.copy() + bs = bs.copy() + + # --- 여기 부분만 수정 --- + def normalize_cols_to_date(idx): + # 이미 DatetimeIndex면 .date (ndarray of date) 사용 + if isinstance(idx, pd.DatetimeIndex): + return idx.date + # 그 외에는 to_datetime 후 .date + return pd.to_datetime(idx).date + + fs.columns = normalize_cols_to_date(fs.columns) + bs.columns = normalize_cols_to_date(bs.columns) + # ------------------------ + + # 계정명 매칭 함수 + def find_row(df, names): + for n in names: + if n in df.index: + return df.loc[n] + df_map = {idx.lower().replace(" ", ""): idx for idx in df.index} + for n in names: + key = n.lower().replace(" ", "") + if key in df_map: + return df.loc[df_map[key]] + return None + + revenue_row = find_row(fs, ["Total Revenue", "Revenue"]) + net_income_row = find_row(fs, ["Net Income", "NetIncome"]) + assets_row = find_row(bs, ["Total Assets"]) + liab_row = find_row(bs, ["Total Liab", "Total Liabilities"]) + equity_row = find_row(bs, ["Total Stockholder Equity", "Total Equity"]) + + info = {} + try: + info = yf_t.info or {} + except: + pass + + eps_value = info.get("trailingEps") + pe_value = info.get("trailingPE") + + report_dates = sorted(set(fs.columns) | set(bs.columns)) + + for d in report_dates: + rows.append({ + "ticker": t, + "date": d, + "revenue": float(revenue_row[d]) if (revenue_row is not None and d in revenue_row and pd.notna(revenue_row[d])) else None, + "net_income": float(net_income_row[d]) if (net_income_row is not None and d in net_income_row and pd.notna(net_income_row[d])) else None, + "total_assets": float(assets_row[d]) if (assets_row is not None and d in assets_row and pd.notna(assets_row[d])) else None, + "total_liabilities": float(liab_row[d]) if (liab_row is not None and d in liab_row and pd.notna(liab_row[d])) else None, + "equity": float(equity_row[d]) if (equity_row is not None and d in equity_row and pd.notna(equity_row[d])) else None, + "EPS": float(eps_value) if eps_value is not None else None, + "PE_ratio": float(pe_value) if pe_value is not None else None, + }) + + if not rows: + return pd.DataFrame( + columns=[ + "ticker", "date", + "revenue", "net_income", + "total_assets", "total_liabilities", + "equity", "EPS", "PE_ratio" + ] + ) + + return pd.DataFrame(rows) + + + +# -------------------------------------------------------------- +# UPSERT Company Fundamentals +# -------------------------------------------------------------- +def upsert_company_fundamentals(db_name: str, df: pd.DataFrame): + if df.empty: + print("[FUND] No fundamentals to upsert.") + return + + # pandas → python scalar + records = [] + for _, r in df.iterrows(): + records.append(( + r["ticker"], r["date"], + to_scalar(r["revenue"]), + to_scalar(r["net_income"]), + to_scalar(r["total_assets"]), + to_scalar(r["total_liabilities"]), + to_scalar(r["equity"]), + to_scalar(r["EPS"]), + to_scalar(r["PE_ratio"]), + )) + + sql = """ + INSERT INTO public.company_fundamentals ( + ticker, date, + revenue, net_income, + total_assets, total_liabilities, + equity, EPS, PE_ratio + ) + VALUES %s + ON CONFLICT (ticker, date) DO UPDATE SET + revenue = EXCLUDED.revenue, + net_income = EXCLUDED.net_income, + total_assets = EXCLUDED.total_assets, + total_liabilities = EXCLUDED.total_liabilities, + equity = EXCLUDED.equity, + EPS = EXCLUDED.EPS, + PE_ratio = EXCLUDED.PE_ratio; + """ + + conn = get_db_conn(db_name) + try: + with conn.cursor() as cur: + execute_values(cur, sql, records) + conn.commit() + print(f"[FUND] Upserted {len(df)} rows into public.company_fundamentals") + + finally: + conn.close() + + +# -------------------------------------------------------------- +# Fundamentals Pipeline (전체 업데이트) +# -------------------------------------------------------------- +def run_company_fundamentals_pipeline(config: Dict[str, Any]): + db_name = config["db_name"] + tickers = config["tickers_for_fund"] + + df = fetch_company_fundamentals_from_yf(tickers) + upsert_company_fundamentals(db_name, df) +# ====================================================================== +# SECTION 6 — DAILY AUTO DATA COLLECTION (증분 업데이트) +# ====================================================================== + +def run_data_collection() -> None: + """ + pipeline/run_pipeline.py 에서 "STEP 0" 으로 호출되는 자동 데이터 수집 함수. + + ▣ 동작 + 1) price_data: MAX(date) 이후 구간만 yfinance 로 추가 수집 후 UPSERT + 2) technical_indicators: price_data 전체 기반으로 다시 계산 후 UPSERT + 3) macroeconomic_indicators: MAX(date)+1 이후 구간만 수집 + 4) company_fundamentals: 전체 재수집 (annual/quarterly 로 증분 개념이 애매해서) + """ + + print("\n=== [STEP 0] DAILY DATA COLLECTION (AUTO, incremental) ===") + + db_name = "db" + today_str = today_kst().strftime("%Y-%m-%d") + + # ------------------------------------------------------------------ + # 1) 가격/기술지표 대상 티커 (환경변수 or 기본값) + # ------------------------------------------------------------------ + # 1) DB에서 모든 티커 가져오기 + tickers = get_all_tickers_from_db(db_name) + + # DB가 비어있으면 기본값 사용 + if not tickers: + print("[INFO] DB에 티커가 없어서 기본 유니버스 사용: AAPL, MSFT, TSLA") + tickers = ["AAPL", "MSFT", "TSLA"] + + + # ------------------------------------------------------------------ + # 2) 펀더멘털 대상 티커 + # ------------------------------------------------------------------ + tickers_for_fund = get_all_tickers_from_db(db_name) + + # DB가 비어있으면 기본값 사용 + if not tickers_for_fund: + print("[INFO] DB에 티커가 없어서 기본 유니버스 사용: AAPL, MSFT, TSLA") + tickers_for_fund = ["AAPL", "MSFT", "TSLA"] + + + # ------------------------------------------------------------------ + # 3) 거시지표 시리즈 매핑 + # ------------------------------------------------------------------ + macro_series = { + # 예시. 실제 쓰는 yfinance symbol 로 교체 가능 + "cpi": "CPIAUCSL", + "gdp": "GDP", + "ppi": "PPIACO", + "jolt": "JTSJOL", + "cci": "CONCCONF", + "interest_rate": "^TNX", + } + + config = { + "db_name": db_name, + "tickers": tickers, + "price_start": "2017-01-01", + "tickers_for_fund": tickers_for_fund, + "macro_series": macro_series, + "macro_start": "2017-01-01", + } + + # ------------------------------------------------------------------ + # [1] price_data 증분 업데이트 + # ------------------------------------------------------------------ + print("\n[1] price_data incremental update…") + run_price_pipeline(config) + + # ------------------------------------------------------------------ + # [2] technical_indicators 전체 재계산 + # ------------------------------------------------------------------ + print("\n[2] technical_indicators incremental recompute…") + run_technical_indicators_incremental(config) + + # ------------------------------------------------------------------ + # [3] macroeconomic_indicators 증분 업데이트 + # ------------------------------------------------------------------ + print("\n[3] macroeconomic_indicators incremental update…") + run_macro_pipeline(config) + + # ------------------------------------------------------------------ + # [4] company_fundamentals 전체 업데이트 + # ------------------------------------------------------------------ + print("\n[4] company_fundamentals full update…") + run_company_fundamentals_pipeline(config) + + print("=== [STEP 0] DAILY DATA COLLECTION DONE ===\n") +# ====================================================================== +# SECTION 7 — MANUAL BACKFILL ALL (FULL REFILL) +# ====================================================================== + +def manual_backfill_all() -> None: + """ + ⚠ 전체 백필(Backfill) 모드 + ----------------------------------------- + 이 함수는 "모든 티커 × 모든 스키마 × 전체 기간"을 다시 채운다. + 즉, 다음 순서로 강제 재수집한다: + + 1) price_data (OHLCV 전체 재수집) + 2) technical_indicators (price_data 전체 기반 재계산) + 3) macroeconomic_indicators (전체 재수집) + 4) company_fundamentals (전체 재수집) + + ※ 이 작업은 매우 오래 걸릴 수 있음. + 실행 전 반드시 티커 개수와 시작일 범위를 확인해야 함. + """ + + db_name = "db" + today_str = today_kst().strftime("%Y-%m-%d") + + # ------------------------------------------------------------ + # 1) 가격/기술지표 대상 티커 + # ------------------------------------------------------------ + tickers = get_all_tickers_from_db(db_name) + if not tickers: + print("[INFO] DB에 티커가 없어서 기본 유니버스 사용: AAPL, MSFT, TSLA") + tickers = ["AAPL", "MSFT", "TSLA"] + + # ------------------------------------------------------------ + # 2) 펀더멘털 대상 티커 + # ------------------------------------------------------------ + tickers_for_fund = get_all_tickers_from_db(db_name) + if not tickers_for_fund: + print("[INFO] DB에 티커가 없어서 기본 유니버스 사용: AAPL, MSFT, TSLA") + tickers_for_fund = ["AAPL", "MSFT", "TSLA"] + + # ------------------------------------------------------------ + # 3) 거시지표 시리즈 + # ------------------------------------------------------------ + macro_series = { + # 필요 시 실제 프로젝트에서 수정 + "cpi": "CPIAUCSL", + "gdp": "GDP", + "ppi": "PPIACO", + "jolt": "JTSJOL", + } + + # ------------------------------------------------------------ + # 4) 백필 시작일 설정 + # ------------------------------------------------------------ + price_backfill_start = os.getenv("PRICE_BACKFILL_START", "2017-01-01") + macro_backfill_start = os.getenv("MACRO_BACKFILL_START", "2017-01-01") + + # ------------------------------------------------------------ + # ⚠️ 진짜 위험한 작업이므로 경고 출력 + # ------------------------------------------------------------ + print( + "\n⚠⚠⚠ [전체 백필 모드 경고] ⚠⚠⚠\n" + "이 명령은 다음 작업을 '모두 강제로' 다시 수행합니다.\n" + " 1) 모든 티커의 전체 OHLCV → price_data 를 다시 채움\n" + " 2) price_data 전체 기반으로 technical_indicators 전체 재계산\n" + " 3) 거시지표 전체 재수집 → macroeconomic_indicators 다시 기록\n" + " 4) 모든 펀더멘털 티커에 대해 재무제표 전체 재수집\n\n" + "■ 작업량이 매우 큽니다. (티커 수 × 기간 × 스키마)\n" + "■ 네트워크 상황에 따라 몇십 분~몇 시간 걸릴 수 있습니다.\n" + "■ 지금 설정된 값은 다음과 같습니다:\n" + "---------------------------------------------------------------\n" + f" 오늘(KST): {today_str}\n" + f" price_data 티커: {tickers}\n" + f" fundamentals 티커: {tickers_for_fund}\n" + f" macro 시리즈: {macro_series}\n" + f" 가격 백필 시작일: {price_backfill_start}\n" + f" 거시 백필 시작일: {macro_backfill_start}\n" + "---------------------------------------------------------------\n" + "※ 의도한 값이 맞는지 반드시 확인하세요.\n" + "※ 중간에 CTRL+C 로 중단해도 그 시점까지는 DB에 반영됨.\n" + "---------------------------------------------------------------\n" + ) + + # ================================================================== + # STEP 1 — PRICE_DATA 전체 백필 (티커도 나눠서 처리) + # ================================================================== + print("\n[STEP 1/4] price_data FULL backfill 시작…") + + tickers_per_batch = 20 # 티커도 20개 단위로 끊어서 처리 (원하면 조정 가능) + + for i in range(0, len(tickers), tickers_per_batch): + batch_tickers = tickers[i : i + tickers_per_batch] + print(f"[STEP 1/4] price_data batch {i} ~ {i + len(batch_tickers) - 1} 티커 처리 중...") + + # 이 배치의 모든 티커에 대해 OHLCV 수집 + df_price_batch = fetch_price_data_from_yf( + batch_tickers, + price_backfill_start, + today_str, + ) + + # 이 배치만 DB에 UPSERT (행 단위 batch_size 는 upsert_price_data 에서 또 나눔) + upsert_price_data(db_name, df_price_batch, batch_size=10_000) + + print("[STEP 1/4] price_data FULL backfill 완료.\n") + + + # ================================================================== + # STEP 2 — TECHNICAL_INDICATORS 전체 재계산/백필 + # ================================================================== + print("[STEP 2/4] technical_indicators FULL 재계산…") + config_for_tech = { + "db_name": db_name, + "tickers": tickers + } + run_technical_indicators_full(config_for_tech) + print("[STEP 2/4] technical_indicators FULL backfill 완료.\n") + + # ================================================================== + # STEP 3 — MACROECONOMIC_INDICATORS 전체 백필 + # ================================================================== + if macro_series: + print("[STEP 3/4] macroeconomic_indicators FULL backfill 시작…") + df_macro = fetch_macro_from_fred(macro_series, macro_backfill_start, today_str) + upsert_macro(db_name, df_macro) + print("[STEP 3/4] macroeconomic_indicators FULL backfill 완료.\n") + else: + print("[STEP 3/4] macroeconomic_indicators: macro_series 비어 있어서 skip.\n") + + # ================================================================== + # STEP 4 — COMPANY FUNDAMENTALS 전체 백필 + # ================================================================== + print("[STEP 4/4] company_fundamentals FULL backfill 시작…") + config_for_fund = { + "db_name": db_name, + "tickers_for_fund": tickers_for_fund + } + run_company_fundamentals_pipeline(config_for_fund) + print("[STEP 4/4] company_fundamentals FULL backfill 완료.\n") + + print("✅ 전체 백필 완료. (manual_backfill_all)") +# ====================================================================== +# SECTION 8 — MAIN ENTRYPOINT +# ====================================================================== + +if __name__ == "__main__": + """ + 이 파일을 직접 실행하면 “전체 백필 모드”가 실행된다. + + 사용 예: + python AI/daily_data_collection/main.py + + 주의: + - manual_backfill_all() 은 매우 무거운 작업이다. + - 다음 환경변수를 통해 대상 티커/기간을 조절할 수 있다: + + PRICE_INGEST_TICKERS : price_data + technical 대상 + FUND_INGEST_TICKERS : company_fundamentals 대상 + PRICE_BACKFILL_START : OHLCV 백필 시작일 + MACRO_BACKFILL_START : macro 백필 시작일 + + - pipeline/run_pipeline.py 에서 자동 실행되는 함수는 + run_data_collection() + 이고, manual_backfill_all() 은 직접 실행할 때만 사용해야 한다. + """ + manual_backfill_all() diff --git a/AI/daily_data_collection/test.py b/AI/daily_data_collection/test.py new file mode 100644 index 00000000..d682c276 --- /dev/null +++ b/AI/daily_data_collection/test.py @@ -0,0 +1,187 @@ +import pandas as pd +import yfinance as yf +from datetime import datetime, date, timedelta +from calendar import monthrange + +# ============================================================ +# ① 안전한 Series 추출 유틸 (중복 컬럼 / MultiIndex 방어) +# ============================================================ + +def get_series(df: pd.DataFrame, col_name: str) -> pd.Series: + """ + df[col_name]이 Series가 아니라 DataFrame으로 나오는 경우 + (동일 이름 컬럼 여러 개 등)를 방어해서 + 항상 1차원 Series만 반환하도록 정규화하는 함수. + """ + col = df[col_name] + if isinstance(col, pd.DataFrame): + # 같은 이름의 컬럼이 여러 개 있으면 첫 번째 컬럼만 사용 + return col.iloc[:, 0] + return col + + +# ============================================================ +# ② 기술적 지표 계산 함수 (네 코드 + 컬럼 정규화 보강) +# ============================================================ + +def compute_technical_indicators(df: pd.DataFrame) -> pd.DataFrame: + """ + 입력 df: ticker, date, open, high, low, close, volume + 출력 df: ticker, date, RSI, MACD, Bollinger_Bands_upper, ... MA_200 + """ + + # MultiIndex 컬럼 방어: 상위 레벨만 사용 + if isinstance(df.columns, pd.MultiIndex): + df.columns = df.columns.get_level_values(0) + + df = df.sort_values("date").reset_index(drop=True) + + # 여기서 무조건 Series로 강제 + close = get_series(df, "close") + high = get_series(df, "high") + low = get_series(df, "low") + volume = get_series(df, "volume").fillna(0) + + # ------------------------- RSI ------------------------- + delta = close.diff() + gain = delta.clip(lower=0) + loss = -delta.clip(upper=0) + + window_rsi = 14 + avg_gain = gain.rolling(window_rsi).mean() + avg_loss = loss.rolling(window_rsi).mean() + rs = avg_gain / (avg_loss + 1e-9) + rsi = 100 - (100 / (1 + rs)) + + # ------------------------- MACD ------------------------- + ema12 = close.ewm(span=12, adjust=False).mean() + ema26 = close.ewm(span=26, adjust=False).mean() + macd = ema12 - ema26 + + # ------------------------- Bollinger Bands ------------------------- + ma20 = close.rolling(20).mean() + std20 = close.rolling(20).std() + bb_upper = ma20 + 2 * std20 + bb_lower = ma20 - 2 * std20 + + # ------------------------- ATR ------------------------- + prev_close = close.shift(1) + tr1 = high - low + tr2 = (high - prev_close).abs() + tr3 = (low - prev_close).abs() + tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) + atr = tr.rolling(14).mean() + + # ------------------------- OBV ------------------------- + obv = [0] + for i in range(1, len(close)): + if close.iloc[i] > close.iloc[i - 1]: + obv.append(obv[-1] + volume.iloc[i]) + elif close.iloc[i] < close.iloc[i - 1]: + obv.append(obv[-1] - volume.iloc[i]) + else: + obv.append(obv[-1]) + obv = pd.Series(obv, index=df.index) + + # ------------------------- Stochastic ------------------------- + lowest_low = low.rolling(14).min() + highest_high = high.rolling(14).max() + stochastic = (close - lowest_low) / (highest_high - lowest_low + 1e-9) * 100 + + # ------------------------- MFI ------------------------- + typical_price = (high + low + close) / 3 + raw_mf = typical_price * volume + tp_diff = typical_price.diff() + pos_mf = raw_mf.where(tp_diff > 0, 0) + neg_mf = raw_mf.where(tp_diff < 0, 0) + pos_14 = pos_mf.rolling(14).sum() + neg_14 = neg_mf.abs().rolling(14).sum() + 1e-9 + mfr = pos_14 / neg_14 + mfi = 100 - (100 / (1 + mfr)) + + # ------------------------- MA ------------------------- + ma_5 = close.rolling(5).mean() + ma_20_roll = close.rolling(20).mean() + ma_50 = close.rolling(50).mean() + ma_200 = close.rolling(200).mean() + + out = pd.DataFrame({ + "ticker": df["ticker"], + "date": df["date"], + "RSI": rsi, + "MACD": macd, + "Bollinger_Bands_upper": bb_upper, + "Bollinger_Bands_lower": bb_lower, + "ATR": atr, + "OBV": obv, + "Stochastic": stochastic, + "MFI": mfi, + "MA_5": ma_5, + "MA_20": ma_20_roll, + "MA_50": ma_50, + "MA_200": ma_200, + }) + + return out + + +# ============================================================ +# ③ yfinance 로 "작년 같은 달" 1개월치 MSFT OHLCV 다운로드 +# ============================================================ + +def fetch_msft_last_year_one_month() -> pd.DataFrame: + today = date.today() + last_year = today.year - 1 + month = today.month + + # 작년 같은 달의 1일 ~ 말일 + start = date(last_year, month, 1) + last_day = monthrange(last_year, month)[1] + end = date(last_year, month, last_day) + timedelta(days=1) # yfinance end는 exclusive + + print(f"[TEST] Fetching MSFT data: {start} ~ {end} (작년 같은 달 1개월)") + + df = yf.download("MSFT", start=start.strftime("%Y-%m-%d"), end=end.strftime("%Y-%m-%d")) + + if df.empty: + print("[TEST] No data returned from yfinance.") + return pd.DataFrame() + + # 인덱스 → 컬럼 + df = df.reset_index() + + # yfinance 포맷 → 표준 컬럼명으로 정리 + # (단일 티커이므로 MultiIndex 방어는 compute 함수에서 추가로 한 번 더 함) + df["date"] = pd.to_datetime(df["Date"]).dt.date + df["ticker"] = "MSFT" + + df = df.rename(columns={ + "Open": "open", + "High": "high", + "Low": "low", + "Close": "close", + "Volume": "volume", + }) + + return df[["ticker", "date", "open", "high", "low", "close", "volume"]] + + +# ============================================================ +# ④ 실제 계산 & 출력 (DB 업서트 없음) +# ============================================================ + +if __name__ == "__main__": + df_price = fetch_msft_last_year_one_month() + + if df_price.empty: + print("[TEST] 가격 데이터가 없어 기술지표를 계산할 수 없습니다.") + else: + df_tech = compute_technical_indicators(df_price) + + # 앞/뒤 일부를 확인해보고 싶으면 둘 다 찍어보자 + print("\n===== MSFT 기술적 지표 (앞 5행) =====\n") + print(df_tech.head(5).to_string(index=False)) + + print("\n===== MSFT 기술적 지표 (뒤 10행) =====\n") + print(df_tech.tail(10).to_string(index=False)) + diff --git a/AI/libs/core/__init__.py b/AI/libs/core/__init__.py index de6ce59a..7ace3c8e 100644 --- a/AI/libs/core/__init__.py +++ b/AI/libs/core/__init__.py @@ -1,3 +1,3 @@ -#AI/libs/core/pipeline.py +#AI/libs/core/__init__.py from libs.core.pipeline import run_pipeline __all__ = ["run_pipeline"] \ No newline at end of file diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index 7043e7db..6a90fba6 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -1,9 +1,8 @@ # pipeline/run_pipeline.py -# -*- coding: utf-8 -*- """ 한국어 주석 (개요): - 본 파일은 "주간 자동 파이프라인"의 전체 흐름을 오케스트레이션합니다. - (Finder → Transformer → XAI 리포트 → Backtrader → DB 저장) + (Finder → Transformer → XAI 리포트 → Backtrade → DB 저장) [전체 플로우] 1) Finder @@ -12,7 +11,7 @@ 2) Transformer - 선택된 종목들의 OHLCV 데이터를 DB에서 가져옵니다(fetch_ohlcv). - LSTM/Rule 기반 등의 Transformer 로직을 통해 의사결정 로그(DataFrame)를 생성합니다. - - 이 의사결정 로그는 XAI와 Backtrader에서 모두 공통으로 사용됩니다. + - 이 의사결정 로그는 XAI와 Backtrade에서 모두 공통으로 사용됩니다. 3) XAI (e.g. GROQ 등 LLM 기반 설명 생성) - 각 의사결정에 대해 feature_name / feature_score를 기반으로 @@ -20,7 +19,7 @@ - 결과는 xai_reports 테이블에 먼저 저장됩니다. - 이 때 생성된 xai_reports.id를 decision_log(logs_df)에 xai_report_id로 심습니다. -4) Backtrader +4) Backtrade - xai_report_id가 포함된 의사결정 로그(decision_log)를 받아, price 컬럼을 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행합니다. - 백테스트 결과(fills_df)는 xai_report_id를 그대로 보존한 상태로 executions 테이블에 저장됩니다. @@ -35,6 +34,7 @@ import os import sys +import json from typing import List, Dict, Optional, Tuple from datetime import datetime, timezone, timedelta @@ -49,9 +49,10 @@ # ---------------------------------------------------------------------- # 외부 모듈 import (각 단계별 역할) # ---------------------------------------------------------------------- +from daily_data_collection.main import run_data_collection # 0) 시세/시장 데이터 수집 from finder.main import run_finder # 1) 종목 발굴 from transformer.main import run_transformer # 2) 신호 생성(의사결정 로그 생성) -from backtrader.run_backtrader import backtrader, BacktradeConfig # 3) 백트레이딩(간소화 체결 엔진) +from backtrade.main import backtrade, BacktradeConfig # 3) 백트레이딩(간소화 체결 엔진) from libs.utils.save_executions_to_db import save_executions_to_db # 3.5) 체결내역 DB 저장 from xai.run_xai import run_xai # 4) XAI 리포트 텍스트 생성 from libs.utils.save_reports_to_db import save_reports_to_db # 4.5) XAI 리포트 DB 저장 (id 반환) @@ -69,7 +70,7 @@ REPORT_DB_NAME = "db" # 체결내역 / XAI 리포트를 저장하는 DB 명 # ---------------------------------------------------------------------- -# XAI 및 Backtrader에서 공통으로 요구하는 "결정 로그 필수 컬럼" 정의 +# XAI 및 Backtrade에서 공통으로 요구하는 "결정 로그 필수 컬럼" 정의 # ---------------------------------------------------------------------- REQUIRED_LOG_COLS = { "ticker", @@ -85,6 +86,11 @@ "feature_score3", } +# ---------------------------------------------------------------------- +# 주간 티커 캐시 파일 경로 +# ---------------------------------------------------------------------- +TICKER_CACHE_PATH = os.path.join(project_root, "weekly_tickers.json") + # ====================================================================== # 유틸리티 함수 모음 @@ -123,6 +129,7 @@ def _to_float(v, fallback: float = 0.0) -> float: def run_weekly_finder() -> List[str]: """ Finder 모듈을 실행하여 후보 티커 리스트를 반환합니다. + (주에 한 번, 월요일에만 실제 실행) """ print("--- [PIPELINE-STEP 1] Finder 모듈 실행 시작 ---") try: @@ -137,6 +144,33 @@ def run_weekly_finder() -> List[str]: return tickers +def save_weekly_tickers(tickers: List[str]) -> None: + """월요일에 산출한 티커 리스트를 로컬 캐시 파일에 저장.""" + try: + with open(TICKER_CACHE_PATH, "w", encoding="utf-8") as f: + json.dump({"tickers": tickers, "saved_at": _utcnow().isoformat()}, f, ensure_ascii=False) + print(f"[INFO] 주간 티커 캐시 저장 완료: {TICKER_CACHE_PATH}") + except Exception as e: + print(f"[WARN] 주간 티커 캐시 저장 실패: {e}") + + +def load_weekly_tickers() -> List[str]: + """캐시 파일에서 주간 티커 리스트를 불러옴.""" + if not os.path.exists(TICKER_CACHE_PATH): + print(f"[WARN] 주간 티커 캐시 파일이 존재하지 않습니다: {TICKER_CACHE_PATH}") + return [] + + try: + with open(TICKER_CACHE_PATH, "r", encoding="utf-8") as f: + data = json.load(f) + tickers = data.get("tickers", []) + print(f"[INFO] 캐시에서 주간 티커 로드: {tickers}") + return tickers + except Exception as e: + print(f"[WARN] 주간 티커 캐시 로드 실패: {e}") + return [] + + # ====================================================================== # 2) Transformer: 신호/의사결정 로그 생성 단계 # ====================================================================== @@ -152,7 +186,8 @@ def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: print("[WARN] 빈 종목 리스트가 입력되어 Transformer 단계를 건너뜁니다.") return pd.DataFrame() - end_date = datetime.strptime("2024-11-1", "%Y-%m-%d") # 임시 고정 날짜 + + end_date = datetime.now().date() start_date = end_date - timedelta(days=600) all_ohlcv_df: List[pd.DataFrame] = [] @@ -213,27 +248,27 @@ def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: # ====================================================================== -# 3) Backtrader: 의사결정 로그 기반 체결/포지션 계산 단계 +# 3) Backtrade: 의사결정 로그 기반 체결/포지션 계산 단계 # ====================================================================== -def run_backtrader(decision_log: pd.DataFrame) -> pd.DataFrame: +def run_backtrade(decision_log: pd.DataFrame) -> pd.DataFrame: """ Transformer에서 생성된 의사결정 로그(decision_log)의 price 컬럼을 OHLCV 없이 "체결 기준가"로 직접 사용해 간소화된 백테스트를 수행합니다. 주의: - decision_log는 xai_report_id 컬럼을 포함할 수 있으며, - backtrader() 구현이 해당 컬럼을 드롭하지 않으면 fills_df에도 그대로 보존됩니다. + backtrade() 구현이 해당 컬럼을 드롭하지 않으면 fills_df에도 그대로 보존됩니다. """ - print("--- [PIPELINE-STEP 4] Backtrader 실행 시작 ---") + print("--- [PIPELINE-STEP 4] Backtrade 실행 시작 ---") if decision_log is None or decision_log.empty: - print("[WARN] Backtrader: 비어있는 결정 로그가 입력되었습니다. 체결을 수행하지 않습니다.") + print("[WARN] Backtrade: 비어있는 결정 로그가 입력되었습니다. 체결을 수행하지 않습니다.") return pd.DataFrame() run_id = _utcnow().strftime("run-%Y%m%d-%H%M%S") - cfg = BacktraderConfig( + cfg = BacktradeConfig( initial_cash=100_000.0, slippage_bps=5.0, commission_bps=3.0, @@ -242,18 +277,18 @@ def run_backtrader(decision_log: pd.DataFrame) -> pd.DataFrame: fill_on_same_day=True, ) - fills_df, summary = backtrader( + fills_df, summary = backtrade( decision_log=decision_log, config=cfg, run_id=run_id, ) if fills_df is None or fills_df.empty: - print("[WARN] Backtrader: 생성된 체결 내역이 없습니다.") + print("[WARN] Backtrade: 생성된 체결 내역이 없습니다.") return pd.DataFrame() print( - f"--- [PIPELINE-STEP 4] Backtrader 완료: " + f"--- [PIPELINE-STEP 4] Backtrade 완료: " f"trades={len(fills_df)}, " f"cash_final={summary.get('cash_final')}, " f"pnl_realized_sum={summary.get('pnl_realized_sum')} ---" @@ -346,22 +381,57 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[ReportRow]: def run_pipeline() -> Optional[List[ReportRow]]: """ - 전체 파이프라인(Finder → Transformer → XAI → Backtrader → DB 저장)을 + 전체 파이프라인(Finder → Transformer → XAI → Backtrade → DB 저장)을 한 번에 실행하는 엔트리 포인트 함수. - """ - # 1) Finder - tickers = run_weekly_finder() - if not tickers: - print("[STOP] Finder에서 종목을 찾지 못해 파이프라인을 중단합니다.") - return None - # 2) Transformer + - Finder: 주 1회, 월요일에만 실제 실행 (티커를 캐시에 저장) + - 나머지(Transformer, XAI, Backtrade, DB 저장): 매일 실행 + → 평일에는 캐시에서 티커를 읽어서 사용 + """ + today = datetime.now() # 서버 로컬 시간 기준 (필요시 timezone 조정 가능) + weekday = today.weekday() # 월=0, 화=1, ..., 일=6 + #------------------------------- + # 0) 주가 데이터 저장 실행 + #------------------------------- + print("--- [PIPELINE-STEP 0] 주가 데이터 수집 실행 시작 ---") + try: + #run_data_collection() + print("--- [PIPELINE-STEP 0] 주가 데이터 수집 실행 완료 ---") + except Exception as e: + print(f"[WARN] 데이터 수집 실행 중 오류 발생: {e} → 계속 진행합니다.") + + # ------------------------------- + # 1) Finder: 월요일에만 실행 + # ------------------------------- + if weekday == 0: # 월요일 + print("[INFO] 오늘은 월요일입니다. Finder를 실행합니다.") + tickers = run_weekly_finder() + if not tickers: + print("[STOP] Finder에서 종목을 찾지 못해 파이프라인을 중단합니다.") + return None + # 월요일에 산출한 티커를 캐시에 저장 + save_weekly_tickers(tickers) + else: + print("[INFO] 오늘은 월요일이 아니므로 Finder를 실행하지 않습니다. 캐시에서 티커를 불러옵니다.") + tickers = load_weekly_tickers() + if not tickers: + print("[WARN] 캐시된 티커가 없어 임시로 Finder를 한 번 실행합니다.") + tickers = run_weekly_finder() + if not tickers: + print("[STOP] 임시 Finder에서도 종목을 찾지 못해 파이프라인을 중단합니다.") + return None + + # ------------------------------- + # 2) Transformer: 매일 실행 + # ------------------------------- logs_df = run_signal_transformer(tickers, MARKET_DB_NAME) if logs_df is None or logs_df.empty: print("[STOP] Transformer에서 유효한 신호를 생성하지 못해 파이프라인을 중단합니다.") return None - # 3) XAI 리포트 생성 + # ------------------------------- + # 3) XAI 리포트 생성: 매일 실행 (환경변수 없으면 자동 스킵) + # ------------------------------- reports = run_xai_report(logs_df) # 3.5) XAI 리포트 DB 저장 → 생성된 id 리스트 수신 @@ -384,10 +454,14 @@ def run_pipeline() -> Optional[List[ReportRow]]: "xai_report_id를 매핑하지 못했습니다. (모두 NULL 처리)" ) - # 4) Backtester: xai_report_id 포함 decision_log로 체결 내역 생성 - fills_df = run_backtrader(logs_df) + # ------------------------------- + # 4) Backtrade: 매일 실행 + # ------------------------------- + fills_df = run_backtrade(logs_df) - # 5) executions 테이블에 체결 내역 저장 + # ------------------------------- + # 5) executions 테이블에 체결 내역 저장: 매일 실행 + # ------------------------------- try: save_executions_to_db(fills_df, REPORT_DB_NAME) print("[INFO] 체결 내역을 DB에 저장했습니다.") @@ -401,7 +475,7 @@ def run_pipeline() -> Optional[List[ReportRow]]: # 스크립트 단독 실행 시 테스트용 엔트리 포인트 # ====================================================================== if __name__ == "__main__": - print(">>> 파이프라인 (Finder → Transformer → XAI → Backtester) 테스트를 시작합니다.") + print(">>> 파이프라인 (Finder → Transformer → XAI → Backtrade) 테스트를 시작합니다.") final_reports = run_pipeline() print("\n>>> 최종 반환 결과 (XAI Reports):") diff --git a/AI/libs/utils/save_executions_to_db.py b/AI/libs/utils/save_executions_to_db.py index a9096c9d..0eeec7cd 100644 --- a/AI/libs/utils/save_executions_to_db.py +++ b/AI/libs/utils/save_executions_to_db.py @@ -1,103 +1,234 @@ -# libs/utils/save_executions_to_db.py # -*- coding: utf-8 -*- """ 한국어 주석: -- 간소화 백테스터의 체결 내역(DataFrame)을 executions 테이블에 저장한다. -- 주요 컬럼: - run_id, xai_report_id, ticker, signal_date, signal_price, signal, fill_date, fill_price, - qty, side, value, commission, cash_after, position_qty, avg_price, - pnl_realized, pnl_unrealized, created_at -- DB 엔진: libs.utils.get_db_conn 모듈의 get_engine(db_name) 사용(프로젝트 규약 준수) +- executions 테이블 저장 → portfolio_positions 갱신 → portfolio_summary 갱신 +- 하나의 executions INSERT가 발생하면 계좌 전체 상태를 즉시 업데이트한다. """ from __future__ import annotations from typing import Optional from sqlalchemy import text from datetime import datetime, timezone +import pandas as pd -from libs.utils.get_db_conn import get_engine # 프로젝트 기존 헬퍼 사용 +from libs.utils.get_db_conn import get_engine # 기존 프로젝트 헬퍼 사용 +# ------------------------------------------------------------------- +# 공용 헬퍼 +# ------------------------------------------------------------------- def _utcnow_iso() -> str: - """한국어 주석: created_at 등의 기록용 ISO8601 타임스탬프 문자열""" + """ISO 포맷 UTC 타임스탬프 문자열""" return datetime.now(timezone.utc).isoformat() +# ------------------------------------------------------------------- +# 📌 portfolio_positions 업데이트 +# ------------------------------------------------------------------- +def update_portfolio_position(conn, execution: dict): + """ + 한국어 주석: + - executions에 새 체결 1건이 기록될 때 호출 + - portfolio_positions는 ticker당 1행만 유지 (UNIQUE) + """ + + ticker = execution["ticker"] + qty = int(execution["qty"]) + fill_price = float(execution["fill_price"]) + side = execution["side"].upper() + realized_pnl = float(execution["pnl_realized"]) + + # --- 기존 포지션 로드 + old = conn.execute(text(""" + SELECT position_qty, avg_price, pnl_realized_cum + FROM portfolio_positions + WHERE ticker = :ticker + """), {"ticker": ticker}).fetchone() + + if old is None: + # 신규 포지션 (BUY만 가능) + if side != "BUY": + raise ValueError(f"[ERROR] 보유량 없이 SELL 발생: {ticker}") + + new_qty = qty + new_avg_price = fill_price + new_realized_cum = 0.0 + + else: + old_qty = old.position_qty + old_avg_price = float(old.avg_price) + old_realized_cum = float(old.pnl_realized_cum) + + if side == "BUY": + new_qty = old_qty + qty + new_avg_price = (old_avg_price * old_qty + fill_price * qty) / new_qty + new_realized_cum = old_realized_cum + + elif side == "SELL": + new_qty = old_qty - qty + new_realized_cum = old_realized_cum + realized_pnl -def ensure_exec_table_schema(engine) -> None: + if new_qty == 0: + new_avg_price = 0.0 + else: + new_avg_price = old_avg_price + + # 평가 관련 데이터 + current_price = fill_price + market_value = new_qty * current_price + pnl_unrealized = (current_price - new_avg_price) * new_qty + + # UPSERT + conn.execute(text(""" + INSERT INTO portfolio_positions + (ticker, position_qty, avg_price, + current_price, market_value, pnl_unrealized, + pnl_realized_cum, updated_at) + VALUES + (:ticker, :q, :avg, + :cp, :mv, :pnl_u, + :pnl_r, NOW()) + ON CONFLICT (ticker) + DO UPDATE SET + position_qty = EXCLUDED.position_qty, + avg_price = EXCLUDED.avg_price, + current_price = EXCLUDED.current_price, + market_value = EXCLUDED.market_value, + pnl_unrealized = EXCLUDED.pnl_unrealized, + pnl_realized_cum = EXCLUDED.pnl_realized_cum, + updated_at = NOW(); + """), { + "ticker": ticker, + "q": new_qty, + "avg": new_avg_price, + "cp": current_price, + "mv": market_value, + "pnl_u": pnl_unrealized, + "pnl_r": new_realized_cum, + }) + + +# ------------------------------------------------------------------- +# 📌 portfolio_summary 업데이트 +# ------------------------------------------------------------------- +def update_portfolio_summary(conn, fill_date: str): """ 한국어 주석: - - executions 테이블이 없으면 생성한다. - - 이미 있을 경우는 CREATE TABLE IF NOT EXISTS로 무해. - - 컬럼 타입은 PostgreSQL 기준(NUMERIC 정밀도 넉넉히 설정). - - xai_report_id 컬럼을 추가하여 xai_reports(id)를 FK로 참조한다. + - 계좌 전체 요약(자산, 평가금액, 수익률)을 fill_date 기준으로 M2M 업데이트. + - executions → portfolio_positions → portfolio_summary 순으로 호출됨. """ - with engine.begin() as conn: - # 테이블이 없을 때만 생성 - conn.execute(text(""" - CREATE TABLE IF NOT EXISTS executions ( - id SERIAL PRIMARY KEY, - run_id VARCHAR(64), - - xai_report_id BIGINT, -- 🔗 xai_reports.id 참조용 (NULL 허용) - - ticker VARCHAR(20) NOT NULL, - signal_date DATE NOT NULL, - signal_price NUMERIC(18,6), - signal VARCHAR(10) NOT NULL, - fill_date DATE NOT NULL, - fill_price NUMERIC(18,6) NOT NULL, - qty INTEGER NOT NULL, - side VARCHAR(5) NOT NULL, - value NUMERIC(20,6) NOT NULL, - commission NUMERIC(18,6) NOT NULL, - cash_after NUMERIC(20,6) NOT NULL, - position_qty INTEGER NOT NULL, - avg_price NUMERIC(18,6) NOT NULL, - pnl_realized NUMERIC(18,6) NOT NULL, - pnl_unrealized NUMERIC(18,6) NOT NULL, - created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() - ); - """)) - - # FK는 이미 있을 수 있으니, 한 번 시도하고 실패하면 무시 - try: - conn.execute(text(""" - ALTER TABLE executions - ADD CONSTRAINT fk_executions_xai_reports - FOREIGN KEY (xai_report_id) - REFERENCES xai_reports(id); - """)) - except Exception: - # 이미 FK가 있거나 에러가 나더라도 전체 플로우를 막지 않음 - pass - - -def save_executions_to_db(rows_df, db_name: str) -> None: + + # 1) 최신 현금(cash): executions.cash_after 기준 + cash_row = conn.execute(text(""" + SELECT cash_after + FROM executions + ORDER BY id DESC + LIMIT 1; + """)).fetchone() + + cash = float(cash_row.cash_after) if cash_row else 0.0 + + # 2) 전체 평가금액 = portfolio_positions.market_value 합 + mv_row = conn.execute(text(""" + SELECT COALESCE(SUM(market_value), 0) AS mv + FROM portfolio_positions; + """)).fetchone() + + market_value = float(mv_row.mv) + + total_asset = cash + market_value + + # 3) 누적 실현손익 + realized_row = conn.execute(text(""" + SELECT COALESCE(SUM(pnl_realized_cum), 0) AS pnl_r + FROM portfolio_positions; + """)).fetchone() + + pnl_realized_cum = float(realized_row.pnl_r) + + # 4) 미실현손익 합계 + unrealized_row = conn.execute(text(""" + SELECT COALESCE(SUM(pnl_unrealized), 0) AS pnl_u + FROM portfolio_positions; + """)).fetchone() + + pnl_unrealized = float(unrealized_row.pnl_u) + + # 5) 초기 원금 불러오기 (없으면 total_asset으로 설정) + init = conn.execute(text(""" + SELECT initial_capital + FROM portfolio_summary + ORDER BY date ASC + LIMIT 1; + """)).fetchone() + + if init is None: + initial_capital = total_asset + else: + initial_capital = float(init.initial_capital) + + return_rate = (total_asset / initial_capital) - 1 + + # UPSERT + conn.execute(text(""" + INSERT INTO portfolio_summary + (date, total_asset, cash, market_value, + pnl_unrealized, pnl_realized_cum, + initial_capital, return_rate, created_at) + VALUES + (:d, :ta, :cash, :mv, + :pnl_u, :pnl_r, + :init, :rr, NOW()) + ON CONFLICT (date) + DO UPDATE SET + total_asset = EXCLUDED.total_asset, + cash = EXCLUDED.cash, + market_value = EXCLUDED.market_value, + pnl_unrealized = EXCLUDED.pnl_unrealized, + pnl_realized_cum = EXCLUDED.pnl_realized_cum, + return_rate = EXCLUDED.return_rate, + created_at = NOW(); + """), { + "d": fill_date, + "ta": total_asset, + "cash": cash, + "mv": market_value, + "pnl_u": pnl_unrealized, + "pnl_r": pnl_realized_cum, + "init": initial_capital, + "rr": return_rate, + }) + + +# ------------------------------------------------------------------- +# 📌 메인 함수: executions + positions + summary +# ------------------------------------------------------------------- +def save_executions_to_db(rows_df: pd.DataFrame, db_name: str) -> None: """ 한국어 주석: - - 체결 내역 DataFrame(rows_df)을 executions 테이블에 일괄 insert 한다. - - rows_df는 backtest()가 반환한 fills_df 스키마를 그대로 따른다. - - XAI 연동 시에는 rows_df에 xai_report_id 컬럼이 포함될 수 있다. - - 빈 DF가 들어오면 아무 것도 하지 않는다. + - rows_df 전체를 executions 테이블에 저장 + - 각 행 마다 portfolio_positions 갱신 + - 마지막으로 portfolio_summary 갱신 """ + if rows_df is None or rows_df.empty: - # 저장할 내용 없음 return engine = get_engine(db_name) - ensure_exec_table_schema(engine) - # XAI를 안 돌렸거나 매핑이 실패한 경우를 대비하여 컬럼이 없으면 NULL로 채워서 생성 + # xai_report_id 없으면 NULL로 if "xai_report_id" not in rows_df.columns: rows_df = rows_df.copy() - print("[WARN] xai_report_id 매핑 실패 또는 XAI 미실행 감지, NULL로 저장합니다.") + print("[WARN] xai_report_id 없음. NULL 처리.") rows_df["xai_report_id"] = None - # dict 레코드 리스트로 변환하여 executemany 형태로 성능 확보 payload = rows_df.to_dict(orient="records") with engine.begin() as conn: - sql = text(""" + + # ============================================================= + # 1) executions 테이블 INSERT (배치) + # ============================================================= + insert_sql = text(""" INSERT INTO executions (run_id, xai_report_id, ticker, signal_date, signal_price, signal, @@ -113,4 +244,18 @@ def save_executions_to_db(rows_df, db_name: str) -> None: :position_qty, :avg_price, :pnl_realized, :pnl_unrealized, NOW()) """) - conn.execute(sql, payload) + + conn.execute(insert_sql, payload) + + # ============================================================= + # 2) portfolio_positions 갱신 (행 단위) + # ============================================================= + for ex in payload: + update_portfolio_position(conn, ex) + + # ============================================================= + # 3) portfolio_summary 갱신 (마지막 fill_date 기준) + # ============================================================= + last_fill_date = payload[-1]["fill_date"] + update_portfolio_summary(conn, last_fill_date) + diff --git a/AI/requirements.txt b/AI/requirements.txt index 94f93321..b520683e 100644 --- a/AI/requirements.txt +++ b/AI/requirements.txt @@ -12,3 +12,4 @@ groq requests beautifulsoup4 pathlib +fredapi \ No newline at end of file diff --git a/AI/tests/test_transfomer.py b/AI/tests/test_transfomer.py deleted file mode 100644 index 8953394c..00000000 --- a/AI/tests/test_transfomer.py +++ /dev/null @@ -1,209 +0,0 @@ -# AI/tests/test_transformer_live_fetch.py -# -*- coding: utf-8 -*- -""" -[목적] -- Transformer 모듈만 실제 OHLCV로 테스트 (저장 없음, 출력만) -- 데이터 수집은 프로젝트 표준 유틸: libs.utils.fetch_ohlcv.fetch_ohlcv 를 '그대로' 사용 - (즉, DB 우선 조회 + 실패/결측 시 야후 파이낸스 API 폴백 로직은 fetch_ohlcv 내부 정책을 따름) - -[실행] -> cd AI -> python -m tests.test_transformer_live_fetch -또는 -> python tests/test_transformer_live_fetch.py -""" - -import os -import sys -import json -from typing import Dict, List, Optional -from datetime import datetime, timedelta -import time -import random - -import pandas as pd -import numpy as np - -# --- 프로젝트 루트 경로 설정 --------------------------------------------------- -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(project_root) -# --------------------------------------------------------------------------- - -# --- 모듈 임포트 -------------------------------------------------------------- -from transformer.main import run_transformer # Transformer 단독 테스트 대상 -from libs.utils.fetch_ohlcv import fetch_ohlcv # ★ 표준 OHLCV 수집 유틸(반드시 이걸 사용) -# --------------------------------------------------------------------------- - - -# ============================================================================= -# (옵션) 안전한 fetch 래퍼: 일시적 실패/429 등에 대비한 재시도 -# - fetch_ohlcv 내부에서도 재시도/폴백이 구현되어 있을 수 있으나, -# 테스트 안정성을 위해 여기서 추가로 얇은 재시도를 감쌉니다. -# ============================================================================= -def safe_fetch_ohlcv( - ticker: str, - start: str, - end: str, - config: Optional[Dict] = None, - max_retries: int = 5, - base_sleep: float = 0.8 -) -> pd.DataFrame: - """ - fetch_ohlcv 호출을 얇게 감싸는 재시도 래퍼. - - 429/일시 네트워크 오류 같은 경우를 대비하여 지수 백오프 + 지터 적용 - - fetch_ohlcv가 raise하면 여기서 재시도 후 최종 raise - """ - attempt = 0 - while True: - try: - df = fetch_ohlcv( - ticker=ticker, - start=start, - end=end, - config=(config or {}) - ) - return df - except Exception as e: - attempt += 1 - if attempt >= max_retries: - raise - # 지수 백오프 + 약간의 랜덤 지터 - sleep_s = base_sleep * (2 ** (attempt - 1)) + random.uniform(0, 0.6) - print(f"[WARN] {ticker} fetch_ohlcv 실패({attempt}/{max_retries}) -> {e} | {sleep_s:.2f}s 대기 후 재시도") - time.sleep(sleep_s) - - -# ============================================================================= -# Transformer 단독 라이브 테스트 -# ============================================================================= -def run_transform_only_live_with_fetch(): - """ - - configs/config.json 로드 → db 설정 전달 - - 티커별로 fetch_ohlcv 호출(★ 프로젝트 유틸 사용) → raw_data 결합 - - run_transformer 호출 → 출력만 수행(저장 없음) - """ - print("=== [TEST] Transformer 단독(실데이터) 테스트 시작 — using libs.utils.fetch_ohlcv ===") - - # ---------------------------------------------------------------------- - # (A) 설정/입력 - # ---------------------------------------------------------------------- - cfg_path = os.path.join(project_root, "configs", "config.json") - config: Dict = {} - if os.path.isfile(cfg_path): - try: - with open(cfg_path, "r", encoding="utf-8") as f: - config = json.load(f) - print("[INFO] configs/config.json 로드 완료") - except Exception as e: - print(f"[WARN] config 로드 실패(빈 설정으로 진행): {e}") - - db_config = (config or {}).get("db", {}) # fetch_ohlcv에 그대로 넘김 - - # 테스트 티커/기간 - tickers: List[str] = ["AAPL", "MSFT", "GOOGL"] # 필요 시 교체 - end_dt = datetime.now() - start_dt = end_dt - timedelta(days=600) - start_str = start_dt.strftime("%Y-%m-%d") - end_str = end_dt.strftime("%Y-%m-%d") - - seq_len = 60 - pred_h = 1 - transformer_cfg: Dict = { - "transformer": { - "features": ["open", "high", "low", "close", "volume","adjusted_close"], - "target": "close", - "scaler": "standard" - } - } - - # ---------------------------------------------------------------------- - # (B) fetch_ohlcv 로 실데이터 가져오기 (티커별 → concat) - # ---------------------------------------------------------------------- - raw_parts: List[pd.DataFrame] = [] - for tkr in tickers: - try: - print(f"[INFO] 수집 시작: {tkr} ({start_str} → {end_str})") - df = safe_fetch_ohlcv( - ticker=tkr, - start=start_str, - end=end_str, - config=db_config, # ★ fetch_ohlcv는 내부에서 DB/야후 폴백 처리 - max_retries=5, - base_sleep=0.8 - ) - if df is None or df.empty: - print(f"[WARN] {tkr} 데이터가 비어 있습니다.") - else: - # 스키마 정합성(Transformer가 기대하는 컬럼 존재 검사) - required = ["ticker", "date", "open", "high", "low", "close", "volume", "adjusted_close"] - missing = [c for c in required if c not in df.columns] - if missing: - raise ValueError(f"{tkr} 수집 데이터에 필수 컬럼 누락: {missing}") - # 날짜형 변환 보정 - if not np.issubdtype(df["date"].dtype, np.datetime64): - df["date"] = pd.to_datetime(df["date"], errors="coerce") - raw_parts.append(df.reset_index(drop=True)) - print(f"[INFO] {tkr} 수집 완료: {len(df):,} rows") - finally: - # API rate 제한 완화(티커 사이 간격) - time.sleep(0.6 + random.uniform(0, 0.6)) - - if not raw_parts: - print("[ERROR] 모든 소스에서 OHLCV 확보 실패(fetch_ohlcv 사용).") - return - - raw_data = pd.concat(raw_parts, ignore_index=True) - - # ---------------------------------------------------------------------- - # (C) Transformer 호출 - # ---------------------------------------------------------------------- - finder_df = pd.DataFrame({"ticker": tickers}) - - print("[INFO] run_transformer 호출 중...") - result = run_transformer( - finder_df=finder_df, - seq_len=seq_len, - pred_h=pred_h, - raw_data=raw_data, - config=transformer_cfg - ) - - logs_df = result.get("logs", pd.DataFrame()) if isinstance(result, dict) else pd.DataFrame() - meta = {k: v for k, v in result.items() if k != "logs"} if isinstance(result, dict) else {} - - # ---------------------------------------------------------------------- - # (D) 출력만(저장 없음) - # ---------------------------------------------------------------------- - print("\n--- [RESULT] Transformer 반환 메타 키 ---") - print(list(meta.keys())) - - print("\n--- [RESULT] 결정 로그(logs) 미리보기 ---") - if not logs_df.empty: - if not np.issubdtype(logs_df["date"].dtype, np.datetime64): - logs_df["date"] = pd.to_datetime(logs_df["date"], errors="coerce") - print(logs_df.head(10).to_string(index=False)) - else: - print("logs_df가 비어 있습니다. Transformer 내부 로직을 확인하세요.") - - if not logs_df.empty: - if "action" in logs_df.columns: - print("\n--- [STATS] 액션별 건수 ---") - print(logs_df["action"].value_counts()) - if {"ticker", "date"}.issubset(logs_df.columns): - print("\n--- [STATS] 티커별 최근 신호 2건 ---") - latest = ( - logs_df.sort_values(["ticker", "date"], ascending=[True, False]) - .groupby("ticker") - .head(2) - .reset_index(drop=True) - ) - print(latest.to_string(index=False)) - - print(f"\n=== [TEST] 종료: 총 원시행(raw_data) = {len(raw_data):,} ===") - - -# ----------------------------------------------------------------------------- -# 엔트리포인트 -# ----------------------------------------------------------------------------- -if __name__ == "__main__": - run_transform_only_live_with_fetch() diff --git a/AI/tests/test_transformer_backtrader.py b/AI/tests/test_transformer_backtrader.py new file mode 100644 index 00000000..757e2822 --- /dev/null +++ b/AI/tests/test_transformer_backtrader.py @@ -0,0 +1,194 @@ +# AI/tests/test_transformer_backtrader.py +# 사용불가. 추후 수정 필요 +import os +import sys +from typing import List, Dict, Optional +from datetime import datetime, timedelta +import pandas as pd +import pathlib + +# --- 프로젝트/레포 경로 설정 --------------------------------------------------- +_this_file = os.path.abspath(__file__) +project_root = os.path.dirname(os.path.dirname(_this_file)) # .../transformer +repo_root = os.path.dirname(project_root) # .../ +libs_root = os.path.join(repo_root, "libs") # .../libs + +# sys.path에 중복 없이 추가 +for p in (project_root, repo_root, libs_root): + if p not in sys.path: + sys.path.append(p) +# ------------------------------------------------------------------------------ + +# ---------------------------------------------------------------------- +# Transformer 실행 +# ---------------------------------------------------------------------- +from transformer.main import run_transformer + +from typing import Optional +import pandas as pd +from sqlalchemy import text + +# DB용 유틸: SQLAlchemy Engine 생성 함수 사용 (get_engine) +from libs.utils.get_db_conn import get_engine + +def fetch_ohlcv( + ticker: List[str], + start: str, + end: str, + interval: str = "1d", + db_name: str = "db", +) -> pd.DataFrame: + """ + 특정 티커, 날짜 범위의 OHLCV 데이터를 DB에서 불러오기 (SQLAlchemy 엔진 사용) + + Args: + ticker (List[str]): 종목 코드 리스트 (예: ["AAPL", "MSFT", "GOOGL"]) + start (str): 시작일자 'YYYY-MM-DD' (inclusive) + end (str): 종료일자 'YYYY-MM-DD' (inclusive) + interval (str): 데이터 간격 ('1d' 등) - 현재 테이블이 일봉만 제공하면 무시됨 + db_name (str): get_engine()가 참조할 설정 블록 이름 (예: "db", "report_DB") + + Returns: + pd.DataFrame: 컬럼 = [ticker, date, open, high, low, close, adjusted_close, volume] + (date 컬럼은 pandas datetime으로 변환됨) + """ + + # 1) SQLAlchemy engine 얻기 (configs/config.json 기준) + engine = get_engine(db_name) + + # 2) 쿼리: named parameter(:tickers 등) 사용 -> 안전하고 가독성 좋음 + query = text(""" + SELECT ticker, date, open, high, low, close, adjusted_close, volume + FROM public.price_data + WHERE ticker IN :tickers -- 수정된 부분 + AND date BETWEEN :start AND :end + ORDER BY date; + """) + + # 3) DB에서 읽기 (with 문으로 커넥션 자동 정리) + with engine.connect() as conn: + df = pd.read_sql( + query, + con=conn, # 꼭 키워드 인자로 con=conn + params={"tickers": tuple(ticker), "start": start, "end": end}, # 수정된 부분 + ) + + # 4) 후처리: 컬럼 정렬 및 date 타입 통일 + if df is None or df.empty: + # 빈 DataFrame이면 일관된 컬럼 스키마로 반환 + return pd.DataFrame(columns=["ticker", "date", "open", "high", "low", "close", "adjusted_close", "volume"]) + + # date 컬럼을 datetime으로 변경 (UTC로 맞추고 싶으면 pd.to_datetime(..., utc=True) 사용) + if "date" in df.columns: + df["date"] = pd.to_datetime(df["date"]) + + # 선택: 컬럼 순서 고정 (일관성 유지) + desired_cols = ["ticker", "date", "open", "high", "low", "close", "adjusted_close", "volume"] + # 존재하는 컬럼만 가져오기 + cols_present = [c for c in desired_cols if c in df.columns] + df = df.loc[:, cols_present] + + return df + + + +def run_transformer_for_test(finder_df: pd.DataFrame, raw_data: pd.DataFrame, start_date: str, end_date: str) -> pd.DataFrame: + """ + Transformer 모델을 실행하여 매매 신호를 예측하는 함수 + :param finder_df: 종목 리스트 + :param raw_data: OHLCV 시계열 데이터 + :param start_date: 시작 날짜 + :param end_date: 종료 날짜 + :return: 매매 신호를 포함한 DataFrame + """ + transformer_result = run_transformer( + finder_df=finder_df, + seq_len=64, + pred_h=5, + raw_data=raw_data, + run_date=end_date, # 예측 날짜 + weights_path=None, # 가중치 경로. transformer/main.py 내부에서 기본 경로 사용 + interval="1d" + ) + + logs_df = transformer_result.get("logs", pd.DataFrame()) + return logs_df + +# ---------------------------------------------------------------------- +# Backtrader 실행 +# ---------------------------------------------------------------------- +from backtrader import Cerebro, Strategy +from backtrader.feeds import PandasData + +class SimpleStrategy(Strategy): + """ + Backtrader 전략 + """ + def __init__(self, logs_df: pd.DataFrame): + self.order = None + self.buy_price = None + self.logs_df = logs_df # 매매 신호를 담은 logs_df를 클래스에 저장 + + def next(self): + # 매수/매도 신호에 따라 거래 진행 + if self.order: + return # 이미 주문이 있으면 아무것도 하지 않음 + + for _, row in self.logs_df.iterrows(): + if row['action'] == 'BUY' and self.data.datetime.date(0) == pd.to_datetime(row['date']).date(): + self.buy_price = row['predicted_price'] + self.order = self.buy(size=1) # 예시: 1주 매수 + + elif row['action'] == 'SELL' and self.data.datetime.date(0) == pd.to_datetime(row['date']).date(): + if self.buy_price: + sell_price = row['predicted_price'] + profit = (sell_price - self.buy_price) / self.buy_price * 100 # 수익률 계산 + print(f"Profit from {row['ticker']}: {profit:.2f}%") + self.order = self.sell(size=1) # 예시: 1주 매도 + self.buy_price = None + +# ---------------------------------------------------------------------- +# 테스트 실행 (2024년 1월 1일부터 12월 31일까지의 데이터로 테스트) +# ---------------------------------------------------------------------- +def test_transformer_backtrader(): + """ + 1년 동안 Transformer 모델을 통해 매매 신호를 예측하고, + Backtrader를 사용하여 수익률을 계산하는 테스트 함수 + """ + start_date = "2024-01-01" + end_date = "2024-12-31" + db_name = "db" # DB 이름 + + # 1. Transformer 모델을 통한 예측 신호 생성 + finder_df = pd.DataFrame({"ticker": ["AAPL", "MSFT", "GOOGL"]}) + + # DB에서 OHLCV 데이터 가져오기 + raw_data = fetch_ohlcv(finder_df['ticker'].tolist(), start_date, end_date, db_name=db_name) + + # Transformer 모델 실행 + logs_df = run_transformer_for_test(finder_df, raw_data, start_date, end_date) + + if logs_df.empty: + print("Transformer 모델에서 예측된 신호가 없습니다.") + return + + # 2. Backtrader 전략 실행 (매매 시뮬레이션) + ohlcv_data_feed = PandasData(dataname=raw_data) + + # Cerebro 엔진 설정 + cerebro = Cerebro() + cerebro.addstrategy(SimpleStrategy, logs_df=logs_df) # logs_df를 전략에 전달 + cerebro.adddata(ohlcv_data_feed) + + # 초기 자본금 및 수수료 설정 + cerebro.broker.set_cash(100000) # 초기 자본금 설정 + cerebro.broker.set_commission(commission=0.001) # 거래 수수료 설정 + + # 3. 백테스트 실행 + cerebro.run() + + # 최종 자본금 출력 + print(f"Final Portfolio Value: {cerebro.broker.getvalue():.2f}") + +if __name__ == "__main__": + test_transformer_backtrader() diff --git a/AI/weekly_tickers.json b/AI/weekly_tickers.json new file mode 100644 index 00000000..182836de --- /dev/null +++ b/AI/weekly_tickers.json @@ -0,0 +1 @@ +{"tickers": ["AAPL", "MSFT", "GOOGL"], "saved_at": "2025-11-24T06:30:32.207866+00:00"} \ No newline at end of file From 80f51ed04b97ca6db2f748d9f865f74f2c7febcf Mon Sep 17 00:00:00 2001 From: twq110 Date: Sat, 29 Nov 2025 21:27:35 +0900 Subject: [PATCH 7/8] =?UTF-8?q?[AI]=20SISC2-43=20[FIX]=20xai=5Freport=5Fid?= =?UTF-8?q?=EA=B0=80=20fills=5Fdf=EC=97=90=20=EC=A0=84=ED=8C=8C=EB=90=98?= =?UTF-8?q?=EC=A7=80=20=EC=95=8A=EB=8A=94=20=EB=AC=B8=EC=A0=9C=20=EC=88=98?= =?UTF-8?q?=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI/backtrade/__init__.py | 2 +- AI/backtrade/main.py | 3 + AI/daily_data_collection/__init__.py | 2 +- AI/daily_data_collection/test.py | 187 --------------------------- AI/libs/core/pipeline.py | 67 ++++++++-- 5 files changed, 63 insertions(+), 198 deletions(-) delete mode 100644 AI/daily_data_collection/test.py diff --git a/AI/backtrade/__init__.py b/AI/backtrade/__init__.py index afd617a7..76aabec1 100644 --- a/AI/backtrade/__init__.py +++ b/AI/backtrade/__init__.py @@ -1,3 +1,3 @@ -#AI/backtrader/__init__.py +#AI/backtrade/__init__.py from .main import backtrade, BacktradeConfig __all__ = ["backtrade", "BacktradeConfig"] \ No newline at end of file diff --git a/AI/backtrade/main.py b/AI/backtrade/main.py index 8185244a..11672835 100644 --- a/AI/backtrade/main.py +++ b/AI/backtrade/main.py @@ -80,6 +80,8 @@ def backtrade( sig = str(r["action"]).upper() sig_price = float(r.get("price", np.nan)) + xai_id = r.get("xai_report_id") + if sig not in ("BUY", "SELL"): continue @@ -128,6 +130,7 @@ def backtrade( records.append({ "run_id": run_id, + "xai_report_id": xai_id, "ticker": ticker, "signal_date": sig_date.date().isoformat(), "signal_price": float(sig_price), diff --git a/AI/daily_data_collection/__init__.py b/AI/daily_data_collection/__init__.py index 2fb0cec4..7e5ba52f 100644 --- a/AI/daily_data_collection/__init__.py +++ b/AI/daily_data_collection/__init__.py @@ -1,3 +1,3 @@ -#AI/dail_data_collection/__init__.py +#AI/daily_data_collection/__init__.py from .main import run_data_collection __all__ = ["run_data_collection"] \ No newline at end of file diff --git a/AI/daily_data_collection/test.py b/AI/daily_data_collection/test.py deleted file mode 100644 index d682c276..00000000 --- a/AI/daily_data_collection/test.py +++ /dev/null @@ -1,187 +0,0 @@ -import pandas as pd -import yfinance as yf -from datetime import datetime, date, timedelta -from calendar import monthrange - -# ============================================================ -# ① 안전한 Series 추출 유틸 (중복 컬럼 / MultiIndex 방어) -# ============================================================ - -def get_series(df: pd.DataFrame, col_name: str) -> pd.Series: - """ - df[col_name]이 Series가 아니라 DataFrame으로 나오는 경우 - (동일 이름 컬럼 여러 개 등)를 방어해서 - 항상 1차원 Series만 반환하도록 정규화하는 함수. - """ - col = df[col_name] - if isinstance(col, pd.DataFrame): - # 같은 이름의 컬럼이 여러 개 있으면 첫 번째 컬럼만 사용 - return col.iloc[:, 0] - return col - - -# ============================================================ -# ② 기술적 지표 계산 함수 (네 코드 + 컬럼 정규화 보강) -# ============================================================ - -def compute_technical_indicators(df: pd.DataFrame) -> pd.DataFrame: - """ - 입력 df: ticker, date, open, high, low, close, volume - 출력 df: ticker, date, RSI, MACD, Bollinger_Bands_upper, ... MA_200 - """ - - # MultiIndex 컬럼 방어: 상위 레벨만 사용 - if isinstance(df.columns, pd.MultiIndex): - df.columns = df.columns.get_level_values(0) - - df = df.sort_values("date").reset_index(drop=True) - - # 여기서 무조건 Series로 강제 - close = get_series(df, "close") - high = get_series(df, "high") - low = get_series(df, "low") - volume = get_series(df, "volume").fillna(0) - - # ------------------------- RSI ------------------------- - delta = close.diff() - gain = delta.clip(lower=0) - loss = -delta.clip(upper=0) - - window_rsi = 14 - avg_gain = gain.rolling(window_rsi).mean() - avg_loss = loss.rolling(window_rsi).mean() - rs = avg_gain / (avg_loss + 1e-9) - rsi = 100 - (100 / (1 + rs)) - - # ------------------------- MACD ------------------------- - ema12 = close.ewm(span=12, adjust=False).mean() - ema26 = close.ewm(span=26, adjust=False).mean() - macd = ema12 - ema26 - - # ------------------------- Bollinger Bands ------------------------- - ma20 = close.rolling(20).mean() - std20 = close.rolling(20).std() - bb_upper = ma20 + 2 * std20 - bb_lower = ma20 - 2 * std20 - - # ------------------------- ATR ------------------------- - prev_close = close.shift(1) - tr1 = high - low - tr2 = (high - prev_close).abs() - tr3 = (low - prev_close).abs() - tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) - atr = tr.rolling(14).mean() - - # ------------------------- OBV ------------------------- - obv = [0] - for i in range(1, len(close)): - if close.iloc[i] > close.iloc[i - 1]: - obv.append(obv[-1] + volume.iloc[i]) - elif close.iloc[i] < close.iloc[i - 1]: - obv.append(obv[-1] - volume.iloc[i]) - else: - obv.append(obv[-1]) - obv = pd.Series(obv, index=df.index) - - # ------------------------- Stochastic ------------------------- - lowest_low = low.rolling(14).min() - highest_high = high.rolling(14).max() - stochastic = (close - lowest_low) / (highest_high - lowest_low + 1e-9) * 100 - - # ------------------------- MFI ------------------------- - typical_price = (high + low + close) / 3 - raw_mf = typical_price * volume - tp_diff = typical_price.diff() - pos_mf = raw_mf.where(tp_diff > 0, 0) - neg_mf = raw_mf.where(tp_diff < 0, 0) - pos_14 = pos_mf.rolling(14).sum() - neg_14 = neg_mf.abs().rolling(14).sum() + 1e-9 - mfr = pos_14 / neg_14 - mfi = 100 - (100 / (1 + mfr)) - - # ------------------------- MA ------------------------- - ma_5 = close.rolling(5).mean() - ma_20_roll = close.rolling(20).mean() - ma_50 = close.rolling(50).mean() - ma_200 = close.rolling(200).mean() - - out = pd.DataFrame({ - "ticker": df["ticker"], - "date": df["date"], - "RSI": rsi, - "MACD": macd, - "Bollinger_Bands_upper": bb_upper, - "Bollinger_Bands_lower": bb_lower, - "ATR": atr, - "OBV": obv, - "Stochastic": stochastic, - "MFI": mfi, - "MA_5": ma_5, - "MA_20": ma_20_roll, - "MA_50": ma_50, - "MA_200": ma_200, - }) - - return out - - -# ============================================================ -# ③ yfinance 로 "작년 같은 달" 1개월치 MSFT OHLCV 다운로드 -# ============================================================ - -def fetch_msft_last_year_one_month() -> pd.DataFrame: - today = date.today() - last_year = today.year - 1 - month = today.month - - # 작년 같은 달의 1일 ~ 말일 - start = date(last_year, month, 1) - last_day = monthrange(last_year, month)[1] - end = date(last_year, month, last_day) + timedelta(days=1) # yfinance end는 exclusive - - print(f"[TEST] Fetching MSFT data: {start} ~ {end} (작년 같은 달 1개월)") - - df = yf.download("MSFT", start=start.strftime("%Y-%m-%d"), end=end.strftime("%Y-%m-%d")) - - if df.empty: - print("[TEST] No data returned from yfinance.") - return pd.DataFrame() - - # 인덱스 → 컬럼 - df = df.reset_index() - - # yfinance 포맷 → 표준 컬럼명으로 정리 - # (단일 티커이므로 MultiIndex 방어는 compute 함수에서 추가로 한 번 더 함) - df["date"] = pd.to_datetime(df["Date"]).dt.date - df["ticker"] = "MSFT" - - df = df.rename(columns={ - "Open": "open", - "High": "high", - "Low": "low", - "Close": "close", - "Volume": "volume", - }) - - return df[["ticker", "date", "open", "high", "low", "close", "volume"]] - - -# ============================================================ -# ④ 실제 계산 & 출력 (DB 업서트 없음) -# ============================================================ - -if __name__ == "__main__": - df_price = fetch_msft_last_year_one_month() - - if df_price.empty: - print("[TEST] 가격 데이터가 없어 기술지표를 계산할 수 없습니다.") - else: - df_tech = compute_technical_indicators(df_price) - - # 앞/뒤 일부를 확인해보고 싶으면 둘 다 찍어보자 - print("\n===== MSFT 기술적 지표 (앞 5행) =====\n") - print(df_tech.head(5).to_string(index=False)) - - print("\n===== MSFT 기술적 지표 (뒤 10행) =====\n") - print(df_tech.tail(10).to_string(index=False)) - diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index 6a90fba6..47e2d809 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -186,7 +186,6 @@ def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: print("[WARN] 빈 종목 리스트가 입력되어 Transformer 단계를 건너뜁니다.") return pd.DataFrame() - end_date = datetime.now().date() start_date = end_date - timedelta(days=600) @@ -304,6 +303,11 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[ReportRow]: """ Transformer 결정 로그를 입력으로 받아, 각 행(의사결정)에 대한 XAI 설명 리포트(자연어 텍스트)를 생성합니다. + + 반환: + - ReportRow 리스트: (ticker, signal, price, date_s, report_text) + 이 리스트의 순서는 decision_log에서 XAI를 생성한 "행 순서"와 동일합니다. + (단, 여기서 일부 행을 스킵할 수도 있으므로, logs_df 행 수와 1:1 대응이라는 보장은 없음) """ print("--- [PIPELINE-STEP 3] XAI 리포트 생성 시작 ---") @@ -336,6 +340,15 @@ def run_xai_report(decision_log: pd.DataFrame) -> List[ReportRow]: signal = str(row.get("action", "")) # action → signal price = _to_float(row.get("price", 0.0)) + # 한국어 주석: + # - ticker / signal / date_s 중 하나라도 비어 있으면 + # save_reports_to_db 단계에서 필터링되어 id 개수가 줄어들어 + # logs_df 와 xai_ids 길이가 어긋날 수 있다. + # - 따라서 여기서부터 필수 값이 비어 있는 행은 XAI 생성 자체를 스킵한다. + if not ticker or not signal or not date_s: + print(f"[WARN] XAI 입력 값 누락으로 스킵 (ticker={ticker}, signal={signal}, date={date_s})") + continue + evidence: List[Dict[str, float]] = [] for i in (1, 2, 3): name = row.get(f"feature_name{i}") @@ -390,12 +403,13 @@ def run_pipeline() -> Optional[List[ReportRow]]: """ today = datetime.now() # 서버 로컬 시간 기준 (필요시 timezone 조정 가능) weekday = today.weekday() # 월=0, 화=1, ..., 일=6 + #------------------------------- # 0) 주가 데이터 저장 실행 #------------------------------- print("--- [PIPELINE-STEP 0] 주가 데이터 수집 실행 시작 ---") try: - #run_data_collection() + # run_data_collection() print("--- [PIPELINE-STEP 0] 주가 데이터 수집 실행 완료 ---") except Exception as e: print(f"[WARN] 데이터 수집 실행 중 오류 발생: {e} → 계속 진행합니다.") @@ -443,17 +457,52 @@ def run_pipeline() -> Optional[List[ReportRow]]: xai_ids = [] # 3.7) logs_df에 xai_report_id 심기 + # ------------------------------------------------ + # ✅ 기존: 길이가 다르면 전부 NULL 처리 + # ❌ 문제: XAI 리포트 수 != logs_df row 수 인 경우가 발생 + # ✅ 수정: (ticker, date, signal) 키로 매핑해서 가능한 행에만 ID 부여 + # ------------------------------------------------ logs_df = logs_df.copy().reset_index(drop=True) - if xai_ids and len(xai_ids) == len(logs_df): - logs_df["xai_report_id"] = xai_ids - else: - logs_df["xai_report_id"] = None - if xai_ids and len(xai_ids) != len(logs_df): + + # 기본값: 전부 None + logs_df["xai_report_id"] = None + + if reports and xai_ids: + if len(xai_ids) != len(reports): print( - f"[WARN] XAI ID 개수({len(xai_ids)})와 decision_log 행 수({len(logs_df)})가 달라 " - "xai_report_id를 매핑하지 못했습니다. (모두 NULL 처리)" + f"[WARN] XAI ID 개수({len(xai_ids)})와 리포트 행 수({len(reports)})가 달라 " + "완전한 1:1 대응은 아닐 수 있습니다. (최소한의 매핑만 수행)" ) + # (ticker, date, signal) → xai_id 매핑 딕셔너리 생성 + mapping: Dict[Tuple[str, str, str], int] = {} + for (ticker, signal, price, date_s, _report_text), xai_id in zip(reports, xai_ids): + if not ticker or not signal or not date_s: + continue + key = (str(ticker), str(date_s), str(signal)) + mapping[key] = xai_id + + if not mapping: + print("[WARN] 유효한 XAI 매핑이 없어 xai_report_id를 None으로 유지합니다.") + else: + def _lookup_xai_id(row: pd.Series) -> Optional[int]: + t = str(row.get("ticker", "")) + d = _to_iso_date(row.get("date", "")) + s = str(row.get("action", "")) + return mapping.get((t, d, s)) + + logs_df["xai_report_id"] = logs_df.apply(_lookup_xai_id, axis=1) + + # 디버그용: 실제로 몇 개의 row에 ID가 들어갔는지 확인 + assigned_count = logs_df["xai_report_id"].notna().sum() + print(f"[INFO] decision_log에 xai_report_id 매핑 완료 (할당된 row 수: {assigned_count})") + + else: + if not reports: + print("[INFO] 생성된 XAI 리포트가 없어 xai_report_id는 모두 NULL입니다.") + elif not xai_ids: + print("[INFO] XAI ID가 비어 있어 xai_report_id는 모두 NULL입니다.") + # ------------------------------- # 4) Backtrade: 매일 실행 # ------------------------------- From c53a7c0f1d938d236354318dc8185e0ad30a7095 Mon Sep 17 00:00:00 2001 From: twq110 Date: Sat, 29 Nov 2025 22:26:28 +0900 Subject: [PATCH 8/8] =?UTF-8?q?[AI]=20SISC2-43=20[REFACTOR]=20DB=20?= =?UTF-8?q?=ED=82=A4=20JSON=EC=97=90=EC=84=9C=20=ED=99=98=EA=B2=BD?= =?UTF-8?q?=EB=B3=80=EC=88=98=ED=99=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI/configs/config.json | 19 ---- AI/finder/financial_eval.py | 215 ++++++++++++++--------------------- AI/libs/utils/get_db_conn.py | 166 +++++++++++++++------------ AI/tests/quick_db_check.py | 89 ++++++++------- 4 files changed, 227 insertions(+), 262 deletions(-) delete mode 100644 AI/configs/config.json diff --git a/AI/configs/config.json b/AI/configs/config.json deleted file mode 100644 index 36e05aff..00000000 --- a/AI/configs/config.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "db": { - "host": "ep-misty-lab-adgec0kl-pooler.c-2.us-east-1.aws.neon.tech", - "user": "neondb_owner", - "password": "npg_hWkg04MwGlYs", - "dbname": "neondb", - "port": 5432 - } - - - , - "report_DB": { - "host": "ep-jolly-waterfall-ads25fir-pooler.c-2.us-east-1.aws.neon.tech", - "user": "neondb_owner", - "password": "npg_lo0rC9aOyFkw", - "dbname": "neondb", - "port": 5432 - } -} diff --git a/AI/finder/financial_eval.py b/AI/finder/financial_eval.py index 9a7be3ab..581a81b6 100644 --- a/AI/finder/financial_eval.py +++ b/AI/finder/financial_eval.py @@ -1,139 +1,87 @@ +# -*- coding: utf-8 -*- +""" +한국어 주석: +- 이 파일은 기존 config.json/psycopg2 기반 코드를 + 프로젝트 최신 표준 DB 유틸(get_engine, get_db_conn) 기반으로 완전히 리팩터링한 버전입니다. + +- 모든 DB 접속은 .env 기반 환경변수에서 읽고 + get_engine("db") 로 일관성 있게 연결합니다. +""" + import pandas as pd -import time -import json -import re from datetime import datetime - import warnings warnings.filterwarnings("ignore", category=FutureWarning) -import psycopg2 +# 🚀 신규 DB 엔진 유틸 (환경변수 기반) +from AI.libs.utils.get_db_conn import get_engine + +# ---------------------------------------------------------------------- +# 1️⃣ DB에서 재무제표 불러오기 +# ---------------------------------------------------------------------- +def load_company_fundamentals(db_name="db"): + """ + company_fundamentals 테이블 전체를 불러오는 함수. + JSON 설정 불필요. 환경변수 기반 DB 연결(get_engine) 사용. + """ + query = "SELECT * FROM company_fundamentals;" + engine = get_engine(db_name) -from pathlib import Path -import os + df = pd.read_sql(query, engine) + print(f"[INFO] Loaded fundamentals: {len(df)} rows") + return df -def load_config(path: str = "AI/configs/config.json") -> dict: #컨피그 파일 열기 - p=Path(path) - if not p.exists(): - alt = Path(__file__).parent / path - if alt.exists(): - p=alt - else: - raise FileNotFoundError(f"config file not found: {path}") - with p.open("r", encoding="utf-8") as f: - cfg = json.load(f) - if "db" not in cfg: - raise ValueError("config.json에 'db'섹션이 없습니다.") - # port 정수 보정 - if isinstance(cfg["db"].get("port",5432),str): #타입 검사, 없는경우 5432 사용 - cfg["db"]["port"]=int(cfg["db"]["port"]) - return cfg - -cfg = load_config("AI/configs/config.json") -db_cfg = cfg["db"] - -# DB 연결 & 데이터 로드 - -query = "SELECT * FROM company_fundamentals;" -with psycopg2.connect(**db_cfg) as conn: - company = pd.read_sql(query, conn) - -print(company.head()) - -ticker_list = company['ticker'].unique() # 498 - -#중복 정의라 주석처리했습니다. 확인후 수정 바람. -# 결측치 처리 -#def fill_financials(df: pd.DataFrame, industry_median: pd.DataFrame = None): -# df = df.copy() -# -# # 1. Forward/Backward fill (연속 시계열 기준) -# df = df.ffill().bfill() -# -# # 2. Equity 보정 (자산 - 부채) -# if 'Assets' in df.columns and 'Liabilities' in df.columns: -# df['Equity'] = df['Assets'] - df['Liabilities'] -# -# # 3. EPS 보정 (순이익 / 자본 or 주식수 정보 있을 경우) -# if 'NetIncome' in df.columns and 'EPS' in df.columns: -# df['EPS'] = df['EPS'].fillna(df['NetIncome'] / (df['Equity'].replace(0, pd.NA))) -# -# # 4. 업계 중앙값으로 채우기 (옵션) -# if industry_median is not None: -# df = df.fillna(industry_median) -# -# return df +# ---------------------------------------------------------------------- +# 2️⃣ 재무제표 결측치 보정 함수 (원본과 동일) +# ---------------------------------------------------------------------- def fill_financials(df, industry_medians=None): df = df.copy() - # 1. 시계열 정렬 df['date'] = pd.to_datetime(df['date']) df = df.sort_values(['ticker', 'date']) - # 2. 그룹별 처리 (기업별) def fill_group(g): g = g.sort_values('date') - - # revenue, net_income 보간 - g['revenue'] = g['revenue'].interpolate(method='linear').fillna(method='ffill').fillna(method='bfill') - g['net_income'] = g['net_income'].interpolate(method='linear').fillna(method='ffill').fillna(method='bfill') - - # equity 보정 (assets, liabilities 둘 다 있는 경우만) + + g['revenue'] = g['revenue'].interpolate().ffill().bfill() + g['net_income'] = g['net_income'].interpolate().ffill().bfill() + + # equity 보정 mask = g['equity'].isna() & g['total_assets'].notna() & g['total_liabilities'].notna() g.loc[mask, 'equity'] = g.loc[mask, 'total_assets'] - g.loc[mask, 'total_liabilities'] - + # liabilities 보정 mask = g['total_liabilities'].isna() & g['total_assets'].notna() & g['equity'].notna() g.loc[mask, 'total_liabilities'] = g.loc[mask, 'total_assets'] - g.loc[mask, 'equity'] - + # assets 보정 mask = g['total_assets'].isna() & g['equity'].notna() & g['total_liabilities'].notna() g.loc[mask, 'total_assets'] = g.loc[mask, 'equity'] + g.loc[mask, 'total_liabilities'] - - # 남은 결측치 시계열 보간 + + # 나머지 보간 for col in ['total_assets', 'total_liabilities', 'equity']: - g[col] = g[col].interpolate(method='linear').fillna(method='ffill').fillna(method='bfill') - - # eps 보간 - g['eps'] = g['eps'].interpolate(method='linear').fillna(method='ffill').fillna(method='bfill') - - # pe_ratio는 같은 기업의 중앙값으로 + g[col] = g[col].interpolate().ffill().bfill() + + g['eps'] = g['eps'].interpolate().ffill().bfill() + + # PE Ratio 기업중앙값 g['pe_ratio'] = g['pe_ratio'].fillna(g['pe_ratio'].median()) - + return g - + df = df.groupby('ticker').apply(fill_group).reset_index(drop=True) - # 3. 업계 평균으로 남은 결측치 채우기 (옵션) if industry_medians is not None: df = df.fillna(df['ticker'].map(industry_medians)) - - return df - -c_df = fill_financials(company) + return df -# 연간 재무제표로 변환 +# ---------------------------------------------------------------------- +# 3️⃣ 연간 재무제표 집계 +# ---------------------------------------------------------------------- def aggregate_yearly_financials(df): - """ - 연간 재무제표로 변환 - - revenue, net_income, eps: 연간 평균 (Flow) - - total_assets, total_liabilities, equity: 연말 값 (Stock) - - Parameters - ---------- - df : DataFrame - columns = [ticker, date, revenue, net_income, total_assets, total_liabilities, equity, eps] - - Returns - ------- - DataFrame - 연간 집계 데이터 - """ - - # 날짜 변환 & 연도 추출 df["date"] = pd.to_datetime(df["date"]) df["year"] = df["date"].dt.year @@ -145,30 +93,26 @@ def aggregate_yearly_financials(df): for (ticker, year), g in df.groupby(["ticker", "year"]): flow_data = g[flow_cols].mean() stock_data = g.loc[g["date"].idxmax(), stock_cols] - + yearly_data = {"ticker": ticker, "year": year} yearly_data.update(flow_data.to_dict()) yearly_data.update(stock_data.to_dict()) yearly_list.append(yearly_data) - yearly_df = pd.DataFrame(yearly_list) - - return yearly_df - -year_df = aggregate_yearly_financials(c_df) -year_df.head() + return pd.DataFrame(yearly_list) -# 평가표 생성 +# ---------------------------------------------------------------------- +# 4️⃣ 안정성 평가 점수 계산 +# ---------------------------------------------------------------------- def stability_score(df): results = [] - # 기업별 평가 for ticker, g in df.groupby("ticker"): latest = g.iloc[0] - # 1. Debt Ratio (총부채 / 총자산) + # Debt Ratio debt_ratio = latest["total_liabilities"] / latest["total_assets"] if latest["total_assets"] else None if debt_ratio is not None: if debt_ratio < 0.4: debt_score = 5 @@ -177,9 +121,9 @@ def stability_score(df): elif debt_ratio < 1.0: debt_score = 2 else: debt_score = 1 else: - debt_score = 3 # 결측 시 중립점 + debt_score = 3 - # 2. ROA (순이익 / 총자산) + # ROA roa = latest["net_income"] / latest["total_assets"] if latest["total_assets"] else None if roa is not None: if roa >= 0.08: roa_score = 5 @@ -190,7 +134,7 @@ def stability_score(df): else: roa_score = 3 - # 3. ROE (순이익 / 자본) + # ROE roe = latest["net_income"] / latest["equity"] if latest["equity"] else None if roe is not None: if roe >= 0.12: roe_score = 5 @@ -201,11 +145,16 @@ def stability_score(df): else: roe_score = 3 - # 4. 매출 성장률 (Revenue Growth: 최근 2년) + # 매출 성장률 if len(g) >= 2: - rev_growth = (g.iloc[-1]["revenue"] - g.iloc[-2]["revenue"]) / g.iloc[-2]["revenue"] if g.iloc[-2]["revenue"] else None + prev, curr = g.iloc[-2], g.iloc[-1] + if prev["revenue"] != 0: + rev_growth = (curr["revenue"] - prev["revenue"]) / prev["revenue"] + else: + rev_growth = None else: rev_growth = None + if rev_growth is not None: if rev_growth >= 0.10: rev_score = 5 elif rev_growth >= 0.05: rev_score = 4 @@ -214,13 +163,9 @@ def stability_score(df): else: rev_score = 3 - # 5. EPS - if latest["eps"] is not None: - eps_score = 5 if latest["eps"] > 0 else 1 - else: - eps_score = 3 + # EPS + eps_score = 5 if latest["eps"] > 0 else 1 if latest["eps"] is not None else 3 - # 최종 점수 (평균) total_score = round((debt_score + roa_score + roe_score + rev_score + eps_score) / 5, 2) results.append({ @@ -235,13 +180,25 @@ def stability_score(df): return pd.DataFrame(results) -recent_y_df = year_df[year_df['year'] > datetime.now().year-3].groupby("ticker")[['revenue', 'net_income', 'eps', - 'total_assets', 'total_liabilities', - 'equity']].mean() -eval_df = stability_score(recent_y_df) -eval_df.sort_values(by='stability_score', ascending=False, inplace=True) -# save -eval_df.to_csv(f'data/stability_score_{datetime.now().year}.csv', index=False) +# ---------------------------------------------------------------------- +# 5️⃣ 실행 파이프라인 +# ---------------------------------------------------------------------- +if __name__ == "__main__": + print("[STEP] Load fundamentals") + company = load_company_fundamentals("db") + + print("[STEP] Fill missing values") + c_df = fill_financials(company) + + print("[STEP] Aggregate yearly") + year_df = aggregate_yearly_financials(c_df) + + print("[STEP] Evaluate stability") + recent_y_df = year_df[year_df['year'] >= datetime.now().year - 3].groupby("ticker").mean() + eval_df = stability_score(recent_y_df) + eval_df.sort_values("stability_score", ascending=False, inplace=True) + eval_df.to_csv(f"data/stability_score_{datetime.now().year}.csv", index=False) + print("[DONE] stability score exported") diff --git a/AI/libs/utils/get_db_conn.py b/AI/libs/utils/get_db_conn.py index 6ac43771..e58a3305 100644 --- a/AI/libs/utils/get_db_conn.py +++ b/AI/libs/utils/get_db_conn.py @@ -1,100 +1,109 @@ # AI/libs/utils/get_db_conn.py -# 한국어 주석: JSON 설정에서 DB 접속정보를 읽어 -# 1) psycopg2 Connection (로우 커넥션) -# 2) SQLAlchemy Engine (권장, 커넥션 풀/프리핑) -# 을 생성하는 유틸. 중복 로딩 방지를 위해 캐시 사용. +# 환경변수 기반 DB 연결 유틸 from __future__ import annotations import os import sys -import json -from typing import Dict, Any, Optional +from typing import Dict, Any from pathlib import Path from urllib.parse import quote_plus import psycopg2 from sqlalchemy import create_engine -# --- 프로젝트 루트 경로 설정 --- -project_root = Path(__file__).resolve().parents[3] # .../AI/libs/utils/get_db_conn.py 기준 상위 3단계 -sys.path.append(str(project_root)) -# -------------------------------- - -# 필수 키(포트는 선택) -REQUIRED_KEYS = {"host", "user", "password", "dbname"} - -# JSON 설정 캐시 (프로세스 내 1회 로드) -_CONFIG_CACHE: Optional[Dict[str, Dict[str, Any]]] = None - -def _config_path() -> Path: - """configs/config.json 경로를 안전하게 계산""" - return project_root/"AI"/"configs"/"config.json" +# ---------------------------------------------------------------------- +# 프로젝트 루트 경로 설정 +# ---------------------------------------------------------------------- +project_root = Path(__file__).resolve().parents[3] +sys.path.append(str(project_root)) -def _load_configs() -> Dict[str, Dict[str, Any]]: +# ---------------------------------------------------------------------- +# 내부 헬퍼: prefix 정규화 +# ---------------------------------------------------------------------- +def _normalize_prefix(name: str) -> str: """ - - configs/config.json을 읽어서 {db_name: {host, user, password, dbname, port?, sslmode?}} 형태로 반환 - - 파일은 깃에 올리지 않는 것을 권장(민감정보) + 한국어 주석: + - db_name="db" → "DB_" + - db_name="DB" → "DB_" + - db_name="DB_" → "DB_" + - 이미 REPORT_DB_처럼 끝이 "_" 로 끝나는 경우 → 그대로 사용 + - REPORT_DB → REPORT_DB_ + - market_db → MARKET_DB_ """ - global _CONFIG_CACHE - if _CONFIG_CACHE is not None: - return _CONFIG_CACHE + if not name: + return "DB_" + + up = name.upper() + + # 이미 정확한 환경변수 prefix 구조인 경우 (예: REPORT_DB_) + if up.endswith("_"): + return up - path = _config_path() - if not path.exists(): - raise FileNotFoundError(f"[DB CONFIG] 설정 파일이 없습니다: {path}") + # REPORT_DB → REPORT_DB_ + if "_" in up and not up.endswith("_"): + return up + "_" - try: - with path.open("r", encoding="utf-8") as f: - data: Dict[str, Dict[str, Any]] = json.load(f) - except json.JSONDecodeError as e: - raise ValueError(f"[DB CONFIG] JSON 파싱 오류: {e}") from e + # db → DB_ + return up + "_" - # 간단한 구조 검증(필수키 확인은 get_* 함수에서 db별로 수행) - if not isinstance(data, dict) or not data: - raise ValueError("[DB CONFIG] 최상위 JSON은 비어있지 않은 객체여야 합니다.") - _CONFIG_CACHE = data - return _CONFIG_CACHE +# ---------------------------------------------------------------------- +# 환경변수에서 DB 설정 가져오기 +# ---------------------------------------------------------------------- +REQUIRED_ENV_KEYS = ["HOST", "USER", "PASSWORD", "NAME"] -def _get_db_config(db_name: str) -> Dict[str, Any]: +def _load_db_env(prefix: str) -> Dict[str, Any]: """ - - 특정 db_name에 해당하는 설정 블록을 반환 - - 필수 키(host, user, password, dbname) 존재 검증 + prefix(DB_, REPORT_DB_, MARKET_DB_ 등)에 맞는 환경변수 로딩 + 예: DB_HOST, REPORT_DB_HOST, MARKET_DB_PASSWORD... """ - if not db_name or not isinstance(db_name, str): - raise ValueError("db_name must be a non-empty string") + cfg = {} - configs = _load_configs() - cfg = configs.get(db_name) - if not cfg: - raise KeyError(f"[DB CONFIG] '{db_name}' 설정 블록을 찾을 수 없습니다. (configs/config.json)") + for key in os.environ: + if key.startswith(prefix): + # DB_HOST → host + sub = key.replace(prefix, "").lower() + cfg[sub] = os.environ[key] + + # 필수값 검사 + missing = [] + for key in REQUIRED_ENV_KEYS: + env_name = prefix + key + if env_name not in os.environ: + missing.append(env_name) - missing = REQUIRED_KEYS - set(cfg.keys()) if missing: - raise KeyError(f"[DB CONFIG] '{db_name}'에 필수 키 누락: {sorted(missing)}") + raise EnvironmentError( + f"[DB CONFIG ERROR] 아래 환경변수가 필요합니다:\n" + f" {missing}\n" + f"예시:\n" + f" export {prefix}HOST=...\n" + f" export {prefix}USER=...\n" + f" export {prefix}PASSWORD=...\n" + f" export {prefix}NAME=..." + ) + + # 기본 포트 + cfg.setdefault("port", "5432") return cfg +# ---------------------------------------------------------------------- +# SQLAlchemy URL 생성 +# ---------------------------------------------------------------------- def _build_sqlalchemy_url(cfg: Dict[str, Any]) -> str: - """ - - SQLAlchemy용 PostgreSQL URI를 안전하게 구성 - - 비밀번호/유저의 특수문자를 URL 인코딩(quote_plus)로 보호 - - 예: postgresql+psycopg2://user:pass@host:port/dbname?sslmode=require - """ - user = quote_plus(str(cfg["user"])) - password = quote_plus(str(cfg["password"])) - host = str(cfg["host"]) - port = int(cfg.get("port", 5432)) - dbname = str(cfg["dbname"]) + user = quote_plus(cfg["user"]) + password = quote_plus(cfg["password"]) + host = cfg["host"] + port = cfg.get("port", "5432") + dbname = cfg["name"] base = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{dbname}" - # 선택 옵션: sslmode - # (Neon/클라우드 Postgres의 경우 require가 흔함) sslmode = cfg.get("sslmode") if sslmode: return f"{base}?sslmode={sslmode}" @@ -102,28 +111,35 @@ def _build_sqlalchemy_url(cfg: Dict[str, Any]) -> str: return base -def get_db_conn(db_name: str): +# ---------------------------------------------------------------------- +# psycopg2 커넥션 +# ---------------------------------------------------------------------- +def get_db_conn(db_name: str = "DB_"): """ - - psycopg2 로우 커넥션 생성(직접 커서 열어 사용할 때) - - pandas 경고가 싫다면 read_sql에는 get_engine() 사용을 권장 + 한국어 주석: + - db_name="db" 또는 "DB" → 자동으로 prefix="DB_" + - db_name="report_db" → "REPORT_DB_" + - 이미 "REPORT_DB_" 라면 그대로 사용 """ - cfg = _get_db_config(db_name) + + prefix = _normalize_prefix(db_name) + cfg = _load_db_env(prefix) + return psycopg2.connect( host=cfg["host"], user=cfg["user"], password=cfg["password"], - dbname=cfg["dbname"], + dbname=cfg["name"], port=int(cfg.get("port", 5432)), - sslmode=cfg.get("sslmode", None), # 필요 시 자동 적용 + sslmode=cfg.get("sslmode"), ) -def get_engine(db_name: str): - """ - - SQLAlchemy Engine 생성(권장) - - 커넥션 풀 + pre_ping으로 죽은 연결 사전 감지 → 운영 안정성↑ - - pandas.read_sql, 대량입출력 등에서 사용 - """ - cfg = _get_db_config(db_name) +# ---------------------------------------------------------------------- +# SQLAlchemy 엔진 +# ---------------------------------------------------------------------- +def get_engine(db_name: str = "DB_"): + prefix = _normalize_prefix(db_name) + cfg = _load_db_env(prefix) url = _build_sqlalchemy_url(cfg) return create_engine(url, pool_pre_ping=True) diff --git a/AI/tests/quick_db_check.py b/AI/tests/quick_db_check.py index f066bd8f..fdcf254c 100644 --- a/AI/tests/quick_db_check.py +++ b/AI/tests/quick_db_check.py @@ -1,48 +1,59 @@ # quick_db_check.py + +""" +DB 연결을 빠르게 테스트하는 스크립트. +- 프로젝트 루트(sisc-web) 자동 계산 +- .env 자동 로드 +- 환경변수 기반 get_db_conn 사용 +""" + import os import sys -import json -from typing import Dict, Union +from dotenv import load_dotenv -import psycopg2 +# ----------------------------- +# 1) 프로젝트 루트 계산 (중요!) +# ----------------------------- +# 현재 위치: sisc-web/AI/tests/quick_db_check.py +project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(project_root) +# ----------------------------- +# 2) .env 파일 로드 +# ----------------------------- +load_dotenv(os.path.join(project_root, ".env")) + +# ----------------------------- +# 3) DB 유틸 +# ----------------------------- +from AI.libs.utils.get_db_conn import get_db_conn -# --- 프로젝트 루트 경로 설정 --------------------------------------------------- -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(project_root) -# --- 설정 파일 로드 ------------------------------------------------------------ -cfg_path = os.path.join(project_root, "configs", "config.json") - -config: Dict = {} -if os.path.isfile(cfg_path): - with open(cfg_path, "r", encoding="utf-8") as f: - config = json.load(f) - print("[INFO] configs/config.json 로드 완료") -else: - print(f"[WARN] 설정 파일이 없습니다: {cfg_path}") - -db_cfg: Union[str, Dict] = (config or {}).get("db", {}) - -# --- DB 연결 테스트 ------------------------------------------------------------ -conn = None -try: - if isinstance(db_cfg, dict): - with psycopg2.connect(**db_cfg) as conn: - with conn.cursor() as cur: - cur.execute("SELECT version();") - print("✅ 연결 성공:", cur.fetchone()[0]) - cur.execute("SELECT current_database(), current_user;") - db, user = cur.fetchone() - print(f"ℹ️ DB/USER: {db} / {user}") - else: - with psycopg2.connect(dsn=str(db_cfg)) as conn: - with conn.cursor() as cur: - ... -except Exception as e: - print("❌ 연결 실패:", repr(e)) -finally: - if conn: +def quick_db_check(db_name: str = "db"): + print(f"[INFO] DB 연결 테스트 시작 (db_name='{db_name}')") + + try: + conn = get_db_conn(db_name) + except Exception as e: + print("❌ DB 연결 실패 (커넥션 생성 오류):", repr(e)) + return + + try: + with conn.cursor() as cur: + cur.execute("SELECT version();") + version = cur.fetchone()[0] + print("✅ 연결 성공:", version) + + cur.execute("SELECT current_database(), current_user;") + db, user = cur.fetchone() + print(f"ℹ️ DB = {db}, USER = {user}") + + except Exception as e: + print("❌ 쿼리 실행 실패:", repr(e)) + finally: conn.close() - print("DB 연결 종료") + print("🔌 DB 연결 종료") + +if __name__ == "__main__": + quick_db_check("db")