Post-Process
Post-quantization process classes for improving quantized model accuracy.
Base Class
PostQuantizationProcess (dataclass)
Abstract base class for post-quantization processes.
Post-quantization processes are executed after the main quantization
step (e.g., GPTQ, DBF). Each process receives a quantized model
on CPU (with quantized inference layers such as GPTQLinear)
and may modify it in-place.
Subclasses must implement the run() method.
If name is not provided, it is automatically set to the class name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str or None | Human-readable name used in log messages. If None, automatically set to the class name. | None |
Examples:
Typical usage via Runner:
>>> from onecomp import Runner, ModelConfig, GPTQ, BlockWisePTQ
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[BlockWisePTQ()],
... )
>>> runner.run()
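The two rules stated for the base class (run() must be implemented, and name falls back to the class name) can be sketched with the standard library alone. This is a minimal illustration, not the onecomp implementation: SketchPostProcess and NoOpProcess are hypothetical names, and the __post_init__ fallback is an assumed mechanism.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SketchPostProcess(ABC):
    """Hypothetical stand-in for an abstract post-quantization process."""

    name: Optional[str] = None

    def __post_init__(self) -> None:
        # Assumed mechanism: fall back to the concrete class name.
        if self.name is None:
            self.name = type(self).__name__

    @abstractmethod
    def run(self, quantized_model: Any, model_config: Any) -> None:
        """Modify quantized_model in-place."""


@dataclass
class NoOpProcess(SketchPostProcess):
    """Hypothetical subclass that leaves the model untouched."""

    def run(self, quantized_model: Any, model_config: Any) -> None:
        return None


default_named = NoOpProcess()
custom_named = NoOpProcess(name="warm-up pass")
print(default_named.name, "|", custom_named.name)
```

Instantiating SketchPostProcess directly raises TypeError because run() is abstract, which is how the "must implement" rule is enforced in this sketch.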
run (abstractmethod)
Execute the post-quantization process.
The model is provided on CPU. Implementations may move it to
GPU for computation, but must move it back to CPU before
returning so that subsequent processes and Runner methods
(e.g. evaluation, saving) can work without device assumptions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| quantized_model | Module | The quantized model on CPU. Linear layers that were quantized have already been replaced with quantized inference layers (e.g. | required |
| model_config | ModelConfig | The model configuration (provides access to tokenizer, model id/path, device, etc.). | required |
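One way to honour the device contract of run(), even when the computation fails midway, is a try/finally block. The sketch below uses a hypothetical FakeModel that only tracks a device string in place of a real torch module; run_on_accelerator is likewise an illustrative helper, not part of the library.

```python
from typing import Callable


class FakeModel:
    """Hypothetical stand-in for a torch.nn.Module; it only tracks a device."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "FakeModel":
        self.device = device
        return self


def run_on_accelerator(
    model: FakeModel, device: str, work: Callable[[FakeModel], None]
) -> None:
    """Move the model to `device`, do the work, and always return it to CPU."""
    model.to(device)
    try:
        work(model)
    finally:
        # Runs on success and on error, so later processes see a CPU model.
        model.to("cpu")


model = FakeModel()
run_on_accelerator(model, "cuda", lambda m: None)
print(model.device)
```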
Global PTQ
GlobalPTQ (dataclass)
GlobalPTQ(name: Optional[str] = None, epochs: int = 5, gptq_lr: float = 1e-05, temperature: float = 1.0, grad_clip: float = 1.0, dbf_lr: float = 5e-05, calibration_config: Optional[CalibrationConfig] = None, warmup_ratio: float = 0.1, min_lr_ratio: float = 0.01, eval_interval: int = 1, use_gradient_checkpointing: bool = True, early_stopping_patience: int = 0, use_mixed_precision: bool = False, grad_accum_steps: int = 1)
Bases: PostQuantizationProcess
Global Post-Training Quantization via KL distillation.
After layer-wise PTQ (GPTQ / DBF) quantises each linear layer independently, global PTQ minimises the KL divergence between an FP16 teacher model and the quantized student model across the entire sequence, fine-tuning continuous quantization parameters (scales and zeros for GPTQ; scaling factors for DBF).
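To make the distillation objective concrete, here is a minimal pure-Python sketch of a temperature-softened KL divergence for a single token position. The direction KL(teacher || student) and the absence of any extra temperature scaling factor are assumptions for illustration; the source does not state which form GlobalPTQ uses.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(
    teacher_logits: List[float],
    student_logits: List[float],
    temperature: float = 1.0,
) -> float:
    """KL(teacher || student) for one token position (assumed direction)."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


teacher = [2.0, 0.0, -1.0]
student = [1.0, 0.5, -0.5]
print(kl_divergence(teacher, student, temperature=1.0))
```

The divergence is zero when the student reproduces the teacher's distribution and positive otherwise, which is what makes it usable as a training signal for the continuous quantization parameters.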
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| epochs | int | Number of distillation epochs. Default is 5. | 5 |
| gptq_lr | float | Learning rate for GPTQ scales / zeros. Default is 1e-5. | 1e-05 |
| temperature | float | Softmax temperature for KL divergence. Default is 1.0. | 1.0 |
| grad_clip | float | Gradient clipping norm. Default is 1.0. | 1.0 |
| dbf_lr | float | Learning rate for DBF scaling parameters. Default is 5e-5. | 5e-05 |
| calibration_config | CalibrationConfig or None | Calibration data configuration. When | None |
| warmup_ratio | float | Fraction of total steps used for LR warm-up. Default is 0.1. | 0.1 |
| min_lr_ratio | float | Minimum LR as a fraction of peak LR (cosine decay floor). Default is 0.01. | 0.01 |
| eval_interval | int | Evaluate every N epochs. Default is 1. | 1 |
| use_gradient_checkpointing | bool | Enable gradient checkpointing to reduce GPU memory at the cost of recomputing activations during backpropagation. Default is True. | True |
| early_stopping_patience | int | Stop training if eval KL does not improve for this many consecutive evaluations. 0 disables early stopping. Default is 0. | 0 |
| use_mixed_precision | bool | Enable BF16 mixed-precision ( | False |
| grad_accum_steps | int | Number of gradient accumulation steps before each optimiser update. Default is 1 (no accumulation). | 1 |
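The interaction between warmup_ratio and min_lr_ratio can be sketched as a linear warm-up followed by cosine decay to a floor. The exact formula below (including how warm-up steps are rounded) is an assumption for illustration, and lr_at is a hypothetical helper name.

```python
import math


def lr_at(
    step: int,
    total_steps: int,
    peak_lr: float,
    warmup_ratio: float = 0.1,
    min_lr_ratio: float = 0.01,
) -> float:
    """Learning rate at `step` under an assumed warm-up + cosine schedule."""
    warmup_steps = int(total_steps * warmup_ratio)
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp that reaches peak_lr on the last warm-up step.
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    floor = peak_lr * min_lr_ratio
    # Decays from peak_lr (progress 0) down to the floor (progress 1).
    return floor + (peak_lr - floor) * cosine


for step in (0, 9, 10, 55, 100):
    print(step, lr_at(step, total_steps=100, peak_lr=1e-5))
```

With the defaults, the rate never drops below 1% of the peak, so late epochs still make small updates instead of stalling at zero.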
Examples:
>>> from onecomp import Runner, ModelConfig, GPTQ, GlobalPTQ, CalibrationConfig
>>> model_config = ModelConfig(model_id="Qwen/Qwen3-0.6B")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[GlobalPTQ(epochs=5, gptq_lr=1e-5)],
... )
>>> runner.run()
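The early_stopping_patience rule from the parameter table can be sketched as a counter over the evaluation history. This is an illustrative reading of "does not improve for this many consecutive evaluations"; should_stop is a hypothetical helper, and treating a tie as "no improvement" is an assumption.

```python
from typing import Iterable


def should_stop(eval_kls: Iterable[float], patience: int) -> bool:
    """Return True once `patience` consecutive evaluations fail to improve."""
    if patience <= 0:
        # 0 disables early stopping.
        return False
    best = float("inf")
    stale = 0
    for kl in eval_kls:
        if kl < best:
            best = kl
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return True
    return False


print(should_stop([1.00, 0.90, 0.95, 0.96], patience=2))
print(should_stop([1.00, 0.90, 0.95, 0.80], patience=2))
```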
run
Execute global PTQ on the quantized model.
Modifies quantized_model in-place. The model is returned on
CPU in eval mode per the PostQuantizationProcess contract.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| quantized_model | Module | Quantized model on CPU (GPTQLinear / DoubleBinaryLinear). | required |
| model_config | ModelConfig | Model configuration (provides tokenizer, model path, etc.). | required |
GlobalPTQDistributed (dataclass)
GlobalPTQDistributed(name: Optional[str] = None, temperature: float = 1.0, w_distill: float = 1.0, w_ntp: float = 0.0, gptq_lr: float = 1e-05, dbf_lr: float = 5e-05, calibration_config: Optional[CalibrationConfig] = None, epochs: int = 5, per_device_train_batch_size: int = 1, gradient_accumulation_steps: int = 1, warmup_ratio: float = 0.1, max_grad_norm: float = 1.0, lr_scheduler_type: str = 'cosine', use_gradient_checkpointing: bool = True, bf16: bool = True, deepspeed_config: Optional[str] = None, output_dir: Optional[str] = None, logging_steps: int = 1, report_to: Optional[str] = 'none', save_strategy: str = 'no', save_steps: Optional[int] = None, eval_interval: int = 1)
Bases: PostQuantizationProcess
Global PTQ via Trainer-based KL distillation.
Trainer-based implementation of global post-training quantization that supports single-GPU and multi-GPU (DeepSpeed) execution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| temperature | float | Softmax temperature for KL divergence. Default is 1.0. | 1.0 |
| w_distill | float | Weight for KL distillation loss. Default is 1.0. | 1.0 |
| w_ntp | float | Weight for next-token prediction loss. Default is 0.0. Setting | 0.0 |
| gptq_lr | float | Learning rate for GPTQ scales / zeros. Default is 1e-5. | 1e-05 |
| dbf_lr | float | Learning rate for DBF scaling parameters. Default is 5e-5. | 5e-05 |
| calibration_config | CalibrationConfig or None | Calibration data configuration. When | None |
| epochs | int | Number of distillation epochs. Default is 5. | 5 |
| per_device_train_batch_size | int | Batch size per device. Default is 1. | 1 |
| gradient_accumulation_steps | int | Gradient accumulation steps. Default is 1. | 1 |
| warmup_ratio | float | Fraction of total steps for LR warm-up. Default is 0.1. | 0.1 |
| max_grad_norm | float | Gradient clipping norm. Default is 1.0. | 1.0 |
| lr_scheduler_type | str | Learning-rate scheduler type (any value accepted by | 'cosine' |
| use_gradient_checkpointing | bool | Enable gradient checkpointing. Default is True. | True |
| bf16 | bool | Enable BF16 mixed-precision. Default is True. | True |
| deepspeed_config | str or None | Path to a DeepSpeed JSON config file. | None |
| output_dir | str or None | Directory for Trainer outputs (logs, checkpoints). | None |
| logging_steps | int | Log training metrics every N steps. Default is 1. | 1 |
| report_to | str or list or None | Logging integrations (e.g. | 'none' |
| save_strategy | str | Checkpoint saving strategy passed to | 'no' |
| save_steps | int or None | Save checkpoint every N steps (when | None |
| eval_interval | int | Evaluate every N epochs. Default is 1. | 1 |
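The w_distill and w_ntp weights suggest a combined objective. A minimal sketch, assuming the two terms are mixed as a plain weighted sum (the source does not spell out the formula), with next_token_loss and total_loss as hypothetical helper names:

```python
import math
from typing import List


def next_token_loss(student_logits: List[float], target_index: int) -> float:
    """Cross-entropy of the student against the true next token."""
    peak = max(student_logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in student_logits))
    return log_z - student_logits[target_index]


def total_loss(
    distill_loss: float,
    ntp_loss: float,
    w_distill: float = 1.0,
    w_ntp: float = 0.0,
) -> float:
    """Assumed combination: weighted sum of the two loss terms."""
    return w_distill * distill_loss + w_ntp * ntp_loss


ntp = next_token_loss([1.0, 0.5, -0.5], target_index=0)
print(total_loss(distill_loss=0.3, ntp_loss=ntp))             # defaults: distillation only
print(total_loss(distill_loss=0.3, ntp_loss=ntp, w_ntp=0.5))  # both terms active
```

Under this assumption the default weights reduce the objective to the distillation term alone, and a weight of zero removes the corresponding term entirely.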
Examples:
Single GPU:
>>> from onecomp import Runner, ModelConfig, GPTQ, GlobalPTQDistributed
>>> runner = Runner(
...     model_config=ModelConfig(model_id="Qwen/Qwen3-0.6B"),
...     quantizer=GPTQ(wbits=4, groupsize=128),
...     post_processes=[GlobalPTQDistributed(epochs=5, gptq_lr=1e-5)],
... )
>>> runner.run()
With DeepSpeed ZeRO-2:
>>> GlobalPTQDistributed(
...     epochs=5,
...     deepspeed_config="configs/ds_zero2.json",
...     per_device_train_batch_size=2,
... )
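For reference, a ZeRO stage 2 config file of the kind passed as deepspeed_config might be generated as follows. This is a sketch under assumptions: the "auto" placeholders are a convention of the Hugging Face Trainer DeepSpeed integration, and whether GlobalPTQDistributed forwards the file to that integration unchanged is not confirmed by the source. The file is written to a temporary directory here; the path and contents are illustrative only.

```python
import json
import tempfile
from pathlib import Path

# Assumed minimal ZeRO-2 settings; "auto" lets the Trainer integration
# fill the value in from its own arguments.
ds_config = {
    "zero_optimization": {"stage": 2},
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
}

config_path = Path(tempfile.mkdtemp()) / "ds_zero2.json"
with config_path.open("w", encoding="utf-8") as handle:
    json.dump(ds_config, handle, indent=2)

print(config_path)
```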
Pure QAT (no teacher model):
Custom calibration config:
>>> from onecomp import CalibrationConfig
>>> GlobalPTQDistributed(
...     calibration_config=CalibrationConfig(
...         calibration_dataset="wikitext2",
...         num_calibration_samples=64,
...     ),
... )
run
Execute global PTQ via Trainer.
Modifies quantized_model in-place. The model is returned on
CPU in eval mode per the PostQuantizationProcess contract.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| quantized_model | Module | Quantized model on CPU (GPTQLinear / DoubleBinaryLinear). | required |
| model_config | ModelConfig | Model configuration (provides tokenizer, model path, etc.). | required |
Block-wise PTQ
BlockWisePTQ (dataclass)
BlockWisePTQ(name: Optional[str] = None, lr: float = 0.0001, epochs: int = 10, cbq_enable: bool = False, gptq_lr: float = 0.001, gptq_optimize_intweight: bool = False, gptq_intweight_lr: float = 0.0001, grad_clip: float = 1.0, optimize_binary: bool = True, k_smooth: float = 100.0, cbq_epochs: int = 0, cbq_lr: float = 5e-05, calibration_config: CalibrationConfig = None)
Bases: PostQuantizationProcess
Block-wise Post-Training Quantization.
After layer-wise PTQ (GPTQ / DBF / OneBit) quantises each linear layer independently, block-wise PTQ minimises intermediate-representation MSE against an FP16 teacher model at the Transformer-block granularity.
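As a toy illustration of minimising block-output MSE against a teacher, the sketch below tunes a single scale so that scale * quantized_out matches the teacher's block output. Real blocks have many parameters and tensor-valued activations; the one-parameter setup, the plain gradient-descent update, and the names block_mse and fit_scale are all assumptions made for clarity.

```python
from typing import List, Tuple


def block_mse(teacher_out: List[float], student_out: List[float]) -> float:
    """Mean squared error between teacher and student block outputs."""
    return sum((t - s) ** 2 for t, s in zip(teacher_out, student_out)) / len(teacher_out)


def fit_scale(
    teacher_out: List[float],
    quantized_out: List[float],
    lr: float,
    epochs: int,
) -> Tuple[float, float]:
    """Gradient descent on one scale so scale * quantized_out tracks the teacher."""
    scale = 1.0
    n = len(teacher_out)
    for _ in range(epochs):
        # d(MSE)/d(scale) = mean(2 * (scale * q - t) * q)
        grad = sum(2.0 * (scale * q - t) * q for t, q in zip(teacher_out, quantized_out)) / n
        scale -= lr * grad
    student_out = [scale * q for q in quantized_out]
    return scale, block_mse(teacher_out, student_out)


teacher = [1.1, 2.2, 3.3]
quantized = [1.0, 2.0, 3.0]
print("before:", block_mse(teacher, quantized))
print("after: ", fit_scale(teacher, quantized, lr=0.05, epochs=10))
```

The same loop structure, repeated per Transformer block, is what "block granularity" refers to: each block is fitted against the teacher's activations at that depth rather than against final logits.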
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lr | float | Learning rate for block-wise optimisation (DBF / OneBit / generic). Default is 1e-4. | 0.0001 |
| epochs | int | Number of optimisation epochs per block. Default is 10. | 10 |
| cbq_enable | bool | Whether to enable Cross-Block Quantisation (Phase 2) after greedy block-wise distillation. Default is False. | False |
| gptq_lr | float | Learning rate for GPTQ scales/zeros optimisation. Default is 1e-3. | 0.001 |
| gptq_optimize_intweight | bool | Whether to optimise integer weights via Smooth STE. Default is False. | False |
| gptq_intweight_lr | float | Learning rate for integer weight optimisation. Default is 1e-4. | 0.0001 |
| grad_clip | float | Gradient clipping norm. Default is 1.0. | 1.0 |
| optimize_binary | bool | Whether to optimise binary matrices (DBF) / sign matrices (OneBit). Default is True. | True |
| k_smooth | float | SmoothSign STE temperature for binary/sign optimisation. Default is 100.0. | 100.0 |
| calibration_config | CalibrationConfig or None | Calibration data configuration. When | None |
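The source does not define the SmoothSign straight-through estimator behind k_smooth. One common construction, assumed here purely for illustration, keeps the hard sign in the forward pass and uses the derivative of tanh(k * x) as a surrogate gradient in the backward pass; both function names are hypothetical.

```python
import math


def smooth_sign_forward(x: float) -> float:
    """Forward pass: hard sign, so the weight stays binary."""
    return 1.0 if x >= 0.0 else -1.0


def smooth_sign_grad(x: float, k_smooth: float = 100.0) -> float:
    """Assumed surrogate gradient: d/dx tanh(k * x) = k * (1 - tanh(k * x)^2)."""
    t = math.tanh(k_smooth * x)
    return k_smooth * (1.0 - t * t)


for x in (-0.5, -0.001, 0.0, 0.001, 0.5):
    print(x, smooth_sign_forward(x), smooth_sign_grad(x))
```

Under this assumption a larger k_smooth concentrates the gradient near zero, so only latent weights close to a sign flip receive meaningful updates.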
Examples:
>>> from onecomp import Runner, ModelConfig, GPTQ, BlockWisePTQ
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[BlockWisePTQ(lr=1e-4, epochs=10, cbq_enable=True)],
... )
>>> runner.run()
run
Execute block-wise PTQ on the quantized model.
Modifies quantized_model in-place. Returns None (per PostQuantizationProcess interface).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| quantized_model | Module | Quantized model on CPU. | required |
| model_config | ModelConfig | Model configuration. | required |
LoRA SFT
PostProcessLoraSFT (dataclass)
PostProcessLoraSFT(name: Optional[str] = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 1.0, teacher_loss_weight: float = 0.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)
Bases: PostQuantizationProcess
LoRA-based SFT post-process on a GPTQ-quantized model.
This post-process improves a GPTQ-quantized causal language model by
injecting LoRA adapters into selected GPTQLinear layers and training
only those adapters while keeping the quantized base weights frozen.
The given quantized_model is modified in-place.
Algorithm overview
- Load an SFT training dataset from dataset_name or data_files.
- Tokenize the dataset and build causal LM labels.
- Find target GPTQLinear modules such as q_proj / k_proj / v_proj / o_proj / gate_proj / up_proj / down_proj.
- Replace each target module with LoRAGPTQLinear, which keeps the original GPTQ layer as the frozen base path and adds trainable LoRA low-rank updates.
- Optimize the LoRA parameters with SFT loss and, optionally, an additional teacher distillation loss against an FP teacher model.
- Move the post-processed model back to CPU at the end so it can be reused by Runner for perplexity / accuracy evaluation.
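The "frozen base path plus trainable low-rank update" step can be sketched in pure Python. This follows the standard LoRA formulation, y = base(x) + (lora_alpha / lora_r) * B(Ax), with B initialised to zero; whether LoRAGPTQLinear uses exactly this scaling and initialisation is an assumption, and matvec and lora_forward are hypothetical helpers.

```python
from typing import Callable, List

Matrix = List[List[float]]
Vector = List[float]


def matvec(matrix: Matrix, vector: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(
    base_forward: Callable[[Vector], Vector],
    lora_a: Matrix,  # shape (lora_r, in_features), trainable
    lora_b: Matrix,  # shape (out_features, lora_r), trainable
    x: Vector,
    lora_r: int,
    lora_alpha: int,
) -> Vector:
    """Frozen base output plus the scaled low-rank update."""
    base_out = base_forward(x)
    update = matvec(lora_b, matvec(lora_a, x))
    scale = lora_alpha / lora_r
    return [b + scale * u for b, u in zip(base_out, update)]


# Stand-in for a frozen quantized layer: 3 inputs, 2 outputs.
frozen_weight = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
base = lambda v: matvec(frozen_weight, v)

lora_a = [[0.1, 0.2, 0.3], [0.0, -0.1, 0.1]]  # rank 2
lora_b = [[0.0, 0.0], [0.0, 0.0]]             # zero init: no change at start

x = [1.0, 2.0, 3.0]
print(lora_forward(base, lora_a, lora_b, x, lora_r=2, lora_alpha=4))
```

Because B starts at zero in this sketch, the adapted layer initially reproduces the quantized layer exactly, and training only has to learn a correction.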
Training objective
- sft_loss_weight > 0 enables the standard causal LM loss.
- teacher_loss_weight > 0 enables teacher-guided distillation.
- teacher_loss_type selects the distillation loss on logits (currently "kl" or "mse").
- cache_teacher_outputs=True precomputes teacher logits on CPU to reduce repeated teacher forward passes during multi-epoch training.
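The saving from cache_teacher_outputs can be illustrated by counting teacher forward passes. The sketch below is schematic: the teacher is a dummy function, and the dictionary cache keyed by sample index is an assumed structure, not the library's storage format.

```python
from typing import Dict, List


def count_teacher_calls(
    num_samples: int, epochs: int, cache_teacher_outputs: bool
) -> int:
    """Count teacher forward passes over a full training run."""
    calls = 0

    def teacher_forward(index: int) -> List[float]:
        nonlocal calls
        calls += 1
        return [float(index), 0.0]  # dummy logits

    cache: Dict[int, List[float]] = {}
    if cache_teacher_outputs:
        # One pass over the data before training starts.
        for index in range(num_samples):
            cache[index] = teacher_forward(index)

    for _ in range(epochs):
        for index in range(num_samples):
            logits = cache[index] if cache_teacher_outputs else teacher_forward(index)
            assert len(logits) == 2  # the training step would consume the logits here
    return calls


print(count_teacher_calls(num_samples=8, epochs=4, cache_teacher_outputs=False))
print(count_teacher_calls(num_samples=8, epochs=4, cache_teacher_outputs=True))
```

The trade-off is memory: cached logits grow with the number of samples, the sequence length, and the vocabulary size, which is presumably why a separate teacher_cache_dtype option exists.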
Typical usage
- Use data_files=... for local JSON/JSONL/CSV/TXT/Parquet files.
- Use dataset_name=... when loading from Hugging Face Datasets.
- Pass this class to Runner(post_processes=[...]) after GPTQ quantization, or call run() directly on a previously saved quantized model loaded with torch.load(..., weights_only=False).
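A local JSONL training file compatible with the default text_column='text' might be prepared as below. The one-object-per-line layout is the usual JSONL convention; the sample sentences are placeholders, and the file is written to a temporary directory rather than to a real project path.

```python
import json
import tempfile
from pathlib import Path

# Placeholder training texts; each record carries the default "text" column.
records = [
    {"text": "Quantization reduces the memory footprint of a model."},
    {"text": "LoRA adds small trainable matrices to frozen layers."},
]

train_path = Path(tempfile.mkdtemp()) / "train.jsonl"
with train_path.open("w", encoding="utf-8") as handle:
    for record in records:
        # One JSON object per line, no enclosing array.
        handle.write(json.dumps(record) + "\n")

print(train_path)
```

The resulting path is what would be passed as data_files; a different column name would need a matching text_column argument.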
LoRA implementations
- PostProcessLoraSFT: Standard LoRA SFT with only the causal LM objective by default.
- PostProcessLoraTeacherSFT: LoRA SFT with teacher distillation enabled by default. Use this when combining SFT loss and teacher loss.
- PostProcessLoraTeacherOnlySFT: LoRA training with only teacher distillation.
Examples:
Via Runner:
>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraSFT(data_files="train.jsonl")
...     ],
... )
>>> runner.run()
Direct execution on a saved quantized model:
>>> import torch
>>> from onecomp import ModelConfig, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantized_model = torch.load(
...     "quantized_model.pt",
...     map_location="cpu",
...     weights_only=False,
... )
>>> post_process = PostProcessLoraSFT(data_files="train.jsonl")
>>> post_process.run(quantized_model, model_config)
With teacher distillation enabled:
>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraTeacherSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraTeacherSFT(
...             data_files="train.jsonl",
...             teacher_model_id="meta-llama/Llama-2-7b-hf",
...         )
...     ],
... )
>>> runner.run()
With teacher-logit caching enabled:
>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraSFT(
...             data_files="train.jsonl",
...             teacher_loss_weight=1.0,
...             cache_teacher_outputs=True,
...         )
...     ],
... )
>>> runner.run()
run
Run LoRA SFT on the GPTQ-quantized model in-place.
Convenience Variants
PostProcessLoraTeacherSFT and PostProcessLoraTeacherOnlySFT are pre-configured
variants of PostProcessLoraSFT with different default loss weights:
PostProcessLoraTeacherSFT (dataclass)
PostProcessLoraTeacherSFT(name: Optional[str] = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 1.0, teacher_loss_weight: float = 1.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)
LoRA SFT with teacher-guided distillation enabled by default.
PostProcessLoraTeacherOnlySFT (dataclass)
PostProcessLoraTeacherOnlySFT(name: Optional[str] = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 0.0, teacher_loss_weight: float = 1.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)
LoRA SFT variant that optimizes only the teacher distillation loss.
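Comparing the three signatures, the variants differ only in the defaults for sft_loss_weight and teacher_loss_weight. A minimal sketch of how such pre-configured variants can be expressed with dataclass inheritance follows; the Sketch* classes are hypothetical stand-ins that carry only the two fields in question, and the use of subclassing is an assumption about the implementation.

```python
from dataclasses import dataclass


@dataclass
class SketchLoraSFT:
    """Stand-in for PostProcessLoraSFT: causal LM loss only by default."""

    sft_loss_weight: float = 1.0
    teacher_loss_weight: float = 0.0


@dataclass
class SketchLoraTeacherSFT(SketchLoraSFT):
    """Stand-in for PostProcessLoraTeacherSFT: both losses by default."""

    teacher_loss_weight: float = 1.0


@dataclass
class SketchLoraTeacherOnlySFT(SketchLoraSFT):
    """Stand-in for PostProcessLoraTeacherOnlySFT: distillation only."""

    sft_loss_weight: float = 0.0
    teacher_loss_weight: float = 1.0


for cls in (SketchLoraSFT, SketchLoraTeacherSFT, SketchLoraTeacherOnlySFT):
    instance = cls()
    print(cls.__name__, instance.sft_loss_weight, instance.teacher_loss_weight)
```

Since only defaults change, any variant can still be overridden at construction time, for example by passing an explicit teacher_loss_weight.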