Skip to content

Commit

Permalink
Merge pull request #582 from Lumiwealth/full_chain_fixes
Browse files Browse the repository at this point in the history
backtest: Small fixes to get_full_chain_info function when running in…
  • Loading branch information
grzesir authored Oct 16, 2024
2 parents 54170dd + b0058a8 commit b948b08
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions lumibot/data_sources/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
import time

import pandas as pd

Expand Down Expand Up @@ -417,15 +418,18 @@ def get_chain_full_info(self, asset: Asset, expiry: str, chains=None, underlying
A DataFrame containing the full chain information for the option asset. Greeks columns will be named as
'greeks.delta', 'greeks.theta', etc.
"""
start_t = time.perf_counter()
# Base level DataSource assumes that the data source does not support this and the greeks will be calculated
# locally. Subclasses can override this method to provide a more efficient implementation.
expiry_dt = datetime.strptime(expiry, "%Y-%m-%d") if isinstance(expiry, str) else expiry
expiry_str = expiry_dt.strftime("%Y-%m-%d")
if chains is None:
chains = self.get_chains(asset)

rows = []
query_total = 0
for right in chains["Chains"]:
for strike in chains["Chains"][right][expiry]:
for strike in chains["Chains"][right][expiry_str]:
# Skip strikes outside the requested range. Saves querying time.
if strike_min and strike < strike_min or strike_max and strike > strike_max:
continue
Expand All @@ -438,15 +442,17 @@ def get_chain_full_info(self, asset: Asset, expiry: str, chains=None, underlying
strike=strike,
right=right,
)
query_t = time.perf_counter()
option_symbol = create_options_symbol(opt_asset.symbol, expiry_dt, right, strike)
opt_price = self.get_last_price(opt_asset)
greeks = self.calculate_greeks(opt_asset, opt_price, underlying_price, risk_free_rate)
query_total += time.perf_counter() - query_t

# Build the row. Match the Tradier column naming conventions.
row = {
"symbol": option_symbol,
"last": opt_price,
"expiration_date": expiry,
"expiration_date": expiry_dt,
"strike": strike,
"option_type": right,
"underlying": opt_asset.symbol,
Expand All @@ -464,7 +470,10 @@ def get_chain_full_info(self, asset: Asset, expiry: str, chains=None, underlying
row.update({f"greeks.{col}": val for col, val in greeks.items()})
rows.append(row)

return pd.DataFrame(rows).sort_values("strike")
logging.info(f"Chain Full Info Query Total: {query_total:.2f}s. "
f"Total Time: {time.perf_counter() - start_t:.2f}s, "
f"Rows: {len(rows)}")
return pd.DataFrame(rows).sort_values("strike") if rows else pd.DataFrame()

def calculate_greeks(
self,
Expand Down

0 comments on commit b948b08

Please sign in to comment.