diff --git a/scripts/ipadapter/plugable_ipadapter.py b/scripts/ipadapter/plugable_ipadapter.py
index 7e1e433d0..7b3ad8124 100644
--- a/scripts/ipadapter/plugable_ipadapter.py
+++ b/scripts/ipadapter/plugable_ipadapter.py
@@ -3,6 +3,35 @@
 from .ipadapter_model import ImageEmbed, IPAdapterModel
 
 
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+    # Fallback implementation for PyTorch v1 compatibility (less efficient)
+    # Slightly modified from: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias.to(query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias += attn_mask
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return attn_weight @ value
+
+try:
+    scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+except AttributeError:
+    pass
+
+
 def get_block(model, flag):
     return {
         "input": model.input_blocks,
@@ -30,7 +59,7 @@ def attn_forward_hacked(self, x, context=None, **kwargs):
         (q, k, v),
     )
 
-    out = torch.nn.functional.scaled_dot_product_attention(
+    out = scaled_dot_product_attention(
         q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
     )
     out = out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
@@ -227,7 +256,7 @@ def forward(attn_blk, x, q):
             ip_k = ip_k.to(dtype=q.dtype)
             ip_v = ip_v.to(dtype=q.dtype)
 
-            ip_out = torch.nn.functional.scaled_dot_product_attention(
+            ip_out = scaled_dot_product_attention(
                 q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
            )
             ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
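
For reviewers who want to exercise the fallback path on a PyTorch 2.x install (where the native kernel normally shadows it via the `try`/`except AttributeError` override), the following is a minimal standalone sanity check. The helper name `sdp_attention_fallback`, the `__main__` harness, and the tensor shapes are illustrative assumptions, not part of the patch; the function body just repeats the same math as the patched-in fallback so its output can be compared against `torch.nn.functional.scaled_dot_product_attention` when that is available.

```python
import math
import torch


def sdp_attention_fallback(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # Same computation as the fallback in the diff: softmax(Q K^T * scale + bias) V
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor + attn_bias, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


if __name__ == "__main__":
    torch.manual_seed(0)
    # (batch, heads, tokens, head_dim) -- shapes chosen only for illustration
    q = torch.randn(2, 8, 64, 40)
    k = torch.randn(2, 8, 77, 40)
    v = torch.randn(2, 8, 77, 40)
    out = sdp_attention_fallback(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        print("max abs diff vs. native SDPA:", (out - ref).abs().max().item())
```

With dropout disabled the two paths should agree to within float32 rounding error; on PyTorch builds older than 2.0 only the fallback branch runs, which is exactly the situation the patch targets.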