
Pre-Process (Rotation Preprocessing)

Rotation preprocessing reduces quantization error by learning optimal rotation matrices (SpinQuant/OstQuant) and absorbing them into model weights before quantization.
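The reason this works without hurting accuracy: an orthogonal rotation folded into the weights, combined with the inverse rotation applied to the incoming activations, leaves the layer output mathematically unchanged, while spreading outlier energy across channels before quantization. A self-contained toy sketch in plain Python (a 4-point Hadamard rotation; illustrative only, not onecomp's implementation):

```python
# Toy illustration: absorbing an orthogonal rotation into a weight matrix
# leaves the layer output unchanged when the inverse rotation is applied
# to the input (the principle behind SpinQuant-style preprocessing).
# Self-contained sketch, not onecomp's actual code.

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

# Normalized 4x4 Hadamard matrix: orthogonal, H @ H.T == I.
H = [[0.5 * s for s in row] for row in
     [[1, 1, 1, 1], [1, -1, 1, -1], [1, 1, -1, -1], [1, -1, -1, 1]]]

W = [[10.0, 0.1, -0.2, 0.05],
     [0.3, -9.5, 0.2, 0.1]]          # 2x4 weight with per-channel outliers
x = [[1.0], [2.0], [-1.0], [0.5]]    # 4x1 input (column vector)

W_rot = matmul(W, H)                 # rotation absorbed into the weights
x_rot = matmul(transpose(H), x)      # inverse rotation applied to the input

y = matmul(W, x)
y_rot = matmul(W_rot, x_rot)

# Outputs match up to float rounding: W @ H @ H.T @ x == W @ x
for a, b in zip(y, y_rot):
    assert abs(a[0] - b[0]) < 1e-9
```

Quantization then operates on the rotated weights, whose outliers have been smeared across channels, which is what reduces the rounding error.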

prepare_rotated_model

prepare_rotated_model(model_config: ModelConfig, save_directory: str, *, rotation: bool = True, scaling: bool = False, rotation_mode: str = 'random', scaling_mode: str = 'identity', seed: int = 0, enable_training: bool = True, calibration_dataset=None, max_length: int = 2048, num_calibration_samples: int = 128, calibration_strategy: str = 'drop_rand', wbits: int = 4, sym: bool = False, groupsize: int = -1, fp32_had: bool = False, use_sdpa: bool = False, training_args_override: dict | None = None) -> RotatedModelConfig

Train rotation matrices, apply them to model weights, and save.

Parameters:

  • model_config (ModelConfig, required): Original model configuration (model_id or path).
  • save_directory (str, required): Directory to save the rotated model.
  • rotation (bool, default True): Whether to apply rotation matrices (R1, R2).
  • scaling (bool, default False): Whether to apply scaling diagonals (S_*).
  • rotation_mode (str, default 'random'): "random" or "identity".
  • scaling_mode (str, default 'identity'): "identity" | "random_ones" | "random".
  • seed (int, default 0): Random seed for rotation matrix initialisation and calibration data preparation. Note that the Trainer uses a separate seed (TrainingArguments.seed, default 42) for data shuffling and training reproducibility.
  • enable_training (bool, default True): If True, train the rotation/scaling matrices; otherwise use the randomly initialised matrices directly.
  • calibration_dataset (default None): List of texts for calibration. If None, the C4 dataset is used (same as Runner).
  • max_length (int, default 2048): Sequence length for calibration data.
  • num_calibration_samples (int, default 128): Number of calibration samples; the default matches Runner.
  • calibration_strategy (str, default 'drop_rand'): Strategy for preparing calibration inputs ("drop_rand", "concat_chunk", etc.). See onecomp.utils.calibration.prepare_calibration_dataset.
  • wbits (int, default 4): Weight quantisation bit-width for the RTN proxy used during training. Should match the quantizer's wbits.
  • sym (bool, default False): Symmetric quantisation for the RTN proxy.
  • groupsize (int, default -1): Group size for the RTN proxy (-1 = per-channel). Should match the quantizer's groupsize. When positive, the value must evenly divide the out_features of every nn.Linear layer in the model.
  • fp32_had (bool, default False): Use FP32 for the online Hadamard transform.
  • use_sdpa (bool, default False): Use the SDPA attention implementation during training.
  • training_args_override (dict | None, default None): Override TrainingArguments fields.

Returns:

RotatedModelConfig: A RotatedModelConfig pointing at save_directory.

Examples:

Basic usage:

>>> from onecomp import ModelConfig, prepare_rotated_model, GPTQ
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> rotated_config = prepare_rotated_model(
...     model_config=model_config,
...     save_directory="./rotated_model",
... )

Without training (random rotation only):

>>> rotated_config = prepare_rotated_model(
...     model_config=model_config,
...     save_directory="./rotated_model",
...     enable_training=False,
... )

RotatedModelConfig

ModelConfig subclass for loading rotation-preprocessed models. Automatically registers a Hadamard forward_pre_hook on each down_proj layer.

RotatedModelConfig(path: str = None, dtype: str = 'float16', device: str = 'auto', fp32_had: bool = None, **kwargs)

Bases: ModelConfig

Inherits ModelConfig and automatically registers a deterministic Hadamard forward_pre_hook on each down_proj layer when load_model() is called.

The saved model directory should contain:

  • config.json — HuggingFace model config (includes fp32_had field)
  • model.safetensors — rotation-applied weights
  • tokenizer.json

Parameters:

  • path (str, required): Path to the saved rotated model.
  • dtype (str, default 'float16'): Data type.
  • device (str, default 'auto'): Device to load the model on.
  • fp32_had (bool or None, default None): Use FP32 for the online Hadamard transform. If None, auto-detect from rotation_config.json in the model directory; falls back to False if not found.

Example

from onecomp import Runner, RotatedModelConfig, GPTQ

model_config = RotatedModelConfig(path="./rotated_model")
quantizer = GPTQ(wbits=4, groupsize=128)
runner = Runner(model_config=model_config, quantizer=quantizer)
runner.run()
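The fp32_had auto-detection described above can be pictured with the following standalone sketch. This is illustrative logic only: the rotation_config.json file name comes from this page, but the JSON key name and the exact lookup are assumptions, and this is not onecomp's code.

```python
# Illustrative sketch of fp32_had auto-detection with a False fallback.
# Assumption: rotation_config.json stores the flag under an "fp32_had" key.
import json
import os
import tempfile

def detect_fp32_had(model_dir, fp32_had=None):
    """If fp32_had is None, read it from rotation_config.json; fall back to False."""
    if fp32_had is not None:
        return fp32_had                      # explicit value wins
    cfg_path = os.path.join(model_dir, "rotation_config.json")
    if os.path.exists(cfg_path):
        with open(cfg_path) as f:
            return json.load(f).get("fp32_had", False)
    return False                             # no metadata found

with tempfile.TemporaryDirectory() as d:
    no_file = detect_fp32_had(d)             # directory without the file -> False
    with open(os.path.join(d, "rotation_config.json"), "w") as f:
        json.dump({"fp32_had": True}, f)
    from_file = detect_fp32_had(d)           # value read from the file -> True
```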

load_model

load_model(**kwargs)

Load the rotated model and register Hadamard hooks.

Returns:

nn.Module: Model with Hadamard pre-hooks registered on down_proj.
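To show mechanically what "Hadamard pre-hooks registered on down_proj" means, here is a plain-Python emulation of the forward_pre_hook pattern: the hook rewrites the input before the layer's forward runs. This is a toy stand-in for torch's nn.Module hooks, not onecomp's code.

```python
# Toy emulation of a forward_pre_hook: the registered hook transforms the
# input before the layer computes its output, which is how an online
# Hadamard transform is applied to down_proj inputs. Not onecomp's code.

class TinyLinear:
    def __init__(self, weight):
        self.weight = weight                 # rows = out_features
        self._pre_hooks = []

    def register_forward_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def __call__(self, x):
        for hook in self._pre_hooks:
            x = hook(self, x)                # hook rewrites the input
        return [sum(w * v for w, v in zip(row, x)) for row in self.weight]

def hadamard4(x):
    """Normalized 4-point Hadamard transform (orthogonal)."""
    H = [[1, 1, 1, 1], [1, -1, 1, -1], [1, 1, -1, -1], [1, -1, -1, 1]]
    return [0.5 * sum(h * v for h, v in zip(row, x)) for row in H]

layer = TinyLinear([[1.0, 0.0, 0.0, 0.0]])
layer.register_forward_pre_hook(lambda module, x: hadamard4(x))
out = layer([1.0, 2.0, 3.0, 4.0])            # input is rotated before the matmul
```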

has_additional_data

has_additional_data()

Returns True (rotation metadata exists).

Workflow

┌─────────────────────────────────────────────────────────────┐
│  Step 1: Rotation Preprocessing                             │
│                                                             │
│  ModelConfig ──► prepare_rotated_model() ──► RotatedModelConfig
│                  (train rotation matrices,                  │
│                   absorb into weights,                      │
│                   save rotated model)                       │
└──────────────────────────┬──────────────────────────────────┘
┌──────────────────────────▼──────────────────────────────────┐
│  Step 2: Quantization                                       │
│                                                             │
│  RotatedModelConfig ──► Runner(quantizer=GPTQ/RTN/...) ──► run()
│  (auto-registers            ──► save_quantized_model()      │
│   Hadamard hooks)                                           │
└──────────────────────────┬──────────────────────────────────┘
┌──────────────────────────▼──────────────────────────────────┐
│  Step 3: Load                                               │
│                                                             │
│  load_quantized_model()                                     │
│  (auto-detects "rotated: true" in config.json,              │
│   registers Hadamard hooks automatically)                   │
└─────────────────────────────────────────────────────────────┘

Note

The wbits, groupsize, and sym parameters passed to prepare_rotated_model() control the RTN proxy used during rotation training. These values must match the quantizer parameters used in Step 2.
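One simple way to keep the two stages in sync is to define the shared settings once and pass them to both calls. The commented lines below sketch where the shared dict would go; whether GPTQ accepts a sym keyword is an assumption, so check the quantizer's signature.

```python
# Keep the RTN-proxy settings (Step 1) and quantizer settings (Step 2) in one
# place so they cannot drift apart. Illustrative sketch only; the commented
# onecomp calls show the intended usage.
QUANT = {"wbits": 4, "sym": False, "groupsize": 128}

# rotated_config = prepare_rotated_model(model_config=..., save_directory=..., **QUANT)
# quantizer = GPTQ(**QUANT)   # assumption: GPTQ also accepts a sym keyword
```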