kernels.quantize
Dequantization utilities for bitsandbytes and FP8 integration.
Functions
| Name | Description |
|---|---|
| dequantize | Fast NF4 dequantization using bitsandbytes CUDA kernels. |
| dequantize_fp8 | Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv. |
dequantize
kernels.quantize.dequantize(W, quant_state=None, out=None)

Fast NF4 dequantization using bitsandbytes CUDA kernels.
Performs efficient dequantization of weights from NF4 format using bitsandbytes’
optimized CUDA implementations. Supports both legacy list and new QuantState
formats.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| W | torch.Tensor | Quantized weight tensor to dequantize | required |
| quant_state | QuantState \| list \| torch.Tensor \| None | Quantization state containing the metadata needed for dequantization; either a QuantState object or the legacy list format. If None, W is returned unchanged. | None |
| out | torch.Tensor \| None | Optional output tensor for storing dequantized results. Must match the expected shape and dtype if provided. | None |
Returns
| Name | Type | Description |
|---|---|---|
|  | torch.Tensor | Dequantized tensor in the specified dtype (fp16 or bf16); transposed if the input W was transposed. |
Raises
| Name | Type | Description |
|---|---|---|
|  | AssertionError | If the provided output tensor does not match the expected shape or dtype. |
Note
Uses CUDA streams for better performance when available in newer bitsandbytes
versions (>0.43.3).
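The block-wise rescaling that NF4 dequantization performs can be sketched in plain PyTorch. This is a hypothetical reference implementation for illustration, not the bitsandbytes CUDA kernel: the codebook values are the 16 NF4 levels published in the QLoRA paper (truncated here), and the nibble packing order and `blocksize` default are assumptions.

```python
import torch

# NF4 codebook: 16 quantile levels of a standard normal, normalized to
# [-1, 1], as published in the QLoRA paper (values truncated). Index 7
# is exactly 0.0 so that zero weights survive quantization.
NF4_CODE = torch.tensor([
    -1.0, -0.6961928, -0.52507305, -0.39491749,
    -0.28444138, -0.18477343, -0.09105004, 0.0,
    0.0795803, 0.1609302, 0.2461123, 0.33791524,
    0.44070983, 0.56261706, 0.72295684, 1.0,
])

def dequantize_nf4_reference(packed, absmax, blocksize=64, dtype=torch.float16):
    """Dequantize NF4-packed bytes: look up codes, rescale per block.

    `packed` is a uint8 tensor holding two 4-bit codes per byte (high
    nibble first -- an assumption about the packing order); `absmax` holds
    one scale per block of `blocksize` elements.
    """
    hi = (packed >> 4).long()
    lo = (packed & 0x0F).long()
    codes = torch.stack([hi, lo], dim=-1).reshape(-1)
    vals = NF4_CODE[codes]
    # Broadcast each block's absmax over its `blocksize` elements.
    scales = absmax.repeat_interleave(blocksize)[: vals.numel()]
    return (vals * scales).to(dtype)
```

For example, a single byte `0x7F` decodes to codes 7 and 15, i.e. codebook values 0.0 and 1.0, which an absmax of 2.0 rescales to 0.0 and 2.0.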
dequantize_fp8
kernels.quantize.dequantize_fp8(W, scale_inv, dtype=torch.bfloat16)

Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| W | torch.Tensor | FP8 weight tensor [out_features, in_features] in float8_e4m3fn. | required |
| scale_inv | torch.Tensor | Per-block inverse scale [ceil(out/block), ceil(in/block)] or per-tensor scalar. | required |
| dtype | torch.dtype | Output dtype (default bf16). | torch.bfloat16 |
Returns
| Name | Type | Description |
|---|---|---|
|  | torch.Tensor | Dequantized tensor in the specified dtype. |
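The per-block multiply above can be sketched in plain PyTorch: expand the `[ceil(out/block), ceil(in/block)]` scale grid back to per-element scales and multiply. This is a hypothetical reference implementation, not the library kernel; the `block` parameter name is an assumption, and the FP8 input is simulated with float32 for portability (the real function takes `float8_e4m3fn` weights).

```python
import torch

def dequantize_fp8_reference(W, scale_inv, block=128, dtype=torch.float32):
    """Dequantize block-quantized weights: W_dequant = W * scale_inv.

    `W` is [out_features, in_features]; `scale_inv` is either a scalar
    (per-tensor) or a [ceil(out/block), ceil(in/block)] grid (per-block).
    """
    if scale_inv.ndim == 0:
        # Per-tensor scalar scale.
        return W.to(dtype) * scale_inv.to(dtype)
    out_f, in_f = W.shape
    # Repeat each block scale over its block x block tile, then trim the
    # ragged edge blocks back to the weight shape.
    s = scale_inv.repeat_interleave(block, dim=0)[:out_f]
    s = s.repeat_interleave(block, dim=1)[:, :in_f]
    return W.to(dtype) * s.to(dtype)
```

For example, with a 4x6 weight of ones, `block=3`, and a 2x2 scale grid `[[2, 3], [4, 5]]`, each 3x3 tile of the output takes the corresponding scale.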