Merge pull request #7 from tejtw/v1.1.1rc2
MAINT: update for pandas version 2
Han860207 authored Feb 23, 2024
2 parents cf93e1f + 576f4ec commit 3bd23f7
Showing 1 changed file with 8 additions and 31 deletions.
TejToolAPI/TejToolAPI.py: 39 changes (8 additions, 31 deletions)
@@ -76,7 +76,6 @@ def get_history_data(ticker:list, columns:list = [], fin_type:list = ['A','Q','T
history_data = history_data.drop(columns=[i for i in history_data.columns if i in para.drop_keys+['fin_date', 'mon_sales_date', 'share_date']])

# Transfer to pandas dataframe
-# history_data = history_data.compute(meta = history_data)
history_data = history_data.compute(meta = Meta_Types.all_meta)

# Drop repeat rows from the table.
@@ -92,10 +91,9 @@ def get_history_data(ticker:list, columns:list = [], fin_type:list = ['A','Q','T
history_data = history_data.compute(meta = Meta_Types.all_meta)

# Truncate resuly by user-setted start.
-history_data = history_data.loc[history_data.mdate >= org_start,:]
+history_data = history_data.loc[history_data.mdate >= pd.Timestamp(org_start),:]

# Transfer columns to abbreviation text.
-# print(history_data.columns)
lang_map = transfer_language_columns(history_data.columns, isChinese=transfer_to_chinese)
history_data = history_data.rename(columns= lang_map)
history_data = history_data.reset_index(drop=True)
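Note: the hunk above wraps the user-supplied start date in pd.Timestamp before comparing it with the datetime64 mdate column, so the truncation no longer relies on implicit coercion of a string or date value as part of the pandas 2 update. A minimal pandas sketch of the same pattern; the frame and org_start value here are hypothetical stand-ins, not TejToolAPI objects:

    import pandas as pd

    # Hypothetical stand-ins for the assembled history table and the user's start date.
    history_data = pd.DataFrame({"mdate": pd.to_datetime(["2023-12-29", "2024-01-02", "2024-01-03"])})
    org_start = "2024-01-01"

    # Explicit Timestamp conversion keeps the comparison datetime-vs-datetime under pandas 2.
    history_data = history_data.loc[history_data.mdate >= pd.Timestamp(org_start), :]
    print(history_data)  # keeps only rows on or after 2024-01-01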
@@ -109,9 +107,6 @@ def process_fin_data(all_tables, variable, tickers, start, end):
days = days.rename(columns = {'zdate':'all_dates'})
all_tables[variable] = all_tables[variable].rename(columns = {'mdate':'fin_date'})
all_tables[variable] = dd.merge(days, all_tables[variable], left_on=['all_dates'], right_on=['annd'], how='left')

-# Drop mdate column
-# all_tables[variable] = all_tables[variable].drop(columns = 'mdate')

# Delete the redundant dataframe to release memory space
del days
@@ -129,9 +124,7 @@ def get_col_name(col, isChinese):
transfer_lang = 'CHN_COLUMN_NAMES' if isChinese else 'ENG_COLUMN_NAMES'
try:
col_name = search_columns([col])[transfer_lang].dropna().drop_duplicates(keep='last').item()

except:
-# print(col)
col_name = search_columns([col])[transfer_lang].dropna().tail(1).item()

return col_name if col_name else col
@@ -180,12 +173,8 @@ def triggers(ticker:list, columns:list = [], fin_type:list = ['A','Q','TTM'], i
columns = get_internal_code(columns)

# Kick out `coid` and `mdate` from
-for i in ['coid', 'mdate','key3', 'no','sem','fin_type', 'curr', 'fin_ind']:
-try:
-columns.remove(i)
-except:
-pass
-# columns = [i for i in columns if i !='coid' or i!='mdate']

+columns = [col for col in columns if col not in ['coid', 'mdate','key3', 'no','sem','fin_type', 'curr', 'fin_ind'] ]

# Qualify the table triggered by the given `columns`
trigger_tables = search_table(columns)
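Note: here the try/except removal loop is replaced by a single list comprehension that filters the key fields out of columns in one pass, with no exception handling needed when a key is absent. A small self-contained illustration; the non-key column names are made up:

    # Hypothetical requested columns mixed with key fields.
    columns = ['coid', 'mdate', 'open_d', 'close_d', 'fin_type', 'eps']
    keys = ['coid', 'mdate', 'key3', 'no', 'sem', 'fin_type', 'curr', 'fin_ind']

    # One pass instead of repeated list.remove() calls wrapped in try/except.
    columns = [col for col in columns if col not in keys]
    print(columns)  # ['open_d', 'close_d', 'eps']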
@@ -196,7 +185,6 @@ def triggers(ticker:list, columns:list = [], fin_type:list = ['A','Q','TTM'], i
coid_calendar = get_stock_calendar(ticker, start = start, end = end, npartitions = npartitions)

# Get trading calendar of all given tickers
-# if 'stk_price' not in trigger_tables['TABLE_NAMES'].unique():
trading_calendar = get_trading_calendar(ticker, start = start, end = end, npartitions = npartitions)

# If include_self_acc equals to 'N', then delete the fin_self_acc in the trigger_tables list
@@ -236,13 +224,15 @@ def consecutive_merge(local_var, loop_array):
# Merge tables by dask merge.

temp = local_var[loop_array[i]]
-if temp['mdate'].dtype != data['mdate'].dtype :
+if ('fin_date' in temp.columns) and ('mdate' not in temp.columns) : # modified 20240223 by Han
+temp['mdate'] = temp['fin_date'].copy()
+if (temp['mdate'].dtype != data['mdate'].dtype) :
data['mdate'] = data['mdate'].astype(temp['mdate'].dtype)
+data = dd.merge(data, temp, on = ['coid', 'mdate'] , how = 'left', suffixes = ('','_surfeit'))

-data = dd.merge(data, temp, on = ['coid', 'mdate'] , how = 'left', suffixes = ('','_surfeit'))
# Drop surfeit columns.
data = data.loc[:,~data.columns.str.contains('_surfeit')]

+data['mdate'] = data['mdate'].astype('datetime64[ns]')
return data

def keep_repo_date(data):
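Note: the consecutive_merge hunk above makes three adjustments: it aliases fin_date to mdate when a financial table has no mdate column, it aligns the mdate dtypes of both sides before the Dask merge, and it casts mdate back to datetime64[ns] before returning. A minimal sketch of that flow; the two frames are hypothetical stand-ins, not the actual trigger tables:

    import pandas as pd
    import dask.dataframe as dd

    # Hypothetical stand-ins for two tables handled inside consecutive_merge.
    data = dd.from_pandas(pd.DataFrame({"coid": ["2330", "2330"],
                                        "mdate": pd.to_datetime(["2024-01-02", "2024-01-03"])}),
                          npartitions=1)
    temp = dd.from_pandas(pd.DataFrame({"coid": ["2330"],
                                        "fin_date": pd.to_datetime(["2024-01-02"]),
                                        "eps": [2.1]}),
                          npartitions=1)

    # Financial tables carry fin_date rather than mdate, so alias it before merging.
    if ("fin_date" in temp.columns) and ("mdate" not in temp.columns):
        temp["mdate"] = temp["fin_date"].copy()
    # Align the merge-key dtype on both sides so the Dask merge keys match.
    if temp["mdate"].dtype != data["mdate"].dtype:
        data["mdate"] = data["mdate"].astype(temp["mdate"].dtype)
    data = dd.merge(data, temp, on=["coid", "mdate"], how="left", suffixes=("", "_surfeit"))

    # Normalize the key back to datetime64[ns] before the result is handed on.
    data["mdate"] = data["mdate"].astype("datetime64[ns]")
    print(data.compute())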
@@ -276,18 +266,6 @@ def get_index_trading_date(tickers):

return data

-# def get_data(tickers):
-# # trading calendar
-# data = get_index_trading_date(tickers)
-
-# if len(data)<1:
-# return pd.DataFrame({'coid': pd.Series(dtype='object'), 'mdate': pd.Series(dtype='datetime64[ns]')})
-
-# return data
-
-
-# Define the meta of the dataframe
-# meta = pd.DataFrame({'coid': pd.Series(dtype='object'), 'mdate': pd.Series(dtype='datetime64[ns]')})

# Calculate the number of tickers in each partition.
ticker_partitions = dask_api.get_partition_group(tickers = tickers, npartitions= npartitions)
Expand Down Expand Up @@ -323,7 +301,6 @@ def get_data(tickers):


# Define the meta of the dataframe
-# meta = pd.DataFrame({'coid': pd.Series(dtype='object'), 'mdate': pd.Series(dtype='datetime64[ns]')})

# Calculate the number of tickers in each partition.
ticker_partitions = dask_api.get_partition_group(tickers = tickers, npartitions= npartitions)
