
Could you provide inference code for S^2 Attention? #120

Open
hxs91 opened this issue Nov 3, 2023 · 4 comments

Comments

@hxs91

hxs91 commented Nov 3, 2023

As per the title.

@yukang2017
Member

Hi,

You can simply replace the forward function with this one at inference time:

def forward_flashattn(

A PR from yesterday has already fixed this:
#114

Regards,
Yukang Chen
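
A minimal sketch of what that replacement could look like at inference time, assuming forward_flashattn can be imported from a llama_attn_replace module (the import path and checkpoint path below are placeholders, not confirmed repository names):

```python
# Sketch only: patch LLaMA's attention forward with forward_flashattn, then
# run generation as usual. Module/checkpoint paths here are assumptions.
import torch
import transformers
from llama_attn_replace import forward_flashattn  # assumed import path

# Route every LlamaAttention layer through the patched forward.
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn

model = transformers.AutoModelForCausalLM.from_pretrained(
    "path/to/longlora-checkpoint",  # placeholder checkpoint
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = transformers.AutoTokenizer.from_pretrained("path/to/longlora-checkpoint")

inputs = tokenizer("A long prompt ...", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```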

@hxs91
Author

hxs91 commented Nov 7, 2023

@yukang2017 The code there looks different from what I had in mind. If I understand correctly, it is for offline inference? What I was hoping for is an online version of S^2 Attention, but as far as I can see the online path currently only has full attention implemented with flash attention.

The online version still has some open questions. For example, when generating one token at a time, the cached KV length is usually not divisible by 4; would that require something extra, such as padding or truncation?
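
A toy sketch (not the repository code) of the grouping behind this concern: S^2 (shift short) attention splits the sequence into groups and shifts half of the heads by half a group, so a KV length that does not divide evenly into groups is exactly where token-by-token decoding gets awkward.

```python
import torch

def s2_group_view(x: torch.Tensor, num_groups: int = 4) -> torch.Tensor:
    """Toy grouping for shift short attention on a (batch, seq, heads, dim) tensor."""
    bsz, seq_len, num_heads, head_dim = x.shape
    if seq_len % num_groups != 0:
        # This is the decoding problem raised above: the KV cache grows by one
        # token per step, so its length is usually not divisible by 4.
        raise ValueError(f"seq_len={seq_len} not divisible by num_groups={num_groups}")
    group_size = seq_len // num_groups

    # Shift half of the heads by half a group so information crosses group borders.
    shifted = x.clone()
    shifted[:, :, num_heads // 2 :] = torch.roll(
        shifted[:, :, num_heads // 2 :], shifts=-group_size // 2, dims=1
    )
    # Attention is then computed independently inside each group.
    return shifted.reshape(bsz * num_groups, group_size, num_heads, head_dim)

x = torch.randn(1, 4096, 32, 128)
print(s2_group_view(x).shape)          # torch.Size([4, 1024, 32, 128])
# s2_group_view(torch.randn(1, 4097, 32, 128))  # would raise: 4097 % 4 != 0
```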

@coranholmes

coranholmes commented Nov 8, 2023

> Hi,
>
> You can simply replace the forward function with this one at inference time:
>
> def forward_flashattn(
>
> A PR from yesterday has already fixed this: #114
>
> Regards, Yukang Chen

I see it said here that S^2 attention is not needed at inference time, but your code also has forward_flashattn_inference, so is S^2 attention actually needed for inference or not? And when using an inference framework such as vLLM, is it enough to just use the default flash_attn?

@yukang2017
Member

yukang2017 commented Nov 19, 2023

@coranholmes Hi, forward_flashattn_inference is just standard attention inference, not S^2 attention. Using the default is fine.

@hxs91 Hi, I indeed have not implemented S^2 attention + KV cache inference yet. The current forward_flashattn version no longer needs padding or the divisibility handling, so you could give it a try.
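
For the vLLM part of the question, a minimal sketch of what "just use the default" could look like, i.e. plain full attention with no S^2 patch (the checkpoint path is a placeholder for a merged LongLoRA model):

```python
from vllm import LLM, SamplingParams

# Plain full-attention decoding; no S^2 attention patch is applied.
llm = LLM(model="path/to/merged-longlora-checkpoint", dtype="float16")
params = SamplingParams(temperature=0.7, max_tokens=128)

outputs = llm.generate(["Summarize the following long document: ..."], params)
print(outputs[0].outputs[0].text)
```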
