Pre-Process (Rotation Preprocessing)

Rotation preprocessing reduces quantization error by learning optimal rotation matrices (SpinQuant/OstQuant) and absorbing them into model weights before quantization.
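
The key invariance behind this trick is that an orthogonal rotation absorbed into the weights leaves the layer's output unchanged when the incoming activations are rotated the same way. A minimal NumPy sketch (illustrative only, independent of the onecomp API):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear layer: y = x @ W.T
d = 8
W = rng.normal(size=(d, d))
x = rng.normal(size=(1, d))

# A random orthogonal rotation Q (via QR decomposition for simplicity)
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))

# "Absorb" the rotation into the weights: W' = W @ Q
W_rot = W @ Q

# If the input is rotated the same way (x' = x @ Q), the output is
# unchanged, because Q @ Q.T = I:
y = x @ W.T
y_rot = (x @ Q) @ W_rot.T
print(np.allclose(y, y_rot))  # True
```

Because the rotation mixes channels, it tends to smooth out activation/weight outliers, which is what reduces quantization error in practice.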

prepare_rotated_model

prepare_rotated_model(model_config: ModelConfig, save_directory: str, *, rotation: bool = True, scaling: bool = False, rotation_mode: str = 'random_hadamard', scaling_mode: str = 'identity', seed: int = 0, enable_training: bool = True, calibration_config: CalibrationConfig | None = None, wbits: int = 4, sym: bool = False, groupsize: int = -1, mse: bool = False, norm: float = 2.4, grid: int = 100, fp32_had: bool = False, use_sdpa: bool = False, training_args_override: dict | None = None) -> RotatedModelConfig

Optionally train rotation/scaling matrices, apply them to model weights, and save.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_config` | `ModelConfig` | Original model configuration (model_id or path). | *required* |
| `save_directory` | `str` | Directory to save the rotated model. | *required* |
| `rotation` | `bool` | Whether to apply rotation matrices (R1, R2). | `True` |
| `scaling` | `bool` | Whether to apply scaling diagonals (S_*). | `False` |
| `rotation_mode` | `str` | `"random_hadamard"` \| `"hadamard"` \| `"random"` \| `"identity"`. | `'random_hadamard'` |
| `scaling_mode` | `str` | `"identity"` \| `"random_ones"` \| `"random"`. | `'identity'` |
| `seed` | `int` | Random seed for rotation matrix initialisation and calibration data preparation. Note that the Trainer uses a separate seed (`TrainingArguments.seed`, default 42) for data shuffling and training reproducibility. | `0` |
| `enable_training` | `bool` | If True, train the rotation/scaling matrices; otherwise use the randomly initialised matrices directly. | `True` |
| `calibration_config` | `CalibrationConfig \| None` | Calibration data configuration. When None (default), a `CalibrationConfig` with default values is created automatically (`calibration_dataset="c4"`, `max_length=2048`, `num_calibration_samples=512`, `strategy="drop_rand"`). See `onecomp.calibration.CalibrationConfig`. | `None` |
| `wbits` | `int` | Weight quantisation bit-width for the RTN proxy during training. Should match the quantizer's `wbits`. | `4` |
| `sym` | `bool` | Symmetric quantisation for the RTN proxy. | `False` |
| `groupsize` | `int` | Group size for the RTN proxy (-1 = per-channel). Should match the quantizer's `groupsize`. When positive, the value must evenly divide the `out_features` of every `nn.Linear` layer in the model. | `-1` |
| `mse` | `bool` | Enable MSE grid search for optimal clipping in the RTN proxy during training. | `False` |
| `norm` | `float` | Lp norm exponent for the MSE grid search. | `2.4` |
| `grid` | `int` | Number of candidate shrink levels for the MSE grid search. | `100` |
| `fp32_had` | `bool` | Use FP32 for the online Hadamard transform. | `False` |
| `use_sdpa` | `bool` | Use SDPA attention implementation during training. | `False` |
| `training_args_override` | `dict \| None` | Override `TrainingArguments` fields (dict). | `None` |
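
The default `rotation_mode="random_hadamard"` corresponds to a construction popularised by SpinQuant-style methods: a normalised Hadamard matrix multiplied by a random ±1 diagonal, which stays orthogonal while randomising the sign pattern. A sketch under that assumption (illustrative only, not onecomp's internal code):

```python
import numpy as np

def sylvester_hadamard(n: int) -> np.ndarray:
    """Build an n x n Hadamard matrix (n must be a power of two)."""
    H = np.ones((1, 1))
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

rng = np.random.default_rng(0)
n = 16

# Normalised Hadamard matrix: orthogonal, entries +-1/sqrt(n)
H = sylvester_hadamard(n) / np.sqrt(n)

# "Random Hadamard": right-multiply by a random +-1 diagonal.
# The product is still orthogonal, so it is a valid rotation.
D = np.diag(rng.choice([-1.0, 1.0], size=n))
R = H @ D

print(np.allclose(R @ R.T, np.eye(n)))  # True
```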

Returns:

| Type | Description |
| --- | --- |
| `RotatedModelConfig` | A `onecomp.rotated_model_config.RotatedModelConfig` pointing at `save_directory`. |

Examples:

Basic usage:

>>> from onecomp import ModelConfig, prepare_rotated_model, GPTQ
>>> model_config = ModelConfig(model_id="meta-llama/Llama-2-7b-hf")
>>> rotated_config = prepare_rotated_model(
...     model_config=model_config,
...     save_directory="./rotated_model",
... )

Without training (random rotation only):

>>> rotated_config = prepare_rotated_model(
...     model_config=model_config,
...     save_directory="./rotated_model",
...     enable_training=False,
... )

RotatedModelConfig

ModelConfig subclass for loading rotation-preprocessed models. Automatically registers Hadamard forward_pre_hook on down_proj layers.

RotatedModelConfig(path: str = None, dtype: str = 'float16', device: str = 'auto', fp32_had: bool = None, **kwargs)

Bases: ModelConfig

ModelConfig subclass for rotation-preprocessed models.

Inherits ModelConfig and automatically registers deterministic Hadamard forward_pre_hook on down_proj layers when load_model() is called.

The saved model directory should contain:

  • config.json — HuggingFace model config (includes fp32_had field)
  • model.safetensors — rotation-applied weights
  • tokenizer.json

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `str` | Path to the saved rotated model (required). | `None` |
| `dtype` | `str` | Data type. | `'float16'` |
| `device` | `str` | Device. | `'auto'` |
| `fp32_had` | `bool \| None` | Use FP32 for the online Hadamard transform. If None (default), auto-detect from `rotation_config.json` in the model directory. Falls back to False if not found. | `None` |

Example

from onecomp import Runner, RotatedModelConfig, GPTQ

model_config = RotatedModelConfig(path="./rotated_model")
quantizer = GPTQ(wbits=4, groupsize=128)
runner = Runner(model_config=model_config, quantizer=quantizer)
runner.run()

load_model

load_model(**kwargs)

Load the rotated model and register Hadamard hooks.

Returns:

| Type | Description |
| --- | --- |
| `nn.Module` | Model with Hadamard pre-hooks registered on down_proj. |
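
A `forward_pre_hook` rotates a layer's input just before the layer runs. This is a minimal PyTorch sketch of what such a hook could look like on a down_proj-style linear layer (illustrative only, not onecomp's actual implementation; the Hadamard construction is assumed):

```python
import torch
import torch.nn as nn

def hadamard_matrix(n: int) -> torch.Tensor:
    """Normalised Hadamard matrix via Sylvester construction (n a power of two)."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H / n ** 0.5

n = 8
down_proj = nn.Linear(n, n, bias=False)
H = hadamard_matrix(n)

# The pre-hook receives the positional args and returns replacements,
# so the layer sees the Hadamard-rotated input.
def hadamard_pre_hook(module, args):
    (x,) = args
    return (x @ H,)

handle = down_proj.register_forward_pre_hook(hadamard_pre_hook)

x = torch.randn(2, n)
with torch.no_grad():
    y_hooked = down_proj(x)
    handle.remove()
    y_manual = down_proj(x @ H)  # same computation, done by hand

print(torch.allclose(y_hooked, y_manual))  # True
```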

has_additional_data

has_additional_data()

Returns True (rotation metadata exists).

Workflow

┌─────────────────────────────────────────────────────────────┐
│  Step 1: Rotation Preprocessing                             │
│                                                             │
│  ModelConfig ──► prepare_rotated_model()                   │
│              ──► RotatedModelConfig                        │
│                  (train rotation matrices,                  │
│                   absorb into weights,                      │
│                   save rotated model)                       │
└──────────────────────────┬──────────────────────────────────┘
┌──────────────────────────▼──────────────────────────────────┐
│  Step 2: Quantization                                       │
│                                                             │
│  RotatedModelConfig ──► Runner(quantizer=GPTQ/RTN/...)     │
│  (auto-registers        ──► run()                          │
│   Hadamard hooks)       ──► save_quantized_model()         │
└──────────────────────────┬──────────────────────────────────┘
┌──────────────────────────▼──────────────────────────────────┐
│  Step 3: Load                                               │
│                                                             │
│  load_quantized_model()                                     │
│  (auto-detects "rotated: true" in config.json,              │
│   registers Hadamard hooks automatically)                   │
└─────────────────────────────────────────────────────────────┘

Note

The wbits, groupsize, and sym parameters passed to prepare_rotated_model() control the RTN proxy used during rotation training. These values must match the quantizer parameters used in Step 2.
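
The `wbits`/`sym`/`groupsize` knobs describe a round-to-nearest (RTN) fake-quantizer. A minimal sketch of such a proxy (illustrative only, not onecomp's implementation; grouping here runs along the input dimension for simplicity, whereas onecomp constrains `groupsize` against `out_features`):

```python
import numpy as np

def rtn_quantize(W: np.ndarray, wbits: int = 4, sym: bool = False,
                 groupsize: int = -1) -> np.ndarray:
    """Round-to-nearest fake-quantization of a weight matrix, per group."""
    out_features, in_features = W.shape
    g = in_features if groupsize == -1 else groupsize
    assert in_features % g == 0
    Wg = W.reshape(out_features, in_features // g, g)
    if sym:
        # Symmetric: zero-point fixed at 0, scale from the max magnitude
        qmax = 2 ** (wbits - 1) - 1
        scale = np.abs(Wg).max(axis=-1, keepdims=True) / qmax
        q = np.clip(np.round(Wg / scale), -qmax - 1, qmax)
        deq = q * scale
    else:
        # Asymmetric: scale and zero-point from the min..max range
        qmax = 2 ** wbits - 1
        wmin = Wg.min(axis=-1, keepdims=True)
        wmax = Wg.max(axis=-1, keepdims=True)
        scale = (wmax - wmin) / qmax
        zero = np.round(-wmin / scale)
        q = np.clip(np.round(Wg / scale) + zero, 0, qmax)
        deq = (q - zero) * scale
    return deq.reshape(out_features, in_features)

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 64))
err_per_channel = np.abs(W - rtn_quantize(W, wbits=4, groupsize=-1)).mean()
err_grouped = np.abs(W - rtn_quantize(W, wbits=4, groupsize=16)).mean()
# Smaller groups track the local range more tightly, so error shrinks
print(err_grouped <= err_per_channel)
```

Because rotation training optimises against this proxy, mismatched `wbits`/`groupsize`/`sym` between Step 1 and Step 2 would mean the rotations were tuned for a different quantization grid than the one actually applied.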