Add inference cache (cleaner version) #4

Open · wants to merge 19 commits into base: weight-sharing
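The cache is opt-in: `generate_images` gains a `use_cache` flag, and `DALLE`/`Transformer` gain `optimize_for_inference` to replace the sparse axial attention layers with cache-friendly, statically masked full attention. A minimal usage sketch of those two knobs, assuming the `dalle_pytorch` package from this branch and purely illustrative hyperparameters (any consistent sizes work):

import torch
from dalle_pytorch import DiscreteVAE, DALLE

# illustrative sizes only
vae = DiscreteVAE(image_size = 64, num_layers = 3, num_tokens = 1024)

dalle = DALLE(
    dim = 256,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 128,
    depth = 2,
    heads = 4,
    attn_types = ('full', 'axial_row'),
    optimize_for_inference = True,   # swap sparse axial attention for statically masked attention
)

text = torch.randint(0, 10000, (1, 128))
images = dalle.generate_images(text, use_cache = True)   # keys/values and shifted tokens are cached across steps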
21 changes: 17 additions & 4 deletions dalle_pytorch/attention.py
@@ -37,7 +37,8 @@ def apply_pos_emb(pos_emb, qkv):
# classes

class Attention(nn.Module):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
static_mask = None):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
@@ -46,25 +47,34 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):

self.stable = stable
self.causal = causal
self.register_buffer('static_mask', static_mask, persistent=False)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x, mask = None, rotary_pos_emb = None):
def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
b, n, _, h, device = *x.shape, self.heads, x.device
softmax = torch.softmax if not self.stable else stable_softmax
offset = cache.get('offset', 0) if exists(cache) else 0

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

q = q * self.scale

if offset > 0:
k_top, v_top = cache[cache_key]
k = torch.cat([k_top, k], dim=-2)
v = torch.cat([v_top, v], dim=-2)
if exists(cache):
cache[cache_key] = k, v

dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
mask_value = max_neg_value(dots)

@@ -73,11 +83,14 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
dots.masked_fill_(~mask, mask_value)
del mask

if self.causal:
if self.causal and offset == 0: # causality is naturally enforced for the cached inference
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
dots.masked_fill_(mask, mask_value)

if exists(self.static_mask):
dots.masked_fill_(~self.static_mask[offset:offset + n, :offset + n], mask_value)

attn = softmax(dots, dim=-1)

out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
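The `Attention` changes above follow the standard key/value caching pattern: the first pass processes the full prefix and stores its keys and values under `cache_key`; later passes compute projections only for the new token, concatenate them onto the cached ones, and skip the causal mask (a single new query may attend to every earlier position). A self-contained sketch of that pattern, independent of this module's helpers such as `exists` and `stable_softmax`:

import torch

def cached_attention_step(q, k, v, cache, cache_key):
    # q, k, v: (batch, heads, n_new, dim_head), projections of the new tokens only
    scale = q.shape[-1] ** -0.5
    if cache_key in cache:
        k_prev, v_prev = cache[cache_key]
        k = torch.cat([k_prev, k], dim = -2)   # prepend cached keys
        v = torch.cat([v_prev, v], dim = -2)   # prepend cached values
    cache[cache_key] = (k, v)                  # store the grown sequence for the next step

    dots = torch.einsum('b h i d, b h j d -> b h i j', q * scale, k)
    # with one new token per step, no causal mask is needed:
    # the single query is the latest position and may see all cached positions
    attn = dots.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

# toy usage: three decoding steps, one token at a time
cache = {}
for step in range(3):
    q = k = v = torch.randn(1, 2, 1, 8)        # (batch = 1, heads = 2, one new token, dim_head = 8)
    out = cached_attention_step(q, k, v, cache, 'attn_0')
    print(step, out.shape)                     # torch.Size([1, 2, 1, 8]) at every step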
18 changes: 15 additions & 3 deletions dalle_pytorch/dalle_pytorch.py
@@ -344,6 +344,7 @@ def __init__(
shared_attn_ids = None,
shared_ff_ids = None,
share_input_output_emb = False,
optimize_for_inference = False,
):
super().__init__()
assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -391,6 +392,7 @@ def __init__(
rotary_emb = rotary_emb,
shared_attn_ids = shared_attn_ids,
shared_ff_ids = shared_ff_ids,
optimize_for_inference = optimize_for_inference,
)

self.stable = stable
@@ -484,7 +486,8 @@ def generate_images(
filter_thres = 0.5,
temperature = 1.,
img = None,
num_init_img_tokens = None
num_init_img_tokens = None,
use_cache = False,
):
vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
total_len = text_seq_len + image_seq_len
@@ -503,12 +506,13 @@
indices = indices[:, :num_img_tokens]
out = torch.cat((out, indices), dim = -1)

cache = {} if use_cache else None
for cur_len in range(out.shape[1], total_len):
is_image = cur_len >= text_seq_len

text, image = out[:, :text_seq_len], out[:, text_seq_len:]

logits = self(text, image, mask = mask)[:, -1, :]
logits = self(text, image, mask = mask, cache = cache)[:, -1, :]

filtered_logits = top_k(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim = -1)
@@ -536,6 +540,7 @@ def forward(
text,
image = None,
mask = None,
cache = None,
return_loss = False
):
assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
@@ -584,7 +589,9 @@ def forward(
alpha = 0.1
tokens = tokens * alpha + tokens.detach() * (1 - alpha)

out = self.transformer(tokens)
if exists(cache) and cache.get('offset'):
tokens = tokens[:, -1:]
out = self.transformer(tokens, cache=cache)

if self.stable:
out = self.norm_by_max(out)
@@ -594,9 +601,14 @@
# mask logits to make sure text predicts text (except last token), and image predicts image

logits_mask = self.logits_mask[:, :seq_len]
if exists(cache) and cache.get('offset'):
logits_mask = logits_mask[:, -1:]
max_neg_value = -torch.finfo(logits.dtype).max
logits.masked_fill_(logits_mask, max_neg_value)

if exists(cache):
cache['offset'] = cache.get('offset', 0) + logits.shape[1]

if not return_loss:
return logits

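In `DALLE.forward`, the cache reduces each sampling step to a single-token pass: once `cache['offset']` is set, only the last position is run through the transformer, `logits_mask` is sliced to match, and the offset advances by the number of positions just processed. A rough stand-alone sketch of that loop; `step_fn` is a hypothetical stand-in for the cached forward pass, not part of this diff:

import torch

def decode_with_cache(step_fn, prefix, total_len):
    # prefix: (batch, n) already-known token ids (e.g. the text prompt)
    out = prefix
    cache = {'offset': 0}
    for _ in range(total_len - prefix.shape[1]):
        # the first call sees the whole prefix; later calls only feed the newest token
        inp = out if cache['offset'] == 0 else out[:, -1:]
        logits = step_fn(inp, cache)                       # (batch, inp_len, vocab)
        cache['offset'] += logits.shape[1]                 # mirrors the cache['offset'] bookkeeping above
        nxt = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat([out, nxt], dim = -1)
    return out

# toy usage with a stand-in step function over a 16-token vocabulary
dummy = lambda inp, cache: torch.randn(inp.shape[0], inp.shape[1], 16)
print(decode_with_cache(dummy, torch.zeros(1, 4, dtype = torch.long), total_len = 8).shape)  # (1, 8)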
128 changes: 117 additions & 11 deletions dalle_pytorch/transformer.py
@@ -1,3 +1,4 @@
from collections import deque
from collections.abc import Iterable
from functools import partial
from itertools import islice, cycle
@@ -35,6 +36,41 @@ def forward(self, x):
maxes = x.amax(dim = self.dim, keepdim = True)
return x / maxes

class NonCached(nn.Module):
"""
A wrapper for layers that don't support the inference cache themselves.
Reconstructs the full sequence from the cache before calling the layer,
then returns only the output suffix corresponding to the new tokens.
"""

def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, x, *, cache = None, cache_key = None, **kwargs):
n = x.shape[-2]
if exists(cache):
if cache_key in cache:
x = torch.cat([cache[cache_key], x], dim=-2)
cache[cache_key] = x

out = self.fn(x, **kwargs)

return out[:, -n:]

class CachedAs(nn.Module):
"""
A wrapper that defines a key for the inference cache.
"""

def __init__(self, cache_key, fn):
super().__init__()
self.cache_key = cache_key
self.fn = fn

def forward(self, x, *, cache=None, **kwargs):
return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)

# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
def __init__(self, dim, depth, fn):
@@ -83,7 +119,7 @@ def __init__(self, dim, dropout = 0., mult = 4.):
nn.Linear(dim * mult, dim)
)

def forward(self, x):
def forward(self, x, cache=None, cache_key=None):
return self.net(x)

# token shift classes
@@ -94,12 +130,30 @@ def __init__(self, fn, image_size, seq_len):
self.fn = fn
self.image_size = image_size
self.seq_len = seq_len
self.img_seq_len = image_size ** 2
self.text_len = seq_len - self.img_seq_len + 1

def forward(self, x, cache=None, cache_key=None, **kwargs):
seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len

if exists(cache) and cache_key in cache:
offset = cache['offset']
assert offset >= text_len, "cached inference for text is not supported"
q = cache[cache_key]
assert isinstance(q, deque) and len(q) == image_size

x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)

q.append((x_top, x_left))
x_top = q.popleft()[0]
x_left = q[-2][1]
if (offset - text_len) % image_size == 0:
x_left = torch.zeros_like(x_left)

x = torch.cat((x_top, x_left, *x_pass), dim=-1)
return self.fn(x[:, None], cache=cache, **kwargs)

def forward(self, x, **kwargs):
n = x.shape[1]
seq_len, image_size = self.seq_len, self.image_size
img_seq_len = image_size ** 2
text_len = seq_len - img_seq_len + 1
padding = seq_len - n + 1

# get text and image tokens
@@ -124,8 +178,22 @@ def forward(self, x, **kwargs):
# merge text and image sequence back together

x_img = rearrange(x_img, 'b h w d -> b (h w) d')
x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
return self.fn(x, **kwargs)
x_img = x_img[:, :-padding]
x = torch.cat((x_text, x_img), dim = 1)

if exists(cache):
dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)

q = deque()
x_img = x_img[:, -image_size:]
for _ in range(image_size - x_img.shape[1]):
q.append((dummy_top, dummy_left))
for i in range(x_img.shape[1]):
q.append(x_img[:, i].chunk(4, dim=-1)[:2])
cache[cache_key] = q

return self.fn(x, cache=cache, **kwargs)

# main transformer class

@@ -152,11 +220,15 @@
rotary_emb = True,
shared_attn_ids = None,
shared_ff_ids = None,
optimize_for_inference = False, # use cache-friendly masked attention instead of sparse one
):
super().__init__()
layers = nn.ModuleList([])
sparse_layer = cast_tuple(sparse_attn, depth)

self.seq_len = seq_len
self.image_fmap_size = image_fmap_size

attn_types = default(attn_types, ('full',))
attn_types = cast_tuple(attn_types)
attn_type_layer = islice(cycle(attn_types), depth)
@@ -173,9 +245,15 @@
elif attn_type == 'sparse':
attn_class = SparseAttention
elif attn_type == 'axial_row':
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
if optimize_for_inference:
attn_class = partial(Attention, stable = stable, static_mask = self._get_attention_mask(attn_type))
else:
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
elif attn_type == 'axial_col':
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
if optimize_for_inference:
attn_class = partial(Attention, stable = stable, static_mask = self._get_attention_mask(attn_type))
else:
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
elif attn_type == 'conv_like':
attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
elif attn_type == 'mlp':
@@ -199,8 +277,15 @@
ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
shared_ff_layers[ff_id] = ff

if isinstance(attn, Attention):
attn = CachedAs(f'attn_{ind}', attn)
else:
# at the moment, other attention classes don't support cache
attn = NonCached(attn)

if shift_tokens:
attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
attn = CachedAs(f'preshift_attn_{ind}', PreShiftToken(attn, image_size = image_fmap_size, seq_len = seq_len))
ff = CachedAs(f'preshift_ff_{ind}', PreShiftToken(ff, image_size = image_fmap_size, seq_len = seq_len))

layers.append(nn.ModuleList([
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
@@ -209,7 +294,9 @@

execute_type = ReversibleSequence if reversible else SequentialSequence
route_attn = ((True, False),) * depth
attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}
route_all = ((True, True),) * depth
attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
'cache': route_all}

self.layers = execute_type(layers, args_route = attn_route_map)

@@ -245,3 +332,22 @@ def __init__(

def forward(self, x, **kwargs):
return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)

def _get_attention_mask(self, attn_type):
img_seq_len = self.image_fmap_size ** 2
text_len = self.seq_len + 1 - img_seq_len

static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool)
static_mask[:, :text_len] = True
if attn_type == 'axial_row':
for row in range(self.image_fmap_size):
begin = text_len + row * self.image_fmap_size
end = text_len + (row + 1) * self.image_fmap_size
static_mask[begin:end, begin:end] = True
elif attn_type == 'axial_col':
for col in range(self.image_fmap_size):
begin = text_len + col
static_mask[begin::self.image_fmap_size, begin::self.image_fmap_size] = True
else:
raise ValueError(f'attention type "{attn_type}" can\'t be simulated with a static mask')
return static_mask
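For intuition, `_get_attention_mask` builds the static pattern that, together with the causal masking in `Attention`, stands in for the sparse axial attention during cached inference: every position may attend to the text positions, and each image token may additionally attend within its own row (for `axial_row`) or column (for `axial_col`). A small stand-alone check of the `axial_row` case, re-implementing the same construction with toy sizes for illustration only:

import torch

def axial_row_mask(image_fmap_size, seq_len):
    img_seq_len = image_fmap_size ** 2
    text_len = seq_len + 1 - img_seq_len
    mask = torch.zeros(seq_len, seq_len, dtype = torch.bool)
    mask[:, :text_len] = True                      # every position may attend to the text
    for row in range(image_fmap_size):
        begin = text_len + row * image_fmap_size
        end = begin + image_fmap_size
        mask[begin:end, begin:end] = True          # image tokens attend within their own row
    return mask

m = axial_row_mask(image_fmap_size = 2, seq_len = 7)   # text_len = 4, one 2x2 image grid (last token dropped)
assert m[4, 5] and not m[4, 6]                     # first image row sees its row-mate, not the next row
assert m[6, 0] and not m[6, 4]                     # every image token still sees the text positions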