Post-Process

Post-quantization process classes for fine-tuning quantized models.

Base Class

PostQuantizationProcess dataclass

PostQuantizationProcess(name: str = None)

Abstract base class for post-quantization processes.

Post-quantization processes are executed after the main quantization step (e.g., GPTQ, DBF). Each process receives a quantized model on CPU (with quantized inference layers such as GPTQLinear) and may modify it in-place.

Subclasses must implement the run() method. name is automatically set to the class name if it is not provided.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` or `None` | Human-readable name used in log messages. If `None`, automatically set to the class name. | `None` |

Examples:

Typical usage via Runner:

>>> from onecomp import Runner, ModelConfig, GPTQ, BlockWisePTQ
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[BlockWisePTQ()],
... )
>>> runner.run()

run abstractmethod

run(quantized_model: Module, model_config: ModelConfig) -> None

Execute the post-quantization process.

The model is provided on CPU. Implementations may move it to GPU for computation, but must move it back to CPU before returning so that subsequent processes and Runner methods (e.g. evaluation, saving) can work without device assumptions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `quantized_model` | `Module` | The quantized model on CPU. Linear layers that were quantized have already been replaced with quantized inference layers (e.g. `GPTQLinear`, `DoubleBinaryLinear`). The process may modify the model in-place. | *required* |
| `model_config` | `ModelConfig` | The model configuration (provides access to the tokenizer, model id/path, device, etc.). | *required* |
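The device round-trip contract described above can be sketched as a minimal subclass. `ParamCountCheck` and its body are illustrative assumptions, not part of onecomp; only `PostQuantizationProcess` and the `run()` signature come from this page. A plain dataclass stands in for the real base class so the sketch runs standalone.

```python
# Minimal sketch of a custom post-process honoring the device contract:
# the model arrives on CPU, may be moved to GPU for compute, and must be
# handed back on CPU. In onecomp this would subclass PostQuantizationProcess.
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ParamCountCheck:
    name: str = None

    def run(self, quantized_model: nn.Module, model_config) -> None:
        if self.name is None:
            self.name = type(self).__name__   # mirrors the base-class default
        device = "cuda" if torch.cuda.is_available() else "cpu"
        quantized_model.to(device)            # may use GPU for computation
        try:
            with torch.no_grad():
                n = sum(p.numel() for p in quantized_model.parameters())
            print(f"[{self.name}] inspected {n} parameters on {device}")
        finally:
            quantized_model.to("cpu")         # must return the model on CPU
```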

LoRA SFT

PostProcessLoraSFT dataclass

PostProcessLoraSFT(name: str = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 1.0, teacher_loss_weight: float = 0.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)

Bases: PostQuantizationProcess

LoRA-based SFT post-process on a GPTQ-quantized model.

This post-process improves a GPTQ-quantized causal language model by injecting LoRA adapters into selected GPTQLinear layers and training only those adapters while keeping the quantized base weights frozen. The given quantized_model is modified in-place.

Algorithm overview
  1. Load an SFT training dataset from dataset_name or data_files.
  2. Tokenize the dataset and build causal LM labels.
  3. Find target GPTQLinear modules such as q_proj / k_proj / v_proj / o_proj / gate_proj / up_proj / down_proj.
  4. Replace each target module with LoRAGPTQLinear, which keeps the original GPTQ layer as the frozen base path and adds trainable LoRA low-rank updates.
  5. Optimize the LoRA parameters with SFT loss and, optionally, an additional teacher distillation loss against an FP teacher model.
  6. Move the post-processed model back to CPU at the end so it can be reused by Runner for perplexity / accuracy evaluation.
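Step 4 above can be sketched as a frozen-base wrapper with a trainable low-rank path. The real LoRAGPTQLinear wraps a GPTQLinear; here a plain `nn.Linear` stands in for the frozen quantized base, and the class name and structure are illustrative only. Hyperparameter names follow the fields documented above (`lora_r`, `lora_alpha`, `lora_dropout`).

```python
# Sketch of the frozen-base + trainable low-rank update from step 4.
# The base path is frozen; only the two small LoRA projections train.
import torch
from torch import nn


class LoraLinearSketch(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 32,
                 dropout: float = 0.05):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)             # quantized base stays frozen
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)      # update starts at zero
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(self.dropout(x))) * self.scaling
```

Zero-initializing the second projection means the wrapped layer initially reproduces the base layer exactly, so training starts from the quantized model's behavior.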
Training objective
  • sft_loss_weight > 0 enables the standard causal LM loss.
  • teacher_loss_weight > 0 enables teacher-guided distillation.
  • teacher_loss_type selects the distillation loss on logits (currently "kl" or "mse").
  • cache_teacher_outputs=True precomputes teacher logits on CPU to reduce repeated teacher forward passes during multi-epoch training.
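The objective above can be sketched as a weighted sum of a causal-LM term and a temperature-scaled KL term. The function below is an assumption about the formulation, not onecomp's actual code; parameter names mirror the fields documented here, and the standard `T^2` scaling of the KL term is a common convention that the library may or may not use.

```python
# Sketch: sft_loss_weight * cross-entropy + teacher_loss_weight * KL("kl" mode).
# teacher_temperature softens both distributions before the KL term.
import torch
import torch.nn.functional as F


def combined_loss(student_logits, teacher_logits, labels,
                  sft_loss_weight=1.0, teacher_loss_weight=1.0,
                  teacher_temperature=1.0):
    t = teacher_temperature
    # Standard causal-LM loss over all positions (labels use -100 for padding).
    sft = F.cross_entropy(student_logits.flatten(0, -2), labels.flatten(),
                          ignore_index=-100)
    # KL divergence between temperature-softened student and teacher logits.
    kl = F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.log_softmax(teacher_logits / t, dim=-1),
        log_target=True,
        reduction="batchmean",
    ) * (t * t)
    return sft_loss_weight * sft + teacher_loss_weight * kl
```

Setting `sft_loss_weight=0.0` recovers the teacher-only variant; setting `teacher_loss_weight=0.0` recovers plain SFT.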
Typical usage
  • Use data_files=... for local JSON/JSONL/CSV/TXT/Parquet files.
  • Use dataset_name=... when loading from Hugging Face Datasets.
  • Pass this class to Runner(post_processes=[...]) after GPTQ quantization, or call run() directly on a previously saved quantized model loaded with torch.load(..., weights_only=False).
LoRA implementations
  • PostProcessLoraSFT: Standard LoRA SFT with only the causal LM objective by default.
  • PostProcessLoraTeacherSFT: LoRA SFT with teacher distillation enabled by default. Use this when combining SFT loss and teacher loss.
  • PostProcessLoraTeacherOnlySFT: LoRA training with only teacher distillation.

Examples:

Via Runner:

>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraSFT(data_files="train.jsonl")
...     ],
... )
>>> runner.run()

Direct execution on a saved quantized model:

>>> import torch
>>> from onecomp import ModelConfig, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantized_model = torch.load(
...     "quantized_model.pt",
...     map_location="cpu",
...     weights_only=False,
... )
>>> post_process = PostProcessLoraSFT(data_files="train.jsonl")
>>> post_process.run(quantized_model, model_config)

With teacher distillation enabled:

>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraTeacherSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraTeacherSFT(
...             data_files="train.jsonl",
...             teacher_model_id="meta-llama/Llama-2-7b-hf",
...         )
...     ],
... )
>>> runner.run()

With teacher-logit caching enabled:

>>> from onecomp import Runner, ModelConfig, GPTQ, PostProcessLoraSFT
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> quantizer = GPTQ(wbits=4, groupsize=128)
>>> runner = Runner(
...     model_config=model_config,
...     quantizer=quantizer,
...     post_processes=[
...         PostProcessLoraSFT(
...             data_files="train.jsonl",
...             teacher_loss_weight=1.0,
...             cache_teacher_outputs=True,
...         )
...     ],
... )
>>> runner.run()
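The caching behavior can be sketched as a single teacher pass whose logits are stored on CPU and reused across epochs. The function below is illustrative; onecomp's internal caching (and its `teacher_cache_dtype` handling) may differ.

```python
# Sketch of cache_teacher_outputs=True: run the teacher once per batch,
# keep its logits on CPU in a compact dtype, and reuse them every epoch
# instead of repeating teacher forward passes.
import torch


def cache_teacher_logits(teacher, batches, cache_dtype=torch.float16):
    cache = []
    teacher.eval()
    with torch.no_grad():
        for batch in batches:
            logits = teacher(batch)                  # one forward pass each
            cache.append(logits.to("cpu", cache_dtype))
    return cache


# Reuse during multi-epoch training:
# for epoch in range(epochs):
#     for batch, teacher_logits in zip(batches, cache):
#         ...  # distillation loss against the cached logits
```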

run

run(quantized_model: Module, model_config: ModelConfig) -> None

Run LoRA SFT on the GPTQ-quantized model in-place.

Convenience Variants

PostProcessLoraTeacherSFT and PostProcessLoraTeacherOnlySFT are pre-configured variants of PostProcessLoraSFT with different default loss weights:

PostProcessLoraTeacherSFT dataclass

PostProcessLoraTeacherSFT(name: str = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 1.0, teacher_loss_weight: float = 1.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)

LoRA SFT with teacher-guided distillation enabled by default.

PostProcessLoraTeacherOnlySFT dataclass

PostProcessLoraTeacherOnlySFT(name: str = None, dataset_name: str | None = None, dataset_config_name: str | None = None, data_files: str | list[str] | dict[str, str] | None = None, train_split: str = 'train', text_column: str = 'text', max_train_samples: int | None = None, max_length: int = 1024, shuffle_seed: int = 42, lr: float = 0.0001, epochs: int = 4, batch_size: int = 1, gradient_accumulation_steps: int = 16, weight_decay: float = 0.0, warmup_ratio: float = 0.03, logging_steps: int = 10, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: tuple[str, ...] | None = None, output_dir: str | None = None, use_bf16: bool | None = None, sft_loss_weight: float = 0.0, teacher_loss_weight: float = 1.0, teacher_loss_type: str = 'kl', teacher_temperature: float = 1.0, teacher_model_id: str | None = None, teacher_model_path: str | None = None, teacher_dtype: str | None = None, teacher_device: str | None = None, cache_teacher_outputs: bool = False, teacher_cache_dtype: str | None = None, intermediate_block_loss_weight: float = 0.0, cache_intermediate_outputs: bool = False, intermediate_block_indices: tuple[int, ...] | None = None, intermediate_cache_dtype: str | None = None)

LoRA SFT variant that optimizes only the teacher distillation loss.