| Blog | Documentation | Slack | Discussion Forum |
FlashInfer is a library and kernel generator for Large Language Models that provides high-performance implementations of LLM GPU kernels such as FlashAttention, SparseAttention, PageAttention, Sampling, and more. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios.
Check our v0.2 release blog for new features!
The core features of FlashInfer include a plan/run split of attention computation: the computation over variable-length inputs is scheduled in the plan stage, which alleviates load-imbalance issues at run time (see the batch-decode sketch after the single-request example below). FlashInfer supports PyTorch, TVM, and C++ (header-only) APIs, and can be easily integrated into existing projects.
Using our PyTorch API is the easiest way to get started:
We provide prebuilt wheels for Linux. You can install FlashInfer with the following command:
# For CUDA 12.4 & torch 2.4
pip install flashinfer -i https://flashinfer.ai/whl/cu124/torch2.4
# For other CUDA & torch versions, please check https://docs.flashinfer.ai/installation.html
We also offer nightly-built wheels to try the latest features from the main branch:
pip install flashinfer -i https://flashinfer.ai/whl/nightly/cu124/torch2.4
Alternatively, you can build FlashInfer from source:
git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
cd flashinfer
pip install -e . -v
By default, FlashInfer uses Just-In-Time (JIT) compilation for its kernels. To pre-compile essential kernels, set the environment variable FLASHINFER_ENABLE_AOT=1 before running the installation command:
FLASHINFER_ENABLE_AOT=1 pip install -e . -v
For more details, refer to the Install from Source documentation.
Below is a minimal example of using FlashInfer's single-request decode/append/prefill attention kernels:
import torch
import flashinfer
kv_len = 2048
num_kv_heads = 32
head_dim = 128
k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
# decode attention
num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(0)
o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA") # decode with LLaMA style RoPE on-the-fly
# append attention
append_qo_len = 128
q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="ROPE_LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask
# prefill attention
qo_len = 2048
q = torch.randn(qo_len, num_qo_heads, head_dim).half().to(0) # prefill attention
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill attention without RoPE on-the-fly, do not apply causal mask
Check out the documentation for usage of batch decode/append/prefill kernels and shared-prefix cascading kernels.
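The batch APIs make the plan/run split mentioned above explicit: auxiliary data structures for a batch of variable-length requests are built once in the plan stage and then reused across layers in the run stage. Below is a minimal sketch of batch decode with a paged KV-cache, assuming the BatchDecodeWithPagedKVCacheWrapper interface from the v0.2 PyTorch API; the shapes, page-table values, "NHD" layout, and plan() arguments shown here are illustrative and may differ slightly between versions.
import torch
import flashinfer

num_layers = 32
num_qo_heads = 64
num_kv_heads = 8
head_dim = 128
page_size = 16
max_num_pages = 128
batch_size = 7

# workspace buffer used by the scheduler (128 MB)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# page table describing which KV-cache pages belong to which request (illustrative values)
kv_page_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda:0")
kv_page_indptr = torch.tensor([0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0")
kv_last_page_len = torch.tensor([1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0")  # valid entries in each request's last page

# paged KV-cache per layer, layout (max_num_pages, 2, page_size, num_kv_heads, head_dim)
kv_cache_at_layer = [
    torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
    for _ in range(num_layers)
]

# plan stage: schedule the variable-length batch once
decode_wrapper.plan(
    kv_page_indptr, kv_page_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    pos_encoding_mode="NONE", data_type=torch.float16,
)

# run stage: reuse the plan's auxiliary data structures for every layer
for layer in range(num_layers):
    q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    o = decode_wrapper.run(q, kv_cache_at_layer[layer])  # [batch_size, num_qo_heads, head_dim]
The same plan/run pattern applies to the batched prefill/append and shared-prefix cascading wrappers described in the documentation.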
We profile FlashInfer kernel performance with nvbench; you can compile and run the benchmarks with the following commands:
mkdir build
cp cmake/config.cmake build # you can modify the config.cmake to enable/disable benchmarks and change CUDA architectures
cd build
cmake ..
make -j12
You can run ./bench_{single/batch}_{prefill/decode} to benchmark the performance (e.g., ./bench_single_prefill for single-request prefill attention). ./bench_{single/batch}_{prefill/decode} --help will show you the available options.
FlashInfer also provides a C++ API and TVM bindings; please refer to the documentation for more details.
We are thrilled to share that FlashInfer is being adopted by many cutting-edge projects.
FlashInfer is inspired by the FlashAttention 1&2, vLLM, Stream-K, CUTLASS, and AITemplate projects.