kernels.quantize

Dequantization utilities for bitsandbytes and FP8 integration.

Functions

| Name | Description |
|------|-------------|
| dequantize | Fast NF4 dequantization using bitsandbytes CUDA kernels. |
| dequantize_fp8 | Dequantize FP8 block-quantized weights: `W_dequant = W_fp8 * scale_inv`. |

dequantize

kernels.quantize.dequantize(W, quant_state=None, out=None)

Fast NF4 dequantization using bitsandbytes CUDA kernels.

Performs efficient dequantization of weights from NF4 format using bitsandbytes’ optimized CUDA implementations. Supports both legacy list and new QuantState formats.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| W | torch.Tensor | Quantized weight tensor to dequantize. | required |
| quant_state | QuantState \| list \| torch.Tensor \| None | Quantization state containing the metadata needed for dequantization; either a QuantState object or the legacy list format. If None, W is returned unchanged. | None |
| out | torch.Tensor \| None | Optional output tensor for the dequantized result. Must match the expected shape and dtype if provided. | None |

Returns

| Name | Type | Description |
|------|------|-------------|
| | torch.Tensor | Dequantized tensor in the specified dtype (fp16 or bf16). Transposed if the input W was transposed. |

Raises

| Name | Type | Description |
|------|------|-------------|
| | AssertionError | If the provided output tensor doesn't match the expected shape or dtype. |

Note

Uses CUDA streams for better performance when available in newer bitsandbytes versions (>0.43.3).
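As a rough illustration of what NF4 dequantization computes, the sketch below reproduces the math on the CPU with NumPy: unpack two 4-bit codes per byte, look each code up in the 16-entry NF4 codebook (values from the QLoRA paper), and scale each block by its stored absmax. This is a reference sketch, not the bitsandbytes CUDA kernel; the block size and nibble packing order here are assumptions.

```python
import numpy as np

# NF4 codebook from the QLoRA paper (16 quantile-based levels in [-1, 1]).
NF4_CODEBOOK = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
], dtype=np.float32)

def dequantize_nf4_reference(packed, absmax, blocksize=64):
    """Reference NF4 dequantization: packed uint8 codes -> float32 values."""
    # Unpack two 4-bit codes per byte (high nibble first -- an assumption
    # about the packing order, made for this sketch).
    codes = np.empty(packed.size * 2, dtype=np.uint8)
    codes[0::2] = packed >> 4
    codes[1::2] = packed & 0x0F
    values = NF4_CODEBOOK[codes]
    # Multiply each block of values by that block's stored absmax scale.
    scales = np.repeat(absmax, blocksize)[: values.size]
    return values * scales
```

A single byte `0x0F` unpacks to codes 0 and 15, which map to -absmax and +absmax respectively.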

dequantize_fp8

kernels.quantize.dequantize_fp8(W, scale_inv, dtype=torch.bfloat16)

Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| W | torch.Tensor | FP8 weight tensor `[out_features, in_features]` in float8_e4m3fn. | required |
| scale_inv | torch.Tensor | Per-block inverse scale `[ceil(out/block), ceil(in/block)]`, or a per-tensor scalar. | required |
| dtype | torch.dtype | Output dtype. | torch.bfloat16 |

Returns

| Name | Type | Description |
|------|------|-------------|
| | torch.Tensor | Dequantized tensor in the specified dtype. |
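The block-wise math can be sketched as follows: each `block x block` tile of W is multiplied by the corresponding entry of `scale_inv`, and a per-tensor scalar scale is broadcast over the whole weight. Real inputs are float8_e4m3fn; float32 stands in here so the sketch runs anywhere, and the block size is a parameter rather than a value taken from this API.

```python
import numpy as np

def dequantize_fp8_reference(W, scale_inv, block=128):
    """Reference block-wise FP8 dequantization: W_dequant = W * scale_inv."""
    out_f, in_f = W.shape
    if scale_inv.ndim == 0:
        # Per-tensor scalar scale: broadcast over the whole weight.
        return W.astype(np.float32) * scale_inv
    # Expand the per-block scale grid to elementwise scales, then trim the
    # padding that ceil-division introduces on ragged edges.
    scales = np.repeat(np.repeat(scale_inv, block, axis=0), block, axis=1)
    return W.astype(np.float32) * scales[:out_f, :in_f]
```

For example, a 4x4 weight with `block=2` uses a 2x2 `scale_inv`, one scale per quadrant.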