diff --git a/README.md b/README.md index 66c5d8f..fdb3f04 100644 --- a/README.md +++ b/README.md @@ -63,15 +63,28 @@ when using without login, following warning will be shown `you are using nologin To download the data use `tv.get_hist` method. -It accepts following arguments and returns pandas dataframe +It accepts following arguments and returns pandas dataframe if `dataFrame` is set to True (default) to get pandas DataDrame, if False it will return data in list format + ```python -(symbol: str, exchange: str = 'NSE', interval: Interval = Interval.in_daily, n_bars: int = 10, fut_contract: int | None = None, extended_session: bool = False) -> DataFrame) +(symbol: str|List[str], exchange: str = 'NSE', interval: Interval = Interval.in_daily, n_bars: int = 10, dataFrame: bool = True, fut_contract: int | None = None, extended_session: bool = False) -> pd.DataFrame|Dict[str, List[List]|pd.DataFrame]|List[List]) ``` +Note: If symbol (str) given it will return DataFrame or List of historical data of the symbol. + If List of symbols is passed to `tv.get_hist` it will return python Dictionary in {'symbol': Data, ......} format. + For multiple symbols, it fetches data asynchronously to get faster results. + for example- ```python +symbols = ['SBIN', 'EICHERMOT', 'INFY', 'BHARTIARTL', 'NESTLEIND', 'ASIANPAINT', 'ITC'] + +# returns {symbol1: pd DataFrame, symbol2: pd DataFrame, .....} +results = tv.get_hist(symbols, "NSE", n_bars=500) + +# returns {symbol1: [[Timestamp, open, high, low, close, volume], .....], symbol2: [[Timestamp, open, high, low, close, volume], .....], .....} +results = tv.get_hist(symbols, "NSE", n_bars=500, dataFrame=False) + # index nifty_index_data = tv.get_hist(symbol='NIFTY',exchange='NSE',interval=Interval.in_1_hour,n_bars=1000) @@ -85,6 +98,12 @@ crudeoil_data = tv.get_hist(symbol='CRUDEOIL',exchange='MCX',interval=Interval.i extended_price_data = tv.get_hist(symbol="EICHERMOT",exchange="NSE",interval=Interval.in_1_hour,n_bars=500, extended_session=False) ``` +To use in Ipython notebooks, add these lines at first +```python +import nest_asyncio # To run asyncio in a notebook environment +nest_asyncio.apply() # Enable asyncio in a notebook environment +``` + --- ## Search Symbol diff --git a/requirements.txt b/requirements.txt index a48b07b..185cf39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ setuptools~=49.2.0 pandas~=1.0.5 -websocket-client~=0.57.0 +websockets~=14.1 requests \ No newline at end of file diff --git a/setup.py b/setup.py index 8c8eac4..f9a245c 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ install_requires=[ "setuptools", "pandas", - "websocket-client", + "websockets", "requests" ], ) diff --git a/tvDatafeed/main.py b/tvDatafeed/main.py index de77ebc..457ea75 100644 --- a/tvDatafeed/main.py +++ b/tvDatafeed/main.py @@ -1,3 +1,4 @@ +from typing import Dict, List import datetime import enum import json @@ -6,13 +7,14 @@ import re import string import pandas as pd -from websocket import create_connection import requests import json +import logging +import asyncio +from websockets import connect # Replaced `websocket` with `websockets` logger = logging.getLogger(__name__) - class Interval(enum.Enum): in_1_minute = "1" in_3_minute = "3" @@ -81,12 +83,6 @@ def __auth(self, username, password): return token - def __create_connection(self): - logging.debug("creating websocket connection") - self.ws = create_connection( - "wss://data.tradingview.com/socket.io/websocket", headers=self.__ws_headers, timeout=self.__ws_timeout - ) - @staticmethod def __filter_raw_message(text): try: @@ -131,41 +127,45 @@ def __send_message(self, func, args): self.ws.send(m) @staticmethod - def __create_df(raw_data, symbol): - try: - out = re.search('"s":\[(.+?)\}\]', raw_data).group(1) - x = out.split(',{"') - data = list() - volume_data = True + def __parse_data(raw_data, is_return_dataframe:bool) -> List[List]: + out = re.search('"s":\[(.+?)\}\]', raw_data).group(1) + x = out.split(',{"') + data = list() + volume_data = True - for xi in x: - xi = re.split("\[|:|,|\]", xi) - ts = datetime.datetime.fromtimestamp(float(xi[4])) + for xi in x: + xi = re.split("\[|:|,|\]", xi) + ts = datetime.datetime.fromtimestamp(float(xi[4])) if is_return_dataframe else int(xi[4].split('.')[0]) - row = [ts] + row = [ts] - for i in range(5, 10): + for i in range(5, 10): - # skip converting volume data if does not exists - if not volume_data and i == 9: - row.append(0.0) - continue - try: - row.append(float(xi[i])) + # skip converting volume data if does not exists + if not volume_data and i == 9: + row.append(0.0) + continue + try: + row.append(float(xi[i])) - except ValueError: - volume_data = False - row.append(0.0) - logger.debug('no volume data') + except ValueError: + volume_data = False + row.append(0.0) + logger.debug('no volume data') - data.append(row) + data.append(row) - data = pd.DataFrame( - data, columns=["datetime", "open", + return data + + @staticmethod + def __create_df(parsed_data, symbol) -> pd.DataFrame: + try: + df = pd.DataFrame( + parsed_data, columns=["datetime", "open", "high", "low", "close", "volume"] ).set_index("datetime") - data.insert(0, "symbol", value=symbol) - return data + df.insert(0, "symbol", value=symbol) + return df except AttributeError: logger.error("no data, please check the exchange and symbol") @@ -185,136 +185,117 @@ def __format_symbol(symbol, exchange, contract: int = None): return symbol - def get_hist( - self, - symbol: str, - exchange: str = "NSE", - interval: Interval = Interval.in_daily, - n_bars: int = 10, - fut_contract: int = None, - extended_session: bool = False, - ) -> pd.DataFrame: - """get historical data + async def __fetch_symbol_data(self, symbol: str, exchange: str, interval: Interval, n_bars: int, fut_contract: int, extended_session: bool, dataFrame: bool) -> pd.DataFrame|List[List]: + """Helper function to asynchronously fetch symbol data.""" + try: + symbol = self.__format_symbol(symbol, exchange, fut_contract) + interval = interval.value + + async with connect( + "wss://data.tradingview.com/socket.io/websocket", + origin="https://data.tradingview.com" + ) as websocket: + # Authentication and session setup + await websocket.send(self.__create_message("set_auth_token", [self.token])) + await websocket.send(self.__create_message("chart_create_session", [self.chart_session, ""])) + await websocket.send(self.__create_message("quote_create_session", [self.session])) + await websocket.send(self.__create_message( + "quote_set_fields", + [ + self.session, + "ch", "chp", "current_session", "description", + "local_description", "language", "exchange", + "fractional", "is_tradable", "lp", "lp_time", + "minmov", "minmove2", "original_name", "pricescale", + "pro_name", "short_name", "type", "update_mode", "volume", + "currency_code", "rchp", "rtc", + ] + )) + await websocket.send(self.__create_message("quote_add_symbols", [self.session, symbol, {"flags": ["force_permission"]}])) + await websocket.send(self.__create_message("quote_fast_symbols", [self.session, symbol])) + + # Symbol resolution and series creation + await websocket.send( + self.__create_message( + "resolve_symbol", + [ + self.chart_session, + "symbol_1", + f'={{"symbol":"{symbol}","adjustment":"splits","session":"{"regular" if not extended_session else "extended"}"}}', + ], + ) + ) + await websocket.send(self.__create_message("create_series", [self.chart_session, "s1", "s1", "symbol_1", interval, n_bars])) + await websocket.send(self.__create_message("switch_timezone", [self.chart_session, "exchange"])) + + raw_data = "" + + # Fetch and parse raw data asynchronously + while True: + try: + result = await websocket.recv() + raw_data += result + "\n" + except Exception as e: + logger.error(e) + break + + if "series_completed" in result: + break + + # Return formatted data + if dataFrame: + parsed_data = self.__parse_data(raw_data, dataFrame) + return self.__create_df(parsed_data, symbol) + else: + return self.__parse_data(raw_data, dataFrame) + except Exception as e: + logger.error(f"Error fetching data for {symbol}: {e}") + return None + + async def get_hist_async(self, symbols: list[str], exchange: str = "NSE", interval: Interval = Interval.in_daily, n_bars: int = 10, dataFrame: bool = True, fut_contract: int = None, extended_session: bool = False) -> Dict[str, List[List]|pd.DataFrame]: + """Fetch historical data for multiple symbols asynchronously.""" + tasks = [ + self.__fetch_symbol_data(symbol, exchange, interval, n_bars, fut_contract, extended_session, dataFrame) + for symbol in symbols + ] + results = await asyncio.gather(*tasks) + + return {sym: data for sym, data in zip(symbols, results)} + + def get_hist(self, symbols: list[str]|str, exchange: str = "NSE", interval: Interval = Interval.in_daily, n_bars: int = 10, dataFrame: bool = True, fut_contract: int = None, extended_session: bool = False) -> pd.DataFrame|Dict[str, List[List]|pd.DataFrame]|List[List]: + """Fetch historical data for a single or multiple symbols. Args: - symbol (str): symbol name - exchange (str, optional): exchange, not required if symbol is in format EXCHANGE:SYMBOL. Defaults to None. - interval (str, optional): chart interval. Defaults to 'D'. - n_bars (int, optional): no of bars to download, max 5000. Defaults to 10. - fut_contract (int, optional): None for cash, 1 for continuous current contract in front, 2 for continuous next contract in front . Defaults to None. - extended_session (bool, optional): regular session if False, extended session if True, Defaults to False. + symbols (list[str] | str): Single symbol or list of symbols. + exchange (str, optional): Exchange. Defaults to "NSE". + interval (Interval, optional): Interval. Defaults to Interval.in_daily. + n_bars (int, optional): Number of bars. Defaults to 10. + dataFrame (bool, optional): Return as DataFrame. Defaults to True. + fut_contract (int, optional): Future contract. Defaults to None. + extended_session (bool, optional): Extended session. Defaults to False. Returns: - pd.Dataframe: dataframe with sohlcv as columns + pd.DataFrame | Dict[str, List[List] | pd.DataFrame] | List[List]: Historical data. """ - symbol = self.__format_symbol( - symbol=symbol, exchange=exchange, contract=fut_contract - ) - - interval = interval.value - - self.__create_connection() - - self.__send_message("set_auth_token", [self.token]) - self.__send_message("chart_create_session", [self.chart_session, ""]) - self.__send_message("quote_create_session", [self.session]) - self.__send_message( - "quote_set_fields", - [ - self.session, - "ch", - "chp", - "current_session", - "description", - "local_description", - "language", - "exchange", - "fractional", - "is_tradable", - "lp", - "lp_time", - "minmov", - "minmove2", - "original_name", - "pricescale", - "pro_name", - "short_name", - "type", - "update_mode", - "volume", - "currency_code", - "rchp", - "rtc", - ], - ) + if isinstance(symbols, str): + return asyncio.run(self.__fetch_symbol_data(symbols, exchange, interval, n_bars, fut_contract, extended_session, dataFrame)) - self.__send_message( - "quote_add_symbols", [self.session, symbol, - {"flags": ["force_permission"]}] - ) - self.__send_message("quote_fast_symbols", [self.session, symbol]) - - self.__send_message( - "resolve_symbol", - [ - self.chart_session, - "symbol_1", - '={"symbol":"' - + symbol - + '","adjustment":"splits","session":' - + ('"regular"' if not extended_session else '"extended"') - + "}", - ], - ) - self.__send_message( - "create_series", - [self.chart_session, "s1", "s1", "symbol_1", interval, n_bars], - ) - self.__send_message("switch_timezone", [ - self.chart_session, "exchange"]) - - raw_data = "" - - logger.debug(f"getting data for {symbol}...") - while True: - try: - result = self.ws.recv() - raw_data = raw_data + result + "\n" - except Exception as e: - logger.error(e) - break - - if "series_completed" in result: - break - - return self.__create_df(raw_data, symbol) - - def search_symbol(self, text: str, exchange: str = ''): - url = self.__search_url.format(text, exchange) - - symbols_list = [] - try: - resp = requests.get(url) - - symbols_list = json.loads(resp.text.replace( - '', '').replace('', '')) - except Exception as e: - logger.error(e) - - return symbols_list + return asyncio.run(self.get_hist_async(symbols, exchange, interval, n_bars, dataFrame, fut_contract, extended_session)) if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) tv = TvDatafeed() - print(tv.get_hist("CRUDEOIL", "MCX", fut_contract=1)) + + symbols = ['SBIN', 'EICHERMOT', 'INFY', 'BHARTIARTL', 'NESTLEIND', 'ASIANPAINT', 'ITC'] + print(tv.get_hist(symbols, "NSE", n_bars=500)) print(tv.get_hist("NIFTY", "NSE", fut_contract=1)) - print( - tv.get_hist( + print(tv.get_hist( "EICHERMOT", "NSE", interval=Interval.in_1_hour, n_bars=500, extended_session=False, + dataFrame=False ) )