The `get_attention` function in MindCV's GPSA layer implementation does not normalize `attn`:
```python
def get_attention(self, x: Tensor) -> Tensor:
    B, N, C = x.shape
    q = ops.reshape(self.q(x), (B, N, self.num_heads, C // self.num_heads))
    q = ops.transpose(q, (0, 2, 1, 3))
    k = ops.reshape(self.k(x), (B, N, self.num_heads, C // self.num_heads))
    k = ops.transpose(k, (0, 2, 3, 1))
    pos_score = self.pos_proj(self.rel_indices)
    pos_score = ops.transpose(pos_score, (0, 3, 1, 2))
    pos_score = self.softmax(pos_score)
    patch_score = self.batch_matmul(q, k)
    patch_score = ops.mul(patch_score, self.scale)
    patch_score = self.softmax(patch_score)
    gating = ops.reshape(self.gating_param, (1, -1, 1, 1))
    gating = ops.Sigmoid()(gating)
    attn = (1.0 - gating) * patch_score + gating * pos_score
    attn = self.attn_drop(attn)
    return attn
```
For comparison, here is the HuggingFace PyTorch implementation:
```python
def get_attention(self, x):
    B, N, C = x.shape
    qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k = qk[0], qk[1]
    pos_score = self.rel_indices.expand(B, -1, -1, -1)
    pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
    patch_score = (q @ k.transpose(-2, -1)) * self.scale
    patch_score = patch_score.softmax(dim=-1)
    pos_score = pos_score.softmax(dim=-1)
    gating = self.gating_param.view(1, -1, 1, 1)
    attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
    attn /= attn.sum(dim=-1).unsqueeze(-1)  # attention normalized by its sum
    attn = self.attn_drop(attn)
    return attn
```
I'm not sure whether this normalization has a large or small impact on performance, but I think it's best to stay consistent with the original paper.
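A minimal sketch of what the fix might look like in the MindSpore version, mirroring the `attn /= attn.sum(...)` line from the PyTorch reference. Variable names follow the MindCV snippet above; this is an untested illustration, not the official patch:

```python
# ... after computing patch_score and pos_score as in the MindCV code above:
gating = ops.reshape(self.gating_param, (1, -1, 1, 1))
gating = ops.Sigmoid()(gating)
attn = (1.0 - gating) * patch_score + gating * pos_score
# Normalize each attention row so it sums to 1 over the last axis,
# matching the HuggingFace/original ConViT implementation.
attn_sum = ops.ReduceSum(keep_dims=True)(attn, -1)
attn = attn / attn_sum
attn = self.attn_drop(attn)
return attn
```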