
Meet Flash-KMeans: An IO-Aware, Exact K-Means That Runs Over 200× Faster Than FAISS on GPUs

k-means has long been an offline tool. You run it once to preprocess data, then move on. A team of researchers from UC Berkeley and UT Austin released Flash-KMeans, a new open-source library that targets a different setting. Modern AI pipelines now call k-means inside training and inference loops. At that frequency, latency per call matters more than theoretical FLOPs.

Flash-KMeans is an IO-aware implementation of standard Lloyd’s k-means. It does not change the math, and it does not approximate. It only restructures how the algorithm moves data on a GPU. On an NVIDIA H200, the research team reported up to 17.9× end-to-end speedup over the best baseline. Against NVIDIA cuML they report 33×. Against FAISS they report over 200×.

What is Flash-KMeans

Flash-KMeans is a batched k-means library written in Triton GPU kernels. It ships under Apache 2.0 and installs with pip install flash-kmeans.

The output is mathematically identical to standard Lloyd’s k-means. The speedup comes from kernel-level dataflow, not from skipping work. That separates it from algorithmic techniques like triangle-inequality pruning or coreset sampling.

A standard Lloyd iteration has two stages. The assignment stage computes each point’s distance to every centroid, then picks the nearest. The update stage averages the points in each cluster to form new centroids. Both stages are simple arithmetic. On GPUs, both are bottlenecked by memory, not compute.
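To make the two stages concrete, here is a minimal CPU sketch of plain Lloyd’s k-means in pure Python. It is not Flash-KMeans code; the function names and the "initialize from k random points" choice are illustrative assumptions. It only shows the math that Flash-KMeans keeps unchanged.

```python
import math
import random


def assign(points, centroids):
    """Assignment stage: index of the nearest centroid for every point."""
    labels = []
    for p in points:
        best_k, best_dist = -1, math.inf
        for k, c in enumerate(centroids):
            dist = sum((pi - ci) ** 2 for pi, ci in zip(p, c))  # squared Euclidean distance
            if dist < best_dist:
                best_k, best_dist = k, dist
        labels.append(best_k)
    return labels


def update(points, labels, centroids):
    """Update stage: mean of each cluster's points (an empty cluster keeps its old centroid)."""
    d = len(points[0])
    sums = [[0.0] * d for _ in centroids]
    counts = [0] * len(centroids)
    for p, k in zip(points, labels):
        counts[k] += 1
        for j in range(d):
            sums[k][j] += p[j]
    return [
        [s / counts[k] for s in sums[k]] if counts[k] else list(centroids[k])
        for k in range(len(centroids))
    ]


def lloyd(points, k, iters=20, seed=0):
    """Alternate assignment and update for a fixed number of iterations."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]  # illustrative init: k random points
    labels = []
    for _ in range(iters):
        labels = assign(points, centroids)
        centroids = update(points, labels, centroids)
    return labels, centroids


pts = [(0.0, 0.0), (0.1, 0.2), (5.0, 5.0), (5.1, 4.9)]
labels, cents = lloyd(pts, k=2)
```

On the toy input the two nearby pairs end up in separate clusters. On a GPU, the assignment loop becomes the N×K distance computation and the update loop becomes the scatter into per-cluster sums; those are the two places the article says memory traffic dominates.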

The Two Bottlenecks It Attacks

The first bottleneck is the assignment stage. Standard code builds a full distance matrix D of shape N×K in High Bandwidth Memory (HBM). It writes the matrix, then reads it back to run argmin. For N=65536, K=1024, d=128, B=32, the distance math takes 2.6ms. Writing and reading D takes about 23ms. The matrix is the cost, not the arithmetic.
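A back-of-the-envelope byte count shows why D dominates. The sketch below is a simplified traffic model, not a measurement: it assumes every tensor (including D) is FP16, labels are int32, and it ignores caching, padding and launch overhead. It does not try to reproduce the 2.6ms/23ms timings above.

```python
def assignment_hbm_bytes(B, N, K, d, bytes_per_elem=2):
    """Rough HBM traffic (bytes) for one assignment pass, counting only the big tensors.

    Assumes one element size for X, C and D (2 bytes = FP16) and int32 labels.
    """
    x = B * N * d * bytes_per_elem       # read points X once
    c = B * K * d * bytes_per_elem       # read centroids C once
    labels = B * N * 4                   # write int32 assignments once
    dist = B * N * K * bytes_per_elem    # size of the distance matrix D
    standard = x + c + 2 * dist + labels  # write D to HBM, then read it back for argmin
    fused = x + c + labels               # fused kernel: D never leaves on-chip memory
    return standard, fused, dist


standard, fused, dist = assignment_hbm_bytes(B=32, N=65536, K=1024, d=128)
```

For this configuration D alone is 4 GiB (2**32 bytes), eight times the 512 MiB of X, because K/d = 1024/128 = 8. Under these assumptions the fused path moves about 16.5× fewer bytes (1090/66).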

Flash-KMeans replaces this with FlashAssign. The design borrows from FlashAttention. FlashAssign streams tiles of points and centroids from HBM into on-chip SRAM. It fuses distance computation with an online argmin. The full N×K matrix is never materialized. This cuts the dominant IO complexity from O(NK) to O(Nd + Kd). At the kernel level, FlashAssign reaches up to 21.2×. In one case it cut assignment from 122.5ms to 5.8ms.
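The fused "distance plus online argmin" idea can be sketched on the CPU. This is an illustration of the technique, not the Triton kernel: tile sizes are arbitrary, and a Python list stands in for an SRAM tile. At any moment only a tile_n × tile_k block of distances exists, and each point keeps a running (best distance, best index) pair.

```python
import math


def sq_dist(p, c):
    """Squared Euclidean distance between two vectors."""
    return sum((pi - ci) ** 2 for pi, ci in zip(p, c))


def flash_assign_sketch(points, centroids, tile_n=2, tile_k=3):
    """Tiled assignment with a running (online) argmin; the N x K matrix is never stored."""
    n, k_total = len(points), len(centroids)
    best_dist = [math.inf] * n
    best_idx = [-1] * n
    for i0 in range(0, n, tile_n):                  # one tile of points
        rows = range(i0, min(i0 + tile_n, n))
        for k0 in range(0, k_total, tile_k):        # stream centroid tiles past it
            cols = range(k0, min(k0 + tile_k, k_total))
            tile = [[sq_dist(points[i], centroids[k]) for k in cols] for i in rows]
            for r, i in enumerate(rows):
                for c, k in enumerate(cols):
                    if tile[r][c] < best_dist[i]:   # online argmin: keep the best so far
                        best_dist[i], best_idx[i] = tile[r][c], k
    return best_idx


pts = [[0, 0], [1, 1], [5, 5], [9, 0], [2, 8]]
cents = [[0, 1], [6, 6], [9, 1], [2, 9], [4, 4]]
labels = flash_assign_sketch(pts, cents)
```

Because centroids are scanned in index order and only a strictly smaller distance replaces the running best, ties resolve to the lowest index, the same as a full argmin over a materialized row.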

The second bottleneck is the centroid update stage. Standard code uses scatter-style atomic adds. Each thread adds its point into a shared sum buffer keyed by cluster id. Many threads hit the same ‘hot’ cluster at once. That causes atomic contention and hardware serialization. The research team measured only 50 GB/s effective bandwidth here on an H200.
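For contrast with the fix that follows, here is a minimal CPU model of the scatter-style update. The atomics counter is an illustrative assumption: it tallies one would-be GPU atomic add per point per feature (the count increments are not tallied), which is where the N·d atomic count comes from.

```python
def scatter_update(points, labels, k):
    """Baseline scatter update: every point adds itself into its cluster's shared sum row.

    On a GPU each += below would be an atomic add into HBM; all points of a
    popular ("hot") cluster target the same row, which causes contention.
    """
    d = len(points[0])
    sums = [[0.0] * d for _ in range(k)]
    counts = [0] * k
    atomics = 0
    for p, cid in zip(points, labels):
        for j in range(d):
            sums[cid][j] += p[j]  # one "atomic add" per point per feature
            atomics += 1
        counts[cid] += 1
    return sums, counts, atomics


sums, counts, atomics = scatter_update([[1, 2], [3, 4], [5, 6]], [0, 0, 1], k=2)
```

Here 3 points × 2 features give 6 atomic adds, and two of the three points contend for cluster 0’s row.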

Flash-KMeans replaces this with Sort-Inverse Update. It sorts the 1D assignment vector by cluster id using argsort. Identical cluster ids then form contiguous segments. Each thread block reduces a segment on-chip, then issues one atomic add per segment. The heavy point matrix is never physically permuted. Atomic operations drop from O(Nd) to O((K + N/B_N)d), where B_N is the number of points each thread block handles. The kernel reaches up to 6.3×.
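The segment-reduce idea can be sketched on the CPU as well. This is an illustration in the spirit of Sort-Inverse Update, not the library’s kernel: block_n plays the role of B_N, Python loops stand in for thread blocks, and the counter tallies would-be atomic adds. Points are read through the sorted index, so the point matrix itself is never permuted.

```python
def sort_inverse_update(points, labels, k, block_n=4):
    """Segment-reduce centroid update (sketch).

    1. argsort the label vector so equal cluster ids become contiguous;
    2. split the sorted order into blocks of block_n entries ("thread blocks");
    3. in each block, reduce every run of equal ids locally, then do one
       "atomic" add per (segment, feature) into the shared sums.
    """
    n, d = len(points), len(points[0])
    order = sorted(range(n), key=lambda i: labels[i])  # argsort of cluster ids
    sums = [[0.0] * d for _ in range(k)]
    counts = [0] * k
    atomics = 0
    for b0 in range(0, n, block_n):
        block = order[b0:b0 + block_n]
        s = 0
        while s < len(block):
            cid = labels[block[s]]
            e = s
            local = [0.0] * d  # on-chip partial sum for this segment
            while e < len(block) and labels[block[e]] == cid:
                p = points[block[e]]
                for j in range(d):
                    local[j] += p[j]
                e += 1
            for j in range(d):
                sums[cid][j] += local[j]  # one atomic add per segment per feature
                atomics += 1
            counts[cid] += e - s
            s = e
    return sums, counts, atomics


points = [[i, 2 * i] for i in range(12)]
labels = [0, 1, 0, 0, 2, 0, 0, 1, 0, 0, 0, 2]  # cluster 0 is "hot"
sums, counts, atomics = sort_inverse_update(points, labels, k=3, block_n=4)
```

With 12 points in 3 blocks, the sorted order yields only 4 segments, so 8 atomic adds replace the 24 a scatter update would issue. In general each block contributes at most one extra segment at its boundary, which is where the (K + N/B_N)·d bound comes from.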

Benchmark

The research team tests it on an H200 with CUDA 12.8, FP16 data, and d=128. They sweep N, K, and batch size B. They compare against four optimized baselines: fast_pytorch_kmeans, fastkmeans, cuML, and FAISS.

| Comparison | Reported speedup | Workload context |
|---|---|---|
| End-to-end vs best baseline | up to 17.9× | N=8M, K=1024 (large N, small K) |
| vs NVIDIA cuML | 33× | industry library |
| vs FAISS | over 200× | industry library |
| FlashAssign kernel | up to 21.2× | N=1M, K=8192 (assignment) |
| Sort-Inverse Update kernel | up to 6.3× | N=33M, K=4096 (update) |
| Out-of-core, large scale | up to 10.5× | N=400M, K=16384, vs fastkmeans |

One failure mode matters for context. Standard PyTorch implementations run out of memory in large-K regimes, because they cannot materialize the N×K matrix. FAISS, for its part, is the industry-standard library under many production vector-search systems.

The library also runs out-of-core. On one billion points (K=32768, d=128), it finishes an iteration in 41.4s, versus 261.8s for the baseline. It uses chunked stream overlap to hide PCIe transfer behind compute. A cache-aware compile heuristic also cuts tuning overhead by up to 175×, while staying within 0.3% of tuned speed.
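Why overlapping helps can be seen with an idealized timing model of double-buffered streaming. This is not the library’s scheduler; the chunk times below are made-up illustration values, and the model ignores launch latency and contention. While chunk i is being computed, chunk i+1 is already crossing PCIe, so after the first copy each step costs max(transfer, compute).

```python
def pipeline_time(chunk_transfer_s, chunk_compute_s, n_chunks, overlap=True):
    """Idealized wall time for streaming n_chunks from host memory and processing them."""
    if not overlap:
        # Serial: copy a chunk, compute on it, then move to the next chunk.
        return n_chunks * (chunk_transfer_s + chunk_compute_s)
    # Double buffering: first copy, then (n - 1) overlapped steps, then the last compute.
    steady = max(chunk_transfer_s, chunk_compute_s)
    return chunk_transfer_s + (n_chunks - 1) * steady + chunk_compute_s


serial = pipeline_time(1.0, 2.0, 10, overlap=False)
overlapped = pipeline_time(1.0, 2.0, 10)
```

With a hypothetical 1s copy and 2s compute per chunk over 10 chunks, the serial schedule takes 30s and the overlapped one 21s: transfer is almost entirely hidden once compute is the longer stage.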

MTP Interactive Explainer


Flash-KMeans: exact k-means, rebuilt around GPU memory

Same Lloyd’s math as standard k-means, faster only because of dataflow. Run clustering live, watch the update bottleneck, and size the IO it removes.

17.9× end-to-end vs best baseline
33× vs NVIDIA cuML
200×+ vs FAISS
1B points, out-of-core







This runs real Lloyd’s k-means in your browser on 2-D points. The algorithm is identical to what Flash-KMeans accelerates; only the GPU dataflow differs. Each step is one assignment plus one centroid update.

Standard scatter-update serializes when blocks write the same “hot” centroid. Sort-Inverse Update sorts cluster IDs first, so each block merges contiguous segments with one atomic add and no conflict.


Standard atomics: O(N·d)
Sort-Inverse atomics: O((K+N/B)·d)
Measured standard bandwidth: 50 GB/s
Kernel speedup: 6.3×

Standard updates issue one atomic add per token. Many threads hit the same centroid at once, causing contention. Sorting by cluster ID turns scatters into segment-level reductions in on-chip memory.

Standard: materialize the N×K matrix, O(NK)
FlashAssign: stream inputs, O(Nd+Kd)

Less HBM traffic for the assignment step (theoretical)




Standard k-means writes, then reads, a full N×K distance matrix in HBM. FlashAssign never builds it: it reads X and C once and writes assignments once (FP16).

© Marktechpost
Speedups: Flash-KMeans paper (arXiv:2603.09229), NVIDIA H200. Demo runs in-browser for illustration · github.com/svg-project/flash-kmeans


Use Cases

Faster exact k-means changes what you can run online, not just offline.

  • Vector search indexing: FAISS builds its search indexes with k-means. Faster k-means lets you re-index as data shifts, instead of rebuilding overnight.
  • Sparse attention routing: Routing Transformers and Tactic cluster tokens to route attention. Millisecond k-means makes this viable inside the inference loop.
  • KV-cache compression: ClusterKV clusters tokens in semantic space to compress the cache. Cheaper clustering makes per-layer, per-step compression practical.
  • Low-bit KV quantization: Recent methods repeatedly cluster KV entries into codebooks. Faster clustering shrinks that preprocessing cost.
  • Diffusion Transformers: Sparse VideoGen2 calls batched k-means during forward passes. It permutes tokens by semantic similarity to exploit sparsity.
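Several of these uses share one step: reorder tokens so members of a cluster sit next to each other, run the sparse computation, then undo the reordering. The sketch below shows only that generic permute/inverse-permute step. It is an assumption about the common pattern, not Sparse VideoGen2’s actual code, and cluster_permutation is a hypothetical helper name.

```python
def cluster_permutation(labels):
    """Permutation grouping equal cluster ids into contiguous runs, plus its inverse.

    [tokens[i] for i in perm] is the clustered order; indexing that result with
    inv restores the original order after the sparse computation.
    """
    perm = sorted(range(len(labels)), key=lambda i: labels[i])  # stable argsort
    inv = [0] * len(perm)
    for pos, i in enumerate(perm):
        inv[i] = pos
    return perm, inv


tokens = ["a", "b", "c", "d", "e"]
labels = [2, 0, 1, 0, 2]
perm, inv = cluster_permutation(labels)
clustered = [tokens[i] for i in perm]
restored = [clustered[inv[i]] for i in range(len(tokens))]
```

Here the clustered order is b, d, c, a, e (clusters 0, 0, 1, 2, 2), and the inverse permutation puts every token back in its original slot.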

Using It

The API mirrors faiss and sklearn. The call below clusters a batched (B, N, d) tensor:

```python
import torch
from flash_kmeans import batch_kmeans_Euclid

# B=32 independent problems, each with N=75600 points of dimension d=128 (FP16 on GPU)
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
cluster_ids, centers, _ = batch_kmeans_Euclid(
    x, n_clusters=1000, tol=1e-4, verbose=True
)
```

A scikit-learn-style interface is also available:

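```python
import torch
from flash_kmeans import FlashKMeans

# Illustrative input: an (N, d) tensor held in CPU memory
large_cpu_tensor = torch.randn(2_000_000, 128)

km = FlashKMeans(d=128, k=8192, niter=100)
labels = km.fit_predict(large_cpu_tensor)  # device=None uses all visible GPUs
```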

The kernel auto-dispatches by shape and dtype. A small-D path handles d≤512. A split-D path handles larger d without materializing the distance matrix. Multi-GPU runs trigger automatically for large-N data held in CPU memory.
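The article does not describe the split-D kernel’s internals. One property any such path can rely on, stated here as an assumption about the approach rather than the library’s code, is that squared Euclidean distance decomposes additively over feature chunks. A kernel can therefore stream a large d in pieces and keep only a running partial sum per (point, centroid) pair.

```python
def sq_dist_split_d(x, c, chunk=4):
    """Squared Euclidean distance accumulated over feature chunks of width `chunk`.

    ||x - c||^2 = sum over chunks of ||x[j:j+chunk] - c[j:j+chunk]||^2
    """
    total = 0.0
    for j in range(0, len(x), chunk):
        total += sum((a - b) ** 2 for a, b in zip(x[j:j + chunk], c[j:j + chunk]))
    return total


x = list(range(10))
c = [0] * 10
```

The chunk width changes only how the sum is split, not its value: for x = 0..9 and c = 0, every chunking gives 285, the sum of the squares 0² through 9².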

Key Takeaways

  • Flash-KMeans is exact, not approximate: same Lloyd’s math, sped up purely by GPU dataflow.
  • FlashAssign fuses distance computation with an online argmin, cutting assignment IO from O(NK) to O(Nd+Kd), for up to 21.2×.
  • Sort-Inverse Update sorts cluster IDs into segments, replacing scatter atomics, for up to 6.3×.
  • Reports up to 17.9× end-to-end, 33× over cuML, and over 200× over FAISS on an H200.
  • Scales out-of-core to 1 billion points and cuts tuning overhead by up to 175×.




The post Meet Flash-KMeans: An IO-Aware, Exact K-Means That Runs Over 200× Faster Than FAISS on GPUs appeared first on MarkTechPost.
