
How to Build Memory-Efficient Transformers with xFormers Using Packed Sequences, GQA, ALiBi, SwiGLU, and Causal Attention

In this tutorial, we work with xFormers, a practical toolkit for building fast, memory-efficient Transformer models on GPUs. We start by validating memory-efficient attention against a standard attention implementation, then compare their speed and memory consumption across different sequence lengths. We then look at causal masking, packed variable-length sequences, grouped-query attention, and custom ALiBi positional biases. Finally, we combine these techniques into a trainable GPT-style model that uses xFormers attention, SwiGLU feed-forward layers, and automatic mixed-precision training.

Setting Up xFormers and Validating Memory-Efficient Attention

import subprocess, sys
def _pip(*a): subprocess.run([sys.executable, "-m", "pip", "install", *a], check=False)
try:
    import xformers
except Exception:
    _pip("-q", "-U", "xformers")
import math, time
import torch, torch.nn as nn, torch.nn.functional as F
import xformers, xformers.ops as xops
from xformers.ops import fmha
ab = fmha.attn_bias
assert torch.cuda.is_available(), (
    "No GPU detected. In Colab: Runtime → Change runtime type → GPU, then re-run.")
device = "cuda"
torch.manual_seed(0)
print("torch    :", torch.__version__)
print("xformers :", xformers.__version__)
print("GPU      :", torch.cuda.get_device_name(0))
print("\n--- xformers.info (which kernels are built/available) ---")
try:
    subprocess.run([sys.executable, "-m", "xformers.info"], check=False)
except Exception as e:
    print("xformers.info unavailable:", e)
def cuda_time(fn, iters=20, warmup=5):
    """Average time per call in milliseconds, measured with CUDA events."""
    for _ in range(warmup): fn()
    torch.cuda.synchronize()
    s, e = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    s.record()
    for _ in range(iters): fn()
    e.record(); torch.cuda.synchronize()
    return s.elapsed_time(e) / iters
def peak_mem_mb(fn):
    """Peak allocated GPU memory (MB) during a single call to fn."""
    torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
    fn(); torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e6
def vanilla_attention(q, k, v, causal=False):
    """Reference attention that MATERIALIZES the [B,H,M,M] score matrix.
       Inputs are xformers-layout [B, M, H, K]."""
    q, k, v = (t.transpose(1, 2).float() for t in (q, k, v))
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if causal:
        M = scores.shape[-1]
        m = torch.triu(torch.ones(M, M, device=q.device, dtype=torch.bool), 1)
        scores = scores.masked_fill(m, float("-inf"))
    out = scores.softmax(-1) @ v
    return out.transpose(1, 2)
print("\n" + "="*70 + "\n1. memory_efficient_attention basics + correctness\n" + "="*70)
B, M, H, K = 2, 512, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_xf  = xops.memory_efficient_attention(q, k, v)
out_ref = vanilla_attention(q, k, v).half()
print("output shape         :", tuple(out_xf.shape), "(layout B, M, H, K)")
print("max abs diff vs ref  : {:.2e}".format((out_xf - out_ref).abs().max().item()))
print("-> it is EXACT attention (fp16 rounding only), just computed without")
print("   ever storing the full MxM score matrix.")

We install and import xFormers, confirm GPU availability, and check which attention kernels the environment supports. We define helper functions for measuring CUDA execution time and peak memory consumption. We then validate memory-efficient attention against standard attention to confirm that both produce outputs that closely match each other.
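Why can the result be exact without ever storing the M x M matrix? The general idea behind memory-efficient attention kernels is to walk over the keys and values in blocks while keeping a running row-wise maximum and a running softmax denominator (the "online softmax" trick). The NumPy sketch below illustrates that idea for a single head, without masking, in float64. It is a simplified illustration under those assumptions, not xFormers' actual kernel, and the function names are hypothetical.

```python
import numpy as np

def naive_attention(q, k, v):
    """Single-head reference: materializes the full [M, N] score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def chunked_attention(q, k, v, block=16):
    """Same result, but only ever holds an [M, block] slice of scores."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = np.full(q.shape[0], -np.inf)           # running row-wise max of scores
    l = np.zeros(q.shape[0])                   # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[-1]))  # running unnormalized output
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T * scale                   # scores for this key block only
        m_new = np.maximum(m, s.max(axis=-1))
        correction = np.exp(m - m_new)         # rescales earlier partial sums
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
print(np.abs(chunked_attention(q, k, v) - naive_attention(q, k, v)).max())
```

The printed difference is at the level of floating-point rounding, which mirrors the tutorial's observation that xFormers computes exact attention rather than an approximation.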

Benchmarking Memory and Speed Against Naive Attention and Applying a Causal Mask

print("\n" + "="*70 + "\n2. Memory & speed vs naive attention (fwd+bwd)\n" + "="*70)
print(f"{'seqlen':>8} | {'naive MB':>10} | {'xformers MB':>12} | {'naive ms':>9} | {'xf ms':>7}")
print("-"*60)
for M in [512, 1024, 2048, 4096]:
    q, k, v = (torch.randn(2, M, 8, 64, device=device, dtype=torch.float16,
                           requires_grad=True) for _ in range(3))
    def run_xf():
        o = xops.memory_efficient_attention(q, k, v); o.sum().backward()
    def run_naive():
        o = vanilla_attention(q, k, v); o.sum().backward()
    try:
        nm = peak_mem_mb(run_naive); nt = cuda_time(run_naive, 8, 3)
    except RuntimeError:  # CUDA OOM errors subclass RuntimeError
        nm, nt = float("nan"), float("nan"); torch.cuda.empty_cache()
    xm = peak_mem_mb(run_xf); xt = cuda_time(run_xf, 8, 3)
    print(f"{M:>8} | {nm:>10.0f} | {xm:>12.0f} | {nt:>9.2f} | {xt:>7.2f}")
print("-> naive memory grows ~4x per doubling of M (it stores BxHxMxM scores);")
print("   xformers grows ~linearly and stays fast.")
print("\n" + "="*70 + "\n3. Causal attention via LowerTriangularMask\n" + "="*70)
B, M, H, K = 2, 256, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_causal = xops.memory_efficient_attention(q, k, v, attn_bias=ab.LowerTriangularMask())
ref_causal = vanilla_attention(q, k, v, causal=True).half()
print("causal max abs diff  : {:.2e}".format((out_causal - ref_causal).abs().max().item()))
print("-> the mask is implicit; no MxM boolean tensor is allocated.")

We benchmark naive attention and xFormers attention across progressively longer sequences using forward and backward passes. We compare their execution times and peak GPU memory usage to see how xFormers avoids quadratic memory growth. We also apply an implicit lower-triangular mask and verify causal attention against the reference implementation.
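The "~4x per doubling" behavior follows from simple arithmetic: the naive path holds at least one B x H x M x M score tensor, and autograd keeps the softmax probabilities around for the backward pass too. A back-of-the-envelope sketch, assuming the benchmark's B=2 and H=8 and the float32 scores that `vanilla_attention` produces; the helper name is hypothetical, and it counts only a single score-sized tensor, so measured peaks will be higher.

```python
def score_matrix_mb(batch, heads, seqlen, bytes_per_elem=4):
    """Size in MB of one [B, H, M, M] attention-score tensor (float32 by default)."""
    return batch * heads * seqlen * seqlen * bytes_per_elem / 1e6

for seqlen in [512, 1024, 2048, 4096]:
    print(f"M={seqlen:>5}: {score_matrix_mb(2, 8, seqlen):8.1f} MB per score tensor")
```

This prints about 16.8, 67.1, 268.4 and 1073.7 MB, so each doubling of M multiplies the footprint by 4. A memory-efficient kernel avoids this term entirely and only needs to save small per-query statistics for the backward pass, which grow linearly in M.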

Packing Variable-Length Sequences and Running Grouped-Query Attention

print("\n" + "="*70 + "\n4. Variable-length packed batch (no padding waste)\n" + "="*70)
seqlens = [37, 120, 8, 200]
total = sum(seqlens)
H, K = 8, 64
# All sequences are concatenated along the token axis of a single batch row.
q = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
k = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
v = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
try:
    bias = ab.BlockDiagonalMask.from_seqlens(seqlens)
    out_packed = xops.memory_efficient_attention(q, k, v, attn_bias=bias)
    s0 = seqlens[0]
    ref0 = vanilla_attention(q[:, :s0], k[:, :s0], v[:, :s0]).half()
    print("packed shape         :", tuple(out_packed.shape), "(all", total, "tokens, no pad)")
    print("segment-0 max diff   : {:.2e}".format((out_packed[:, :s0] - ref0).abs().max().item()))
    cbias = ab.BlockDiagonalCausalMask.from_seqlens(seqlens)
    _ = xops.memory_efficient_attention(q, k, v, attn_bias=cbias)
    print("-> also ran a packed CAUSAL pass. This is how vLLM-style engines")
    print("   batch requests of different lengths with zero padding overhead.")
    splits = bias.split(out_packed)
    print("recovered segments   :", [tuple(t.shape) for t in splits])
except Exception as e:
    print("BlockDiagonalMask path skipped on this version/backend:", repr(e))
print("\n" + "="*70 + "\n5. Grouped-query attention (5-D BMGHK layout)\n" + "="*70)
B, M, K = 2, 256, 64
n_q_heads, n_kv_heads = 8, 2
G, Hq = n_kv_heads, n_q_heads // n_kv_heads   # G groups, Hq query heads per group
try:
    qg = torch.randn(B, M, G, Hq, K, device=device, dtype=torch.float16)
    # One KV head per group, broadcast across the group with a stride-0 expand (no copy).
    kg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16).expand(B, M, G, Hq, K)
    vg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16).expand(B, M, G, Hq, K)
    out_gqa = xops.memory_efficient_attention(qg, kg, vg)
    print("GQA output shape     :", tuple(out_gqa.shape), "= [B, M, G, Hq, K]")
    print(f"-> {n_q_heads} query heads, only {n_kv_heads} KV heads: smaller KV cache,")
    print("   which is exactly what Llama-/Mistral-class models use at inference.")
except Exception as e:
    print("GQA 5-D path skipped on this version/backend:", repr(e))

We concatenate variable-length sequences and use BlockDiagonalMask to keep attention from crossing sequence boundaries, without any padding. We recover the individual outputs and also run packed causal attention for decoder-style workloads. We then demonstrate grouped-query attention, where several query heads share each key-value head, which reduces the size of the KV cache.
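Conceptually, a block-diagonal bias behaves like a dense mask in which token i may attend to token j only when both belong to the same packed sequence (and, in the causal variant, only when j <= i). xFormers avoids materializing that mask, but spelling it out makes the semantics easy to check. The NumPy sketch below does that on small single-head inputs and also shows which KV head each query head reads under GQA. The helper names are hypothetical, and the code illustrates the semantics only, not how xFormers computes them.

```python
import numpy as np

def block_diagonal_mask(seqlens, causal=False):
    """Dense boolean [T, T] mask: True where attention is allowed."""
    seg = np.repeat(np.arange(len(seqlens)), seqlens)   # segment id per token
    allowed = seg[:, None] == seg[None, :]
    if causal:
        t = np.arange(len(seg))
        allowed &= t[None, :] <= t[:, None]             # key index <= query index
    return allowed

def masked_attention(q, k, v, allowed):
    """Single-head attention with disallowed positions set to -inf."""
    s = np.where(allowed, q @ k.T / np.sqrt(q.shape[-1]), -np.inf)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
seqlens = [3, 5, 2]
q, k, v = (rng.standard_normal((sum(seqlens), 4)) for _ in range(3))
packed = masked_attention(q, k, v, block_diagonal_mask(seqlens))

# The packed result equals running each sequence on its own.
offsets = np.cumsum([0] + seqlens)
for a, b in zip(offsets[:-1], offsets[1:]):
    alone = masked_attention(q[a:b], k[a:b], v[a:b], np.ones((b - a, b - a), bool))
    assert np.allclose(packed[a:b], alone)

# GQA: 8 query heads share 2 KV heads, so flattened query head h reads KV head h // 4.
n_q_heads, n_kv_heads = 8, 2
kv_for_q_head = [h // (n_q_heads // n_kv_heads) for h in range(n_q_heads)]
print(kv_for_q_head)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

The grouping matches the 5-D layout used above: group g holds Hq = 4 query heads that all read the same KV head, so the KV cache stores 2 heads instead of 8.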

Adding a Custom ALiBi Additive Positional Bias

print("\n" + "="*70 + "\n6. Custom ALiBi additive bias\n" + "="*70)
B, M, H, K = 1, 128, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
try:
    # Per-head slopes: the geometric sequence 2^(-8/H), 2^(-16/H), ..., 2^(-8).
    slopes = (2.0 ** (-8.0 / H)) ** torch.arange(1, H + 1, device=device)
    pos = torch.arange(M, device=device)
    # rel[i, j] = j - i for past keys (<= 0); future keys are clamped to 0, then masked below.
    rel = (pos[None, :] - pos[:, None]).clamp(max=0).float()
    alibi = slopes[:, None, None] * rel[None]                     # [H, M, M]
    alibi = alibi[None].expand(B, H, M, M).to(torch.float16).contiguous()
    causal = torch.triu(torch.ones(M, M, device=device, dtype=torch.bool), 1)
    alibi = alibi.masked_fill(causal[None, None], float("-inf"))
    out_alibi = xops.memory_efficient_attention(q, k, v, attn_bias=alibi)
    print("ALiBi output shape   :", tuple(out_alibi.shape))
    print("-> any per-(head, query, key) additive bias works the same way.")
except Exception as e:
    print("Custom-bias path skipped (some backends restrict bias shapes):", repr(e))

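We construct a custom ALiBi tensor that applies a different linear positional penalty to each attention head. We combine this additive bias with a causal mask so that tokens attend only to valid earlier positions. We pass the resulting bias directly to xFormers attention and check the shape of its output. Unlike LowerTriangularMask and BlockDiagonalMask, this bias is a dense B x H x M x M tensor, so the bias itself reintroduces quadratic memory even though the attention computation does not.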
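The slope formula in the code, (2^(-8/H))^h for h = 1..H, yields a geometric sequence; with H = 8 it is 1/2, 1/4, ..., 1/256, so later heads apply gentler distance penalties. The pure-Python sketch below reproduces those slopes and builds one head's causal bias matrix so the pattern is visible: the diagonal is 0, each step further into the past subtracts one more multiple of the slope, and future positions are -inf. Both helper names are hypothetical.

```python
import math

def alibi_slopes(n_heads):
    """Slopes (2 ** (-8 / n_heads)) ** h for h = 1..n_heads, matching the code above."""
    base = 2.0 ** (-8.0 / n_heads)
    return [base ** h for h in range(1, n_heads + 1)]

def alibi_causal_bias(slope, seqlen):
    """bias[query][key] = slope * (key - query) for past keys, -inf for future keys."""
    return [[slope * (j - i) if j <= i else -math.inf for j in range(seqlen)]
            for i in range(seqlen)]

print(alibi_slopes(8))
for row in alibi_causal_bias(0.5, 4):
    print(row)
```

For the first head (slope 0.5) the last row is [-1.5, -1.0, -0.5, 0.0]: the most recent token gets no penalty and each older token is penalized a bit more.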

Training a GPT Block with xFormers Attention and SwiGLU

print("\n" + "="*70 + "\n7. Train a small GPT block (xformers attn + SwiGLU)\n" + "="*70)
def make_swiglu(d, hidden):
    """xformers' SwiGLU module if available (it uses a fused kernel when supported),
    else a plain manual fallback."""
    try:
        m = xops.SwiGLU(in_features=d, hidden_features=hidden, out_features=d, bias=True)
        return m, "xops.SwiGLU"
    except Exception:
        class SwiGLU(nn.Module):
            def __init__(self):
                super().__init__()
                # One projection produces both gate and value halves; w3 projects back to d.
                self.w12 = nn.Linear(d, 2 * hidden); self.w3 = nn.Linear(hidden, d)
            def forward(self, x):
                a, b = self.w12(x).chunk(2, -1)
                return self.w3(F.silu(a) * b)
        return SwiGLU(), "manual SwiGLU fallback"
class Block(nn.Module):
    """Pre-norm decoder block: causal xformers attention + SwiGLU feed-forward."""
    def __init__(self, d, n_heads, mlp_mult=4):
        super().__init__()
        self.h, self.k = n_heads, d // n_heads
        self.n1, self.n2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.qkv, self.proj = nn.Linear(d, 3 * d), nn.Linear(d, d)
        self.ff, self.ff_kind = make_swiglu(d, mlp_mult * d)
    def forward(self, x):
        B, M, d = x.shape
        # Project to q, k, v in xformers' [B, M, H, K] layout.
        qkv = self.qkv(self.n1(x)).reshape(B, M, 3, self.h, self.k)
        q, kk, vv = qkv.unbind(2)
        a = xops.memory_efficient_attention(q, kk, vv, attn_bias=ab.LowerTriangularMask())
        x = x + self.proj(a.reshape(B, M, d))
        return x + self.ff(self.n2(x))
class TinyGPT(nn.Module):
    def __init__(self, vocab, d=128, n_layers=3, n_heads=8, maxlen=64):
        super().__init__()
        self.tok = nn.Embedding(vocab, d); self.pos = nn.Embedding(maxlen, d)
        self.blocks = nn.ModuleList(Block(d, n_heads) for _ in range(n_layers))
        self.nf, self.head = nn.LayerNorm(d), nn.Linear(d, vocab)
    def forward(self, idx):
        B, M = idx.shape
        x = self.tok(idx) + self.pos(torch.arange(M, device=idx.device))[None]
        for b in self.blocks: x = b(x)
        return self.head(self.nf(x))
VOCAB, SEQ = 64, 64
def make_batch(B):
    """Synthetic task: each row counts upward from a random start, modulo VOCAB."""
    start = torch.randint(0, VOCAB, (B, 1), device=device)
    return (start + torch.arange(SEQ, device=device)[None]) % VOCAB
model = TinyGPT(VOCAB).to(device)
print("FFN type             :", model.blocks[0].ff_kind)
opt = torch.optim.AdamW(model.parameters(), lr=3e-3)
scaler = torch.amp.GradScaler("cuda")
for step in range(400):
    seq = make_batch(64); inp, tgt = seq[:, :-1], seq[:, 1:]
    with torch.autocast("cuda", dtype=torch.float16):
        logits = model(inp)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), tgt.reshape(-1))
    opt.zero_grad(); scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
    if step % 80 == 0 or step == 399:
        acc = (logits.argmax(-1) == tgt).float().mean().item()
        print(f"step {step:4d} | loss {loss.item():.4f} | next-token acc {acc*100:5.1f}%")
print("-> a full causal transformer running on memory-efficient attention,")
print("   trained end to end with AMP. Swap in real data/tokenizer to scale up.")
print("\nDone. Sections 1-3 are core; 4-6 are the advanced bits worth keeping.")

We build a compact GPT-style Transformer using causal xFormers attention, residual connections, layer normalization, and SwiGLU feed-forward layers. We train the model with automatic mixed precision on a synthetic next-token prediction task that counts upward modulo the vocabulary size. We track its loss and accuracy to confirm that the whole memory-efficient Transformer learns successfully end to end.
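The manual fallback in `make_swiglu` computes W3(SiLU(a) * b), where a and b are the two halves of one fused projection W12 x and SiLU(x) = x * sigmoid(x). The NumPy sketch below mirrors that forward pass with small random weights standing in for trained parameters, and also reproduces the counting task generated by `make_batch`. Note one assumption: it stores weights as [in, out] and computes x @ W, whereas `nn.Linear` stores [out, in] and computes x @ W.T. All names here are hypothetical.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))               # equals x * sigmoid(x)

def swiglu_forward(x, w12, b12, w3, b3):
    """x: [..., d]; w12: [d, 2*hidden]; w3: [hidden, d]. Mirrors the manual fallback."""
    a, b = np.split(x @ w12 + b12, 2, axis=-1)  # one matmul, then split into gate/value
    return (silu(a) * b) @ w3 + b3

d, hidden = 8, 32
rng = np.random.default_rng(0)
w12, b12 = rng.standard_normal((d, 2 * hidden)) * 0.1, np.zeros(2 * hidden)
w3, b3 = rng.standard_normal((hidden, d)) * 0.1, np.zeros(d)
x = rng.standard_normal((2, 5, d))              # [batch, seq, d]
y = swiglu_forward(x, w12, b12, w3, b3)
print(y.shape)                                  # (2, 5, 8)

# make_batch's task: each row counts upward mod VOCAB from a random start,
# so the target is always (input + 1) % VOCAB.
VOCAB, SEQ = 64, 64
start = rng.integers(0, VOCAB, size=(3, 1))
seq = (start + np.arange(SEQ)[None]) % VOCAB
inp, tgt = seq[:, :-1], seq[:, 1:]
assert ((inp + 1) % VOCAB == tgt).all()
```

Because every target is a deterministic function of the previous token, a model that learns the task should approach 100% next-token accuracy, which makes it a quick sanity check for the training loop.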

Conclusion

In conclusion, we developed a practical understanding of how xFormers improves Transformer efficiency without changing the underlying attention computation. We saw how memory-efficient kernels reduce the cost of long sequences, while causal masks, packed sequences, grouped-query attention, and additive biases support realistic training and inference workflows. We finished by integrating these capabilities into a compact GPT model and training it end to end, giving us a solid foundation for applying xFormers to larger language models and more demanding datasets.




The post How to Build Memory-Efficient Transformers with xFormers Using Packed Sequences, GQA, ALiBi, SwiGLU, and Causal Attention appeared first on MarkTechPost.
