
Commit

remove attention abstraction (#5324)
contentis authored Oct 22, 2024
1 parent 8ce2a10 commit 5a8a489
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from .. import attention
+from ..attention import optimized_attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 import comfy.ops
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]
 
-def optimized_attention(qkv, num_heads):
-    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)
 
 class SelfAttention(nn.Module):
     ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +324,9 @@ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        qkv = self.pre_attention(x)
+        q, k, v = self.pre_attention(x)
         x = optimized_attention(
-            qkv, num_heads=self.num_heads
+            q, k, v, heads=self.num_heads
         )
         x = self.post_attention(x)
         return x
@@ -531,8 +529,8 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
         assert not self.pre_only
         qkv, intermediates = self.pre_attention(x, c)
         attn = optimized_attention(
-            qkv,
-            num_heads=self.attn.num_heads,
+            qkv[0], qkv[1], qkv[2],
+            heads=self.attn.num_heads,
         )
         return self.post_attention(attn, *intermediates)
 
@@ -557,8 +555,8 @@ def _block_mixing(context, x, context_block, x_block, c):
     qkv = tuple(o)
 
     attn = optimized_attention(
-        qkv,
-        num_heads=x_block.attn.num_heads,
+        qkv[0], qkv[1], qkv[2],
+        heads=x_block.attn.num_heads,
     )
     context_attn, x_attn = (
         attn[:, : context_qkv[0].shape[1]],
@@ -642,7 +640,7 @@ def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operation
     def forward(self, x):
         qkv = self.qkv(x)
         q, k, v = split_qkv(qkv, self.dim_head)
-        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
+        x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
         return self.proj(x)
 
 class ContextProcessorBlock(nn.Module):
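
With the local wrapper deleted, every call site uses the shared optimized_attention from comfy/ldm/modules/attention.py directly, passing q, k and v as separate tensors plus a heads keyword instead of a packed qkv tuple. Below is a minimal sketch of the new calling convention, not part of the commit; it assumes q, k and v are (batch, sequence, heads * dim_head) tensors and uses illustrative sizes.

import torch
from comfy.ldm.modules.attention import optimized_attention

# Illustrative sizes (assumed, not from the commit): batch 1, sequence 16, 8 heads of width 64.
q = torch.randn(1, 16, 8 * 64)
k = torch.randn(1, 16, 8 * 64)
v = torch.randn(1, 16, 8 * 64)

# Before this commit, call sites packed the tensors into a tuple and went through
# the now-removed wrapper: optimized_attention((q, k, v), num_heads=8).
# After it, they call the shared function directly:
out = optimized_attention(q, k, v, heads=8)
print(out.shape)  # expected: torch.Size([1, 16, 512])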