Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 140 additions & 116 deletions AI/libs/core/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,126 +1,150 @@
import os
import sys
from typing import List, Dict
import json
import datetime as dt
from datetime import datetime, timedelta
import pandas as pd
from typing import Dict

# ==============================================
# 내부 모듈 (이미 구현돼 있다고 가정)
# ==============================================
from finder.modules.finder import run_finder_with_scores # 종목+점수 매기기 포함
from transform.modules.transform import run_transform
from xai.modules.xai import run_xai
from AI.libs.utils.data import fetch_ohlcv
from AI.libs.utils.io import _log
# --- 프로젝트 루트 경로 설정 ---
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(project_root)
# ------------------------------

# ==============================================
# Helper: Finder 결과 → JSON 변환
# ==============================================
def make_reasons_json(finder_df: pd.DataFrame, run_date: str) -> Dict:
# --- 모듈 import ---
from finder.main import run_finder
from transform.modules.main import run_transform
from libs.utils.data.fetch_ohlcv import fetch_ohlcv
from xai.run_xai import run_xai
# ---------------------------------

def run_weekly_finder() -> List[str]:
    """
    Run the weekly stock-discovery (Finder) step and return the selected tickers.

    Returns:
        List[str]: tickers chosen by the Finder module; may be empty when
        nothing qualifies.
    """
    print("--- [PIPELINE-STEP 1] Finder 모듈 실행 시작 ---")
    # Delegate to the Finder module; it encapsulates scoring/selection logic.
    top_tickers = run_finder()
    print("--- [PIPELINE-STEP 1] Finder 모듈 실행 완료 ---")
    return top_tickers

# ==============================================
# 일간 Transform + XAI
# ==============================================
def run_daily_tasks(config: dict, run_date: str, finder_df: pd.DataFrame) -> None:
_log(f"[DAILY] Transform + XAI 실행 ({run_date})")

# 데이터 수집
tickers = finder_df["ticker"].tolist()
window_days = int(config.get("data", {}).get("window_days", 252 * 5))
interval = str(config.get("data", {}).get("interval", "1d"))
cache_dir = str(config.get("storage", {}).get("cache_dir", ""))

market_data = fetch_ohlcv(tickers, period_days=window_days, interval=interval, cache_dir=cache_dir)

