diff --git a/.gitignore b/.gitignore index 110ceb61..362bb45b 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,7 @@ __pycache__/ /venv/ /env /.vs +/.venv/ + +# ===== Backend ===== backend/src/main/java/org/sejongisc/backend/stock/TestController.java diff --git a/AI/configs/config.json b/AI/configs/config.json index c66eaa90..36e05aff 100644 --- a/AI/configs/config.json +++ b/AI/configs/config.json @@ -1,4 +1,4 @@ -{ +{ "db": { "host": "ep-misty-lab-adgec0kl-pooler.c-2.us-east-1.aws.neon.tech", "user": "neondb_owner", diff --git a/AI/libs/core/pipeline.py b/AI/libs/core/pipeline.py index caca8fc9..0d0d29d2 100644 --- a/AI/libs/core/pipeline.py +++ b/AI/libs/core/pipeline.py @@ -1,8 +1,7 @@ -import os +import os import sys -from typing import List, Dict -import json -from datetime import datetime, timedelta +from typing import List, Dict, Optional +from datetime import datetime, timedelta, timezone import pandas as pd # --- 프로젝트 루트 경로 설정 --- @@ -16,160 +15,189 @@ from libs.utils.fetch_ohlcv import fetch_ohlcv from xai.run_xai import run_xai from libs.utils.get_db_conn import get_db_conn +from libs.utils.save_reports_to_db import save_reports_to_db # --------------------------------- +# DB 이름 상수(실제 등록된 키와 반드시 일치해야 함) +MARKET_DB_NAME = "db" # 시세/원천 데이터 DB +REPORT_DB_NAME = "report_DB" # 리포트 저장 DB + +REQUIRED_LOG_COLS = { + "ticker", "date", "action", "price", + "feature1", "feature2", "feature3", + "prob1", "prob2", "prob3" +} + def run_weekly_finder() -> List[str]: """ 주간 종목 발굴(Finder)을 실행하고 결과(종목 리스트)를 반환합니다. """ print("--- [PIPELINE-STEP 1] Finder 모듈 실행 시작 ---") - #top_tickers = run_finder() - top_tickers = ['AAPL', 'MSFT', 'GOOGL'] # 임시 데이터 - print(f"--- [PIPELINE-STEP 1] Finder 모듈 실행 완료 ---") + # top_tickers = run_finder() + top_tickers = ["AAPL", "MSFT", "GOOGL"] # 임시 데이터 + print("--- [PIPELINE-STEP 1] Finder 모듈 실행 완료 ---") return top_tickers -def run_signal_transformer(tickers: List[str], config: Dict) -> pd.DataFrame: +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + +def run_signal_transformer(tickers: List[str], db_name: str) -> pd.DataFrame: """ 종목 리스트를 받아 Transformer 모듈을 실행하고, 신호(결정 로그)를 반환합니다. """ - try: - with open(os.path.join(project_root, 'configs', 'config.json'), 'r') as f: - config = json.load(f) - except FileNotFoundError: - print("Config file not found") - except json.JSONDecodeError: - print("Invalid JSON format in config file") - db_config = (config or {}).get("db", {}) # ★ db 섹션만 추출 print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 시작 ---") - - # --- 실제 Transformer 모듈 호출 --- - end_date = datetime.now() + + if not tickers: + print("[WARN] 빈 종목 리스트가 입력되어 Transformer를 건너뜁니다.") + return pd.DataFrame() + + #end_date = _utcnow() # 한국 시간 기준 당일 종가까지 사용, 서버 사용시 주석 해제 + end_date = datetime.strptime("2024-10-30", "%Y-%m-%d") #임시 고정 날짜 start_date = end_date - timedelta(days=600) - all_ohlcv_df = [] + + all_ohlcv_df: List[pd.DataFrame] = [] for ticker in tickers: - ohlcv_df = fetch_ohlcv( - ticker=ticker, - start=start_date.strftime('%Y-%m-%d'), - end=end_date.strftime('%Y-%m-%d'), - config=db_config - ) - ohlcv_df['ticker'] = ticker - all_ohlcv_df.append(ohlcv_df) + try: + ohlcv_df = fetch_ohlcv( + ticker=ticker, + start=start_date.strftime("%Y-%m-%d"), + end=end_date.strftime("%Y-%m-%d"), + db_name=db_name + ) + if ohlcv_df is None or ohlcv_df.empty: + print(f"[WARN] OHLCV 미수집: {ticker}") + continue + ohlcv_df = ohlcv_df.copy() + ohlcv_df["ticker"] = ticker + all_ohlcv_df.append(ohlcv_df) + except Exception as e: + print(f"[ERROR] OHLCV 수집 실패({ticker}): {e}") + if not all_ohlcv_df: - print("OHLCV 데이터를 가져오지 못했습니다.") + print("[ERROR] 어떤 티커에서도 OHLCV 데이터를 가져오지 못했습니다.") return pd.DataFrame() + raw_data = pd.concat(all_ohlcv_df, ignore_index=True) - finder_df = pd.DataFrame(tickers, columns=['ticker']) - transformer_result = run_transformer( + + finder_df = pd.DataFrame(tickers, columns=["ticker"]) + transformer_result: Dict = run_transformer( finder_df=finder_df, seq_len=60, pred_h=1, - raw_data=raw_data, - config=config - ) - logs_df = transformer_result.get("logs", pd.DataFrame()) - - # --- 임시 결정 로그 데이터 (주석 처리) --- - # data = { - # 'ticker': ['AAPL', 'GOOGL', 'MSFT'], - # 'date': ['2025-09-17', '2025-09-17', '2025-09-17'], - # 'action': ['SELL', 'BUY', 'SELL'], - # 'price': [238.99, 249.52, 510.01], - # 'weight': [0.16, 0.14, 0.15], - # 'feature1': ['RSI', 'Stochastic', 'MACD'], - # 'feature2': ['MACD', 'MA_5', 'ATR'], - # 'feature3': ['Bollinger_Bands_lower', 'RSI', 'MA_200'], - # 'prob1': [0.5, 0.4, 0.6], - # 'prob2': [0.3, 0.25, 0.2], - # 'prob3': [0.1, 0.15, 0.1] - # } - # logs_df = pd.DataFrame(data) - - print(f"--- [PIPELINE-STEP 2] Transformer 모듈 실행 완료 ---") + raw_data=raw_data + ) or {} + + logs_df: pd.DataFrame = transformer_result.get("logs", pd.DataFrame()) + if logs_df is None or logs_df.empty: + print("[WARN] Transformer 결과 로그가 비어 있습니다.") + return pd.DataFrame() + + # 필수 컬럼 검증 + missing_cols = REQUIRED_LOG_COLS - set(logs_df.columns) + if missing_cols: + print(f"[ERROR] 결정 로그에 필수 컬럼 누락: {sorted(missing_cols)}") + return pd.DataFrame() + + print("--- [PIPELINE-STEP 2] Transformer 모듈 실행 완료 ---") return logs_df -def run_xai_report(decision_log: pd.DataFrame) -> List[str]: +# --- 안전 변환 유틸 --- +def _to_iso_date(v) -> str: + import pandas as pd + from datetime import datetime + try: + if isinstance(v, (pd.Timestamp, datetime)): + return v.strftime("%Y-%m-%d") + return str(v) + except Exception: + return str(v) + +def _to_float(v, fallback=0.0) -> float: + try: + return float(v) + except Exception: + return float(fallback) + +# --- XAI 리포트: 5-튜플(rows)로 반환 --- +from typing import List, Tuple + +def run_xai_report(decision_log: pd.DataFrame) -> List[Tuple[str, str, float, str, str]]: """ - 결정 로그를 바탕으로 실제 XAI 리포트를 생성합니다. + save_reports_to_db()가 기대하는 형식: + rows = List[ (ticker, signal, price, date_str, report_text) ] """ print("--- [PIPELINE-STEP 3] XAI 리포트 생성 시작 ---") api_key = os.environ.get("GROQ_API_KEY") if not api_key: - raise ValueError("XAI 리포트 생성을 위해 GROQ_API_KEY 환경 변수를 설정해주세요.") - reports = [] - if decision_log.empty: - return reports + print("[STOP] GROQ_API_KEY 미설정: XAI 리포트 단계를 건너뜁니다.") + return [] + + if decision_log is None or decision_log.empty: + print("[WARN] 비어있는 결정 로그가 입력되어 XAI 리포트를 생성하지 않습니다.") + return [] + + rows: List[Tuple[str, str, float, str, str]] = [] + for _, row in decision_log.iterrows(): - decision = { - "ticker": row['ticker'], - "date": row['date'], - "signal": row['action'], - "price": row['price'], + ticker = str(row.get("ticker", "UNKNOWN")) + date_s = _to_iso_date(row.get("date", "")) + signal = str(row.get("action", "")) + price = _to_float(row.get("price", 0.0)) + + # evidence 등은 DB에 안 넣는 설계로 보이므로 내부 호출에만 사용 + decision_payload = { + "ticker": ticker, + "date": date_s, + "signal": signal, + "price": price, "evidence": [ - {"feature_name": row['feature1'], "contribution": row['prob1']}, - {"feature_name": row['feature2'], "contribution": row['prob2']}, - {"feature_name": row['feature3'], "contribution": row['prob3']}, - ] + {"feature_name": str(row.get("feature1", "")), "contribution": _to_float(row.get("prob1", 0.0))}, + {"feature_name": str(row.get("feature2", "")), "contribution": _to_float(row.get("prob2", 0.0))}, + {"feature_name": str(row.get("feature3", "")), "contribution": _to_float(row.get("prob3", 0.0))}, + ], } + try: - report = run_xai(decision, api_key) - reports.append(report) - print(f"--- {row['ticker']} XAI 리포트 생성 완료 ---") + report_text = run_xai(decision_payload, api_key) + report_text = str(report_text) # 혹시 모를 비문자 타입 대비 + print(f"--- {ticker} XAI 리포트 생성 완료 ---") except Exception as e: - error_message = f"--- {row['ticker']} XAI 리포트 생성 중 오류 발생: {e} ---" - print(error_message) - reports.append(error_message) - print(f"--- [PIPELINE-STEP 3] XAI 리포트 생성 완료 ---") - return reports + report_text = f"[ERROR] XAI 리포트 생성 실패: {e}" + print(f"--- {ticker} XAI 리포트 생성 중 오류: {e} ---") -def save_reports_to_db(reports: List[str], config: Dict): - """ - 생성된 XAI 리포트를 데이터베이스에 저장합니다. - """ - db_config = config.get("report_DB", {}) - conn = get_db_conn(db_config) - cursor = conn.cursor() - insert_query = """ - INSERT INTO xai_reports (report_text, created_at) - VALUES (%s, %s); - """ - for report in reports: - cursor.execute(insert_query, (report, datetime.now())) - conn.commit() - cursor.close() - conn.close() - print(f"--- {len(reports)}개의 XAI 리포트가 데이터베이스에 저장되었습니다. ---") - -# --- 전체 파이프라인 실행 --- -def run_pipeline(): + rows.append((ticker, signal, price, date_s, report_text)) + + print("--- [PIPELINE-STEP 3] XAI 리포트 생성 완료 ---") + return rows + + + + +def run_pipeline() -> Optional[List[str]]: """ 전체 파이프라인(Finder -> Transformer -> XAI)을 실행합니다. """ - #--- 설정 파일 로드 --- - config : Dict = {} - try: - with open(os.path.join(project_root, 'configs', 'config.json'), 'r') as f: - config = json.load(f) - except FileNotFoundError: - print("[WARN] configs/config.json 파일을 찾을 수 없어 DB 연결이 필요 없는 기능만 작동합니다.") - - #--- 파이프라인 단계별 실행 --- - top_tickers = run_weekly_finder() - if not top_tickers: - print("Finder에서 종목을 찾지 못해 파이프라인을 중단합니다.") + # 1) Finder + tickers = run_weekly_finder() + if not tickers: + print("[STOP] Finder에서 종목을 찾지 못해 파이프라인을 중단합니다.") return None - decision_log = run_signal_transformer(top_tickers, config) - if decision_log.empty: - print("Transformer에서 신호를 생성하지 못해 파이프라인을 중단합니다.") + + # 2) Transformer + logs_df = run_signal_transformer(tickers, MARKET_DB_NAME) + if logs_df is None or logs_df.empty: + print("[STOP] Transformer에서 신호를 생성하지 못해 파이프라인을 중단합니다.") return None - xai_reports = run_xai_report(decision_log) - - save_reports_to_db(xai_reports, config) - return xai_reports + # 3) XAI + reports = run_xai_report(logs_df) + # 4) 저장 + save_reports_to_db(reports, REPORT_DB_NAME) + + return reports -# --- 테스트를 위한 실행 코드 --- +# --- 테스트 실행 --- if __name__ == "__main__": print(">>> 파이프라인 (Finder -> Transformer -> XAI) 테스트를 시작합니다.") final_reports = run_pipeline() diff --git a/AI/libs/db/.gitkeep b/AI/libs/db/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/AI/libs/utils/fetch_ohlcv.py b/AI/libs/utils/fetch_ohlcv.py index 9570c771..8fffd8dd 100644 --- a/AI/libs/utils/fetch_ohlcv.py +++ b/AI/libs/utils/fetch_ohlcv.py @@ -1,41 +1,69 @@ -import pandas as pd +# libs/utils/fetch_ohlcv.py +from __future__ import annotations +from typing import Optional +import pandas as pd +from sqlalchemy import text -# DB 접속 커넥션 생성 -from .get_db_conn import get_db_conn +# DB용 유틸: SQLAlchemy Engine 생성 함수 사용 (get_engine) +from .get_db_conn import get_engine -# OHLCV 데이터 불러오기 def fetch_ohlcv( ticker: str, start: str, end: str, interval: str = "1d", - config: dict = None, # type: ignore + db_name: str = "db", ) -> pd.DataFrame: """ - 특정 티커, 날짜 범위의 OHLCV 데이터를 DB에서 불러오기 + 특정 티커, 날짜 범위의 OHLCV 데이터를 DB에서 불러오기 (SQLAlchemy 엔진 사용) Args: - ticker (str): 종목 코드 - start (str): 시작일자 'YYYY-MM-DD' - end (str): 종료일자 'YYYY-MM-DD' - interval (str): 데이터 간격 ('1d' 등) - config (dict): DB 접속 정보 포함한 설정 + ticker (str): 종목 코드 (예: "AAPL") + start (str): 시작일자 'YYYY-MM-DD' (inclusive) + end (str): 종료일자 'YYYY-MM-DD' (inclusive) + interval (str): 데이터 간격 ('1d' 등) - 현재 테이블이 일봉만 제공하면 무시됨 + db_name (str): get_engine()가 참조할 설정 블록 이름 (예: "db", "report_DB") Returns: - DataFrame: 컬럼 = [ticker, date, open, high, low, close, volume, adjusted_close] + pd.DataFrame: 컬럼 = [ticker, date, open, high, low, close, adjusted_close, volume] + (date 컬럼은 pandas datetime으로 변환됨) """ - conn = get_db_conn(config) - query = """ + # 1) SQLAlchemy engine 얻기 (configs/config.json 기준) + engine = get_engine(db_name) + + # 2) 쿼리: named parameter(:ticker 등) 사용 -> 안전하고 가독성 좋음 + # - interval 분기가 필요하면 테이블/파티션 구조에 따라 쿼리를 분기하도록 확장 가능 + query = text(""" SELECT ticker, date, open, high, low, close, adjusted_close, volume FROM public.price_data - WHERE ticker = %s - AND date BETWEEN %s AND %s + WHERE ticker = :ticker + AND date BETWEEN :start AND :end ORDER BY date; - """ + """) + + # 3) DB에서 읽기 (with 문으로 커넥션 자동 정리) + with engine.connect() as conn: + df = pd.read_sql( + query, + con=conn, # 꼭 키워드 인자로 con=conn + params={"ticker": ticker, "start": start, "end": end}, # 튜플 X, 딕셔너리 O + ) - # 파라미터 바인딩 (%s) 사용 → SQL injection 방지 - df = pd.read_sql(query, conn, params=(ticker, start, end)) + # 4) 후처리: 컬럼 정렬 및 date 타입 통일 + if df is None or df.empty: + # 빈 DataFrame이면 일관된 컬럼 스키마로 반환 + return pd.DataFrame(columns=["ticker", "date", "open", "high", "low", "close", "adjusted_close", "volume"]) + + # date 컬럼을 datetime으로 변경 (UTC로 맞추고 싶으면 pd.to_datetime(..., utc=True) 사용) + if "date" in df.columns: + df["date"] = pd.to_datetime(df["date"]) + + # 선택: 컬럼 순서 고정 (일관성 유지) + desired_cols = ["ticker", "date", "open", "high", "low", "close", "adjusted_close", "volume"] + # 존재하는 컬럼만 가져오기 + cols_present = [c for c in desired_cols if c in df.columns] + df = df.loc[:, cols_present] - conn.close() return df + diff --git a/AI/libs/utils/get_db_conn.py b/AI/libs/utils/get_db_conn.py index ad930d63..6ac43771 100644 --- a/AI/libs/utils/get_db_conn.py +++ b/AI/libs/utils/get_db_conn.py @@ -1,12 +1,129 @@ +# AI/libs/utils/get_db_conn.py +# 한국어 주석: JSON 설정에서 DB 접속정보를 읽어 +# 1) psycopg2 Connection (로우 커넥션) +# 2) SQLAlchemy Engine (권장, 커넥션 풀/프리핑) +# 을 생성하는 유틸. 중복 로딩 방지를 위해 캐시 사용. + +from __future__ import annotations +import os +import sys +import json +from typing import Dict, Any, Optional +from pathlib import Path +from urllib.parse import quote_plus + import psycopg2 +from sqlalchemy import create_engine + +# --- 프로젝트 루트 경로 설정 --- +project_root = Path(__file__).resolve().parents[3] # .../AI/libs/utils/get_db_conn.py 기준 상위 3단계 +sys.path.append(str(project_root)) +# -------------------------------- + +# 필수 키(포트는 선택) +REQUIRED_KEYS = {"host", "user", "password", "dbname"} + +# JSON 설정 캐시 (프로세스 내 1회 로드) +_CONFIG_CACHE: Optional[Dict[str, Dict[str, Any]]] = None + + +def _config_path() -> Path: + """configs/config.json 경로를 안전하게 계산""" + return project_root/"AI"/"configs"/"config.json" + + +def _load_configs() -> Dict[str, Dict[str, Any]]: + """ + - configs/config.json을 읽어서 {db_name: {host, user, password, dbname, port?, sslmode?}} 형태로 반환 + - 파일은 깃에 올리지 않는 것을 권장(민감정보) + """ + global _CONFIG_CACHE + if _CONFIG_CACHE is not None: + return _CONFIG_CACHE + + path = _config_path() + if not path.exists(): + raise FileNotFoundError(f"[DB CONFIG] 설정 파일이 없습니다: {path}") + + try: + with path.open("r", encoding="utf-8") as f: + data: Dict[str, Dict[str, Any]] = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f"[DB CONFIG] JSON 파싱 오류: {e}") from e + + # 간단한 구조 검증(필수키 확인은 get_* 함수에서 db별로 수행) + if not isinstance(data, dict) or not data: + raise ValueError("[DB CONFIG] 최상위 JSON은 비어있지 않은 객체여야 합니다.") + + _CONFIG_CACHE = data + return _CONFIG_CACHE -def get_db_conn(config: dict): - """config에서 DB 접속 정보 가져와 psycopg2 Connection 생성""" - conn = psycopg2.connect( - host=config["host"], - user=config["user"], - password=config["password"], - dbname=config["dbname"], - port=config.get("port", 5432), + +def _get_db_config(db_name: str) -> Dict[str, Any]: + """ + - 특정 db_name에 해당하는 설정 블록을 반환 + - 필수 키(host, user, password, dbname) 존재 검증 + """ + if not db_name or not isinstance(db_name, str): + raise ValueError("db_name must be a non-empty string") + + configs = _load_configs() + cfg = configs.get(db_name) + if not cfg: + raise KeyError(f"[DB CONFIG] '{db_name}' 설정 블록을 찾을 수 없습니다. (configs/config.json)") + + missing = REQUIRED_KEYS - set(cfg.keys()) + if missing: + raise KeyError(f"[DB CONFIG] '{db_name}'에 필수 키 누락: {sorted(missing)}") + + return cfg + + +def _build_sqlalchemy_url(cfg: Dict[str, Any]) -> str: + """ + - SQLAlchemy용 PostgreSQL URI를 안전하게 구성 + - 비밀번호/유저의 특수문자를 URL 인코딩(quote_plus)로 보호 + - 예: postgresql+psycopg2://user:pass@host:port/dbname?sslmode=require + """ + user = quote_plus(str(cfg["user"])) + password = quote_plus(str(cfg["password"])) + host = str(cfg["host"]) + port = int(cfg.get("port", 5432)) + dbname = str(cfg["dbname"]) + + base = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{dbname}" + + # 선택 옵션: sslmode + # (Neon/클라우드 Postgres의 경우 require가 흔함) + sslmode = cfg.get("sslmode") + if sslmode: + return f"{base}?sslmode={sslmode}" + + return base + + +def get_db_conn(db_name: str): + """ + - psycopg2 로우 커넥션 생성(직접 커서 열어 사용할 때) + - pandas 경고가 싫다면 read_sql에는 get_engine() 사용을 권장 + """ + cfg = _get_db_config(db_name) + return psycopg2.connect( + host=cfg["host"], + user=cfg["user"], + password=cfg["password"], + dbname=cfg["dbname"], + port=int(cfg.get("port", 5432)), + sslmode=cfg.get("sslmode", None), # 필요 시 자동 적용 ) - return conn \ No newline at end of file + + +def get_engine(db_name: str): + """ + - SQLAlchemy Engine 생성(권장) + - 커넥션 풀 + pre_ping으로 죽은 연결 사전 감지 → 운영 안정성↑ + - pandas.read_sql, 대량입출력 등에서 사용 + """ + cfg = _get_db_config(db_name) + url = _build_sqlalchemy_url(cfg) + return create_engine(url, pool_pre_ping=True) diff --git a/AI/libs/utils/save_reports_to_db.py b/AI/libs/utils/save_reports_to_db.py new file mode 100644 index 00000000..6b833fdf --- /dev/null +++ b/AI/libs/utils/save_reports_to_db.py @@ -0,0 +1,99 @@ +# libs/core/save_reports_to_db.py +from __future__ import annotations +from typing import Iterable, Tuple, List +from datetime import datetime, timezone +import sys +from sqlalchemy import create_engine, text +import os + +# --- 프로젝트 루트 경로 설정 --- +project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(project_root) +# ------------------------------ + +from libs.utils.get_db_conn import get_engine + +ReportRow = Tuple[str, str, float, str, str] + +def utcnow() -> datetime: + return datetime.now(timezone.utc) + + +def ensure_table_schema(engine) -> None: + """ + 한국어 주석: + - 정보스키마 조회 후 필요한 컬럼만 추가. + """ + with engine.begin() as conn: + cols = conn.execute(text(""" + SELECT column_name FROM information_schema.columns + WHERE table_schema='public' AND table_name='xai_reports'; + """)).fetchall() + existing = {r[0] for r in cols} + need = {"ticker", "signal", "price", "date", "report", "created_at"} + missing = need - existing + if missing: + parts = [] + if "ticker" in missing: parts.append("ADD COLUMN IF NOT EXISTS ticker varchar(20) NOT NULL") + if "signal" in missing: parts.append("ADD COLUMN IF NOT EXISTS signal varchar(10) NOT NULL") + if "price" in missing: parts.append("ADD COLUMN IF NOT EXISTS price numeric(10,2) NOT NULL") + if "date" in missing: parts.append("ADD COLUMN IF NOT EXISTS date date NOT NULL") + if "report" in missing: parts.append("ADD COLUMN IF NOT EXISTS report text") + if "created_at" in missing: + parts.append("ADD COLUMN IF NOT EXISTS created_at timestamptz NOT NULL DEFAULT now()") + conn.execute(text(f"ALTER TABLE public.xai_reports {', '.join(parts)};")) + +def build_insert_params(rows: Iterable[ReportRow], created_at: datetime) -> List[dict]: + """ + 한국어 주석: + - SQLAlchemy의 named parameter 형태(dict)로 변환. + """ + out: List[dict] = [] + for (ticker, signal, price, date_s, report_text) in rows: + if not ticker or not signal or not date_s: + continue + out.append({ + "ticker": ticker, + "signal": signal, + "price": float(price), + "date": date_s, # 'YYYY-MM-DD' + "report": str(report_text), + "created_at": created_at, + }) + return out + +def save_reports_to_db(rows: List[ReportRow], db_name: str) -> int: + """ + 한국어 주석: + - SQLAlchemy로 안전하게 INSERT. + - pandas 경고 제거, 커넥션 관리 자동화, 프리핑으로 죽은 커넥션 방지. + """ + if not rows: + print("[INFO] 저장할 리포트가 없습니다.") + return 0 + + engine = get_engine(db_name) + ensure_table_schema(engine) + + created_at = utcnow() + params = build_insert_params(rows, created_at) + if not params: + print("[WARN] 유효한 저장 파라미터가 없어 INSERT를 건너뜁니다.") + return 0 + + insert_sql = text(""" + INSERT INTO public.xai_reports (ticker, signal, price, date, report, created_at) + VALUES (:ticker, :signal, :price, :date, :report, :created_at) + """) + + inserted = 0 + # 대량이면 청크 분할 권장 + CHUNK = 1000 + with engine.begin() as conn: + for i in range(0, len(params), CHUNK): + batch = params[i:i+CHUNK] + conn.execute(insert_sql, batch) + inserted += len(batch) + + print(f"--- {inserted}개의 XAI 리포트가 데이터베이스에 저장되었습니다. ---") + return inserted diff --git a/AI/transformer/main.py b/AI/transformer/main.py index f998b282..bc14ba25 100644 --- a/AI/transformer/main.py +++ b/AI/transformer/main.py @@ -5,9 +5,6 @@ from pathlib import Path -# (선택) 프로젝트 공용 로거가 있다면 교체: from AI.libs.utils.io import _log -_log = print - # ★ 실제 추론 로직은 modules/inference.run_inference 에 구현되어 있음 from .modules.inference import run_inference @@ -19,7 +16,7 @@ def run_transformer( pred_h: int, raw_data: pd.DataFrame, run_date: Optional[str] = None, - config: Optional[dict] = None, + weights_path: Optional[str] = None, interval: str = "1d", ) -> Dict[str, pd.DataFrame]: """ @@ -60,14 +57,19 @@ def run_transformer( """ # 1) weights_path 경로지정 - base_dir = Path("/transformer/weights") - candidate = base_dir / "inital.weights.h5" + PROJECT_ROOT = Path(__file__).resolve().parents[1] + + weights_dir = PROJECT_ROOT / "transformer" / "weights" + candidate = weights_dir / "initial.weights.h5" - weights_path = str(candidate) if candidate.exists() else None + weights_path = str(candidate) + if candidate.exists(): + + print(f"[TRANSFORMER] weights_path 설정됨: {weights_path}") if not weights_path: - _log("[TRANSFORMER][WARN] weights_path 미설정 → 가중치 없이 랜덤 초기화로 추론될 수 있음(품질 저하).") - _log(" config 예시: {'transformer': {'weights_path': 'weights/inital.weights.h5'}}") + print("[TRANSFORMER][WARN] weights_path 미설정 → 가중치 없이 랜덤 초기화로 추론될 수 있음(품질 저하).") + print(" config 예시: {'transformer': {'weights_path': 'weights/initial.weights.h5'}}") # 2) 실제 추론 실행(모듈 위임) diff --git a/AI/transformer/modules/__init__.py b/AI/transformer/modules/__init__.py index 75684179..66d610c0 100644 --- a/AI/transformer/modules/__init__.py +++ b/AI/transformer/modules/__init__.py @@ -1,3 +1,3 @@ -# AI/finder/__init__.py +# AI/transformer/modules/__init__.py from .models import build_transformer_classifier __all__ = ["build_transformer_classifier"] diff --git a/AI/transformer/modules/inference.py b/AI/transformer/modules/inference.py index 7c7700c0..8d145f9b 100644 --- a/AI/transformer/modules/inference.py +++ b/AI/transformer/modules/inference.py @@ -1,13 +1,12 @@ -# transformer/modules/inference.py +# transformer/modules/inference.py from __future__ import annotations from typing import Dict, List, Optional, Tuple import numpy as np import pandas as pd from sklearn.preprocessing import MinMaxScaler -from tensorflow.keras import Model +from tensorflow.keras import Model + -# from AI.libs.utils.io import _log -_log = print # TODO: 추후 io._log 로 교체 from transformer.modules.models import build_transformer_classifier from transformer.modules.features import FEATURES, build_features @@ -37,11 +36,11 @@ def _load_or_build_model(seq_len: int, n_features: int, weights_path: Optional[s if weights_path: try: model.load_weights(weights_path) - _log(f"[INFER] Transformer weights loaded: {weights_path}") + print(f"[INFER] 가중치 로드 완료 : {weights_path}") except Exception as e: - _log(f"[INFER][WARN] 가중치 로드 실패 → 랜덤 초기화: {e}") + print(f"[INFER][WARN] 가중치 로드 실패 → 랜덤 초기화: {e}") else: - _log("[INFER][WARN] weights_path 미지정 → 랜덤 초기화로 진행") + print("[INFER][WARN] weights_path 미지정 → 랜덤 초기화로 진행") return model # ===== 공개 엔트리포인트 (추론) ===== @@ -80,7 +79,7 @@ def run_inference( """ tickers = finder_df["ticker"].astype(str).tolist() if raw_data is None or raw_data.empty: - _log("[INFER] raw_data empty -> empty logs") + print("[INFER] raw_data empty -> empty logs") return {"logs": pd.DataFrame(columns=[ "ticker","date","action","price","weight", "feature1","feature2","feature3","prob1","prob2","prob3" @@ -94,7 +93,7 @@ def run_inference( df = df.rename(columns={c: c.lower() for c in df.columns}) df = df[df["ticker"].astype(str).isin(tickers)] if df.empty: - _log("[INFER] 대상 종목 데이터 없음") + print("[INFER] 대상 종목 데이터 없음") return {"logs": pd.DataFrame(columns=[ "ticker","date","action","price","weight", "feature1","feature2","feature3","prob1","prob2","prob3" @@ -131,12 +130,12 @@ def run_inference( feats = build_features(ohlcv) if feats.empty: - _log(f"[INFER] {t} features empty -> skip") + print(f"[INFER] {t} features empty -> skip") continue X_seq = _make_sequence(feats, model_feats, seq_len) if X_seq is None: - _log(f"[INFER] {t} 부족한 길이(seq_len={seq_len}) -> skip") + print(f"[INFER] {t} 부족한 길이(seq_len={seq_len}) -> skip") continue X_scaled, _ = _scale_per_ticker(X_seq) @@ -149,7 +148,7 @@ def run_inference( buy_p, hold_p, sell_p = float(probs[0]), float(probs[1]), float(probs[2]) action = ["BUY","HOLD","SELL"][int(np.argmax(probs))] except Exception as e: - _log(f"[INFER][WARN] 예측 실패({t}) → 룰기반 fallback: {e}") + print(f"[INFER][WARN] 예측 실패({t}) → 룰기반 fallback: {e}") recent = feats.iloc[-1] rsi = float(recent["RSI"]) macd = float(recent["MACD"]) @@ -188,7 +187,7 @@ def run_inference( "prob3": float(sell_p), }) except Exception as e: - _log(f"[INFER][ERROR] {t}: {e}") + print(f"[INFER][ERROR] {t}: {e}") continue logs_df = pd.DataFrame(rows, columns=[ diff --git a/AI/transformer/modules/models.py b/AI/transformer/modules/models.py index b844cb51..d361e5e8 100644 --- a/AI/transformer/modules/models.py +++ b/AI/transformer/modules/models.py @@ -1,9 +1,9 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import layers, Model +from tensorflow.keras import layers, Model # type: ignore[reportMissingImports] # 위치 인코딩 -def positional_encoding(maxlen: int, d_model: int) -> tf.Tensor: +def positional_encoding(maxlen: int, d_model: int): angles = np.arange(maxlen)[:, None] / np.power( 10000, (2 * (np.arange(d_model)[None, :] // 2)) / d_model ) diff --git a/AI/transformer/scaler/scaler.pkl b/AI/transformer/scaler/scaler.pkl new file mode 100644 index 00000000..10b3e9b8 Binary files /dev/null and b/AI/transformer/scaler/scaler.pkl differ diff --git a/AI/transformer/training/train_transformer.py b/AI/transformer/training/train_transformer.py index 4bda0db2..3f20c887 100644 --- a/AI/transformer/training/train_transformer.py +++ b/AI/transformer/training/train_transformer.py @@ -1,11 +1,23 @@ -# transformer/training/train_transformer.py +# transformer/training/train_transformer.py from __future__ import annotations -from typing import Dict, List, Optional + +# ───────────────────────────────────────────────────────────────────────────── +# 코드 설명 +# 1) DB(PostgreSQL, psycopg2)에서 public.price_data로부터 "모든" 티커와 일봉 데이터 추출 +# 2) 수집된 일봉 데이터를 기반으로 사용자 정의 피처를 생성하고 +# 3) 미래 수익률 라벨링(BUY/HOLD/SELL) 후 Transformer 분류 모델을 학습 +# 4) 최적 가중치(.h5)와 스케일러(.pkl)를 저장 +# +# 주의: +# - 시간대 처리는 일관성을 위해 기본적으로 UTC 'date' 컬럼을 우선 사용한다. +# - run_date 컷은 Asia/Seoul 기준 날짜를 UTC로 변환 후 비교한다. +# - 대량 티커 수집 시 API 호출 간 sleep을 넣어 서버 과부하/차단을 완화한다. +# ───────────────────────────────────────────────────────────────────────────── + +from typing import Dict, List, Optional, Any import os -import time import pickle import sys -import requests # ← yfinance SSL 이슈 회피용: REST 직접 호출 import numpy as np import pandas as pd from sklearn.preprocessing import MinMaxScaler @@ -13,20 +25,90 @@ import tensorflow as tf from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau +# --- 프로젝트/레포 경로 설정 --------------------------------------------------- +_this_file = os.path.abspath(__file__) +project_root = os.path.dirname(os.path.dirname(_this_file)) # .../transformer +repo_root = os.path.dirname(project_root) # .../ +libs_root = os.path.join(repo_root, "libs") # .../libs -# --- 프로젝트 루트 경로 설정 --------------------------------------------------- -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(project_root) -# --------------------------------------------------------------------------- +# sys.path에 중복 없이 추가 +for p in (project_root, repo_root, libs_root): + if p not in sys.path: + sys.path.append(p) +# ------------------------------------------------------------------------------ from modules.features import FEATURES, build_features from modules.models import build_transformer_classifier +from libs.utils.get_db_conn import get_db_conn # from AI.libs.utils.io import _log _log = print # TODO: io._log 로 교체 +# DB 이름(프로젝트 환경에 맞춰 설정) +db_name = "db" + +# 클래스 이름 매핑(라벨→사람이 읽을 수 있는 문자열) CLASS_NAMES = ["BUY", "HOLD", "SELL"] + +# ============================================================================= +# DB에서 "모든" 티커 목록 가져오기 (제한 없음) +# - PostgreSQL(psycopg2) 연결을 get_db_conn(db_name)으로 얻는다고 가정 +# - 스키마: public.price_data (PK: (ticker, date)) +# ============================================================================= +def load_all_tickers_from_db(verbose: bool = True) -> List[str]: + """ + public.price_data에서 사용 가능한 모든 티커를 DISTINCT로 추출하여 반환한다. + + 반환 + ---- + List[str] + 대문자 티커 문자열 리스트(중복 제거/공백 제거/알파벳 정렬) + + 구현 상세 + -------- + - psycopg2 커넥션을 사용한다고 가정하고, 사용 후 conn.close()로 커넥션을 닫는다. + - DataFrame 정리 단계에서 NULL/빈문자열, 공백 등을 제거한다. + """ + conn = get_db_conn(db_name) # psycopg2 커넥션 또는 SQLAlchemy 엔진 + try: + sql = """ + SELECT DISTINCT ticker + FROM public.price_data + WHERE ticker IS NOT NULL AND ticker <> '' + """ + df = pd.read_sql(sql, conn) + finally: + # psycopg2 커넥션일 경우 명시적으로 닫아 리소스 누수 방지 + try: + conn.close() + except Exception: + pass + + if df.empty or "ticker" not in df.columns: + raise RuntimeError("[load_all_tickers_from_db] price_data에서 티커를 찾지 못했습니다.") + + # 문자열 정리: 공백 제거 → 대문자화 → 결측/중복 제거 → 정렬 + tickers = ( + df["ticker"] + .astype(str) + .str.strip() + .str.upper() + .dropna() + .drop_duplicates() + .tolist() + ) + tickers = sorted([t for t in tickers if t]) # 안전한 최종 정리 + + if not tickers: + raise RuntimeError("[load_all_tickers_from_db] 정리 후 유효한 티커가 없습니다. DB 데이터를 확인하세요.") + + if verbose: + _log(f"[DB] 모든 티커 로드 완료: {len(tickers)}개. 예시: {tickers[:10]}") + + return tickers + + # ============================================================================= # 1) 라벨링 정책 (분류) # ============================================================================= @@ -51,7 +133,7 @@ def _label_by_future_return(close: pd.Series, pred_h: int, hold_thr: float = 0.0 future = close.shift(-pred_h) r = (future / close) - 1.0 - buy = (r > hold_thr).astype(int) + buy = (r > hold_thr).astype(int) sell = (r < -hold_thr).astype(int) hold = ((r.abs() <= hold_thr) & r.notna()).astype(int) @@ -59,6 +141,7 @@ def _label_by_future_return(close: pd.Series, pred_h: int, hold_thr: float = 0.0 label = np.select([buy.eq(1), hold.eq(1), sell.eq(1)], [0, 1, 2], default=np.nan) return pd.Series(label, index=close.index, dtype="float") + # ============================================================================= # 2) 시퀀스/스케일링 유틸 # ============================================================================= @@ -79,6 +162,7 @@ def _build_sequences(feats: pd.DataFrame, use_cols: List[str], seq_len: int) -> return np.empty((0, seq_len, len(use_cols)), dtype="float32") return np.stack(X_list, axis=0) + def _align_labels(feats: pd.DataFrame, labels: pd.Series, seq_len: int) -> np.ndarray: """ 시퀀스 끝 시점에 대한 라벨을 맞추기 위해, 시퀀스 시작 오프셋(seq_len-1)만큼 라벨을 잘라서 정렬. @@ -86,6 +170,7 @@ def _align_labels(feats: pd.DataFrame, labels: pd.Series, seq_len: int) -> np.nd """ return labels.iloc[seq_len - 1 :].values + def _fit_scaler_on_train(X: np.ndarray) -> MinMaxScaler: """ 학습 데이터 전체 분포에 맞춰 스케일러 적합. @@ -98,6 +183,7 @@ def _fit_scaler_on_train(X: np.ndarray) -> MinMaxScaler: scaler.fit(X2) return scaler + def _apply_scaler(X: np.ndarray, scaler: MinMaxScaler) -> np.ndarray: """학습/검증/테스트에 동일 스케일 적용.""" n, s, f = X.shape @@ -105,6 +191,7 @@ def _apply_scaler(X: np.ndarray, scaler: MinMaxScaler) -> np.ndarray: X2 = scaler.transform(X2) return X2.reshape(n, s, f).astype("float32") + # ============================================================================= # 3) 학습 메인 파이프라인 # ============================================================================= @@ -122,7 +209,7 @@ def train_transformer_classifier( hold_thr: float = 0.003, batch_size: int = 64, epochs: int = 50, -) -> Dict[str, any]: +) -> Dict[str, Any]: """ Transformer 분류기 학습 파이프라인. - 입력: 원천 OHLCV(raw_data; 여러 티커 혼합 가능) @@ -138,7 +225,7 @@ def train_transformer_classifier( pred_h : int 미래 라벨링 지평(일수/캔들수) model_out_path : str - 최종 가중치 저장 경로(.h5 권장): 예) 'artifacts/transformer_cls.h5' + 최종 가중치 저장 경로(.h5 권장): scaler_out_path : str, optional 스케일러 저장 경로(.pkl). 추론 시 동일 스케일 사용을 원할 때 권장. tickers : list, optional @@ -155,12 +242,26 @@ def train_transformer_classifier( raise ValueError("raw_data가 비어있습니다.") df = raw_data.copy() - # 컬럼 소문자화(혼용 방지) + + # 컬럼 소문자화(혼용 방지). 단, tz 정보 유지를 위해 시점 컬럼 해석은 아래에서 별도로 처리 df = df.rename(columns={c: c.lower() for c in df.columns}) - ts_col = "ts_local" if "ts_local" in df.columns else ("date" if "date" in df.columns else None) + + # ⚠️ 시간대 컬럼 선택 정책: + # - 항상 'date'(UTC)를 우선 사용 → run_date 컷 등에서 일관성 확보 + # - 'date'가 없고 'ts_local'(Asia/Seoul)만 있으면 tz-aware로 변환/보존 + ts_col = "date" if "date" in df.columns else ("ts_local" if "ts_local" in df.columns else None) if ts_col is None: - raise ValueError("raw_data에 'ts_local' 또는 'date' 컬럼이 필요합니다.") - df[ts_col] = pd.to_datetime(df[ts_col], utc=True, errors="coerce") + raise ValueError("raw_data에 'date' 또는 'ts_local' 컬럼이 필요합니다.") + + if ts_col == "date": + # 'date'는 UTC 기준 타임스탬프로 파싱 + df[ts_col] = pd.to_datetime(df[ts_col], utc=True, errors="coerce") + else: + # 'ts_local'만 있는 경우: Asia/Seoul로 인식(naive면 현지 부여) + df[ts_col] = pd.to_datetime(df[ts_col], errors="coerce") + if df[ts_col].dt.tz is None: + df[ts_col] = df[ts_col].dt.tz_localize("Asia/Seoul") + if df[ts_col].isna().any(): raise ValueError("타임스탬프 파싱 중 NaT가 발생했습니다. 원본 데이터를 확인하세요.") @@ -168,15 +269,18 @@ def train_transformer_classifier( df["ticker"] = df["ticker"].astype(str) df = df[df["ticker"].isin([str(t) for t in tickers])] - # run_date 컷 (Asia/Seoul 기준) + # run_date 컷 (Asia/Seoul 기준 날짜 → UTC로 변환하여 비교) if run_date is not None: # Asia/Seoul 자정까지 포함되도록 끝점 계산 end_dt = pd.to_datetime(run_date).tz_localize("Asia/Seoul", nonexistent="shift_forward").normalize() end_dt = end_dt + pd.Timedelta(days=1) - pd.Timedelta(microseconds=1) - # df는 UTC → 동일 기준으로 비교 + # df는 UTC(date) 또는 Asia/Seoul(ts_local)일 수 있으므로 UTC로 변환해 비교 end_cut_utc = end_dt.tz_convert("UTC") - df = df[df[ts_col] <= end_cut_utc] + # 비교 시 df[ts_col]도 UTC 기준으로 맞춰 사용 + compare_ts = df[ts_col].dt.tz_convert("UTC") + df = df[compare_ts <= end_cut_utc] + # 정렬 df = df.sort_values(["ticker", ts_col]).reset_index(drop=True) # ---------- 피처 + 라벨 ---------- @@ -185,10 +289,11 @@ def train_transformer_classifier( X_all, y_all = [], [] for t, g in df.groupby("ticker", sort=False): + # 모델 피처 함수가 'date' 인덱스를 기대한다고 가정 g = g.rename(columns={ts_col: "date"}).set_index("date") ohlcv = g[["open", "high", "low", "close", "volume"]].copy() - # 사용자 정의 피처 빌드 + # 사용자 정의 피처 빌드 (FEATURES와 build_features는 프로젝트 모듈 제공) feats = build_features(ohlcv) # 반드시 'CLOSE_RAW' 포함한다고 가정 if len(feats) < (seq_len + pred_h + 1): # 시퀀스/라벨링 최소 길이 부족 시 스킵 @@ -233,7 +338,7 @@ def train_transformer_classifier( scaler = _fit_scaler_on_train(X) X = _apply_scaler(X, scaler) - # 클래스 불균형이 심할 수 있으니 stratify 분할 + # 클래스 불균형이 심할 수 있으니 stratify 분할 권장 X_train, X_val, y_train, y_val = train_test_split( X, y, test_size=test_size, random_state=random_state, stratify=y ) @@ -285,91 +390,99 @@ def train_transformer_classifier( "scaler_path": scaler_out_path } + # ============================================================================= -# 4) 야후 파이낸스 REST 폴백: OHLCV 수집 (requests) -# - yfinance SSL/차단 이슈를 피해 직접 엔드포인트 호출 +# DB에서 OHLCV 수집 (public.price_data) +# - 스키마: (ticker TEXT, date DATE, open/high/low/close NUMERIC, volume BIGINT, adjusted_close NUMERIC) +# - 대량 티커를 대비해 IN 절을 청크로 나눠 반복 조회 +# - 반환 컬럼: ['ticker','date(UTC tz-aware)','open','high','low','close','volume','ts_local(Asia/Seoul)'] # ============================================================================= -def _yahoo_interval_str(interval: str) -> str: - """ - 야후 차트 API interval 명세 검증/정규화. - - 허용: '1d','1h','1wk','1mo' 등 - """ - allowed = {"1m","2m","5m","15m","30m","60m","90m","1h","1d","5d","1wk","1mo","3mo"} - if interval not in allowed: - raise ValueError(f"지원하지 않는 interval: {interval} (허용: {sorted(allowed)})") - return interval - -def _fetch_yahoo_ohlcv( - ticker: str, - start: pd.Timestamp, - end: pd.Timestamp, - interval: str = "1d", - retries: int = 3, - sleep_sec: float = 1.0, +def _fetch_db_ohlcv_for_tickers( + tickers: List[str], + start_date: str, + end_date: str, + use_adjusted_close: bool = True, + chunk_size: int = 200, ) -> pd.DataFrame: """ - 야후 파이낸스 차트 API(v8)에서 OHLCV를 수집하여 DataFrame 반환. - - 요청 URL: https://query2.finance.yahoo.com/v8/finance/chart/{ticker} - - 파라미터: period1(UNIX), period2(UNIX), interval - - 반환 컬럼: ['ticker','date','open','high','low','close','volume','ts_local(Asia/Seoul)'] + DB에서 지정한 티커 리스트와 날짜 구간에 해당하는 OHLCV를 읽어 하나의 DataFrame으로 반환. - 주의 - ---- - * period1/period2는 초 단위 UNIX 타임스탬프. - * 반환 timeZone은 종목 거래소 기준이므로, ts_local은 Asia/Seoul로 별도 변환해서 제공. - * 프리마켓/서머타임 등 미세한 체결 시간 차이에 따른 분봉은 케이스별 확인 필요. + Parameters + ---------- + tickers : List[str] + 조회할 티커 목록(대문자/소문자 무관, 내부에서 그대로 비교) + start_date : str + 'YYYY-MM-DD' (price_data.date는 DATE 타입 기준) + end_date : str + 'YYYY-MM-DD' (포함 조건) + use_adjusted_close : bool + True면 adjusted_close가 있는 행은 그 값을 close로 대체(배당/분할 반영) + chunk_size : int + 너무 많은 티커로 IN 절이 길어지는 것을 방지하기 위한 청크 크기 + + Returns + ------- + DataFrame + ['ticker','date','open','high','low','close','volume','ts_local'] 정렬 완료 """ - interval = _yahoo_interval_str(interval) - base = "https://query2.finance.yahoo.com/v8/finance/chart/{}".format(ticker) - params = { - "period1": int(pd.Timestamp(start).tz_convert("UTC").timestamp()), - "period2": int(pd.Timestamp(end).tz_convert("UTC").timestamp()), - "interval": interval, - "events": "div,splits" - } - headers = { - "User-Agent": "Mozilla/5.0", - "Accept": "application/json, text/plain, */*", - "Connection": "keep-alive", - } - - last_err = None - for _ in range(retries): + conn = get_db_conn(db_name) + try: + frames = [] + # 티커 청크 분할 + for i in range(0, len(tickers), chunk_size): + chunk = tickers[i:i+chunk_size] + # IN 절 플레이스홀더 생성: (%s, %s, ..., %s) + placeholders = ",".join(["%s"] * len(chunk)) + sql = f""" + SELECT + ticker, + date, + open, + high, + low, + close, + volume, + adjusted_close + FROM public.price_data + WHERE date >= %s + AND date <= %s + AND ticker IN ({placeholders}) + ORDER BY ticker, date + """ + params = [start_date, end_date] + chunk + df = pd.read_sql(sql, conn, params=params) + if not df.empty: + frames.append(df) + if not frames: + return pd.DataFrame(columns=["ticker","date","open","high","low","close","volume","ts_local"]) + + out = pd.concat(frames, ignore_index=True) + + # ---- 데이터 정리: 숫자형 변환 (NUMERIC -> float) ---- + num_cols = ["open","high","low","close","volume","adjusted_close"] + for c in num_cols: + if c in out.columns: + out[c] = pd.to_numeric(out[c], errors="coerce") + + # ---- 조정 종가 적용 옵션 ---- + if use_adjusted_close and "adjusted_close" in out.columns: + # adjusted_close가 존재할 때만 대체(결측은 원래 close 유지) + out["close"] = np.where(out["adjusted_close"].notna(), out["adjusted_close"], out["close"]) + + # ---- 타임존 컬럼 구성 ---- + # DB의 date는 "캘린더 날짜"이므로 UTC 자정으로 타임스탬프화 + out["date"] = pd.to_datetime(out["date"], format="%Y-%m-%d", errors="coerce").dt.tz_localize("UTC") + out["ts_local"] = out["date"].dt.tz_convert("Asia/Seoul") + + # ---- 최종 컬럼/정렬 ---- + out = out[["ticker","date","open","high","low","close","volume","ts_local"]].sort_values(["ticker","date"]) + out = out.dropna(subset=["open","high","low","close"]) # 필수값 결측 제거 + return out.reset_index(drop=True) + finally: try: - resp = requests.get(base, params=params, headers=headers, timeout=10) - resp.raise_for_status() - data = resp.json() - result = data.get("chart", {}).get("result") - if not result: - raise ValueError(f"Yahoo 응답에 result가 없습니다: {data}") - result = result[0] - - ts_list = result["timestamp"] # 초 단위 UNIX - ind = pd.to_datetime(ts_list, unit="s", utc=True) - - q = result["indicators"]["quote"][0] - df = pd.DataFrame({ - "open": q.get("open"), - "high": q.get("high"), - "low": q.get("low"), - "close": q.get("close"), - "volume": q.get("volume"), - }, index=ind) - - # 기본 정리 - df = df.dropna(subset=["open","high","low","close"]) # 완전결측 제거 - df["ticker"] = str(ticker) - - # 로컬(Asia/Seoul) 타임스탬프 컬럼 별도 생성 - df["ts_local"] = df.index.tz_convert("Asia/Seoul") - - # date(UTC), ts_local 둘 다 보유 (학습코드는 ts_local/ date 둘 중 하나만 있으면 동작) - df = df.reset_index().rename(columns={"index": "date"}) - return df[["ticker","date","open","high","low","close","volume","ts_local"]] - except Exception as e: - last_err = e - time.sleep(sleep_sec) - raise RuntimeError(f"야후 차트 API 호출 실패: {last_err}") + conn.close() + except Exception: + pass # ============================================================================= # 5) 단독 실행(초기 가중치 생성)용 CLI 엔트리포인트 @@ -377,21 +490,28 @@ def _fetch_yahoo_ohlcv( def run_training(config: dict): """config 딕셔너리 기반 Transformer 학습 실행""" - # ---- 1) 데이터 수집 ---- - start = pd.Timestamp(config["start"], tz="Asia/Seoul").tz_convert("UTC") - end = pd.Timestamp(config["end"], tz="Asia/Seoul").tz_convert("UTC") + pd.Timedelta(days=1) - - frames = [] - for t in config["tickers"]: - _log(f"[FETCH] {t} {config['interval']} {config['start']}→{config['end']}") - df_t = _fetch_yahoo_ohlcv( - ticker=t, - start=start, - end=end, - interval=config["interval"] - ) - frames.append(df_t) - raw = pd.concat(frames, ignore_index=True) + # ---- (A) 사용할 티커 소스 결정 ---- + use_db = (config.get("tickers_source", "db") == "db") or (not config.get("tickers")) + if use_db: + tickers = load_all_tickers_from_db(verbose=True) # ← DB에서 "모든" 티커 + else: + tickers = [str(t).upper() for t in config["tickers"]] + _log(f"[CFG] 수동 입력 티커 {len(tickers)}개 사용: {tickers[:8]}...") + + # ---- 1) 데이터 수집: DB에서 가격 읽기 ---- + # - config["start"], ["end"]는 'YYYY-MM-DD' 문자열로 받았다고 가정 + # - price_data.date (DATE)와 동일 포맷이므로 그대로 전달 + use_adjusted = bool(config.get("use_adjusted_close", True)) + raw = _fetch_db_ohlcv_for_tickers( + tickers=tickers, + start_date=config["start"], + end_date=config["end"], + use_adjusted_close=use_adjusted, + chunk_size=int(config.get("db_chunk_size", 200)), + ) + + if raw.empty: + raise RuntimeError("[run_training] DB에서 아무 데이터도 읽히지 않았습니다. 기간/티커/DB 내용을 확인하세요.") # ---- 2) 학습 ---- os.makedirs(os.path.dirname(config["model_out"]), exist_ok=True) @@ -403,7 +523,7 @@ def run_training(config: dict): pred_h=config["pred_h"], model_out_path=config["model_out"], scaler_out_path=config["scaler_out"], - tickers=config["tickers"], + tickers=tickers, # 실제 학습대상 기록 run_date=config.get("run_date"), test_size=config["test_size"], hold_thr=config["hold_thr"], @@ -420,21 +540,33 @@ def run_training(config: dict): if __name__ == "__main__": - # ⚙️ 여기에 원하는 설정만 바꾸면 됨 + config = { - "tickers": ["AAPL", "MSFT"], # 학습 대상 종목 - "start": "2018-01-01", # 시작일 - "end": "2025-10-31", # 종료일 - "interval": "1d", # 일봉 - "seq_len": 64, # 시퀀스 길이 - "pred_h": 5, # 예측 지평(미래 5일) - "hold_thr": 0.003, # HOLD 임계치 - "test_size": 0.2, # 검증셋 비율 - "epochs": 3, # 에폭 수 - "batch_size": 128, # 배치 크기 - "model_out": "transformer/weights/initial.weights.h5", # 가중치 저장 - "scaler_out": "transformer/scaler/scaler.pkl", # 스케일러 저장 - "run_date": None, # 특정 날짜까지만 사용할 경우 지정 + # --- 데이터/티커 소스 --- + "tickers_source": "db", # 티커: DB에서 전체 로드 + "use_adjusted_close": True, # adjusted_close가 있으면 close로 사용 + "db_chunk_size": 200, # IN 청크 크기(파라미터/성능 균형) + + # --- 기간/빈도 --- + "start": "2018-01-01", # DB DATE 기준 (YYYY-MM-DD) + "end": "2024-10-31", + + # --- 시퀀스/라벨 --- + "seq_len": 128, + "pred_h": 7, + "hold_thr": 0.004, + + # --- 학습/평가 --- + "test_size": 0.2, + "epochs": 50, + "batch_size": 512, + + # --- 출력 경로 --- + "model_out": "AI/transformer/weights/initial.weights.h5", + "scaler_out": "AI/transformer/scaler/scaler.pkl", + + # --- 기타 --- + "run_date": None, } run_training(config) diff --git a/AI/transformer/weights/initial.weights.h5 b/AI/transformer/weights/initial.weights.h5 index 52a04ebd..8f3c1302 100644 Binary files a/AI/transformer/weights/initial.weights.h5 and b/AI/transformer/weights/initial.weights.h5 differ diff --git a/transformer/scaler/scaler.pkl b/transformer/scaler/scaler.pkl deleted file mode 100644 index b71b9ba4..00000000 Binary files a/transformer/scaler/scaler.pkl and /dev/null differ