Sparse Autoencoder (SAE) API

Sparse Autoencoders, training, concepts, and related modules for mechanistic interpretability.

Core SAE Classes

mi_crow.mechanistic.sae.sae.Sae

Sae(n_latents, n_inputs, hook_id=None, device='cpu', store=None, *args, **kwargs)

Bases: Controller, Detector, ABC

Source code in src/mi_crow/mechanistic/sae/sae.py
def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        hook_id: str | None = None,
        device: str = 'cpu',
        store: Store | None = None,
        *args: Any,
        **kwargs: Any
) -> None:
    # Initialize both Controller and Detector
    Controller.__init__(self, hook_type=HookType.FORWARD, hook_id=hook_id)
    Detector.__init__(self, hook_type=HookType.FORWARD, hook_id=hook_id, store=store)

    self._autoencoder_context = AutoencoderContext(
        autoencoder=self,
        n_latents=n_latents,
        n_inputs=n_inputs
    )
    self._autoencoder_context.device = device
    self.sae_engine: OvercompleteSAE = self._initialize_sae_engine()
    self.concepts = AutoencoderConcepts(self._autoencoder_context)

    # Text tracking flag
    self._text_tracking_enabled: bool = False

    # Training component
    self.trainer = SaeTrainer(self)

context property writable

context

Get the AutoencoderContext associated with this SAE.

process_activations abstractmethod

process_activations(module, input, output)

Process activations to save neuron activations in metadata.

This implements the Detector interface. It extracts activations, encodes them to get neuron activations (latents), and saves metadata for each item in the batch individually, including nonzero latent indices and activations.

Parameters:

- module (Module, required): The PyTorch module being hooked
- input (HOOK_FUNCTION_INPUT, required): Tuple of input tensors to the module
- output (HOOK_FUNCTION_OUTPUT, required): Output tensor(s) from the module
Source code in src/mi_crow/mechanistic/sae/sae.py
@abc.abstractmethod
def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
) -> None:
    """
    Process activations to save neuron activations in metadata.

    This implements the Detector interface. It extracts activations, encodes them
    to get neuron activations (latents), and saves metadata for each item in the batch
    individually, including nonzero latent indices and activations.

    Args:
        module: The PyTorch module being hooked
        input: Tuple of input tensors to the module
        output: Output tensor(s) from the module
    """
    raise NotImplementedError("process_activations method not implemented.")
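
For orientation, a minimal sketch of what a concrete implementation might do. This is illustrative only: it assumes the hooked module emits a plain tensor, that the subclass provides encode, and that the Detector base exposes a metadata dict, as TopKSae does below.

def process_activations(self, module, input, output):
    # Assumption: the hooked module returns a plain tensor of shape (batch, seq, hidden).
    tensor = output if isinstance(output, torch.Tensor) else output[0]
    latents = self.encode(tensor.reshape(-1, tensor.shape[-1])).detach().cpu()
    # Record nonzero latent indices and values per item, mirroring the Detector contract.
    self.metadata["batch_items"] = [
        {
            "nonzero_indices": (row != 0).nonzero(as_tuple=False).flatten().tolist(),
            "activations": {int(i): float(row[i]) for i in (row != 0).nonzero(as_tuple=False).flatten().tolist()},
        }
        for row in latents
    ]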

set_context

set_context(context)

Set the LanguageModelContext for this hook and sync to AutoencoderContext.

When the hook is registered, this method is called with the LanguageModelContext. It automatically syncs relevant values to the AutoencoderContext, including device.

Parameters:

- context (LanguageModelContext, required): The LanguageModelContext instance from the LanguageModel
Source code in src/mi_crow/mechanistic/sae/sae.py
def set_context(self, context: "LanguageModelContext") -> None:
    """Set the LanguageModelContext for this hook and sync to AutoencoderContext.

    When the hook is registered, this method is called with the LanguageModelContext.
    It automatically syncs relevant values to the AutoencoderContext, including device.

    Args:
        context: The LanguageModelContext instance from the LanguageModel
    """
    Hook.set_context(self, context)
    self._context = context
    if context is not None:
        self._autoencoder_context.lm = context.language_model
        if context.model_id is not None:
            self._autoencoder_context.model_id = context.model_id
        if context.store is not None and self._autoencoder_context.store is None:
            self._autoencoder_context.store = context.store
        if self.layer_signature is not None:
            self._autoencoder_context.lm_layer_signature = self.layer_signature
        if context.device is not None:
            self._autoencoder_context.device = context.device

mi_crow.mechanistic.sae.modules.topk_sae.TopKSae

TopKSae(n_latents, n_inputs, hook_id=None, device='cpu', store=None, *args, **kwargs)

Bases: Sae

Initialize TopK SAE.

Parameters:

- n_latents (int, required): Number of latent dimensions (concepts)
- n_inputs (int, required): Number of input dimensions
- hook_id (str | None, default None): Optional hook identifier
- device (str, default 'cpu'): Device to run on ('cpu', 'cuda', 'mps')
- store (Store | None, default None): Optional store instance
Note

The k parameter must be provided in TopKSaeTrainingConfig during training. For loaded models, k is restored from saved metadata. A temporary default k=1 is used for engine initialization and will be overridden with the actual k value from config during training.

Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        hook_id: str | None = None,
        device: str = 'cpu',
        store: Store | None = None,
        *args: Any,
        **kwargs: Any
) -> None:
    """
    Initialize TopK SAE.

    Args:
        n_latents: Number of latent dimensions (concepts)
        n_inputs: Number of input dimensions
        hook_id: Optional hook identifier
        device: Device to run on ('cpu', 'cuda', 'mps')
        store: Optional store instance

    Note:
        The `k` parameter must be provided in TopKSaeTrainingConfig during training.
        For loaded models, `k` is restored from saved metadata.
        A temporary default k=1 is used for engine initialization and will be
        overridden with the actual k value from config during training.
    """
    super().__init__(n_latents, n_inputs, hook_id, device, store, *args, **kwargs)
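
A minimal construction example (the dimensions are illustrative; k is supplied later via TopKSaeTrainingConfig or restored when loading a checkpoint):

from mi_crow.mechanistic.sae.modules.topk_sae import TopKSae

# A 768-dimensional input expanded into 4096 latents (illustrative sizes).
sae = TopKSae(n_latents=4096, n_inputs=768, device="cpu")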

decode

decode(x)

Decode latents using sae_engine.

Parameters:

- x (Tensor, required): Encoded tensor of shape [batch_size, n_latents]

Returns:

- Tensor: Reconstructed tensor of shape [batch_size, n_inputs]

Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
def decode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Decode latents using sae_engine.

    Args:
        x: Encoded tensor of shape [batch_size, n_latents]

    Returns:
        Reconstructed tensor of shape [batch_size, n_inputs]
    """
    return self.sae_engine.decode(x)

encode

encode(x)

Encode input using sae_engine.

Parameters:

- x (Tensor, required): Input tensor of shape [batch_size, n_inputs]

Returns:

- Tensor: Encoded latents (TopK sparse activations)

Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode input using sae_engine.

    Args:
        x: Input tensor of shape [batch_size, n_inputs]

    Returns:
        Encoded latents (TopK sparse activations)
    """
    # Overcomplete TopKSAE encode returns (pre_codes, codes)
    _, codes = self.sae_engine.encode(x)
    return codes

forward

forward(x)

Forward pass using sae_engine.

Parameters:

- x (Tensor, required): Input tensor of shape [batch_size, n_inputs]

Returns:

- Tensor: Reconstructed tensor of shape [batch_size, n_inputs]

Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass using sae_engine.

    Args:
        x: Input tensor of shape [batch_size, n_inputs]

    Returns:
        Reconstructed tensor of shape [batch_size, n_inputs]
    """
    # Overcomplete TopKSAE forward returns (pre_codes, codes, x_reconstructed)
    _, _, x_reconstructed = self.sae_engine.forward(x)
    return x_reconstructed
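
A round-trip sketch with a random batch (shapes follow the docstrings above; the SAE is assumed to be trained or loaded already, and 768 must match n_inputs):

import torch

x = torch.randn(8, 768)          # [batch_size, n_inputs]
codes = sae.encode(x)            # [batch_size, n_latents], TopK-sparse
x_hat = sae.decode(codes)        # [batch_size, n_inputs]
x_hat_direct = sae.forward(x)    # same reconstruction in a single call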

load staticmethod

load(path)

Load TopKSAE from saved file using overcomplete's load method + our metadata.

Parameters:

- path (Path, required): Path to saved model file

Returns:

- TopKSae: Loaded TopKSAE instance

Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
@staticmethod
def load(path: Path) -> "TopKSae":
    """
    Load TopKSAE from saved file using overcomplete's load method + our metadata.

    Args:
        path: Path to saved model file

    Returns:
        Loaded TopKSAE instance
    """
    p = Path(path)

    # Load payload
    if torch.cuda.is_available():
        map_location = 'cuda'
    elif torch.backends.mps.is_available():
        map_location = 'mps'
    else:
        map_location = 'cpu'
    payload = torch.load(p, map_location=map_location)

    # Extract our metadata
    if "mi_crow_metadata" not in payload:
        raise ValueError(f"Invalid TopKSAE save format: missing 'mi_crow_metadata' key in {p}")

    mi_crow_meta = payload["mi_crow_metadata"]
    n_latents = int(mi_crow_meta["n_latents"])
    n_inputs = int(mi_crow_meta["n_inputs"])
    k = int(mi_crow_meta["k"])
    device = mi_crow_meta.get("device", "cpu")
    layer_signature = mi_crow_meta.get("layer_signature")
    model_id = mi_crow_meta.get("model_id")
    concepts_state = mi_crow_meta.get("concepts_state", {})

    # Create TopKSAE instance
    topk_sae = TopKSae(
        n_latents=n_latents,
        n_inputs=n_inputs,
        device=device
    )

    topk_sae.sae_engine = topk_sae._initialize_sae_engine(k=k)

    # Load overcomplete model state dict
    if "sae_state_dict" in payload:
        topk_sae.sae_engine.load_state_dict(payload["sae_state_dict"])
    elif "model" in payload:
        # Backward compatibility with old format
        topk_sae.sae_engine.load_state_dict(payload["model"])
    else:
        # Assume payload is the state dict itself (backward compatibility)
        topk_sae.sae_engine.load_state_dict(payload)

    # Load concepts state
    if concepts_state:
        device = topk_sae.context.device
        if isinstance(device, str):
            device = torch.device(device)
        if "multiplication" in concepts_state:
            topk_sae.concepts.multiplication.data = concepts_state["multiplication"].to(device)
        if "bias" in concepts_state:
            topk_sae.concepts.bias.data = concepts_state["bias"].to(device)

    # Note: Top texts loading was removed as serialization methods were removed
    # Top texts should be exported/imported separately if needed

    # Set context metadata
    topk_sae.context.lm_layer_signature = layer_signature
    topk_sae.context.model_id = model_id

    params_str = f"n_latents={n_latents}, n_inputs={n_inputs}, k={k}"
    logger.info(f"\nLoaded TopKSAE from {p}\n{params_str}")

    return topk_sae
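
Loading a previously saved checkpoint (the file path is illustrative):

from pathlib import Path

sae = TopKSae.load(Path("checkpoints/my_sae.pt"))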

modify_activations

modify_activations(module, inputs, output)

Modify activations using TopKSAE (Controller hook interface).

Extracts tensor from inputs/output, applies SAE forward pass, and optionally applies concept manipulation.

Parameters:

- module (Module, required): The PyTorch module being hooked
- inputs (Tensor | None, required): Tuple of inputs to the module
- output (Tensor | None, required): Output from the module (None for pre_forward hooks)

Returns:

- Tensor | None: Modified activations with same shape as input

Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
def modify_activations(
        self,
        module: "torch.nn.Module",
        inputs: torch.Tensor | None,
        output: torch.Tensor | None
) -> torch.Tensor | None:
    """
    Modify activations using TopKSAE (Controller hook interface).

    Extracts tensor from inputs/output, applies SAE forward pass,
    and optionally applies concept manipulation.

    Args:
        module: The PyTorch module being hooked
        inputs: Tuple of inputs to the module
        output: Output from the module (None for pre_forward hooks)

    Returns:
        Modified activations with same shape as input
    """
    # Extract tensor from output/inputs, handling objects with last_hidden_state
    if self.hook_type == HookType.FORWARD:
        if isinstance(output, torch.Tensor):
            tensor = output
        elif hasattr(output, "last_hidden_state") and isinstance(output.last_hidden_state, torch.Tensor):
            tensor = output.last_hidden_state
        elif isinstance(output, (tuple, list)):
            # Try to find first tensor in tuple/list
            tensor = next((item for item in output if isinstance(item, torch.Tensor)), None)
        else:
            tensor = None
    else:
        tensor = inputs[0] if len(inputs) > 0 and isinstance(inputs[0], torch.Tensor) else None

    if tensor is None or not isinstance(tensor, torch.Tensor):
        return output if self.hook_type == HookType.FORWARD else inputs

    original_shape = tensor.shape

    # Flatten to 2D for SAE processing: (batch, seq_len, hidden) -> (batch * seq_len, hidden)
    # or keep as 2D if already 2D: (batch, hidden)
    if len(original_shape) > 2:
        batch_size, seq_len = original_shape[:2]
        tensor_flat = tensor.reshape(-1, original_shape[-1])
    else:
        batch_size = original_shape[0]
        seq_len = 1
        tensor_flat = tensor

    # Get full activations (pre_codes) and sparse codes
    # Overcomplete TopKSAE encode returns (pre_codes, codes)
    pre_codes, codes = self.sae_engine.encode(tensor_flat)

    # Save SAE activations (pre_codes) as 3D tensor: (batch, seq, n_latents)
    latents_cpu = pre_codes.detach().cpu()
    latents_3d = latents_cpu.reshape(batch_size, seq_len, -1)

    # Save to tensor_metadata
    self.tensor_metadata['neurons'] = latents_3d
    self.tensor_metadata['activations'] = latents_3d

    # Process each item in the batch individually for metadata
    batch_items = []
    n_items = latents_cpu.shape[0]
    for item_idx in range(n_items):
        item_latents = latents_cpu[item_idx]  # [n_latents]

        # Find nonzero indices for this item
        nonzero_mask = item_latents != 0
        nonzero_indices = torch.nonzero(nonzero_mask, as_tuple=False).flatten().tolist()

        # Create map of nonzero indices to activations
        activations_map = {
            int(idx): float(item_latents[idx].item())
            for idx in nonzero_indices
        }

        # Create item metadata
        item_metadata = {
            "nonzero_indices": nonzero_indices,
            "activations": activations_map
        }
        batch_items.append(item_metadata)

    # Save batch items metadata
    self.metadata['batch_items'] = batch_items

    # Use sparse codes for reconstruction
    latents = codes

    # Update top texts if text tracking is enabled
    if self._text_tracking_enabled and self.context.lm is not None:
        input_tracker = self.context.lm.get_input_tracker()
        if input_tracker is not None:
            texts = input_tracker.get_current_texts()
            if texts:
                # Use pre_codes (full activations) for text tracking
                self.concepts.update_top_texts_from_latents(
                    latents_cpu,
                    texts,
                    original_shape
                )

    # Apply concept manipulation if parameters are set
    # Check if multiplication or bias differ from defaults (ones)
    if not torch.allclose(self.concepts.multiplication, torch.ones_like(self.concepts.multiplication)) or \
            not torch.allclose(self.concepts.bias, torch.ones_like(self.concepts.bias)):
        # Apply manipulation: latents = latents * multiplication + bias
        latents = latents * self.concepts.multiplication + self.concepts.bias

    # Decode to get reconstruction
    reconstructed = self.decode(latents)

    # Reshape back to original shape
    if len(original_shape) > 2:
        reconstructed = reconstructed.reshape(original_shape)

    # Return in appropriate format
    if self.hook_type == HookType.FORWARD:
        if isinstance(output, torch.Tensor):
            return reconstructed
        elif isinstance(output, (tuple, list)):
            # Replace first tensor in tuple/list
            result = list(output)
            for i, item in enumerate(result):
                if isinstance(item, torch.Tensor):
                    result[i] = reconstructed
                    break
            return tuple(result) if isinstance(output, tuple) else result
        else:
            # For objects with attributes, try to set last_hidden_state
            if hasattr(output, "last_hidden_state"):
                output.last_hidden_state = reconstructed
            return output
    else:  # PRE_FORWARD
        # Return modified inputs tuple
        result = list(inputs)
        if len(result) > 0:
            result[0] = reconstructed
        return tuple(result)
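
Because modify_activations applies latents * multiplication + bias whenever these parameters differ from their defaults, steering a concept amounts to editing those tensors before inference. A hedged sketch (latent index 42 and the scaling factor are arbitrary; how inference is invoked depends on your LanguageModel setup):

import torch

with torch.no_grad():
    sae.concepts.multiplication.data[42] = 5.0   # amplify latent 42 (illustrative index)
# Subsequent hooked forward passes apply latents * multiplication + bias before decoding.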

process_activations

process_activations(module, input, output)

Process activations (Detector interface).

Metadata saving is handled in modify_activations to avoid duplicate work. This method is kept for interface compatibility but does nothing since modify_activations already saves the metadata when called.

Parameters:

- module (Module, required): The PyTorch module being hooked
- input (HOOK_FUNCTION_INPUT, required): Tuple of input tensors to the module
- output (HOOK_FUNCTION_OUTPUT, required): Output tensor(s) from the module
Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
) -> None:
    """
    Process activations (Detector interface).

    Metadata saving is handled in modify_activations to avoid duplicate work.
    This method is kept for interface compatibility but does nothing since
    modify_activations already saves the metadata when called.

    Args:
        module: The PyTorch module being hooked
        input: Tuple of input tensors to the module
        output: Output tensor(s) from the module
    """
    # Metadata saving is done in modify_activations to avoid duplicate encoding
    pass

save

save(name, path=None, k=None)

Save model using overcomplete's state dict + our metadata.

Parameters:

- name (str, required): Model name
- path (str | Path | None, default None): Directory path to save to (defaults to current directory)
- k (int | None, default None): Top-K value to save (if None, attempts to get from engine or raises error)
Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
def save(self, name: str, path: str | Path | None = None, k: int | None = None) -> None:
    """
    Save model using overcomplete's state dict + our metadata.

    Args:
        name: Model name
        path: Directory path to save to (defaults to current directory)
        k: Top-K value to save (if None, attempts to get from engine or raises error)
    """
    if path is None:
        path = Path.cwd()
    save_dir = Path(path)
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f"{name}.pt"

    # Save overcomplete model state dict
    sae_state_dict = self.sae_engine.state_dict()

    # Get k value - prefer parameter, then try to get from engine
    if k is None:
        if hasattr(self.sae_engine, 'top_k'):
            k = self.sae_engine.top_k
        else:
            raise ValueError(
                "k parameter must be provided to save() method. "
                "The engine does not expose top_k attribute."
            )

    mi_crow_metadata = {
        "concepts_state": {
            'multiplication': self.concepts.multiplication.data,
            'bias': self.concepts.bias.data,
        },
        "n_latents": self.context.n_latents,
        "n_inputs": self.context.n_inputs,
        "k": k,
        "device": self.context.device,
        "layer_signature": self.context.lm_layer_signature,
        "model_id": self.context.model_id,
    }

    payload = {
        "sae_state_dict": sae_state_dict,
        "mi_crow_metadata": mi_crow_metadata,
    }

    torch.save(payload, save_path)
    logger.info(f"Saved TopKSAE to {save_path}")

train

train(store, run_id, layer_signature, config=None, training_run_id=None)

Train TopKSAE using activations from a Store.

This method delegates to the SaeTrainer composite class. The SAE engine will be reinitialized with the k value from config.

Parameters:

- store (Store, required): Store instance containing activations
- run_id (str, required): Run ID to train on
- layer_signature (str | int, required): Layer signature selecting which layer's activations from the store to train on
- config (TopKSaeTrainingConfig | None, default None): Training configuration (must include the k parameter)
- training_run_id (str | None, default None): Optional training run ID

Returns:

- dict[str, Any]: Dictionary with keys:
  - "history": Training history dictionary
  - "training_run_id": Training run ID where outputs were saved

Raises:

- ValueError: If config is None or config.k is not set

Source code in src/mi_crow/mechanistic/sae/modules/topk_sae.py
def train(
        self,
        store: Store,
        run_id: str,
        layer_signature: str | int,
        config: TopKSaeTrainingConfig | None = None,
        training_run_id: str | None = None
) -> dict[str, Any]:
    """
    Train TopKSAE using activations from a Store.

    This method delegates to the SaeTrainer composite class.
    The SAE engine will be reinitialized with the k value from config.

    Args:
        store: Store instance containing activations
        run_id: Run ID to train on
        config: Training configuration (must include k parameter)
        training_run_id: Optional training run ID

    Returns:
        Dictionary with keys:
            - "history": Training history dictionary
            - "training_run_id": Training run ID where outputs were saved

    Raises:
        ValueError: If config is None or config.k is not set
    """
    if config is None:
        config = TopKSaeTrainingConfig()

    # Ensure k is set in config
    if not hasattr(config, 'k') or config.k is None:
        raise ValueError(
            "TopKSaeTrainingConfig must have k parameter set. "
            "Example: TopKSaeTrainingConfig(k=10, epochs=100, ...)"
        )

    # Reinitialize engine with k from config
    logger.info(f"Initializing SAE engine with k={config.k}")
    self.sae_engine = self._initialize_sae_engine(k=config.k)
    if hasattr(config, 'device') and config.device:
        device = torch.device(config.device)
        self.sae_engine.to(device)
        self.context.device = str(device)

    return self.trainer.train(store, run_id, layer_signature, config, training_run_id)
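
An end-to-end training sketch. The store construction, run ID, and layer signature below are assumptions about your setup, not fixed values:

from mi_crow.mechanistic.sae.modules.topk_sae import TopKSaeTrainingConfig

config = TopKSaeTrainingConfig(k=10, epochs=5, batch_size=1024, lr=1e-3)
result = sae.train(
    store=store,                        # a Store holding previously captured activations (assumed)
    run_id="run-001",                   # illustrative run ID
    layer_signature="model.layers.10",  # illustrative layer signature
    config=config,
)
history = result["history"]
training_run_id = result["training_run_id"]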

mi_crow.mechanistic.sae.autoencoder_context.AutoencoderContext dataclass

AutoencoderContext(autoencoder, n_latents, n_inputs, lm=None, lm_layer_signature=None, model_id=None, device='cpu', experiment_name=None, run_id=None, text_tracking_enabled=False, text_tracking_k=5, text_tracking_negative=False, store=None, tied=False, bias_init=0.0, init_method='kaiming')

Shared context for Autoencoder and its nested components.

Training

mi_crow.mechanistic.sae.sae_trainer.SaeTrainer

SaeTrainer(sae)

Composite trainer class for SAE models using overcomplete's training functions.

This trainer handles training of any SAE that has a sae_engine attribute compatible with overcomplete's train_sae functions.

Initialize SaeTrainer.

Parameters:

- sae (Sae, required): The SAE instance to train
Source code in src/mi_crow/mechanistic/sae/sae_trainer.py
def __init__(self, sae: "Sae") -> None:
    """
    Initialize SaeTrainer.

    Args:
        sae: The SAE instance to train
    """
    self.sae = sae
    self.logger = get_logger(__name__)

mi_crow.mechanistic.sae.sae_trainer.SaeTrainingConfig dataclass

SaeTrainingConfig(epochs=1, batch_size=1024, lr=0.001, l1_lambda=0.0, device='cpu', dtype=None, max_batches_per_epoch=None, verbose=False, use_amp=True, amp_dtype=None, grad_accum_steps=1, clip_grad=1.0, monitoring=1, scheduler=None, max_nan_fallbacks=5, use_wandb=False, wandb_project=None, wandb_entity=None, wandb_name=None, wandb_tags=None, wandb_config=None, wandb_mode='online', wandb_slow_metrics_frequency=50, wandb_api_key=None, memory_efficient=False, snapshot_every_n_epochs=None, snapshot_base_path=None)

Configuration for SAE training (compatible with overcomplete.train_sae).

mi_crow.mechanistic.sae.modules.topk_sae.TopKSaeTrainingConfig dataclass

TopKSaeTrainingConfig(epochs=1, batch_size=1024, lr=0.001, l1_lambda=0.0, device='cpu', dtype=None, max_batches_per_epoch=None, verbose=False, use_amp=True, amp_dtype=None, grad_accum_steps=1, clip_grad=1.0, monitoring=1, scheduler=None, max_nan_fallbacks=5, use_wandb=False, wandb_project=None, wandb_entity=None, wandb_name=None, wandb_tags=None, wandb_config=None, wandb_mode='online', wandb_slow_metrics_frequency=50, wandb_api_key=None, memory_efficient=False, snapshot_every_n_epochs=None, snapshot_base_path=None, k=10)

Bases: SaeTrainingConfig

Training configuration for TopK SAE models.

This class extends SaeTrainingConfig to provide a type-safe configuration interface specifically for TopK SAE models. It adds the k parameter which specifies how many top activations to keep during encoding.

Parameters:

- k (int, default 10): Number of top activations to keep (required for TopK SAE training)
Note

All other parameters are inherited from SaeTrainingConfig.

Attributes:

- k (int): Number of top activations to keep during TopK encoding

Example

config = TopKSaeTrainingConfig(
    k=10,
    epochs=100,
    batch_size=1024,
    lr=1e-3,
    l1_lambda=1e-4
)

Concepts

mi_crow.mechanistic.sae.concepts.autoencoder_concepts.AutoencoderConcepts

AutoencoderConcepts(context)
Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def __init__(
        self,
        context: AutoencoderContext
):
    self.context = context
    self._n_size = context.n_latents
    self.dictionary: ConceptDictionary | None = None

    self.multiplication = nn.Parameter(torch.ones(self._n_size))
    self.bias = nn.Parameter(torch.ones(self._n_size))

    self._text_heaps_positive: list[TextHeap] | None = None
    self._text_heaps_negative: list[TextHeap] | None = None
    self._text_tracking_k: int = 5
    self._text_tracking_negative: bool = False

enable_text_tracking

enable_text_tracking()

Enable text tracking using context parameters.

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def enable_text_tracking(self):
    """Enable text tracking using context parameters."""
    if self.context.lm is None:
        raise ValueError("LanguageModel must be set in context to enable tracking")

    # Store tracking parameters
    self._text_tracking_k = self.context.text_tracking_k
    self._text_tracking_negative = self.context.text_tracking_negative

    # Ensure InputTracker singleton exists on LanguageModel and enable it
    input_tracker = self.context.lm._ensure_input_tracker()
    input_tracker.enable()

    # Enable text tracking on the SAE instance
    if hasattr(self.context.autoencoder, '_text_tracking_enabled'):
        self.context.autoencoder._text_tracking_enabled = True
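
A sketch of the text-tracking workflow. How inference is run depends on your LanguageModel API, so that step is only indicated; the output path is illustrative:

sae.concepts.enable_text_tracking()

# ... run inference through the hooked language model so the SAE sees activations ...

top_texts = sae.concepts.get_all_top_texts()
sae.concepts.export_top_texts_to_json("top_texts.json")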

export_bottom_texts_to_json

export_bottom_texts_to_json(filepath)

Export bottom texts (negative activations) to JSON file.

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def export_bottom_texts_to_json(self, filepath: Path | str) -> Path:
    """Export bottom texts (negative activations) to JSON file."""
    if not self._text_tracking_negative or self._text_heaps_negative is None:
        raise ValueError("No bottom texts available. Enable negative text tracking and run inference first.")

    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    all_texts = self.get_all_bottom_texts()
    export_data = {}

    for neuron_idx, neuron_texts in enumerate(all_texts):
        export_data[neuron_idx] = [
            {
                "text": nt.text,
                "score": nt.score,
                "token_str": nt.token_str,
                "token_idx": nt.token_idx
            }
            for nt in neuron_texts
        ]

    with filepath.open("w", encoding="utf-8") as f:
        json.dump(export_data, f, ensure_ascii=False, indent=2)

    return filepath

export_top_texts_to_json

export_top_texts_to_json(filepath)

Export top texts (positive activations) to JSON file.

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def export_top_texts_to_json(self, filepath: Path | str) -> Path:
    """Export top texts (positive activations) to JSON file."""
    if self._text_heaps_positive is None:
        raise ValueError("No top texts available. Enable text tracking and run inference first.")

    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    all_texts = self.get_all_top_texts()
    export_data = {}

    for neuron_idx, neuron_texts in enumerate(all_texts):
        export_data[neuron_idx] = [
            {
                "text": nt.text,
                "score": nt.score,
                "token_str": nt.token_str,
                "token_idx": nt.token_idx
            }
            for nt in neuron_texts
        ]

    with filepath.open("w", encoding="utf-8") as f:
        json.dump(export_data, f, ensure_ascii=False, indent=2)

    return filepath

generate_concepts_with_llm

generate_concepts_with_llm(llm_provider=None)

Generate concepts using LLM based on current top texts

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def generate_concepts_with_llm(self, llm_provider: str | None = None):
    """Generate concepts using LLM based on current top texts"""
    if self._text_heaps_positive is None:
        raise ValueError("No top texts available. Enable text tracking and run inference first.")

    from mi_crow.mechanistic.sae.concepts.concept_dictionary import ConceptDictionary
    neuron_texts = self.get_all_top_texts()

    self.dictionary = ConceptDictionary.from_llm(
        neuron_texts=neuron_texts,
        n_size=self._n_size,
        store=self.dictionary.store if self.dictionary else None,
        llm_provider=llm_provider
    )

get_all_bottom_texts

get_all_bottom_texts()

Get bottom texts for all neurons (negative activations).

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def get_all_bottom_texts(self) -> list[list[NeuronText]]:
    """Get bottom texts for all neurons (negative activations)."""
    if not self._text_tracking_negative or self._text_heaps_negative is None:
        return []
    return [self.get_bottom_texts_for_neuron(i) for i in range(len(self._text_heaps_negative))]

get_all_top_texts

get_all_top_texts()

Get top texts for all neurons (positive activations).

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def get_all_top_texts(self) -> list[list[NeuronText]]:
    """Get top texts for all neurons (positive activations)."""
    if self._text_heaps_positive is None:
        return []
    return [self.get_top_texts_for_neuron(i) for i in range(len(self._text_heaps_positive))]

get_bottom_texts_for_neuron

get_bottom_texts_for_neuron(neuron_idx, top_m=None)

Get bottom texts for a specific neuron (negative activations).

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def get_bottom_texts_for_neuron(self, neuron_idx: int, top_m: int | None = None) -> list[NeuronText]:
    """Get bottom texts for a specific neuron (negative activations)."""
    if not self._text_tracking_negative:
        return []
    if self._text_heaps_negative is None or neuron_idx < 0 or neuron_idx >= len(self._text_heaps_negative):
        return []
    heap = self._text_heaps_negative[neuron_idx]
    items = heap.get_items()
    items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=False)
    if top_m is not None:
        items_sorted = items_sorted[: top_m]

    neuron_texts = []
    for score, text, token_idx in items_sorted:
        token_str = self._decode_token(text, token_idx)
        neuron_texts.append(NeuronText(score=score, text=text, token_idx=token_idx, token_str=token_str))
    return neuron_texts

get_top_texts_for_neuron

get_top_texts_for_neuron(neuron_idx, top_m=None)

Get top texts for a specific neuron (positive activations).

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def get_top_texts_for_neuron(self, neuron_idx: int, top_m: int | None = None) -> list[NeuronText]:
    """Get top texts for a specific neuron (positive activations)."""
    if self._text_heaps_positive is None or neuron_idx < 0 or neuron_idx >= len(self._text_heaps_positive):
        return []
    heap = self._text_heaps_positive[neuron_idx]
    items = heap.get_items()
    items_sorted = sorted(items, key=lambda s_t: s_t[0], reverse=True)
    if top_m is not None:
        items_sorted = items_sorted[: top_m]

    neuron_texts = []
    for score, text, token_idx in items_sorted:
        token_str = self._decode_token(text, token_idx)
        neuron_texts.append(NeuronText(score=score, text=text, token_idx=token_idx, token_str=token_str))
    return neuron_texts
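
Inspecting a single latent after tracking (the neuron index is illustrative):

for nt in sae.concepts.get_top_texts_for_neuron(42, top_m=5):
    print(f"{nt.score:.3f}  {nt.token_str!r}  {nt.text[:80]}")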

reset_top_texts

reset_top_texts()

Reset all tracked top texts.

Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def reset_top_texts(self) -> None:
    """Reset all tracked top texts."""
    self._text_heaps_positive = None
    self._text_heaps_negative = None

update_top_texts_from_latents

update_top_texts_from_latents(latents, texts, original_shape=None)

Update top texts heaps from latents and texts.

Optimized version that:
- Only processes active neurons (non-zero activations)
- Vectorizes argmax/argmin operations
- Eliminates per-neuron tensor slicing

Parameters:

- latents (Tensor, required): Latent activations tensor, shape [B*T, n_latents] or [B, n_latents] (already flattened)
- texts (Sequence[str], required): List of texts corresponding to the batch
- original_shape (tuple[int, ...] | None, default None): Original shape before flattening, e.g., (B, T, D) or (B, D)
Source code in src/mi_crow/mechanistic/sae/concepts/autoencoder_concepts.py
def update_top_texts_from_latents(
        self,
        latents: torch.Tensor,
        texts: Sequence[str],
        original_shape: tuple[int, ...] | None = None
) -> None:
    """
    Update top texts heaps from latents and texts.

    Optimized version that:
    - Only processes active neurons (non-zero activations)
    - Vectorizes argmax/argmin operations
    - Eliminates per-neuron tensor slicing

    Args:
        latents: Latent activations tensor, shape [B*T, n_latents] or [B, n_latents] (already flattened)
        texts: List of texts corresponding to the batch
        original_shape: Original shape before flattening, e.g., (B, T, D) or (B, D)
    """
    if not texts:
        return

    n_neurons = latents.shape[-1]
    self._ensure_heaps(n_neurons)

    # Calculate batch and token dimensions
    original_B = len(texts)
    BT = latents.shape[0]  # Total positions (B*T if 3D original, or B if 2D original)

    # Determine if original was 3D or 2D
    if original_shape is not None and len(original_shape) == 3:
        # Original was [B, T, D], latents are [B*T, n_latents]
        B, T, _ = original_shape
        # Verify batch size matches
        if B != original_B:
            logger.warning(f"Batch size mismatch: original_shape has B={B}, but {original_B} texts provided")
            # Use the actual number of texts as batch size
            B = original_B
            T = BT // B if B > 0 else 1
    else:
        # Original was [B, D], latents are [B, n_latents]
        B = original_B
        T = 1

    # OPTIMIZATION 1: Find active neurons (have any non-zero activation across batch)
    # Shape: [n_neurons] - boolean mask
    active_neurons_mask = (latents.abs().sum(dim=0) > 0)
    active_neuron_indices = torch.nonzero(active_neurons_mask, as_tuple=False).flatten().tolist()

    if not active_neuron_indices:
        return  # No active neurons, skip

    # OPTIMIZATION 2: Vectorize argmax/argmin for all neurons at once
    if original_shape is not None and len(original_shape) == 3:
        # Reshape to [B, T, n_neurons]
        latents_3d = latents.view(B, T, n_neurons)
        # For each text, find max/min across tokens for each neuron
        # Shape: [B, n_neurons] - max activation per text per neuron
        max_activations, max_token_indices_3d = latents_3d.max(dim=1)  # [B, n_neurons]
        min_activations, min_token_indices_3d = latents_3d.min(dim=1)  # [B, n_neurons]
        # max_token_indices_3d is already the token index (0 to T-1)
        max_token_indices = max_token_indices_3d
        min_token_indices = min_token_indices_3d
    else:
        # Shape: [B, n_neurons]
        latents_2d = latents.view(B, n_neurons)
        max_activations = latents_2d  # [B, n_neurons]
        max_token_indices = torch.zeros(B, n_neurons, dtype=torch.long, device=latents.device)
        min_activations = latents_2d
        min_token_indices = torch.zeros(B, n_neurons, dtype=torch.long, device=latents.device)

    # Convert to numpy for faster CPU access (already on CPU from l1_sae.py)
    max_activations_np = max_activations.cpu().numpy()
    min_activations_np = min_activations.cpu().numpy()
    max_token_indices_np = max_token_indices.cpu().numpy()
    min_token_indices_np = min_token_indices.cpu().numpy()

    # OPTIMIZATION 3: Only process active neurons
    for j in active_neuron_indices:
        heap_positive = self._text_heaps_positive[j]
        heap_negative = self._text_heaps_negative[j] if self._text_tracking_negative else None

        # OPTIMIZATION 4: Batch process all texts for this neuron
        for batch_idx in range(original_B):
            if batch_idx >= len(texts):
                continue

            text = texts[batch_idx]

            # Use pre-computed max/min (no tensor slicing needed!)
            max_score_positive = float(max_activations_np[batch_idx, j])
            token_idx_positive = int(max_token_indices_np[batch_idx, j])

            if max_score_positive > 0.0:
                heap_positive.update(text, max_score_positive, token_idx_positive)

            if self._text_tracking_negative and heap_negative is not None:
                min_score_negative = float(min_activations_np[batch_idx, j])
                if min_score_negative != 0.0:
                    token_idx_negative = int(min_token_indices_np[batch_idx, j])
                    heap_negative.update(text, min_score_negative, token_idx_negative, adjusted_score=-min_score_negative)

mi_crow.mechanistic.sae.concepts.concept_dictionary.ConceptDictionary

ConceptDictionary(n_size, store=None)
Source code in src/mi_crow/mechanistic/sae/concepts/concept_dictionary.py
def __init__(
        self,
        n_size: int,
        store: Store | None = None
) -> None:
    self.n_size = n_size
    self.concepts_map: Dict[int, Concept] = {}
    self.store = store
    self._directory: Path | None = None

mi_crow.mechanistic.sae.concepts.concept_models

mi_crow.mechanistic.sae.concepts.input_tracker.InputTracker

InputTracker(language_model)

Simple listener that saves input texts before tokenization.

This is a singleton per LanguageModel instance. It's used as a listener during inference to capture texts before they are tokenized. SAE hooks can then access these texts to track top activating texts for their neurons.

Initialize InputTracker.

Parameters:

- language_model (LanguageModel, required): Language model instance
Source code in src/mi_crow/mechanistic/sae/concepts/input_tracker.py
def __init__(
        self,
        language_model: "LanguageModel",
) -> None:
    """
    Initialize InputTracker.

    Args:
        language_model: Language model instance
    """
    self.language_model = language_model

    # Flag to control whether to save inputs
    self._enabled: bool = False

    # Runtime state - only stores texts
    self._current_texts: list[str] = []

enabled property

enabled

Whether input tracking is enabled.

disable

disable()

Disable input tracking.

Source code in src/mi_crow/mechanistic/sae/concepts/input_tracker.py
def disable(self) -> None:
    """Disable input tracking."""
    self._enabled = False

enable

enable()

Enable input tracking.

Source code in src/mi_crow/mechanistic/sae/concepts/input_tracker.py
def enable(self) -> None:
    """Enable input tracking."""
    self._enabled = True

get_current_texts

get_current_texts()

Get the current batch of texts.

Source code in src/mi_crow/mechanistic/sae/concepts/input_tracker.py
def get_current_texts(self) -> list[str]:
    """Get the current batch of texts."""
    return self._current_texts.copy()

reset

reset()

Reset stored texts.

Source code in src/mi_crow/mechanistic/sae/concepts/input_tracker.py
def reset(self) -> None:
    """Reset stored texts."""
    self._current_texts.clear()

set_current_texts

set_current_texts(texts)

Set the current batch of texts being processed.

This is called by LanguageModel._inference() before tokenization if tracking is enabled.

Source code in src/mi_crow/mechanistic/sae/concepts/input_tracker.py
def set_current_texts(self, texts: Sequence[str]) -> None:
    """
    Set the current batch of texts being processed.

    This is called by LanguageModel._inference() before tokenization
    if tracking is enabled.
    """
    if self._enabled:
        self._current_texts = list(texts)

Training Utilities

mi_crow.mechanistic.sae.training.wandb_logger.WandbLogger

WandbLogger(config, run_id)

Handles wandb logging for SAE training.

Encapsulates all wandb-related operations including initialization, metric logging, and summary updates.

Initialize WandbLogger.

Parameters:

- config (SaeTrainingConfig, required): Training configuration
- run_id (str, required): Training run identifier
Source code in src/mi_crow/mechanistic/sae/training/wandb_logger.py
def __init__(self, config: SaeTrainingConfig, run_id: str):
    """
    Initialize WandbLogger.

    Args:
        config: Training configuration
        run_id: Training run identifier
    """
    self.config = config
    self.run_id = run_id
    self.wandb_run: Optional[Any] = None
    self._initialized = False

initialize

initialize()

Initialize wandb run if enabled in config.

Returns:

- bool: True if wandb was successfully initialized, False otherwise

Source code in src/mi_crow/mechanistic/sae/training/wandb_logger.py
def initialize(self) -> bool:
    """
    Initialize wandb run if enabled in config.

    Returns:
        True if wandb was successfully initialized, False otherwise
    """
    if not self.config.use_wandb:
        return False

    try:
        import wandb
    except ImportError:
        logger.warning("[WandbLogger] wandb not installed, skipping wandb logging")
        logger.warning("[WandbLogger] Install with: pip install wandb")
        return False

    try:
        wandb_project = self.config.wandb_project or "sae-training"
        wandb_name = self.config.wandb_name or self.run_id
        wandb_mode = self.config.wandb_mode.lower() if self.config.wandb_mode else "online"

        self.wandb_run = wandb.init(
            project=wandb_project,
            entity=self.config.wandb_entity,
            name=wandb_name,
            mode=wandb_mode,
            config=self._build_wandb_config(),
            tags=self.config.wandb_tags or [],
        )
        self._initialized = True
        return True
    except Exception as e:
        logger.warning(f"[WandbLogger] Unexpected error initializing wandb: {e}")
        logger.warning("[WandbLogger] Continuing training without wandb logging")
        return False
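
Weights & Biases logging is driven entirely by the training config; a hedged example of enabling it (project and run names are placeholders):

config = TopKSaeTrainingConfig(
    k=10,
    epochs=5,
    use_wandb=True,
    wandb_project="sae-training",
    wandb_name="topk-4096",
    wandb_mode="offline",   # or "online"
)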

log_metrics

log_metrics(history, verbose=False)

Log training metrics to wandb.

Parameters:

- history (dict[str, list[float | None]], required): Dictionary with training history (loss, r2, l1, l0, etc.)
- verbose (bool, default False): Whether to log verbose information
Source code in src/mi_crow/mechanistic/sae/training/wandb_logger.py
def log_metrics(
        self,
        history: dict[str, list[float | None]],
        verbose: bool = False
) -> None:
    """
    Log training metrics to wandb.

    Args:
        history: Dictionary with training history (loss, r2, l1, l0, etc.)
        verbose: Whether to log verbose information
    """
    if not self._initialized or self.wandb_run is None:
        return

    try:
        num_epochs = len(history.get("loss", []))
        slow_metrics_freq = self.config.wandb_slow_metrics_frequency

        # Helper to get last known value for slow metrics
        def get_last_known_value(values: list[float | None], idx: int) -> float:
            """Get the last non-None value up to idx, or 0.0 if none found."""
            for i in range(idx, -1, -1):
                if i < len(values) and values[i] is not None:
                    return values[i]
            return 0.0

        # Log metrics for each epoch
        for epoch in range(1, num_epochs + 1):
            epoch_idx = epoch - 1
            should_log_slow = (epoch % slow_metrics_freq == 0) or (epoch == num_epochs)

            metrics = self._build_epoch_metrics(history, epoch_idx, should_log_slow, get_last_known_value)
            self.wandb_run.log(metrics)

        # Log final summary metrics
        self._log_summary_metrics(history, get_last_known_value)

        if verbose:
            self._log_wandb_url()

    except Exception as e:
        logger.warning(f"[WandbLogger] Failed to log metrics to wandb: {e}")