Post-Process (LoRA SFT)¶
OneComp supports post-quantization processing — additional steps applied to a quantized model to improve accuracy or inject domain-specific knowledge. The main implementation is LoRA SFT, which fine-tunes quantized models using Low-Rank Adaptation (LoRA) adapters.
Overview¶
The post-process framework integrates into the Runner pipeline via the post_processes parameter. After quantization completes, Runner builds a quantized model on CPU and executes each process in order. The processed model is stored as runner.quantized_model and is automatically used by subsequent evaluation and save operations.
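Conceptually, the "run each process in order, keep the result" loop can be pictured with a minimal sketch. All names below (`MiniRunner`, `tag`) are illustrative stand-ins, not OneComp's actual internals:

```python
# Illustrative sketch of the post-process loop (hypothetical names,
# not the actual OneComp implementation).
class MiniRunner:
    def __init__(self, post_processes):
        self.post_processes = post_processes
        self.quantized_model = None

    def run(self):
        model = {"name": "quantized-model"}  # stand-in for the CPU-built model
        # Each post-process receives the current model and returns a new one.
        for process in self.post_processes:
            model = process(model)
        # The final result is kept for later evaluation and save steps.
        self.quantized_model = model
        return self.quantized_model

def tag(step):
    # Toy "post-process" that records its name on the model.
    return lambda model: {**model, "steps": model.get("steps", []) + [step]}

runner = MiniRunner(post_processes=[tag("lora_sft"), tag("another_step")])
result = runner.run()
print(result["steps"])  # processes applied in list order
```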
LoRA SFT: Accuracy Recovery¶
The most common use case is recovering accuracy lost during quantization. Provide a general-purpose dataset (e.g., WikiText-2) to fine-tune the quantized model:
```python
from onecomp import GPTQ, ModelConfig, Runner, PostProcessLoraSFT, setup_logger

setup_logger()

model_config = ModelConfig(
    model_id="meta-llama/Llama-2-7b-hf",
    device="cuda:0",
)

gptq = GPTQ(wbits=4, groupsize=128)

post_process = PostProcessLoraSFT(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    train_split="train",
    text_column="text",
    max_train_samples=256,
    max_length=512,
    epochs=4,
    batch_size=2,
    gradient_accumulation_steps=8,
    lr=1e-4,
    lora_r=16,
    lora_alpha=32,
)

runner = Runner(
    model_config=model_config,
    quantizer=gptq,
    post_processes=[post_process],
)
runner.run()

# Evaluate: PPL should be lower than without LoRA SFT
original_ppl, _, quantized_ppl = runner.calculate_perplexity(
    original_model=True, quantized_model=True,
)
print(f"Original PPL: {original_ppl:.4f}")
print(f"Quantized + LoRA SFT PPL: {quantized_ppl:.4f}")
```
Tip
A complete working example is available at `example/post_process/example_lora_sft.py`.
LoRA SFT: Knowledge Injection¶
LoRA SFT also supports injecting new knowledge into a quantized model using custom training data. Provide a JSONL file where each line has a `"text"` field:
{"text": "OneCompression (OneComp) is an open-source Python library for LLM quantization developed by Fujitsu."}
{"text": "OneComp supports GPTQ, DBF, RTN, and AutoBit quantization methods."}
Then pass the file path to `data_files`:
```python
post_process = PostProcessLoraSFT(
    data_files="./my_knowledge.jsonl",
    max_length=256,
    epochs=20,
    batch_size=2,
    lr=3e-4,
    lora_r=16,
    lora_alpha=32,
)

runner = Runner(
    model_config=model_config,
    quantizer=gptq,
    post_processes=[post_process],
)
runner.run()
```
After training, the model can generate responses based on the injected knowledge.
Tip
A complete working example with a before/after comparison is available at `example/post_process/example_lora_sft_knowledge.py`.
Saving and Loading LoRA Models¶
LoRA-applied models contain custom module types (`LoRAGPTQLinear`) that are not compatible with the standard safetensors format. Use the dedicated PyTorch `.pt` save/load methods instead:
Save¶
```python
# After quantization + LoRA SFT
runner.run()

# Save the LoRA-applied model (PyTorch .pt format)
runner.save_quantized_model_pt("./my_model_lora")
```
Load¶
```python
from onecomp import load_quantized_model_pt

model, tokenizer = load_quantized_model_pt("./my_model_lora")
```
`save_quantized_model` vs `save_quantized_model_pt`

| Method | Format | Use Case |
|---|---|---|
| `save_quantized_model()` | safetensors (HF-compatible) | Standard quantized models (no post-processing) |
| `save_quantized_model_pt()` | PyTorch `.pt` | Post-processed models (e.g. LoRA adapters) |

Similarly, use `load_quantized_model()` for safetensors and `load_quantized_model_pt()` for `.pt` files.
Data Sources¶
`PostProcessLoraSFT` supports two ways to provide training data:
Hugging Face Datasets¶
```python
PostProcessLoraSFT(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    train_split="train",
    text_column="text",
)
```
Local Files¶
Supported formats: JSON, JSONL, CSV, TXT, Parquet.
Key Parameters¶
| Parameter | Default | Description |
|---|---|---|
| `epochs` | `4` | Number of training epochs |
| `lr` | `1e-4` | Learning rate |
| `batch_size` | `1` | Training batch size |
| `gradient_accumulation_steps` | `16` | Gradient accumulation steps |
| `max_length` | `1024` | Maximum sequence length for tokenization |
| `max_train_samples` | `None` | Cap on the number of training samples (unlimited if `None`) |
| `lora_r` | `16` | LoRA rank |
| `lora_alpha` | `32` | LoRA scaling factor (effective scaling = `alpha / r`) |
| `lora_dropout` | `0.05` | LoRA dropout rate |
| `target_modules` | `None` | Module name suffixes to wrap with LoRA. Defaults to `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj` |
| `warmup_ratio` | `0.03` | Learning rate warmup ratio |
| `weight_decay` | `0.0` | Weight decay |
| `use_bf16` | `None` | Use bfloat16 training. Auto-detected from GPU capability if `None` |
See the API Reference for the full parameter list.
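To make the `lora_alpha` / `lora_r` relationship concrete, here is a minimal pure-Python sketch of a LoRA forward pass. It is illustrative only (real adapters operate on torch tensors inside the wrapped quantized layers):

```python
# Minimal LoRA forward pass: y = W @ x + (alpha / r) * B @ (A @ x)

def matvec(M, x):
    """Plain matrix-vector product over nested lists."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, lora_r, lora_alpha):
    scaling = lora_alpha / lora_r          # effective scaling = alpha / r
    base = matvec(W, x)                    # frozen (quantized) weight path
    delta = matvec(B, matvec(A, x))        # low-rank update, rank = lora_r
    return [b + scaling * d for b, d in zip(base, delta)]

# Toy shapes: W is 2x2, A is 1x2 (rank-1 down-projection), B is 2x1.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]
B = [[1.0], [2.0]]
x = [1.0, 2.0]

# With the defaults alpha=32, r=16, the low-rank update is scaled by 2.0.
y = lora_forward(W, A, B, x, lora_r=16, lora_alpha=32)
print(y)  # [7.0, 14.0]
```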
Advanced: Teacher Distillation¶
Teacher distillation aligns the quantized model's output distribution with a full-precision teacher model. This can improve accuracy beyond what SFT alone achieves:
```python
post_process = PostProcessLoraSFT(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    train_split="train",
    text_column="text",
    sft_loss_weight=1.0,
    teacher_loss_weight=0.5,  # Enable teacher distillation
    teacher_loss_type="kl",  # "kl" or "mse"
    teacher_temperature=1.0,
    teacher_model_id="meta-llama/Llama-2-7b-hf",  # Full-precision teacher
    cache_teacher_outputs=True,  # Pre-compute teacher logits for speed
)
```
| Parameter | Default | Description |
|---|---|---|
| `sft_loss_weight` | `1.0` | Weight for causal LM (SFT) loss |
| `teacher_loss_weight` | `0.0` | Weight for teacher distillation loss (0 = disabled) |
| `teacher_loss_type` | `"kl"` | `"kl"` (KL divergence) or `"mse"` (mean squared error) on logits |
| `teacher_temperature` | `1.0` | Temperature for softening teacher logits |
| `teacher_model_id` | `None` | Hugging Face model ID for the teacher |
| `teacher_model_path` | `None` | Local path for the teacher model |
| `cache_teacher_outputs` | `False` | Pre-compute and cache teacher outputs on CPU |
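The `"kl"` distillation term can be sketched in pure Python as the KL divergence between temperature-softened teacher and student distributions. This is an illustrative reading of the parameters above, not OneComp's actual implementation (which operates on torch logits and combines this term with the SFT loss according to the two weights):

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax with temperature scaling; higher T gives a softer distribution."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_distill_loss(teacher_logits, student_logits, temperature=1.0):
    """KL(p_teacher || p_student) on temperature-softened distributions."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

teacher = [2.0, 1.0, 0.1]
student = [1.8, 1.1, 0.2]

loss = kl_distill_loss(teacher, student, temperature=1.0)
identical = kl_distill_loss(teacher, teacher, temperature=1.0)
print(loss > 0, abs(identical) < 1e-12)
```

The loss is zero when student and teacher logits match and grows as the distributions diverge; raising `teacher_temperature` flattens both distributions before comparison.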
Advanced: Intermediate Block Alignment¶
Intermediate block alignment adds a loss term that aligns hidden states at selected transformer blocks between the teacher and student models:
```python
post_process = PostProcessLoraSFT(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    train_split="train",
    text_column="text",
    teacher_model_id="meta-llama/Llama-2-7b-hf",
    intermediate_block_loss_weight=0.1,
    intermediate_block_indices=[8, 16, 24],
    cache_intermediate_outputs=True,
)
```
| Parameter | Default | Description |
|---|---|---|
| `intermediate_block_loss_weight` | `0.0` | Weight for intermediate alignment loss (0 = disabled) |
| `intermediate_block_indices` | `None` | Transformer block indices to align |
| `cache_intermediate_outputs` | `False` | Pre-compute and cache teacher block outputs |
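One natural reading of this loss is a weighted mean squared error between teacher and student hidden states at the selected block indices. The sketch below is a pure-Python illustration under that assumption, with tiny made-up hidden states; the actual implementation works on torch hidden-state tensors:

```python
def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def intermediate_block_loss(teacher_hidden, student_hidden, block_indices, weight):
    """Average the per-block MSE over the selected blocks, then apply the weight."""
    per_block = [mse(teacher_hidden[i], student_hidden[i]) for i in block_indices]
    return weight * sum(per_block) / len(per_block)

# Four toy "blocks" with 3-dimensional hidden states.
teacher_hidden = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [1.0, 0.0, 1.0], [2.0, 2.0, 2.0]]
student_hidden = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [1.0, 0.0, 1.0], [1.0, 2.0, 2.0]]

# Align blocks 1 and 3 only; block 1 matches exactly, block 3 differs.
loss = intermediate_block_loss(teacher_hidden, student_hidden,
                               block_indices=[1, 3], weight=0.1)
print(round(loss, 6))
```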
Limitations¶
vLLM Inference
LoRA-applied models saved with `save_quantized_model_pt()` are not currently supported by the vLLM plugins. vLLM integration for LoRA post-processed models is planned for a future release.
For standard quantized models (without LoRA), use `save_quantized_model()` and serve via vLLM as described in the vLLM Inference guide.
Supported Quantizers
LoRA SFT currently supports GPTQ-quantized models only. Support for other quantization methods (DBF, RTN) may be added in the future.