How do I use state tuning with rwkv6-7B? #246

Open
xinyinan9527 opened this issue May 23, 2024 · 4 comments

Comments

@xinyinan9527

I followed the instructions on the official site.
As I understand it, only time_state should be trained, but I get this error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
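
For readers hitting the same message: PyTorch raises it whenever `backward()` is called on a tensor that is not connected to anything trainable. The snippet below is only a minimal reproduction under that assumption, using a made-up frozen tensor `w`; it is not the RWKV training code.

```python
import torch

# A frozen tensor: requires_grad defaults to False.
w = torch.ones(3)

# The loss is computed only from frozen values, so it has no grad_fn.
loss = (w * 2).sum()

try:
    loss.backward()
    message = ""
except RuntimeError as err:
    # Same error text as reported in this issue.
    message = str(err)

print(message)
```

If the loss in a state-tuning run ends up in this situation, it usually means the graph between the trainable state and the loss was cut somewhere along the way.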

@JL-er

JL-er commented Jul 6, 2024

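Are you using the RWKV-LM project directly, or a version you modified yourself? If it is your own modified project, DeepSpeed's checkpoint raises this error when gradients are frozen, and you need to use torch.checkpoint (PyTorch's own checkpointing, torch.utils.checkpoint) instead. See RWKV-PEFT for details.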
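
A rough sketch of the suggestion above, assuming a toy module rather than the real model: `TinyBlock`, its sizes, and the way `time_state` is mixed in are all hypothetical, and this is not how RWKV-LM or RWKV-PEFT implement it. It only shows that with every parameter frozen except `time_state`, `torch.utils.checkpoint.checkpoint` called with `use_reentrant=False` still lets gradients reach the state even though the input itself does not require grad.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    """Hypothetical stand-in for one block; not the real RWKV architecture."""

    def __init__(self, n_head=2, head_size=4):
        super().__init__()
        n_embd = n_head * head_size
        self.n_head, self.head_size = n_head, head_size
        self.proj = nn.Linear(n_embd, n_embd)
        # One (head_size x head_size) trainable matrix per head.
        self.time_state = nn.Parameter(torch.zeros(n_head, head_size, head_size))

    def forward(self, x):
        # x: (batch, n_embd) -> (batch, n_head, head_size)
        h = self.proj(x).view(-1, self.n_head, self.head_size)
        # Mix each head's vector with that head's state matrix.
        h = h + torch.einsum("bhi,hij->bhj", h, self.time_state)
        return h.reshape(x.shape[0], -1)


model = TinyBlock()

# Freeze everything except time_state.
for name, param in model.named_parameters():
    param.requires_grad_("time_state" in name)

x = torch.randn(3, 8)  # the input does not require grad

# Non-reentrant checkpointing keeps the graph to trainable parameters
# even when no input tensor requires grad.
out = checkpoint(model, x, use_reentrant=False)
loss = out.pow(2).sum()
loss.backward()

print(model.time_state.grad.shape)  # gradient exists for the state only
```

After `backward()`, `model.time_state.grad` is populated while `model.proj.weight.grad` stays `None`, which is the behaviour state tuning needs.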

@shouldsee

@JL-er Thanks. Why is only time_state, the 64*64 matrix, tuned? Why aren't the other two groups of state fine-tuned along with it?

            state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
            state[i*3+1] = state_xueshan_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
            state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()

@JL-er

JL-er commented Aug 1, 2024

Those two parameters are very small and have little effect, so only the core part of the state is used, to keep things simple.
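
One way to read "very small" is by element count. The arithmetic below is a sketch under assumptions not stated in this thread: `n_embd = 4096` for rwkv6-7B, a head size of 64 (taken from the "64*64" mentioned in the question), and a `time_state` shape of `(n_head, head_size, head_size)` per layer.

```python
# Assumed sizes (not confirmed in this thread).
n_embd = 4096                  # embedding width assumed for rwkv6-7B
head_size = 64                 # from the "64*64" matrix in the question
n_head = n_embd // head_size   # heads per layer

# state[i*3+0] and state[i*3+2]: two vectors of length n_embd per layer.
other_state_elems = 2 * n_embd

# state[i*3+1]: one head_size x head_size matrix per head.
time_state_elems = n_head * head_size * head_size

ratio = time_state_elems // other_state_elems

print(other_state_elems, time_state_elems, ratio)
```

Under these assumptions the matrix state holds 262144 values per layer against 8192 for the two vectors combined, so the matrix accounts for the bulk of the per-layer state.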

@shouldsee

OK, understood. Thanks!
