Post-Process
Post-quantization process classes for fine-tuning quantized models.
Base Class
PostQuantizationProcess (dataclass)
Abstract base class for post-quantization processes.
Post-quantization processes are executed after the main quantization
step (e.g., GPTQ, DBF). Each process receives a quantized model
on CPU (with quantized inference layers such as GPTQLinear)
and may modify it in-place.
Subclasses must implement the run() method.
If name is not provided, it is automatically set to the class name.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str or None | Human-readable name used in log messages. If None, automatically set to the class name. | None |
Examples:
Typical usage via Runner:
>>> from onecomp import Runner, ModelConfig, GPTQ, BlockWisePTQ
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
... model_config=model_config,
... quantizer=quantizer,
... post_processes=[BlockWisePTQ()],
... )
>>> runner.run()
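The base-class contract described above (a dataclass whose optional name falls back to the class name, plus an abstract run()) can be sketched in plain Python. This is a hypothetical stand-in written for illustration, not the onecomp source: PostQuantizationProcessSketch and NoOpProcess are made-up names, and filling in the default name from __post_init__ is an assumption about how the fallback is implemented.

```python
import inspect
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PostQuantizationProcessSketch(ABC):
    """Hypothetical stand-in for PostQuantizationProcess (not the library code)."""

    name: Optional[str] = None

    def __post_init__(self) -> None:
        # Assumed behaviour: fall back to the concrete class name for log messages.
        if self.name is None:
            self.name = type(self).__name__

    @abstractmethod
    def run(self, quantized_model: Any, model_config: Any) -> None:
        """Modify quantized_model in-place and leave it on CPU."""


@dataclass
class NoOpProcess(PostQuantizationProcessSketch):
    """Hypothetical subclass that leaves the model untouched."""

    def run(self, quantized_model: Any, model_config: Any) -> None:
        print(f"[{self.name}] nothing to do")


NoOpProcess().run(quantized_model=None, model_config=None)  # [NoOpProcess] nothing to do
print(inspect.isabstract(PostQuantizationProcessSketch))    # True: cannot be instantiated
```

Because the base is abstract, only subclasses that implement run() can be instantiated, which is what lets a runner treat every entry in post_processes uniformly.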
run() (abstractmethod)
Execute the post-quantization process.
The model is provided on CPU. Implementations may move it to
GPU for computation, but must move it back to CPU before
returning so that subsequent processes and Runner methods
(e.g. evaluation, saving) can work without device assumptions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| quantized_model | Module | The quantized model on CPU. Linear layers that were quantized have already been replaced with quantized inference layers (e.g. GPTQLinear). | required |
| model_config | ModelConfig | The model configuration (provides access to tokenizer, model id/path, device, etc.). | required |
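One way for an implementation to honour the "return on CPU" requirement, even when training raises, is a try/finally around the device move. The sketch below is an assumption-labelled helper, not part of onecomp: temporarily_on and RecordingModel are hypothetical names, and RecordingModel only mimics the .to() method of torch.nn.Module so the example runs without a GPU.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def temporarily_on(model: Any, device: str) -> Iterator[Any]:
    """Move `model` to `device` for the duration of the block, then back to CPU."""
    model = model.to(device)  # nn.Module.to() moves parameters in place and returns self
    try:
        yield model
    finally:
        # Runs even if the body raised, so later processes still see a CPU model.
        model.to("cpu")


class RecordingModel:
    """Hypothetical stand-in for an nn.Module that only records device moves."""

    def __init__(self) -> None:
        self.device = "cpu"
        self.moves: list[str] = []

    def to(self, device: str) -> "RecordingModel":
        self.device = device
        self.moves.append(device)
        return self


model = RecordingModel()
with temporarily_on(model, "cuda") as m:
    assert m.device == "cuda"  # GPU work inside run() would happen here
print(model.device, model.moves)  # cpu ['cuda', 'cpu']
```

Inside a real run(), the target device would typically be chosen with something like "cuda" if torch.cuda.is_available() else "cpu"; the finally clause is what guarantees the CPU hand-off described above.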
LoRA SFT
PostProcessLoraSFT (dataclass)
PostProcessLoraSFT(name: str = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 1.0, teacher_loss_weight: float = 0.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)
Bases: PostQuantizationProcess
LoRA-based SFT post-process on a GPTQ-quantized model.
This post-process improves a GPTQ-quantized causal language model by
injecting LoRA adapters into selected GPTQLinear layers and training
only those adapters while keeping the quantized base weights frozen.
The given quantized_model is modified in-place.
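To see why training only the adapters is cheap, the count below works through the arithmetic for Llama-2-7B (hidden size 4096, MLP size 11008, 32 decoder layers) at the default lora_r=16. It assumes that the default target_modules=None selects all seven projections listed in the algorithm overview, and that each adapter is a pair of matrices A (r x in_features) and B (out_features x r), as in standard LoRA; both are assumptions about this implementation.

```python
def lora_param_count(shapes: list[tuple[int, int]], r: int) -> int:
    """Trainable LoRA parameters: A is (r x in_features), B is (out_features x r)."""
    return sum(r * (d_in + d_out) for d_in, d_out in shapes)


# Llama-2-7B projection shapes (in_features, out_features) per decoder layer.
HIDDEN, INTERMEDIATE, LAYERS = 4096, 11008, 32
per_layer = [
    (HIDDEN, HIDDEN),        # q_proj
    (HIDDEN, HIDDEN),        # k_proj
    (HIDDEN, HIDDEN),        # v_proj
    (HIDDEN, HIDDEN),        # o_proj
    (HIDDEN, INTERMEDIATE),  # gate_proj
    (HIDDEN, INTERMEDIATE),  # up_proj
    (INTERMEDIATE, HIDDEN),  # down_proj
]
total = LAYERS * lora_param_count(per_layer, r=16)
print(f"{total:,} trainable LoRA parameters")  # 39,976,960 trainable LoRA parameters
```

That is roughly 40M trainable parameters, well under 1% of the roughly 7B frozen base weights.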
Algorithm overview
- Load an SFT training dataset from dataset_name or data_files.
- Tokenize the dataset and build causal LM labels.
- Find target GPTQLinear modules such as q_proj/k_proj/v_proj/o_proj/gate_proj/up_proj/down_proj.
- Replace each target module with LoRAGPTQLinear, which keeps the original GPTQ layer as the frozen base path and adds trainable LoRA low-rank updates.
- Optimize the LoRA parameters with SFT loss and, optionally, an additional teacher distillation loss against an FP teacher model.
- Move the post-processed model back to CPU at the end so it can be reused by Runner for perplexity / accuracy evaluation.
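The "frozen base path plus trainable low-rank update" structure from the replacement step can be sketched with NumPy. This is a hypothetical illustration, not LoRAGPTQLinear: the base is a plain dense matrix W here, whereas the library keeps the GPTQ layer itself, and the zero-initialised B, the alpha / r scaling, and dropout on the adapter input are the standard LoRA conventions, assumed rather than confirmed for this implementation.

```python
import numpy as np


class LoRALinearSketch:
    """Hypothetical LoRA wrapper: frozen base weight plus a low-rank update.

    forward(x) = x @ W.T + (alpha / r) * dropout(x) @ A.T @ B.T
    """

    def __init__(self, weight: np.ndarray, r: int = 16, alpha: float = 32,
                 dropout: float = 0.05, seed: int = 0) -> None:
        self.rng = np.random.default_rng(seed)
        out_features, in_features = weight.shape
        self.weight = weight                                    # frozen base path
        self.A = self.rng.normal(0.0, 0.01, (r, in_features))  # trainable
        self.B = np.zeros((out_features, r))                    # trainable, zero-init
        self.scaling = alpha / r
        self.dropout = dropout

    def forward(self, x: np.ndarray, training: bool = False) -> np.ndarray:
        base = x @ self.weight.T
        h = x
        if training and self.dropout > 0:
            # Inverted dropout on the adapter input only; the base path is untouched.
            keep = self.rng.random(x.shape) >= self.dropout
            h = x * keep / (1.0 - self.dropout)
        return base + self.scaling * (h @ self.A.T) @ self.B.T


W = np.random.default_rng(1).normal(size=(8, 4))  # (out_features, in_features)
layer = LoRALinearSketch(W, r=2, alpha=4)
x = np.ones((3, 4))
print(np.allclose(layer.forward(x), x @ W.T))  # True: B == 0 makes the adapter a no-op at init
```

Starting from B = 0 means the wrapped model initially reproduces the quantized model exactly, so training begins from the GPTQ result rather than from a perturbed one.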
Training objective
- sft_loss_weight > 0 enables the standard causal LM loss.
- teacher_loss_weight > 0 enables teacher-guided distillation.
- teacher_loss_type selects the distillation loss on logits (currently "kl" or "mse").
- cache_teacher_outputs=True precomputes teacher logits on CPU to reduce repeated teacher forward passes during multi-epoch training.
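The weighted objective can be sketched in NumPy as below. This is an illustration under assumptions, not the library's loss code: the shifted next-token cross-entropy with ignored label -100 follows the usual Hugging Face causal LM convention, the "kl" term is taken as KL(teacher || student) on temperature-softened distributions averaged over all token positions, and no T² rescaling or padding mask is applied to the teacher term. The actual reductions in PostProcessLoraSFT may differ.

```python
import numpy as np

IGNORE_INDEX = -100  # label value excluded from the loss (Hugging Face convention)


def log_softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def causal_lm_loss(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean next-token cross-entropy: position t is scored on labels[:, t + 1]."""
    logp = log_softmax(logits[:, :-1])  # (batch, seq - 1, vocab)
    targets = labels[:, 1:]             # (batch, seq - 1)
    mask = targets != IGNORE_INDEX
    safe_targets = np.where(mask, targets, 0)  # any valid index; masked out below
    nll = -np.take_along_axis(logp, safe_targets[..., None], axis=-1)[..., 0]
    return float((nll * mask).sum() / max(int(mask.sum()), 1))


def teacher_loss(student: np.ndarray, teacher: np.ndarray,
                 loss_type: str = "kl", temperature: float = 1.0) -> float:
    """Distillation loss on logits; "kl" is KL(teacher || student) per token."""
    if loss_type == "kl":
        log_s = log_softmax(student / temperature)
        log_t = log_softmax(teacher / temperature)
        return float((np.exp(log_t) * (log_t - log_s)).sum(axis=-1).mean())
    if loss_type == "mse":
        return float(((student - teacher) ** 2).mean())
    raise ValueError(f"unsupported teacher_loss_type: {loss_type!r}")


def objective(student, teacher, labels, sft_loss_weight=1.0,
              teacher_loss_weight=0.0, **teacher_kwargs) -> float:
    """Weighted sum of the two terms; a zero weight skips that term entirely."""
    total = 0.0
    if sft_loss_weight > 0:
        total += sft_loss_weight * causal_lm_loss(student, labels)
    if teacher_loss_weight > 0:
        total += teacher_loss_weight * teacher_loss(student, teacher, **teacher_kwargs)
    return total
```

Skipping a term whose weight is zero mirrors the "> 0 enables" wording above: with teacher_loss_weight=0.0 no teacher logits are needed at all.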
Typical usage
- Use data_files=... for local JSON/JSONL/CSV/TXT/Parquet files.
- Use dataset_name=... when loading from Hugging Face Datasets.
- Pass an instance of this class to Runner(post_processes=[...]) after GPTQ quantization, or call run() directly on a previously saved quantized model loaded with torch.load(..., weights_only=False).
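A minimal local training file for data_files can be produced with the standard library alone. The sketch assumes JSON Lines with one object per line whose key matches text_column (default "text"); the sample sentences are placeholders. The str | list[str] | dict[str, str] type of data_files mirrors the data_files argument of the Hugging Face datasets.load_dataset function, so forwarding it there is a plausible but unconfirmed assumption.

```python
import json
from pathlib import Path

# Placeholder training snippets; any plain-text corpus works.
records = [
    {"text": "Quantization maps weights to a small set of levels."},
    {"text": "LoRA adds trainable low-rank matrices to frozen layers."},
]

path = Path("train.jsonl")
with path.open("w", encoding="utf-8") as f:
    for rec in records:
        # One JSON object per line; the key must match text_column ("text" by default).
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(path.read_text(encoding="utf-8").count("\n"))  # 2
```

The file can then be passed as data_files="train.jsonl", or, if the dict form behaves as in datasets.load_dataset, as data_files={"train": "train.jsonl"} together with the default train_split="train".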
LoRA implementations
- PostProcessLoraSFT: Standard LoRA SFT with only the causal LM objective by default.
- PostProcessLoraTeacherSFT: LoRA SFT with teacher distillation enabled by default. Use this when combining SFT loss and teacher loss.
- PostProcessLoraTeacherOnlySFT: LoRA training with only teacher distillation.
Examples:
Via Runner:
>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraSFT(data_files="train.jsonl")
...     ],
... )
>>> runner.run()
Direct execution on a saved quantized model:
>>> import torch
>>> from onecomp import ModelConfig, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantized_model = torch.load(
...     "quantized_model.pt",
...     map_location="cpu",
...     weights_only=False,
... )
>>> post_process = PostProcessLoraSFT(data_files="train.jsonl")
>>> post_process.run(quantized_model, model_config)
With teacher distillation enabled:
>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraTeacherSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraTeacherSFT(
...             data_files="train.jsonl",
...             teacher_model_id="meta-llama/Llama-2-7b-hf",
...         )
...     ],
... )
>>> runner.run()
With teacher-logit caching enabled:
>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraSFT(
...             data_files="train.jsonl",
...             teacher_loss_weight=1.0,
...             cache_teacher_outputs=True,
...         )
...     ],
... )
>>> runner.run()
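The saving from cache_teacher_outputs=True comes from running the teacher once per sample instead of once per sample per epoch. The toy below counts teacher calls under both settings; CountingTeacher and distill are hypothetical names, the "logits" are placeholders, and the real implementation's caching details (CPU storage, teacher_cache_dtype) are not modelled.

```python
from typing import Callable, List, Optional, Sequence

Logits = List[float]


class CountingTeacher:
    """Hypothetical teacher that only counts how often it is run."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, sample: str) -> Logits:
        self.calls += 1
        return [float(len(sample)), 0.0]  # placeholder "logits"


def distill(samples: Sequence[str], teacher: Callable[[str], Logits],
            epochs: int, cache_teacher_outputs: bool) -> None:
    # With caching, every teacher pass happens once, up front.
    cached: Optional[List[Logits]] = (
        [teacher(s) for s in samples] if cache_teacher_outputs else None
    )
    for _ in range(epochs):
        for i, sample in enumerate(samples):
            target = cached[i] if cached is not None else teacher(sample)
            # ... the teacher loss against `target` and the optimizer step go here ...
            del target


samples = ["a", "bb", "ccc"]
uncached, cached = CountingTeacher(), CountingTeacher()
distill(samples, uncached, epochs=4, cache_teacher_outputs=False)
distill(samples, cached, epochs=4, cache_teacher_outputs=True)
print(uncached.calls, cached.calls)  # 12 3
```

The trade-off is memory: if full logits are stored, one 1024-token sample over a 32,000-token vocabulary is 32,768,000 values, about 65.5 MB in a 16-bit dtype, which is presumably why teacher_cache_dtype is configurable.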
run()
Run LoRA SFT on the GPTQ-quantized model in-place.
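The defaults batch_size=1, gradient_accumulation_steps=16, epochs=4, and warmup_ratio=0.03 determine how many optimizer updates run() performs. The helper below works through that arithmetic assuming Hugging Face Trainer-style accounting (updates per epoch = batches // accumulation steps, warmup = ceil(total_updates * warmup_ratio)); the source does not state that run() uses this exact scheme, and the 1,000-sample dataset is hypothetical.

```python
import math


def training_schedule(num_samples: int, batch_size: int = 1, grad_accum: int = 16,
                      epochs: int = 4, warmup_ratio: float = 0.03) -> dict:
    """Optimizer-update and warmup counts under assumed Trainer-style accounting."""
    batches_per_epoch = math.ceil(num_samples / batch_size)
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_updates = updates_per_epoch * epochs
    return {
        "effective_batch_size": batch_size * grad_accum,
        "total_updates": total_updates,
        "warmup_steps": math.ceil(total_updates * warmup_ratio),
    }


print(training_schedule(1000))
# {'effective_batch_size': 16, 'total_updates': 248, 'warmup_steps': 8}
```

With small datasets the update count drops quickly (fewer than 16 samples still yields only one update per epoch), so lowering gradient_accumulation_steps may be worth considering there.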
Convenience Variants
PostProcessLoraTeacherSFT and PostProcessLoraTeacherOnlySFT are pre-configured
variants of PostProcessLoraSFT with different default loss weights:
PostProcessLoraTeacherSFT (dataclass)
PostProcessLoraTeacherSFT(name: str = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 1.0, teacher_loss_weight: float = 1.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)
LoRA SFT with teacher-guided distillation enabled by default (sft_loss_weight=1.0, teacher_loss_weight=1.0).
PostProcessLoraTeacherOnlySFT (dataclass)
PostProcessLoraTeacherOnlySFT(name: str = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 0.0, teacher_loss_weight: float = 1.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)
LoRA SFT variant that optimizes only the teacher distillation loss (sft_loss_weight=0.0, teacher_loss_weight=1.0).
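Judging from the three signatures, which are identical except for sft_loss_weight and teacher_loss_weight, the variants behave like dataclass subclasses that override two defaults. The sketch mirrors that with hypothetical stand-in classes holding only those two fields; it does not import onecomp, and whether the library defines the variants exactly this way is an assumption.

```python
from dataclasses import dataclass


@dataclass
class LoraSFTDefaults:
    """Hypothetical stand-in for the loss-weight defaults of PostProcessLoraSFT."""

    sft_loss_weight: float = 1.0
    teacher_loss_weight: float = 0.0


@dataclass
class LoraTeacherSFTDefaults(LoraSFTDefaults):
    teacher_loss_weight: float = 1.0  # teacher term switched on


@dataclass
class LoraTeacherOnlySFTDefaults(LoraSFTDefaults):
    sft_loss_weight: float = 0.0      # causal LM term switched off
    teacher_loss_weight: float = 1.0


for cls in (LoraSFTDefaults, LoraTeacherSFTDefaults, LoraTeacherOnlySFTDefaults):
    cfg = cls()
    print(f"{cls.__name__}: sft={cfg.sft_loss_weight}, teacher={cfg.teacher_loss_weight}")
```

Under that reading, PostProcessLoraSFT(teacher_loss_weight=1.0) should configure the same objective as PostProcessLoraTeacherSFT(); the named variants mainly make the intended loss mix explicit at the call site.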