From dffbf53ed74a622edfe57cb59052b6bdd77ca3e6 Mon Sep 17 00:00:00 2001 From: connorsanders Date: Sun, 17 Dec 2023 01:43:29 -0600 Subject: [PATCH] Hardened data.py. --- test/test_yahoofinancials.py | 9 +++++---- yahoofinancials/data.py | 17 ++++++++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/test/test_yahoofinancials.py b/test/test_yahoofinancials.py index 18dc3c3..e72b557 100644 --- a/test/test_yahoofinancials.py +++ b/test/test_yahoofinancials.py @@ -43,6 +43,7 @@ def setUp(self): self.test_yf_currencies = yf(currencies) self.test_yf_concurrent = yf(stocks, concurrent=True) self.test_yf_stock_flat = yf('C', flat_format=True) + self.test_yf_stock_analytic = yf('WFC') # Fundamentals Test def test_yf_fundamentals(self): @@ -120,15 +121,15 @@ def test_yf_fundamentals_flat(self): def test_yf_analytic_methods(self): # Get Insights - out = self.test_yf_stock_single.get_insights() - if out.get("C").get("instrumentInfo").get("technicalEvents").get("sector") == "Financial Services": + out = self.test_yf_stock_analytic.get_insights() + if out.get("WFC").get("instrumentInfo").get("technicalEvents").get("sector") == "Financial Services": self.assertEqual(True, True) else: self.assertEqual(False, True) # Get Recommendations - out = self.test_yf_stock_single.get_recommendations() - if isinstance(out.get("C"), list): + out = self.test_yf_stock_analytic.get_recommendations() + if isinstance(out.get("WFC"), list): self.assertEqual(True, True) else: self.assertEqual(False, True) diff --git a/yahoofinancials/data.py b/yahoofinancials/data.py index 384ed02..1975b14 100644 --- a/yahoofinancials/data.py +++ b/yahoofinancials/data.py @@ -560,6 +560,17 @@ def _create_dict_ent(self, up_ticker, statement_type, tech_type, report_name, hi dict_ent = {up_ticker: re_data} return dict_ent + def _retry_create_dict_ent(self, up_ticker, statement_type, tech_type, report_name, hist_obj): + i = 0 + while i < 250: + try: + out = self._create_dict_ent(up_ticker, statement_type, tech_type, report_name, hist_obj) + return out + except: + time.sleep(random.randint(2, 10)) + i += 1 + continue + # Private method to return the stmt_id for the reformat_process def _get_stmt_id(self, statement_type, raw_data): stmt_id = '' @@ -611,17 +622,17 @@ def get_time_code(self, time_interval): # Public Method to get stock data def get_stock_data(self, statement_type='income', tech_type='', report_name='', hist_obj={}): data = {} - if statement_type == 'income' and tech_type == '' and report_name == '': # temp, so this method doesn't return nulls + if statement_type == 'income' and tech_type == '' and report_name == '': # temp, so this method doesn't return nulls statement_type = 'profile' tech_type = 'assetProfile' report_name = 'assetProfile' if isinstance(self.ticker, str): - dict_ent = self._create_dict_ent(self.ticker, statement_type, tech_type, report_name, hist_obj) + dict_ent = self._retry_create_dict_ent(self.ticker, statement_type, tech_type, report_name, hist_obj) data.update(dict_ent) else: if self.concurrent: with Pool(self._get_worker_count()) as pool: - dict_ents = pool.map(partial(self._create_dict_ent, + dict_ents = pool.map(partial(self._retry_create_dict_ent, statement_type=statement_type, tech_type=tech_type, report_name=report_name,