# Transform (학습 + 로그 생성)
tr = run_transform(
finder_df,
seq_len=config["transform"]["seq_len"],
pred_h=config["transform"]["pred_h"],
def run_signal_transform(tickers: List[str], config: Optional[Dict]) -> pd.DataFrame:
    """
    Run the Transform module for the given tickers and return the decision log.

    Fetches ~600 calendar days of OHLCV per ticker, feeds the combined frame
    to ``run_transform`` and returns its "logs" DataFrame (one decision row
    per ticker/date). Returns an empty DataFrame instead of raising when the
    configuration or data is unavailable.

    Args:
        tickers: tickers selected by the Finder step.
        config: pipeline configuration; must contain a "db" section for OHLCV
            retrieval. May be None (e.g. demo runs without a DB).

    Returns:
        pd.DataFrame: the decision log, or an empty frame when config/data
        is missing.
    """
    print("--- [PIPELINE-STEP 2] Transform 모듈 실행 시작 ---")
    # Guard: fetch_ohlcv(config=None) would crash with a TypeError, killing
    # the whole pipeline; skip this stage gracefully for demo/test runs.
    if not config or "db" not in config:
        print("[WARN] DB 설정이 없어 Transform 단계를 건너뜁니다.")
        return pd.DataFrame()

    # --- 실제 Transform 모듈 호출 ---
    end_date = datetime.now()
    start_date = end_date - timedelta(days=600)  # ~2 trading years of history
    all_ohlcv_df = []
    for ticker in tickers:
        try:
            ohlcv_df = fetch_ohlcv(
                ticker=ticker,
                start=start_date.strftime('%Y-%m-%d'),
                end=end_date.strftime('%Y-%m-%d'),
                config=config,
            )
        except Exception as e:
            # One bad ticker must not abort the whole batch.
            print(f"[WARN] {ticker} OHLCV 조회 실패: {e}")
            continue
        if ohlcv_df.empty:
            continue
        ohlcv_df['ticker'] = ticker
        all_ohlcv_df.append(ohlcv_df)

    if not all_ohlcv_df:
        print("OHLCV 데이터를 가져오지 못했습니다.")
        return pd.DataFrame()

    raw_data = pd.concat(all_ohlcv_df, ignore_index=True)
    finder_df = pd.DataFrame(tickers, columns=['ticker'])
    transform_result = run_transform(
        finder_df=finder_df,
        seq_len=60,  # model input window (days)
        pred_h=1,    # prediction horizon (days)
        raw_data=raw_data,
        config=config
    )
    logs_df = transform_result.get("logs", pd.DataFrame())
    print("--- [PIPELINE-STEP 2] Transform 모듈 실행 완료 ---")
    return logs_df
Comment on lines +30 to +61
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

DB 설정 누락 시 즉시 크래시(핵심 경로 가드 필요).

config=None인 상태에서 fetch_ohlcv 호출 시 TypeError: 'NoneType' is not subscriptable로 파이프라인이 중단됩니다. 설정 부재 시 조기 반환하도록 가드하세요. 이는 데모/테스트 실행 안정성에 중요합니다.

-def run_signal_transform(tickers: List[str], config: Dict) -> pd.DataFrame:
+def run_signal_transform(tickers: List[str], config: Dict | None) -> pd.DataFrame:
@@
-    print("--- [PIPELINE-STEP 2] Transform 모듈 실행 시작 ---")
+    print("--- [PIPELINE-STEP 2] Transform 모듈 실행 시작 ---")
+    if not config or "db" not in config:
+        print("[WARN] DB 설정이 없어 Transform 단계를 건너뜁니다.")
+        return pd.DataFrame()
@@
-    for ticker in tickers:
-        ohlcv_df = fetch_ohlcv(
-            ticker=ticker,
-            start=start_date.strftime('%Y-%m-%d'),
-            end=end_date.strftime('%Y-%m-%d'),
-            config=config
-        )
-        ohlcv_df['ticker'] = ticker
-        all_ohlcv_df.append(ohlcv_df)
+    for ticker in tickers:
+        try:
+            ohlcv_df = fetch_ohlcv(
+                ticker=ticker,
+                start=start_date.strftime('%Y-%m-%d'),
+                end=end_date.strftime('%Y-%m-%d'),
+                config=config,
+            )
+            if ohlcv_df.empty:
+                continue
+            ohlcv_df['ticker'] = ticker
+            all_ohlcv_df.append(ohlcv_df)
+        except Exception as e:
+            print(f"[WARN] {ticker} OHLCV 조회 실패: {e}")
+            continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def run_signal_transform(tickers: List[str], config: Dict) -> pd.DataFrame:
"""
종목 리스트를 받아 Transform 모듈을 실행하고, 신호(결정 로그) 반환합니다.
"""
print("--- [PIPELINE-STEP 2] Transform 모듈 실행 시작 ---")
# --- 실제 Transform 모듈 호출 ---
end_date = datetime.now()
start_date = end_date - timedelta(days=600)
all_ohlcv_df = []
for ticker in tickers:
ohlcv_df = fetch_ohlcv(
ticker=ticker,
start=start_date.strftime('%Y-%m-%d'),
end=end_date.strftime('%Y-%m-%d'),
config=config
)
ohlcv_df['ticker'] = ticker
all_ohlcv_df.append(ohlcv_df)
if not all_ohlcv_df:
print("OHLCV 데이터를 가져오지 못했습니다.")
return pd.DataFrame()
raw_data = pd.concat(all_ohlcv_df, ignore_index=True)
finder_df = pd.DataFrame(tickers, columns=['ticker'])
transform_result = run_transform(
finder_df=finder_df,
seq_len=60,
pred_h=1,
raw_data=raw_data,
config=config
)
logs_df: pd.DataFrame = tr["logs"] # (종목,날짜,매매여부,가격,비중,피쳐...,확률...)
# Transform 로그 저장 (Parquet)
out_dir = os.path.join(config["storage"]["out_dir"], "transform")
os.makedirs(out_dir, exist_ok=True)
log_path = os.path.join(out_dir, f"logs_{run_date}.parquet")
logs_df.to_parquet(log_path, index=False)
# XAI 리포트 생성 + 저장 (JSON per ticker)
xai_out_dir = os.path.join(config["storage"]["out_dir"], "xai", run_date)
os.makedirs(xai_out_dir, exist_ok=True)
xai_reports = run_xai(logs_df)
for ticker, report in xai_reports.items():
with open(os.path.join(xai_out_dir, f"{ticker}.json"), "w", encoding="utf-8") as f:
json.dump(report, f, ensure_ascii=False, indent=2)
_log(f"[DAILY] Transform 로그 + XAI 저장 완료 ({run_date})")
# ==============================================
# 메인 파이프라인
# ==============================================
def run_pipeline(config: dict) -> bool:
run_date = dt.datetime.now(dt.timezone(dt.timedelta(hours=9))).strftime("%Y-%m-%d")
logs_df = transform_result.get("logs", pd.DataFrame())
def run_signal_transform(tickers: List[str], config: Dict | None) -> pd.DataFrame:
"""
종목 리스트를 받아 Transform 모듈을 실행하고, 신호(결정 로그) 반환합니다.
"""
print("--- [PIPELINE-STEP 2] Transform 모듈 실행 시작 ---")
if not config or "db" not in config:
print("[WARN] DB 설정이 없어 Transform 단계를 건너뜁니다.")
return pd.DataFrame()
# --- 실제 Transform 모듈 호출 ---
end_date = datetime.now()
start_date = end_date - timedelta(days=600)
all_ohlcv_df = []
for ticker in tickers:
try:
ohlcv_df = fetch_ohlcv(
ticker=ticker,
start=start_date.strftime('%Y-%m-%d'),
end=end_date.strftime('%Y-%m-%d'),
config=config,
)
if ohlcv_df.empty:
continue
ohlcv_df['ticker'] = ticker
all_ohlcv_df.append(ohlcv_df)
except Exception as e:
print(f"[WARN] {ticker} OHLCV 조회 실패: {e}")
continue
if not all_ohlcv_df:
print("OHLCV 데이터를 가져오지 못했습니다.")
return pd.DataFrame()
raw_data = pd.concat(all_ohlcv_df, ignore_index=True)
finder_df = pd.DataFrame(tickers, columns=['ticker'])
transform_result = run_transform(
finder_df=finder_df,
seq_len=60,
pred_h=1,
raw_data=raw_data,
config=config
)
logs_df = transform_result.get("logs", pd.DataFrame())
🤖 Prompt for AI Agents
In AI/libs/core/pipeline.py around lines 30 to 61, the function
run_signal_transform assumes config is present and passes it to fetch_ohlcv,
causing a TypeError when config is None; add an early guard at the top of the
function that checks for a missing or invalid config (e.g., if config is None or
required keys like DB/settings are absent), log/print a clear message about the
missing configuration, and immediately return an empty pd.DataFrame to avoid
calling fetch_ohlcv with None; this prevents crashes for demo/test runs and
ensures the pipeline fails gracefully.


# --- 임시 결정 로그 데이터 (주석 처리) ---
# data = {
# 'ticker': ['AAPL', 'GOOGL', 'MSFT'],
# 'date': ['2025-09-17', '2025-09-17', '2025-09-17'],
# 'action': ['SELL', 'BUY', 'SELL'],
# 'price': [238.99, 249.52, 510.01],
# 'weight': [0.16, 0.14, 0.15],
# 'feature1': ['RSI', 'Stochastic', 'MACD'],
# 'feature2': ['MACD', 'MA_5', 'ATR'],
# 'feature3': ['Bollinger_Bands_lower', 'RSI', 'MA_200'],
# 'prob1': [0.5, 0.4, 0.6],
# 'prob2': [0.3, 0.25, 0.2],
# 'prob3': [0.1, 0.15, 0.1]
# }
# logs_df = pd.DataFrame(data)

print(f"--- [PIPELINE-STEP 2] Transform 모듈 실행 완료 ---")
return logs_df

def run_xai_report(decision_log: pd.DataFrame) -> List[str]:
    """Generate XAI reports from the decision log, one per decision row.

    Requires the ``GROQ_API_KEY`` environment variable (raises ValueError
    otherwise). A failure on one row does not abort the run: the error
    message string is appended in place of that row's report.
    """
    print("--- [PIPELINE-STEP 3] XAI 리포트 생성 시작 ---")
    groq_key = os.environ.get("GROQ_API_KEY")
    if not groq_key:
        raise ValueError("XAI 리포트 생성을 위해 GROQ_API_KEY 환경 변수를 설정해주세요.")

    results: List[str] = []
    if decision_log.empty:
        return results

    def _as_decision(entry) -> Dict:
        # Repackage one log row into the structure run_xai expects.
        evidence = [
            {"feature_name": entry[feat], "contribution": entry[prob]}
            for feat, prob in (
                ("feature1", "prob1"),
                ("feature2", "prob2"),
                ("feature3", "prob3"),
            )
        ]
        return {
            "ticker": entry['ticker'],
            "date": entry['date'],
            "signal": entry['action'],
            "price": entry['price'],
            "evidence": evidence,
        }

    for _, entry in decision_log.iterrows():
        decision = _as_decision(entry)
        try:
            report = run_xai(decision, groq_key)
            results.append(report)
            print(f"--- {entry['ticker']} XAI 리포트 생성 완료 ---")
        except Exception as e:
            error_message = f"--- {entry['ticker']} XAI 리포트 생성 중 오류 발생: {e} ---"
            print(error_message)
            results.append(error_message)

    print(f"--- [PIPELINE-STEP 3] XAI 리포트 생성 완료 ---")
    return results

# --- 전체 파이프라인 실행 ---
def run_pipeline():
    """
    Run the full pipeline: Finder -> Transform -> XAI.

    Returns:
        list | None: the XAI reports, or None when an earlier stage produced
        nothing and the pipeline stopped early.
    """
    # Config is optional: without it only DB-free stages can run (the
    # Transform stage guards against a missing config itself).
    config = None
    try:
        # encoding fixed to utf-8: the config contains non-ASCII text and the
        # platform default encoding is not reliable (e.g. cp949 on Windows).
        with open(os.path.join(project_root, 'configs', 'config.json'), 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError:
        print("[WARN] configs/config.json 파일을 찾을 수 없어 DB 연결이 필요 없는 기능만 작동합니다.")

    top_tickers = run_weekly_finder()
    if not top_tickers:
        print("Finder에서 종목을 찾지 못해 파이프라인을 중단합니다.")
        return None

    decision_log = run_signal_transform(top_tickers, config)
    if decision_log.empty:
        print("Transform에서 신호를 생성하지 못해 파이프라인을 중단합니다.")
        return None

    xai_reports = run_xai_report(decision_log)
    return xai_reports

# --- 테스트를 위한 실행 코드 ---
def _main() -> None:
    # Smoke-test entry point: run the whole pipeline and dump the reports.
    print(">>> 파이프라인 (Finder -> Transform -> XAI) 테스트를 시작합니다.")
    final_reports = run_pipeline()
    print("\n>>> 최종 반환 결과 (XAI Reports):")
    if not final_reports:
        print("생성된 리포트가 없습니다.")
    else:
        for item in final_reports:
            print(item)
    print("\n---")
    print("테스트가 정상적으로 완료되었다면, 위 '최종 반환 결과'에 각 종목에 대한 XAI 리포트가 출력되어야 합니다.")
    print("---")


if __name__ == "__main__":
    _main()
5 changes: 2 additions & 3 deletions AI/libs/utils/data/fetch_ohlcv.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ def fetch_ohlcv(

query = """
SELECT date, open, high, low, close, volume
FROM stock_prices
FROM public.price_data
WHERE ticker = %s
AND interval = %s
AND date BETWEEN %s AND %s
ORDER BY date;
"""

# Parameter binding (%s) used → prevents SQL injection.
# NOTE(review): the query above still contains FOUR placeholders
# (ticker, interval, start, end) but only THREE params are bound below —
# this will fail at runtime. Either drop the `AND interval = %s` filter
# from the query or re-add `interval` to the params tuple.
df = pd.read_sql(query, conn, params=(ticker, interval, start, end))
df = pd.read_sql(query, conn, params=(ticker, start, end))

conn.close()
return df
Loading