
Post-Process

Post-quantization process classes for improving quantized model accuracy.

Base Class

PostQuantizationProcess dataclass

PostQuantizationProcess(name: Optional[str] = None)

Abstract base class for post-quantization processes.

Post-quantization processes are executed after the main quantization step (e.g., GPTQ, DBF). Each process receives a quantized model on CPU (with quantized inference layers such as GPTQLinear) and may modify it in place.

Subclasses must implement the run() method. If name is not provided, it is automatically set to the class name.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| name | str or None | Human-readable name used in log messages. If None, it is automatically set to the class name. | None |

Examples:

Typical usage via Runner:

>>> from onecomp import Runner, ModelConfig, GPTQ, BlockWisePTQ
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[BlockWisePTQ()],
... )
>>> runner.run()
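The base-class contract described above (subclasses implement run(), and name falls back to the class name) can be mirrored with the standard library alone. The sketch below is a hypothetical stand-in written for illustration, not the onecomp source: it assumes the fallback happens in `__post_init__`, and both class names are invented.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class PostQuantizationProcessSketch(ABC):
    """Hypothetical stand-in for the documented base class."""

    name: Optional[str] = None

    def __post_init__(self) -> None:
        # Assumed mechanism: fall back to the concrete class name for logs.
        if self.name is None:
            self.name = type(self).__name__

    @abstractmethod
    def run(self, quantized_model: Any, model_config: Any) -> None:
        """Modify quantized_model in place and return None."""


@dataclass
class RecordingProcess(PostQuantizationProcessSketch):
    """Hypothetical subclass that only records the models it receives."""

    seen: List[Any] = field(default_factory=list)

    def run(self, quantized_model: Any, model_config: Any) -> None:
        self.seen.append(quantized_model)


print(RecordingProcess().name)              # RecordingProcess
print(RecordingProcess(name="probe").name)  # probe
```

Because the base is abstract, instantiating it directly raises TypeError; only subclasses that define run() can be created.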

run abstractmethod

run(quantized_model: Module, model_config: ModelConfig) -> None

Execute the post-quantization process.

The model is provided on CPU. Implementations may move it to GPU for computation, but must move it back to CPU before returning so that subsequent processes and Runner methods (e.g. evaluation, saving) can work without device assumptions.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| quantized_model | Module | The quantized model on CPU. Linear layers that were quantized have already been replaced with quantized inference layers (e.g. GPTQLinear, DoubleBinaryLinear). The process may modify the model in-place. | required |
| model_config | ModelConfig | The model configuration (provides access to tokenizer, model id/path, device, etc.). | required |
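One way to honour the rule that run() must return the model on CPU is to restore the device in a `finally` clause, so the model goes back to CPU even if the optimisation fails. The sketch uses a hypothetical `FakeModel` that only records device moves, so it runs without PyTorch; a real `torch.nn.Module.to(device)` likewise moves parameters in place and returns the module. `GpuRoundTripProcess` and its `compute_device` field are invented for illustration.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


class FakeModel:
    """Hypothetical stand-in for torch.nn.Module that records device moves."""

    def __init__(self) -> None:
        self.device = "cpu"
        self.moves: List[str] = []
        self.updates = 0

    def to(self, device: str) -> "FakeModel":
        # Like nn.Module.to(): move in place and return self.
        self.device = device
        self.moves.append(device)
        return self


@dataclass
class GpuRoundTripProcess:
    """Hypothetical process following the documented run() device contract."""

    name: Optional[str] = None
    compute_device: str = "cuda"  # assumed; use "cpu" when no GPU is present
    fail: bool = False  # only used to demonstrate the error path

    def run(self, quantized_model: Any, model_config: Any) -> None:
        quantized_model.to(self.compute_device)
        try:
            if self.fail:
                raise RuntimeError("simulated failure during optimisation")
            quantized_model.updates += 1  # placeholder for in-place work
        finally:
            # Always hand the model back on CPU for later processes and Runner.
            quantized_model.to("cpu")


model = FakeModel()
GpuRoundTripProcess().run(model, model_config=None)
print(model.device, model.moves)  # cpu ['cuda', 'cpu']
```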

Block-wise PTQ

BlockWisePTQ dataclass

BlockWisePTQ(name: Optional[str] = None, lr: float = 0.0001, epochs: int = 10, cbq_enable: bool = False, gptq_lr: float = 0.001, gptq_optimize_intweight: bool = False, gptq_intweight_lr: float = 0.0001, grad_clip: float = 1.0, optimize_binary: bool = True, k_smooth: float = 100.0, cbq_epochs: int = 0, cbq_lr: float = 5e-05, calibration_config: CalibrationConfig = None)

Bases: PostQuantizationProcess

Block-wise Post-Training Quantization

After layer-wise PTQ (GPTQ / DBF / OneBit) has quantised each linear layer independently, block-wise PTQ refines the quantized model at Transformer-block granularity by minimising the mean squared error (MSE) between its intermediate representations and those of an FP16 teacher model.
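A one-layer toy version of this objective: the FP weights act as the teacher, while the student keeps GPTQ-style integer weights frozen and learns only its scale by gradient descent on the output MSE, with gradient clipping (loosely mirroring gptq_lr, grad_clip and the default gptq_optimize_intweight=False). The helper `fit_block_scale`, the symmetric signed grid and the round-to-nearest initialisation are all illustrative assumptions; the real process operates on whole Transformer blocks.

```python
import random
from typing import List, Tuple


def fit_block_scale(w: List[float], xs: List[List[float]], wbits: int = 4,
                    lr: float = 1e-3, steps: int = 200,
                    grad_clip: float = 1.0) -> Tuple[float, List[int], float, float]:
    """Toy block-wise PTQ on one output channel (hypothetical helper).

    Teacher output: w . x.  Student output: s * (q . x) with the integer
    weights q frozen; only the scale s is optimised against the teacher MSE.
    """
    qmax = 2 ** (wbits - 1) - 1  # symmetric signed grid [-qmax - 1, qmax]
    s = max(abs(v) for v in w) / qmax  # layer-wise (round-to-nearest) start
    q = [max(-qmax - 1, min(qmax, round(v / s))) for v in w]

    a = [sum(qi * xi for qi, xi in zip(q, x)) for x in xs]  # q . x
    b = [sum(wi * xi for wi, xi in zip(w, x)) for x in xs]  # teacher outputs

    def mse(scale: float) -> float:
        return sum((scale * ai - bi) ** 2 for ai, bi in zip(a, b)) / len(xs)

    initial = mse(s)
    for _ in range(steps):
        grad = sum(2 * (s * ai - bi) * ai for ai, bi in zip(a, b)) / len(xs)
        grad = max(-grad_clip, min(grad_clip, grad))  # clip the scalar gradient
        s -= lr * grad
    return s, q, initial, mse(s)


rng = random.Random(0)
w = [0.8, -0.35, 0.12, 0.5]  # FP "teacher" weights of one output channel
xs = [[rng.gauss(0.0, 1.0) for _ in w] for _ in range(64)]  # calibration inputs
s, q, before, after = fit_block_scale(w, xs)
print(q, before, after)  # the fitted scale gives a lower MSE than the RTN start
```

The loss is quadratic in the scale, so with a small enough learning rate the loop converges to the least-squares scale `sum(a*b) / sum(a*a)`; real blocks are non-convex, which is why the library exposes epochs and learning rates as tunables.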

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| lr | float | Learning rate for block-wise optimisation (DBF / OneBit / generic). | 0.0001 |
| epochs | int | Number of optimisation epochs per block. | 10 |
| cbq_enable | bool | Whether to enable Cross-Block Quantisation (Phase 2) after greedy block-wise distillation. | False |
| gptq_lr | float | Learning rate for GPTQ scales/zeros optimisation. | 0.001 |
| gptq_optimize_intweight | bool | Whether to optimise integer weights via Smooth STE. | False |
| gptq_intweight_lr | float | Learning rate for integer weight optimisation. | 0.0001 |
| grad_clip | float | Gradient clipping norm. | 1.0 |
| optimize_binary | bool | Whether to optimise binary matrices (DBF) / sign matrices (OneBit). | True |
| k_smooth | float | SmoothSign STE temperature for binary/sign optimisation. | 100.0 |
| calibration_config | CalibrationConfig or None | Calibration data configuration. When None (default), a CalibrationConfig is created with num_calibration_samples=128. See onecomp.calibration.CalibrationConfig. | None |
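The k_smooth parameter points to a straight-through estimator whose backward pass uses a smooth approximation of sign(). A common choice, assumed here rather than taken from the onecomp source, is tanh(k * x): the forward pass keeps the hard +1/-1 value, while the gradient is k * (1 - tanh(k * x)^2). Both function names are hypothetical.

```python
import math


def smooth_sign_forward(x: float) -> float:
    """Forward pass: hard sign, mapping 0 to +1 (assumed convention)."""
    return 1.0 if x >= 0.0 else -1.0


def smooth_sign_backward(x: float, grad_out: float, k_smooth: float = 100.0) -> float:
    """Backward pass: chain rule through the surrogate tanh(k_smooth * x)."""
    t = math.tanh(k_smooth * x)
    return grad_out * k_smooth * (1.0 - t * t)


for x in (-0.5, -0.01, 0.0, 0.01, 0.5):
    print(x, smooth_sign_forward(x), smooth_sign_backward(x, grad_out=1.0))
```

Under this surrogate, a larger k_smooth makes the backward pass closer to that of sign() itself: the gradient peaks at k_smooth for latent values at zero and vanishes for values far from zero, so only entries near a sign flip receive meaningful updates.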

Examples:

>>> from onecomp import Runner, ModelConfig, GPTQ, BlockWisePTQ
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[BlockWisePTQ(lr=1e-4, epochs=10, cbq_enable=True)],
... )
>>> runner.run()

run

run(quantized_model: Module, model_config: ModelConfig) -> None

Execute block-wise PTQ on the quantized model.

Modifies quantized_model in place and returns None, as required by the PostQuantizationProcess interface.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| quantized_model | Module | Quantized model on CPU. | required |
| model_config | ModelConfig | Model configuration. | required |

LoRA SFT

PostProcessLoraSFT dataclass

PostProcessLoraSFT(name: Optional[str] = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 1.0, teacher_loss_weight: float = 0.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)

Bases: PostQuantizationProcess

LoRA-based SFT post-process on a GPTQ-quantized model.

This post-process improves a GPTQ-quantized causal language model by injecting LoRA adapters into selected GPTQLinear layers and training only those adapters while keeping the quantized base weights frozen. The given quantized_model is modified in-place.

Algorithm overview
  1. Load an SFT training dataset from dataset_name or data_files.
  2. Tokenize the dataset and build causal LM labels.
  3. Find target GPTQLinear modules such as q_proj / k_proj / v_proj / o_proj / gate_proj / up_proj / down_proj.
  4. Replace each target module with LoRAGPTQLinear, which keeps the original GPTQ layer as the frozen base path and adds trainable LoRA low-rank updates.
  5. Optimize the LoRA parameters with SFT loss and, optionally, an additional teacher distillation loss against an FP teacher model.
  6. Move the post-processed model back to CPU at the end so it can be reused by Runner for perplexity / accuracy evaluation.
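Step 4 follows the standard LoRA formulation: the frozen base layer computes W x, and a rank-r update B(A x), scaled by lora_alpha / lora_r, is added on top. The sketch below uses plain Python lists, replaces the GPTQ base with a dense frozen matrix, initialises B to zero (the usual LoRA choice, so training starts from the quantized model's exact output) and omits lora_dropout. `LoRALinearSketch` is a hypothetical illustration, not LoRAGPTQLinear itself.

```python
import random
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    return [sum(mij * xj for mij, xj in zip(row, x)) for row in m]


class LoRALinearSketch:
    """Frozen base path plus a low-rank update (hypothetical)."""

    def __init__(self, base: Callable[[Vector], Vector], in_features: int,
                 out_features: int, lora_r: int = 16, lora_alpha: int = 32,
                 seed: int = 0) -> None:
        rng = random.Random(seed)
        self.base = base  # stands in for the frozen GPTQLinear forward
        self.scaling = lora_alpha / lora_r
        # A (r x in, small random init) and B (out x r, zeros) are the only
        # parameters an optimizer would update; the base is never touched.
        self.lora_A = [[rng.gauss(0.0, 0.01) for _ in range(in_features)]
                       for _ in range(lora_r)]
        self.lora_B = [[0.0] * lora_r for _ in range(out_features)]

    def forward(self, x: Vector) -> Vector:
        delta = matvec(self.lora_B, matvec(self.lora_A, x))
        return [yb + self.scaling * d for yb, d in zip(self.base(x), delta)]


W_q = [[0.5, -1.0, 0.25], [1.0, 0.0, -0.5]]  # pretend dequantized GPTQ weight
layer = LoRALinearSketch(lambda x: matvec(W_q, x), in_features=3,
                         out_features=2, lora_r=2, lora_alpha=4)
x = [1.0, 2.0, 3.0]
print(layer.forward(x))  # [-0.75, -0.5]: identical to W_q @ x while B is zero
```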
Training objective
  • sft_loss_weight > 0 enables the standard causal LM loss.
  • teacher_loss_weight > 0 enables teacher-guided distillation.
  • teacher_loss_type selects the distillation loss on logits (currently "kl" or "mse").
  • cache_teacher_outputs=True precomputes teacher logits on CPU to reduce repeated teacher forward passes during multi-epoch training.
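The weighted objective can be sketched for a single token position. The exact formulas are assumptions: plain cross-entropy for the SFT term, KL(teacher || student) on temperature-softened softmax distributions for "kl" (with no T^2 rescaling), and mean squared error on raw logits for "mse". onecomp's implementation may average over tokens or scale the terms differently; `post_process_loss` is a hypothetical helper.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - lse for z in scaled]


def post_process_loss(student_logits: List[float], teacher_logits: List[float],
                      target: int, sft_loss_weight: float = 1.0,
                      teacher_loss_weight: float = 0.0,
                      teacher_loss_type: str = "kl",
                      teacher_temperature: float = 1.0) -> float:
    """Weighted SFT + teacher loss for one token position (illustrative)."""
    loss = 0.0
    if sft_loss_weight > 0:
        # Causal LM term: cross-entropy of the next-token target.
        loss += sft_loss_weight * -log_softmax(student_logits)[target]
    if teacher_loss_weight > 0:
        if teacher_loss_type == "kl":
            log_t = log_softmax(teacher_logits, teacher_temperature)
            log_s = log_softmax(student_logits, teacher_temperature)
            distill = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_t, log_s))
        elif teacher_loss_type == "mse":
            diffs = [(s - t) ** 2 for s, t in zip(student_logits, teacher_logits)]
            distill = sum(diffs) / len(diffs)
        else:
            raise ValueError(f"unsupported teacher_loss_type: {teacher_loss_type}")
        loss += teacher_loss_weight * distill
    return loss


student, teacher = [2.0, 0.5, -1.0], [1.5, 1.0, -0.5]
print(post_process_loss(student, teacher, target=0))  # SFT only
print(post_process_loss(student, teacher, 0, teacher_loss_weight=1.0))  # SFT + KL
print(post_process_loss(student, teacher, 0, sft_loss_weight=0.0,
                        teacher_loss_weight=1.0))  # teacher only
```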
Typical usage
  • Use data_files=... for local JSON/JSONL/CSV/TXT/Parquet files.
  • Use dataset_name=... when loading from Hugging Face Datasets.
  • Pass this class to Runner(post_processes=[...]) after GPTQ quantization, or call run() directly on a previously saved quantized model loaded with torch.load(..., weights_only=False).
LoRA implementations
  • PostProcessLoraSFT: Standard LoRA SFT with only the causal LM objective by default.
  • PostProcessLoraTeacherSFT: LoRA SFT with teacher distillation enabled by default. Use this when combining SFT loss and teacher loss.
  • PostProcessLoraTeacherOnlySFT: LoRA training with only teacher distillation.
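Step 2 of the algorithm overview (building causal LM labels) commonly follows the Hugging Face convention: labels are a copy of input_ids, the model shifts them by one position internally, and padded positions are set to -100 so the cross-entropy loss ignores them. Whether onecomp masks exactly this way is an assumption; `build_causal_lm_example` is a hypothetical helper operating on already-tokenized ids.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_causal_lm_example(token_ids: List[int], max_length: int,
                            pad_id: int) -> Dict[str, List[int]]:
    """Truncate/pad one tokenized text and derive its labels (hypothetical)."""
    ids = token_ids[:max_length]
    attention_mask = [1] * len(ids) + [0] * (max_length - len(ids))
    ids = ids + [pad_id] * (max_length - len(ids))
    # Labels stay unshifted; a causal LM shifts them by one internally.
    labels = [t if m == 1 else IGNORE_INDEX for t, m in zip(ids, attention_mask)]
    return {"input_ids": ids, "attention_mask": attention_mask, "labels": labels}


example = build_causal_lm_example([5, 17, 42], max_length=5, pad_id=0)
print(example["labels"])  # [5, 17, 42, -100, -100]
```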

Examples:

Via Runner:

>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraSFT(data_files="train.jsonl")
...     ],
... )
>>> runner.run()

Direct execution on a saved quantized model:

>>> import torch
>>> from onecomp import ModelConfig, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantized_model = torch.load(
...     "quantized_model.pt",
...     map_location="cpu",
...     weights_only=False,
... )
>>> post_process = PostProcessLoraSFT(data_files="train.jsonl")
>>> post_process.run(quantized_model, model_config)

With teacher distillation enabled:

>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraTeacherSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraTeacherSFT(
...             data_files="train.jsonl",
...             teacher_model_id="meta-llama/Llama-2-7b-hf",
...         )
...     ],
... )
>>> runner.run()

With teacher-logit caching enabled:

>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraSFT(
...             data_files="train.jsonl",
...             teacher_loss_weight=1.0,
...             cache_teacher_outputs=True,
...         )
...     ],
... )
>>> runner.run()
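cache_teacher_outputs=True trades memory for compute: teacher logits are computed once per training example and reused in every epoch instead of rerunning the teacher each time. The sketch counts teacher calls to show the saving; the training loop, the toy teacher and the dict-based cache are hypothetical, and per the description above the real cache keeps logits on CPU (presumably in the dtype given by teacher_cache_dtype).

```python
from typing import Callable, Dict, List, Sequence


def train_with_teacher(examples: Sequence[List[int]],
                       teacher: Callable[[List[int]], List[float]],
                       epochs: int, cache_teacher_outputs: bool) -> int:
    """Return how many teacher forward passes a toy training loop performs."""
    calls = 0

    def run_teacher(ids: List[int]) -> List[float]:
        nonlocal calls
        calls += 1
        return teacher(ids)

    cache: Dict[int, List[float]] = {}
    if cache_teacher_outputs:
        # Precompute every example's teacher output once, before epoch 1.
        cache = {i: run_teacher(ids) for i, ids in enumerate(examples)}

    for _ in range(epochs):
        for i, ids in enumerate(examples):
            teacher_logits = cache[i] if cache_teacher_outputs else run_teacher(ids)
            _ = teacher_logits  # a real loop would compute the distillation loss
    return calls


def toy_teacher(ids: List[int]) -> List[float]:
    return [float(t) for t in ids]


data = [[1, 2], [3, 4], [5, 6]]
print(train_with_teacher(data, toy_teacher, epochs=4, cache_teacher_outputs=False))  # 12
print(train_with_teacher(data, toy_teacher, epochs=4, cache_teacher_outputs=True))   # 3
```

The saving grows with the number of epochs, at the cost of holding one logit tensor per example in memory.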

run

run(quantized_model: Module, model_config: ModelConfig) -> None

Run LoRA SFT on the GPTQ-quantized model in-place.

Convenience Variants

PostProcessLoraTeacherSFT and PostProcessLoraTeacherOnlySFT are pre-configured variants of PostProcessLoraSFT that differ only in the defaults of sft_loss_weight and teacher_loss_weight:

PostProcessLoraTeacherSFT dataclass

PostProcessLoraTeacherSFT(name: Optional[str] = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 1.0, teacher_loss_weight: float = 1.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)

LoRA SFT with teacher-guided distillation enabled by default (sft_loss_weight=1.0, teacher_loss_weight=1.0).

PostProcessLoraTeacherOnlySFT dataclass

PostProcessLoraTeacherOnlySFT(name: Optional[str] = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 0.0, teacher_loss_weight: float = 1.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)

LoRA SFT variant that optimizes only the teacher distillation loss (sft_loss_weight=0.0, teacher_loss_weight=1.0).
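Because these classes are dataclasses, variants like these can be expressed as subclasses that only override field defaults. The mirror below is a hypothetical reduction to the two loss-weight fields (defaults copied from the signatures above), not the onecomp source; explicit keyword arguments still take precedence over the overridden defaults.

```python
from dataclasses import dataclass


@dataclass
class LoraSFTSketch:
    sft_loss_weight: float = 1.0
    teacher_loss_weight: float = 0.0


@dataclass
class LoraTeacherSFTSketch(LoraSFTSketch):
    teacher_loss_weight: float = 1.0  # SFT loss plus teacher distillation


@dataclass
class LoraTeacherOnlySFTSketch(LoraSFTSketch):
    sft_loss_weight: float = 0.0  # drop the causal LM term
    teacher_loss_weight: float = 1.0


for cls in (LoraSFTSketch, LoraTeacherSFTSketch, LoraTeacherOnlySFTSketch):
    print(cls())
print(LoraTeacherSFTSketch(sft_loss_weight=0.5))  # explicit value wins
```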