vLLM quantization workflow
Jan 18, 2025
This post walks through how quantization works in vLLM, using the gptq_marlin method as the running example.
initialize llm_engine
INFO 12-23 02:54:03 llm_engine.py:234] Initializing an LLM engine (v0.6.6.dev31+gb880ffb8) with config: model='TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ', speculative_config=None, tokenizer='TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"candidate_compile_sizes":[],"compile_sizes":[],"capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False,
INFO 12-23 02:54:11 selector.py:120] Using Flash Attention backend.
INFO 12-23 02:54:12 model_runner.py:1094] Starting to load model TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ...
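For reference, an engine like the one in the log above can be created directly from Python. A minimal sketch (the prompt and sampling settings are only illustrative):

from vllm import LLM, SamplingParams

# Load a GPTQ checkpoint; vLLM reads the quantization settings from the
# checkpoint config and can upgrade them to the gptq_marlin kernel on
# supported GPUs. Passing quantization= explicitly also works.
llm = LLM(
    model="TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ",
    quantization="gptq_marlin",
    dtype="bfloat16",
)

outputs = llm.generate(["What is quantization?"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)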
create model_config
vllm/config.py: __init__(): stores the relevant model information, including the quantization method, in model_config.
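For a GPTQ checkpoint, the quantization settings come from the quantization_config section of the model's config.json. A quick way to inspect what vLLM picks up (illustrative only; this is not vLLM code):

from transformers import AutoConfig

# Peek at the quantization_config shipped with the checkpoint; model_config
# is built from the same information.
hf_config = AutoConfig.from_pretrained("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ")
print(hf_config.quantization_config)
# Typically contains fields like bits, group_size, desc_act and quant_method,
# which determine which quant method vLLM selects.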
create model: initialize model + load weights
model_runner.py: self.model = get_model(vllm_config=self.vllm_config)
1. initialize model: initialize a model with the given configuration → each model's __init__()
vllm/model_executor/model_loader/loader.py
with target_device:
    model = _initialize_model(vllm_config=vllm_config)
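Inside each model's __init__(), every linear layer picks its quant_method from the quantization config. Below is a simplified stand-in for what vLLM's LinearBase does (the class itself is only a sketch; the quant-method classes named in the comments are the real vLLM ones):

import torch
from vllm.model_executor.layers.linear import UnquantizedLinearMethod

class LinearBaseSketch(torch.nn.Module):
    """Sketch of how a linear layer obtains its quant_method."""

    def __init__(self, quant_config=None, prefix: str = ""):
        super().__init__()
        if quant_config is None:
            # No quantization config: plain bf16/fp16 matmul.
            self.quant_method = UnquantizedLinearMethod()
        else:
            # With quantization=gptq_marlin, get_quant_method() returns a
            # GPTQMarlinLinearMethod for this layer.
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

Everything that follows (create_weights, process_weights_after_loading, apply) goes through this quant_method object.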
vLLM creates the layers' weights according to quant_method when it initializes the model:
self.quant_method.create_weights(
    layer=self,
    input_size_per_partition=self.input_size,
    output_partition_sizes=self.output_partition_sizes,
    input_size=self.input_size,
    output_size=self.output_size,
    params_dtype=self.params_dtype,
    weight_loader=(
        self.weight_loader_v2 if self.quant_method.__class__.__name__
        in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
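To make the contract concrete, here is a simplified sketch of what a GPTQ-style create_weights() registers: packed integer weights plus per-group scales instead of a dense floating-point weight. The shapes assume 4-bit values packed into int32 and a fixed group size; the real vLLM implementations cover more cases (zero points, act-order, tensor parallelism):

import torch
from torch.nn import Parameter

def create_weights_sketch(layer: torch.nn.Module,
                          input_size_per_partition: int,
                          output_size_per_partition: int,
                          group_size: int = 128,
                          pack_factor: int = 8):  # 8 x 4-bit values per int32
    # Packed quantized weight: each int32 element holds pack_factor 4-bit values.
    qweight = Parameter(
        torch.empty(input_size_per_partition // pack_factor,
                    output_size_per_partition,
                    dtype=torch.int32),
        requires_grad=False)
    # One scale per (group, output channel).
    scales = Parameter(
        torch.empty(input_size_per_partition // group_size,
                    output_size_per_partition,
                    dtype=torch.float16),
        requires_grad=False)
    layer.register_parameter("qweight", qweight)
    layer.register_parameter("scales", scales)

The weight_loader passed in above is attached to these parameters so that load_weights() knows how to copy checkpoint shards into them.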
2. load weights: loader.py: call load_model()
- load weights
loaded_weights = model.load_weights(self._get_all_weights(model_config, model))
- Call the quant method's process_weights_after_loading() to repack the weights for the kernel
for _, module in model.named_modules():
    quant_method = getattr(module, "quant_method", None)
    if isinstance(quant_method, QuantizeMethodBase):
        # When quant methods need to process weights after loading
        # (for repacking, quantizing, etc), they expect parameters
        # to be on the global target device. This scope is for the
        # case where cpu offloading is used, where we will move the
        # parameters onto device for processing and back off after.
        with device_loading_context(module, target_device):
            quant_method.process_weights_after_loading(module)
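The pattern inside process_weights_after_loading() is: take the checkpoint-format tensors that were just loaded, repack them into the layout the GEMM kernel expects, and re-register the result on the layer. A simplified sketch of that pattern (the real gptq_marlin method calls a dedicated repack kernel, ops.gptq_marlin_repack, instead of the placeholder copy below):

import torch
from torch.nn import Parameter

def process_weights_after_loading_sketch(layer: torch.nn.Module) -> None:
    # The checkpoint-format tensors were loaded into the parameters that
    # create_weights() registered (e.g. layer.qweight, layer.scales).
    qweight = layer.qweight.data
    scales = layer.scales.data

    # Placeholder for the kernel-specific shuffle/repack step.
    repacked = qweight.contiguous()

    # Replace the checkpoint-format parameters with kernel-ready ones,
    # matching the names used by apply() in the next section.
    layer.register_parameter(
        "marlin_qweight", Parameter(repacked, requires_grad=False))
    layer.register_parameter(
        "marlin_scales", Parameter(scales.contiguous(), requires_grad=False))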
Inference
If a quantization method is configured, the linear layer calls the method's apply() in the forward path. Below is an example for the gptq_marlin method.
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Scratch space required by the Marlin kernel.
    workspace = MarlinWorkspace(self.output_size_per_partition,
                                GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)
    scales = layer.marlin_scales
    zeros = layer.marlin_zeros
    orig_type = x.dtype

    # This path runs the kernel in fp16, so cast the activations,
    # scales and zero points if needed.
    if orig_type != torch.float16:
        x = x.to(torch.float16)
        scales = scales.to(torch.float16)
        zeros = zeros.to(torch.float16)

    # Fused dequantize + GEMM on the repacked 4-bit weights.
    marlin_out = ops.gptq_marlin_gemm(
        x,
        layer.marlin_qweight,
        scales,
        zeros,
        layer.g_idx,
        layer.g_idx_sort_indices,
        workspace.scratch,
        scalar_types.uint4,
        x.shape[0],
        self.output_size_per_partition,
        self.input_size_per_partition,
        True,  # is_k_full
        True,  # has_zp
        True,  # use 32-bit reduce
        True,  # use float zp
    )

    # Cast back to the original activation dtype and add the bias.
    if orig_type != torch.float16:
        marlin_out = marlin_out.to(orig_type)
    if bias is not None:
        marlin_out.add_(bias)

    return marlin_out
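For context, the call site of apply() is the linear layer's forward(). A toy stand-in for that dispatch (the real vLLM layers additionally handle tensor-parallel gathering/reduction and return the bias separately):

import torch
from typing import Optional

class QuantizedLinearSketch(torch.nn.Module):
    """Toy layer showing how forward() delegates to the quant method."""

    def __init__(self, quant_method, bias: Optional[torch.Tensor] = None):
        super().__init__()
        self.quant_method = quant_method
        self.bias = bias
        # In vLLM, quant_method.create_weights(...) is called here and
        # process_weights_after_loading(...) after the checkpoint is loaded.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The layer never touches the packed weights directly; apply()
        # runs the quantized GEMM kernel (gptq_marlin_gemm above).
        return self.quant_method.apply(self, x, self.bias)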