-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathebm_agent.py
74 lines (59 loc) · 2.64 KB
/
ebm_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from langchain.agents import ZeroShotAgent
from langchain.agents import AgentExecutor
from langchain.chains.llm import LLMChain
from langchain.memory import ConversationBufferMemory
import markdown
from IPython.display import Image, display
from llm2ebm import feature_importances_to_text
from tool import get_tools
from prompt import suffix_no_df,suffix_with_df,get_prefix
#用md语法表示的图的字符串生成图
def md2img(text):
    """Display every image referenced in a Markdown string.

    The Markdown is converted to HTML with ``markdown.markdown``, the
    resulting ``<img>`` tags are collected with BeautifulSoup, and each one
    is shown through ``IPython.display.display(Image(url=..., alt=...))``.

    Args:
        text: Markdown source, e.g. an LLM answer containing ``![alt](url)``
            image references.

    Returns:
        None. Images are emitted only as a display side effect.
    """
    # Imported inside the function so bs4 is only needed by callers of md2img.
    from bs4 import BeautifulSoup

    # Convert the Markdown text to HTML.
    html_output = markdown.markdown(text)
    soup = BeautifulSoup(html_output, 'html.parser')
    for img in soup.find_all('img'):
        url = img.get('src')
        # Raw HTML passed through Markdown may contain <img> tags without a
        # src attribute (img['src'] would raise KeyError), and "![alt]()"
        # produces src="" -- there is nothing to display for either.
        if not url:
            continue
        display(Image(url=url, alt=img.get('alt', '')))
def get_agent(llm, ebm, df=None, dataset_description=None, y_axis_description=None):
    """Build a LangChain conversational agent for questions about an EBM.

    The agent is a ``ZeroShotAgent`` whose prompt prefix comes from
    ``get_prefix`` (fed with the EBM's feature-importance text and the
    optional descriptions). It is wrapped in an ``AgentExecutor`` that keeps
    a ``ConversationBufferMemory`` under the ``chat_history`` key.

    The first tool returned by ``get_tools(ebm)`` gets its Python REPL
    namespace seeded with ``ft_graph`` (``ebm.explain_global().data``) and,
    when ``df`` is given, with ``df`` as well.

    Args:
        llm: LangChain language model that drives the agent.
        ebm: Fitted EBM; must support ``explain_global()`` plus whatever
            ``feature_importances_to_text``, ``get_tools`` and ``get_prefix``
            require.
        df: Optional DataFrame. When provided it is exposed to the REPL as
            ``df``, ``suffix_with_df`` is used as the prompt suffix, and
            ``df.head()`` rendered as Markdown is bound to the prompt's
            ``df_head`` variable. Otherwise ``suffix_no_df`` is used.
        dataset_description: Optional text forwarded to ``get_prefix``.
        y_axis_description: Optional text forwarded to ``get_prefix``.

    Returns:
        An ``AgentExecutor`` (verbose, with conversation memory and
        ``handle_parsing_errors=True``).
    """
    feature_importances = feature_importances_to_text(ebm)
    # NOTE(review): in `interpret`, Explanation.data looks like a method
    # (data(key=None)), so ft_graph is probably a callable the REPL code
    # invokes per feature rather than precomputed data -- confirm against
    # the prompt text produced by get_prefix / the suffixes.
    global_explanation = ebm.explain_global().data
    prefix = get_prefix(ebm, feature_importances, dataset_description, y_axis_description)

    tools = get_tools(ebm)
    # Assumes get_tools returns the Python REPL tool first -- TODO confirm.
    python_tool = tools[0]

    # Single source of truth for "a DataFrame was supplied"; it decides the
    # REPL namespace, the prompt suffix and the df_head binding below.
    has_df = df is not None
    input_variables = ["input", "chat_history", "agent_scratchpad"]
    if has_df:
        python_tool.python_repl.locals = {"df": df, "ft_graph": global_explanation}
        input_variables.append("df_head")
        suffix = suffix_with_df
    else:
        python_tool.python_repl.locals = {"ft_graph": global_explanation}
        suffix = suffix_no_df

    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=input_variables,
    )
    if has_df:
        # Bind df_head now so it does not have to be supplied on every call.
        prompt = prompt.partial(df_head=df.head().to_markdown())

    # memory_key matches the "chat_history" prompt variable declared above.
    memory = ConversationBufferMemory(memory_key="chat_history")
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        verbose=True,
        memory=memory,
        handle_parsing_errors=True,
    )