Language Model API

Core language model functionality for loading models, running inference, and managing activations.

Main Classes

mi_crow.language_model.language_model.LanguageModel

LanguageModel(model, tokenizer, store, model_id=None, device=None)

Facade-style language model wrapper.

Provides a unified interface for working with language models, including:

- Model initialization and configuration
- Inference operations through the inference property
- Hook management (detectors and controllers)
- Model persistence
- Activation tracking

Initialize LanguageModel.

Parameters:

- model (Module, required): PyTorch model module
- tokenizer (PreTrainedTokenizerBase, required): HuggingFace tokenizer
- store (Store, required): Store instance for persistence
- model_id (str | None, default None): Optional model identifier (auto-extracted if not provided)
- device (str | device | None, default None): Optional device string or torch.device (defaults to 'cpu' if None)
Source code in src/mi_crow/language_model/language_model.py
def __init__(
        self,
        model: nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        store: Store,
        model_id: str | None = None,
        device: str | torch.device | None = None,
):
    """
    Initialize LanguageModel.

    Args:
        model: PyTorch model module
        tokenizer: HuggingFace tokenizer
        store: Store instance for persistence
        model_id: Optional model identifier (auto-extracted if not provided)
        device: Optional device string or torch.device (defaults to 'cpu' if None)
    """
    self.context = LanguageModelContext(self)
    self.context.model = model
    self.context.tokenizer = tokenizer
    self.context.model_id = initialize_model_id(model, model_id)
    self.context.store = store
    self.context.special_token_ids = _extract_special_token_ids(tokenizer)
    self.context.device = normalize_device(device)
    sync_model_to_context_device(self)

    self.layers = LanguageModelLayers(self.context)
    self.lm_tokenizer = LanguageModelTokenizer(self.context)
    self.activations = LanguageModelActivations(self.context)
    self.inference = InferenceEngine(self)

    self._input_tracker: "InputTracker | None" = None

model property

Get the underlying PyTorch model.

model_id property

Get the model identifier.

store property writable

Get the store instance.

tokenizer property

Get the tokenizer.

clear_detectors

clear_detectors()

Clear all accumulated metadata for registered detectors.

This is useful when running multiple independent inference runs (e.g. separate infer_texts / infer_dataset calls) and you want to ensure that detector state does not leak between runs.

Source code in src/mi_crow/language_model/language_model.py
def clear_detectors(self) -> None:
    """
    Clear all accumulated metadata for registered detectors.

    This is useful when running multiple independent inference runs
    (e.g. separate `infer_texts` / `infer_dataset` calls) and you want
    to ensure that detector state does not leak between runs.
    """
    detectors = self.layers.get_detectors()
    for detector in detectors:
        detector.metadata.clear()
        detector.tensor_metadata.clear()

        clear_captured = getattr(detector, "clear_captured", None)
        if callable(clear_captured):
            clear_captured()
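
Example

A minimal sketch of resetting detector state between two independent runs; lm is assumed to be an already constructed LanguageModel with detectors registered:

lm.inference.infer_texts(["first run text"])
metadata_a, tensors_a = lm.get_all_detector_metadata()

# Reset accumulated detector state so the second run starts clean
lm.clear_detectors()

lm.inference.infer_texts(["second run text"])
metadata_b, tensors_b = lm.get_all_detector_metadata()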

from_huggingface classmethod

from_huggingface(model_name, store, tokenizer_params=None, model_params=None, device=None)

Load a language model from HuggingFace Hub.

Automatically loads the model to GPU when device is "cuda" and CUDA is available, materializing weights directly on the GPU instead of staging them in CPU RAM; this avoids out-of-memory errors when loading large models.

Parameters:

- model_name (str, required): HuggingFace model identifier
- store (Store, required): Store instance for persistence
- tokenizer_params (dict, default None): Optional tokenizer parameters
- model_params (dict, default None): Optional model parameters
- device (str | device | None, default None): Target device ("cuda", "cpu", "mps"). If "cuda" and CUDA is available, the model is loaded directly to GPU using device_map="auto" (via the HuggingFace factory helpers).

Returns:

- LanguageModel: LanguageModel instance

Source code in src/mi_crow/language_model/language_model.py
@classmethod
def from_huggingface(
        cls,
        model_name: str,
        store: Store,
        tokenizer_params: dict = None,
        model_params: dict = None,
        device: str | torch.device | None = None,
) -> "LanguageModel":
    """
    Load a language model from HuggingFace Hub.

    Automatically loads model to GPU if device is "cuda" and CUDA is available.
    This prevents OOM errors by keeping the model on GPU instead of CPU RAM.

    Args:
        model_name: HuggingFace model identifier
        store: Store instance for persistence
        tokenizer_params: Optional tokenizer parameters
        model_params: Optional model parameters
        device: Target device ("cuda", "cpu", "mps"). If "cuda" and CUDA is available,
            model will be loaded directly to GPU using device_map="auto"
            (via the HuggingFace factory helpers).

    Returns:
        LanguageModel instance
    """
    return create_from_huggingface(cls, model_name, store, tokenizer_params, model_params, device)
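
Example

A minimal sketch of loading a model from the Hub. The store construction is an assumption here; use whichever Store implementation your project provides:

store = ...  # any mi_crow Store instance (construction not covered by this reference)

lm = LanguageModel.from_huggingface(
    "gpt2",         # HuggingFace model identifier
    store,
    device="cuda",  # loaded directly to GPU when CUDA is available
)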

from_local classmethod

from_local(saved_path, store, model_id=None, device=None)

Load a language model from a saved file (created by save_model).

Parameters:

- saved_path (Path | str, required): Path to the saved model file (.pt file)
- store (Store, required): Store instance for persistence
- model_id (str | None, default None): Optional model identifier. If not provided, the model_id from the saved metadata is used. If provided, it is used to load the model architecture from HuggingFace.
- device (str | device | None, default None): Optional device string or torch.device (defaults to 'cpu' if None)

Returns:

- LanguageModel: LanguageModel instance

Raises:

- FileNotFoundError: If the saved file doesn't exist
- ValueError: If the saved file format is invalid or model_id is required but not provided

Source code in src/mi_crow/language_model/language_model.py
@classmethod
def from_local(
        cls,
        saved_path: Path | str,
        store: Store,
        model_id: str | None = None,
        device: str | torch.device | None = None,
) -> "LanguageModel":
    """
    Load a language model from a saved file (created by save_model).

    Args:
        saved_path: Path to the saved model file (.pt file)
        store: Store instance for persistence
        model_id: Optional model identifier. If not provided, will use the model_id from saved metadata.
                 If provided, will be used to load the model architecture from HuggingFace.
        device: Optional device string or torch.device (defaults to 'cpu' if None)

    Returns:
        LanguageModel instance

    Raises:
        FileNotFoundError: If the saved file doesn't exist
        ValueError: If the saved file format is invalid or model_id is required but not provided
    """
    return load_model_from_saved_file(cls, saved_path, store, model_id, device)
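
Example

A sketch of the save/load round trip, assuming lm is an existing LanguageModel whose store is set:

saved_path = lm.save_model()  # defaults to {model_id}/model.pt under the store base path
restored = LanguageModel.from_local(saved_path, store, device="cpu")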

from_local_torch classmethod

from_local_torch(model_path, tokenizer_path, store, device=None)

Load a language model from local HuggingFace paths.

Parameters:

- model_path (str, required): Path to the model directory or file
- tokenizer_path (str, required): Path to the tokenizer directory or file
- store (Store, required): Store instance for persistence
- device (str | device | None, default None): Optional device string or torch.device (defaults to 'cpu' if None)

Returns:

- LanguageModel: LanguageModel instance

Source code in src/mi_crow/language_model/language_model.py
@classmethod
def from_local_torch(
        cls,
        model_path: str,
        tokenizer_path: str,
        store: Store,
        device: str | torch.device | None = None,
) -> "LanguageModel":
    """
    Load a language model from local HuggingFace paths.

    Args:
        model_path: Path to the model directory or file
        tokenizer_path: Path to the tokenizer directory or file
        store: Store instance for persistence
        device: Optional device string or torch.device (defaults to 'cpu' if None)

    Returns:
        LanguageModel instance
    """
    return create_from_local_torch(cls, model_path, tokenizer_path, store, device)
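
Example

A sketch assuming the model and tokenizer were previously saved in HuggingFace format (e.g. via save_pretrained); the paths are hypothetical:

lm = LanguageModel.from_local_torch(
    "checkpoints/my-model",      # directory with model weights and config
    "checkpoints/my-tokenizer",  # directory with tokenizer files
    store,
)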

get_all_detector_metadata

get_all_detector_metadata()

Get metadata from all registered detectors.

Returns:

- tuple[dict[str, dict[str, Any]], dict[str, dict[str, Tensor]]]: Tuple of (detectors_metadata, detectors_tensor_metadata)

Source code in src/mi_crow/language_model/language_model.py
def get_all_detector_metadata(self) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Tensor]]]:
    """
    Get metadata from all registered detectors.

    Returns:
        Tuple of (detectors_metadata, detectors_tensor_metadata)
    """
    detectors = self.layers.get_detectors()
    detectors_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict)
    detectors_tensor_metadata: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)

    for detector in detectors:
        detectors_metadata[detector.layer_signature] = dict(detector.metadata)
        detectors_tensor_metadata[detector.layer_signature] = dict(detector.tensor_metadata)

    return detectors_metadata, detectors_tensor_metadata
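
Example

A sketch of inspecting detector output after an inference call has populated the registered detectors:

metadata, tensor_metadata = lm.get_all_detector_metadata()
for layer_signature, tensors in tensor_metadata.items():
    for key, tensor in tensors.items():
        print(layer_signature, key, tuple(tensor.shape))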

get_input_tracker

get_input_tracker()

Get the input tracker instance if it exists.

Returns:

- InputTracker | None: InputTracker instance or None

Source code in src/mi_crow/language_model/language_model.py
def get_input_tracker(self) -> "InputTracker | None":
    """
    Get the input tracker instance if it exists.

    Returns:
        InputTracker instance or None
    """
    return self._input_tracker

save_detector_metadata

save_detector_metadata(run_name, batch_idx, unified=False, clear_after_save=True)

Save detector metadata to store.

Parameters:

- run_name (str, required): Name of the run
- batch_idx (int | None, required): Batch index. Ignored when unified is True.
- unified (bool, default False): If True, save metadata in a single detectors directory for the whole run instead of per-batch directories.
- clear_after_save (bool, default True): If True, clear detector metadata after saving to free memory. Defaults to True to prevent OOM errors when processing large batches.

Returns:

- str: Path where metadata was saved

Raises:

- ValueError: If store is not set

Source code in src/mi_crow/language_model/language_model.py
def save_detector_metadata(self, run_name: str, batch_idx: int | None, unified: bool = False, clear_after_save: bool = True) -> str:
    """
    Save detector metadata to store.

    Args:
        run_name: Name of the run
        batch_idx: Batch index. Ignored when ``unified`` is True.
        unified: If True, save metadata in a single detectors directory
            for the whole run instead of per‑batch directories.
        clear_after_save: If True, clear detector metadata after saving to free memory.
            Defaults to True to prevent OOM errors when processing large batches.

    Returns:
        Path where metadata was saved

    Raises:
        ValueError: If store is not set
    """
    if self.store is None:
        raise ValueError("Store must be provided or set on the language model")

    detectors_metadata, detectors_tensor_metadata = self.get_all_detector_metadata()

    if unified:
        result = self.store.put_run_detector_metadata(run_name, detectors_metadata, detectors_tensor_metadata)
    else:
        if batch_idx is None:
            raise ValueError("batch_idx must be provided when unified is False")
        result = self.store.put_detector_metadata(run_name, batch_idx, detectors_metadata, detectors_tensor_metadata)

    if clear_after_save:
        for layer_signature in list(detectors_tensor_metadata.keys()):
            detector_tensors = detectors_tensor_metadata[layer_signature]
            for tensor_key in list(detector_tensors.keys()):
                del detector_tensors[tensor_key]
            del detectors_tensor_metadata[layer_signature]
        detectors_metadata.clear()

        detectors = self.layers.get_detectors()
        for detector in detectors:
            clear_captured = getattr(detector, "clear_captured", None)
            if callable(clear_captured):
                clear_captured()
            for key in list(detector.tensor_metadata.keys()):
                del detector.tensor_metadata[key]
            detector.metadata.clear()

        gc.collect()

    return result
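
Example

A sketch contrasting the two saving layouts, assuming lm.store is set and the detectors hold captured metadata; the run names are hypothetical:

# Per-batch directories: batch_idx is required
lm.save_detector_metadata("run_a", batch_idx=0)

# One detectors directory for the whole run: batch_idx is ignored
lm.save_detector_metadata("run_b", batch_idx=None, unified=True)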

save_model

save_model(path=None)

Save the model and its metadata to the store.

Parameters:

- path (Path | str | None, default None): Optional path to save the model. If None, defaults to {model_id}/model.pt relative to the store base path.

Returns:

- Path: Path where the model was saved

Raises:

- ValueError: If store is not set

Source code in src/mi_crow/language_model/language_model.py
def save_model(self, path: Path | str | None = None) -> Path:
    """
    Save the model and its metadata to the store.

    Args:
        path: Optional path to save the model. If None, defaults to {model_id}/model.pt
              relative to the store base path.

    Returns:
        Path where the model was saved

    Raises:
        ValueError: If store is not set
    """
    return save_model(self, path)

tokenize

tokenize(texts, **kwargs)

Tokenize texts using the language model tokenizer.

Parameters:

- texts (Sequence[str], required): Sequence of text strings to tokenize
- **kwargs (Any, default {}): Additional tokenizer arguments

Returns:

- Any: Tokenized encodings

Source code in src/mi_crow/language_model/language_model.py
def tokenize(self, texts: Sequence[str], **kwargs: Any) -> Any:
    """
    Tokenize texts using the language model tokenizer.

    Args:
        texts: Sequence of text strings to tokenize
        **kwargs: Additional tokenizer arguments

    Returns:
        Tokenized encodings
    """
    return self.lm_tokenizer.tokenize(texts, **kwargs)
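
Example

A sketch of batch tokenization; keyword arguments such as return_tensors are forwarded to the underlying HuggingFace tokenizer:

enc = lm.tokenize(
    ["Hello world", "A second, longer example"],
    padding=True,
    return_tensors="pt",
)
print(enc["input_ids"].shape)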

mi_crow.language_model.context.LanguageModelContext dataclass

LanguageModelContext(language_model, model_id=None, tokenizer_params=None, model_params=None, device='cpu', dtype=None, model=None, tokenizer=None, store=None, special_token_ids=None, _hook_registry=dict(), _hook_id_map=dict())

Shared context for LanguageModel and its components.

mi_crow.language_model.layers.LanguageModelLayers

LanguageModelLayers(context)

Manages layer access and hook registration for LanguageModel.

Initialize LanguageModelLayers.

Parameters:

- context (LanguageModelContext, required): LanguageModelContext instance
Source code in src/mi_crow/language_model/layers.py
def __init__(
        self,
        context: "LanguageModelContext",
):
    """
    Initialize LanguageModelLayers.

    Args:
        context: LanguageModelContext instance
    """
    self.context = context
    self.name_to_layer: Dict[str, nn.Module] = {}
    self.idx_to_layer: Dict[int, nn.Module] = {}
    self._flatten_layer_names()

disable_all_hooks

disable_all_hooks()

Disable all registered hooks.

Source code in src/mi_crow/language_model/layers.py
def disable_all_hooks(self) -> None:
    """Disable all registered hooks."""
    for _, _, hook in self.context._hook_id_map.values():
        hook.disable()

disable_hook

disable_hook(hook_id)

Disable a specific hook by ID.

Parameters:

- hook_id (str, required): Hook ID to disable

Returns:

- bool: True if hook was found and disabled, False otherwise

Source code in src/mi_crow/language_model/layers.py
def disable_hook(self, hook_id: str) -> bool:
    """
    Disable a specific hook by ID.

    Args:
        hook_id: Hook ID to disable

    Returns:
        True if hook was found and disabled, False otherwise
    """
    if hook_id in self.context._hook_id_map:
        _, _, hook = self.context._hook_id_map[hook_id]
        hook.disable()
        return True
    return False

enable_all_hooks

enable_all_hooks()

Enable all registered hooks.

Source code in src/mi_crow/language_model/layers.py
def enable_all_hooks(self) -> None:
    """Enable all registered hooks."""
    for _, _, hook in self.context._hook_id_map.values():
        hook.enable()

enable_hook

enable_hook(hook_id)

Enable a specific hook by ID.

Parameters:

- hook_id (str, required): Hook ID to enable

Returns:

- bool: True if hook was found and enabled, False otherwise

Source code in src/mi_crow/language_model/layers.py
def enable_hook(self, hook_id: str) -> bool:
    """
    Enable a specific hook by ID.

    Args:
        hook_id: Hook ID to enable

    Returns:
        True if hook was found and enabled, False otherwise
    """
    if hook_id in self.context._hook_id_map:
        _, _, hook = self.context._hook_id_map[hook_id]
        hook.enable()
        return True
    return False

get_controllers

get_controllers()

Get all registered Controller hooks.

Returns:

- List[Controller]: List of Controller instances

Source code in src/mi_crow/language_model/layers.py
def get_controllers(self) -> List[Controller]:
    """
    Get all registered Controller hooks.

    Returns:
        List of Controller instances
    """
    return [hook for hook in self.get_hooks() if isinstance(hook, Controller)]

get_detectors

get_detectors()

Get all registered Detector hooks.

Returns:

- List[Detector]: List of Detector instances

Source code in src/mi_crow/language_model/layers.py
def get_detectors(self) -> List[Detector]:
    """
    Get all registered Detector hooks.

    Returns:
        List of Detector instances
    """
    return [hook for hook in self.get_hooks() if isinstance(hook, Detector)]

get_hooks

get_hooks(layer_signature=None, hook_type=None)

Get registered hooks, optionally filtered by layer and/or type.

Parameters:

- layer_signature (str | int | None, default None): Optional layer to filter by
- hook_type (HookType | str | None, default None): Optional hook type to filter by (HookType.FORWARD or HookType.PRE_FORWARD)

Returns:

- List[Hook]: List of Hook instances

Source code in src/mi_crow/language_model/layers.py
def get_hooks(
        self,
        layer_signature: str | int | None = None,
        hook_type: HookType | str | None = None
) -> List[Hook]:
    """
    Get registered hooks, optionally filtered by layer and/or type.

    Args:
        layer_signature: Optional layer to filter by
        hook_type: Optional hook type to filter by (HookType.FORWARD or HookType.PRE_FORWARD)

    Returns:
        List of Hook instances
    """
    # Normalize hook_type if string
    normalized_hook_type = None
    if hook_type is not None:
        if isinstance(hook_type, str):
            normalized_hook_type = HookType(hook_type)
        else:
            normalized_hook_type = hook_type

    return self._get_hooks_from_registry(layer_signature, normalized_hook_type)

get_layer_names

get_layer_names()

Get all layer names.

Returns:

- List[str]: List of layer names

Source code in src/mi_crow/language_model/layers.py
def get_layer_names(self) -> List[str]:
    """
    Get all layer names.

    Returns:
        List of layer names
    """
    return list(self.name_to_layer.keys())

print_layer_names

print_layer_names()

Print layer names with basic info.

Useful for debugging and exploring model structure.

Source code in src/mi_crow/language_model/layers.py
def print_layer_names(self) -> None:
    """
    Print layer names with basic info.

    Useful for debugging and exploring model structure.
    """
    names = self.get_layer_names()
    for name in names:
        layer = self.name_to_layer[name]
        weight_shape = getattr(layer, 'weight', None)
        weight_info = weight_shape.shape if weight_shape is not None else 'No weight'
        print(f"{name}: {weight_info}")

register_forward_hook_for_layer

register_forward_hook_for_layer(layer_signature, hook, hook_args=None)

Register a forward hook directly on a layer.

Parameters:

- layer_signature (str | int, required): Layer name or index
- hook (Callable, required): Hook callable
- hook_args (dict, default None): Optional arguments for register_forward_hook

Returns:

- Any: Hook handle

Source code in src/mi_crow/language_model/layers.py
def register_forward_hook_for_layer(
        self,
        layer_signature: str | int,
        hook: Callable,
        hook_args: dict = None
) -> Any:
    """
    Register a forward hook directly on a layer.

    Args:
        layer_signature: Layer name or index
        hook: Hook callable
        hook_args: Optional arguments for register_forward_hook

    Returns:
        Hook handle
    """
    layer = self._resolve_layer(layer_signature)
    return layer.register_forward_hook(hook, **(hook_args or {}))

register_hook

register_hook(layer_signature, hook, hook_type=None)

Register a hook on a layer.

Parameters:

- layer_signature (str | int, required): Layer name or index
- hook (Hook, required): Hook instance to register
- hook_type (HookType | str | None, default None): Type of hook (HookType.FORWARD or HookType.PRE_FORWARD). If None, uses hook.hook_type.

Returns:

- str: The hook's ID

Raises:

- ValueError: If hook ID is not unique or if mixing hook types on same layer

Source code in src/mi_crow/language_model/layers.py
def register_hook(
        self,
        layer_signature: str | int,
        hook: Hook,
        hook_type: HookType | str | None = None
) -> str:
    """
    Register a hook on a layer.

    Args:
        layer_signature: Layer name or index
        hook: Hook instance to register
        hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD). 
                  If None, uses hook.hook_type

    Returns:
        The hook's ID

    Raises:
        ValueError: If hook ID is not unique or if mixing hook types on same layer
    """
    layer = self._resolve_layer(layer_signature)

    if hook_type is None:
        hook_type = self._get_hook_type_from_hook(hook)
    elif isinstance(hook_type, str):
        hook_type = HookType(hook_type)

    self._validate_hook_registration(layer_signature, hook)

    hook.layer_signature = layer_signature
    hook.set_context(self.context)

    if layer_signature not in self.context._hook_registry:
        self.context._hook_registry[layer_signature] = {}

    if hook_type not in self.context._hook_registry[layer_signature]:
        self.context._hook_registry[layer_signature][hook_type] = []

    torch_hook_fn = hook.get_torch_hook()

    if hook_type == HookType.PRE_FORWARD:
        handle = layer.register_forward_pre_hook(torch_hook_fn)
    else:
        handle = layer.register_forward_hook(torch_hook_fn)

    self.context._hook_registry[layer_signature][hook_type].append((hook, handle))
    self.context._hook_id_map[hook.id] = (layer_signature, hook_type, hook)

    return hook.id
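
Example

A sketch of registering a detector on a named layer. MyDetector is a hypothetical Detector subclass, and the layer name is model-specific (use print_layer_names() to discover real names):

detector = MyDetector()  # hypothetical Detector subclass
hook_id = lm.layers.register_hook("transformer.h.0.mlp", detector)

lm.inference.infer_texts(["probe text"])

lm.layers.unregister_hook(hook_id)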

register_pre_forward_hook_for_layer

register_pre_forward_hook_for_layer(layer_signature, hook, hook_args=None)

Register a pre-forward hook directly on a layer.

Parameters:

- layer_signature (str | int, required): Layer name or index
- hook (Callable, required): Hook callable
- hook_args (dict, default None): Optional arguments for register_forward_pre_hook

Returns:

- Any: Hook handle

Source code in src/mi_crow/language_model/layers.py
def register_pre_forward_hook_for_layer(
        self,
        layer_signature: str | int,
        hook: Callable,
        hook_args: dict = None
) -> Any:
    """
    Register a pre-forward hook directly on a layer.

    Args:
        layer_signature: Layer name or index
        hook: Hook callable
        hook_args: Optional arguments for register_forward_pre_hook

    Returns:
        Hook handle
    """
    layer = self._resolve_layer(layer_signature)
    return layer.register_forward_pre_hook(hook, **(hook_args or {}))

unregister_hook

unregister_hook(hook_or_id)

Unregister a hook by Hook instance or ID.

Parameters:

- hook_or_id (Hook | str, required): Hook instance or hook ID string

Returns:

- bool: True if hook was found and removed, False otherwise

Source code in src/mi_crow/language_model/layers.py
def unregister_hook(self, hook_or_id: Hook | str) -> bool:
    """
    Unregister a hook by Hook instance or ID.

    Args:
        hook_or_id: Hook instance or hook ID string

    Returns:
        True if hook was found and removed, False otherwise
    """
    # Get hook ID
    if isinstance(hook_or_id, Hook):
        hook_id = hook_or_id.id
    else:
        hook_id = hook_or_id

    # Look up hook
    if hook_id not in self.context._hook_id_map:
        return False

    layer_signature, hook_type, hook = self.context._hook_id_map[hook_id]

    if layer_signature not in self.context._hook_registry:
        del self.context._hook_id_map[hook_id]
        return True

    hook_types = self.context._hook_registry[layer_signature]
    if hook_type not in hook_types:
        del self.context._hook_id_map[hook_id]
        return True

    hooks_list = hook_types[hook_type]
    for i, (h, handle) in enumerate(hooks_list):
        if h.id == hook_id:
            handle.remove()
            hooks_list.pop(i)
            break

    del self.context._hook_id_map[hook_id]
    return True
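
Example

A sketch of the hook lifecycle by ID, assuming hook_id was returned by register_hook:

lm.layers.disable_hook(hook_id)     # hook stays registered but inactive
lm.layers.enable_hook(hook_id)      # reactivate it
lm.layers.unregister_hook(hook_id)  # remove it and its torch handle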

mi_crow.language_model.tokenizer.LanguageModelTokenizer

LanguageModelTokenizer(context)

Handles tokenization for LanguageModel.

Initialize LanguageModelTokenizer.

Parameters:

- context (LanguageModelContext, required): LanguageModelContext instance
Source code in src/mi_crow/language_model/tokenizer.py
def __init__(
        self,
        context: "LanguageModelContext"
):
    """
    Initialize LanguageModelTokenizer.

    Args:
        context: LanguageModelContext instance
    """
    self.context = context

split_to_tokens

split_to_tokens(text, add_special_tokens=False)

Split text into token strings.

Parameters:

- text (Union[str, Sequence[str]], required): Single string or sequence of strings to tokenize
- add_special_tokens (bool, default False): Whether to add special tokens (e.g., BOS, EOS)

Returns:

- Union[List[str], List[List[str]]]: For a single string, a list of token strings; for a sequence of strings, a list of lists of token strings

Source code in src/mi_crow/language_model/tokenizer.py
def split_to_tokens(
        self,
        text: Union[str, Sequence[str]],
        add_special_tokens: bool = False
) -> Union[List[str], List[List[str]]]:
    """
    Split text into token strings.

    Args:
        text: Single string or sequence of strings to tokenize
        add_special_tokens: Whether to add special tokens (e.g., BOS, EOS)

    Returns:
        For a single string: list of token strings
        For a sequence of strings: list of lists of token strings
    """
    if isinstance(text, str):
        return self._split_single_text_to_tokens(text, add_special_tokens)

    return [self._split_single_text_to_tokens(t, add_special_tokens) for t in text]
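
Example

A sketch showing both input forms; the exact token strings depend on the tokenizer (the comment assumes a GPT-2 style BPE tokenizer):

single = lm.lm_tokenizer.split_to_tokens("Hello world")
# e.g. ['Hello', 'Ġworld']

batch = lm.lm_tokenizer.split_to_tokens(["Hello world", "Goodbye"], add_special_tokens=True)
# one list of token strings per input string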

tokenize

tokenize(texts, padding=False, pad_token='[PAD]', **kwargs)

Robust batch tokenization that works across tokenizer variants.

Tries methods in order:

- callable tokenizer (most HF tokenizers)
- batch_encode_plus
- encode_plus per item + tokenizer.pad to collate

Parameters:

- texts (Sequence[str], required): Sequence of text strings to tokenize
- padding (bool, default False): Whether to pad sequences
- pad_token (str, default '[PAD]'): Pad token string
- **kwargs (Any, default {}): Additional tokenizer arguments

Returns:

- Any: Tokenized encodings

Raises:

- ValueError: If tokenizer is not initialized
- TypeError: If tokenizer is not usable for batch tokenization

Source code in src/mi_crow/language_model/tokenizer.py
def tokenize(
        self,
        texts: Sequence[str],
        padding: bool = False,
        pad_token: str = "[PAD]",
        **kwargs: Any
) -> Any:
    """
    Robust batch tokenization that works across tokenizer variants.

    Tries methods in order:
    - callable tokenizer (most HF tokenizers)
    - batch_encode_plus
    - encode_plus per item + tokenizer.pad to collate

    Args:
        texts: Sequence of text strings to tokenize
        padding: Whether to pad sequences
        pad_token: Pad token string
        **kwargs: Additional tokenizer arguments

    Returns:
        Tokenized encodings

    Raises:
        ValueError: If tokenizer is not initialized
        TypeError: If tokenizer is not usable for batch tokenization
    """
    tokenizer = self.context.tokenizer
    if tokenizer is None:
        raise ValueError("Tokenizer must be initialized before tokenization")

    model = self.context.model

    if padding and pad_token and getattr(tokenizer, "pad_token", None) is None:
        self._setup_pad_token(tokenizer, model)

    kwargs["padding"] = padding

    # Try callable tokenizer first (most common case)
    if callable(tokenizer):
        try:
            return tokenizer(texts, **kwargs)
        except (TypeError, NotImplementedError):
            pass

    # Try batch_encode_plus
    if hasattr(tokenizer, "batch_encode_plus"):
        return tokenizer.batch_encode_plus(texts, **kwargs)

    # Fallback to encode_plus per item
    if hasattr(tokenizer, "encode_plus"):
        encoded = [tokenizer.encode_plus(t, **kwargs) for t in texts]
        if hasattr(tokenizer, "pad"):
            rt = kwargs.get("return_tensors") or "pt"
            return tokenizer.pad(encoded, return_tensors=rt)
        return encoded

    raise TypeError("Tokenizer object on LanguageModel is not usable for batch tokenization")

mi_crow.language_model.activations.LanguageModelActivations

LanguageModelActivations(context)

Handles activation saving and processing for LanguageModel.

Initialize LanguageModelActivations.

Parameters:

- context (LanguageModelContext, required): LanguageModelContext instance
Source code in src/mi_crow/language_model/activations.py
def __init__(self, context: "LanguageModelContext"):  # noqa: F821
    """
    Initialize LanguageModelActivations.

    Args:
        context: LanguageModelContext instance
    """
    self.context = context

save_activations

save_activations(texts, layer_signature, run_name=None, batch_size=None, *, dtype=None, max_length=None, autocast=True, autocast_dtype=None, free_cuda_cache_every=0, verbose=False, save_in_batches=True, save_attention_mask=False, stop_after_last_layer=True)

Save activations from a list of texts.

Parameters:

- texts (Sequence[str], required): Sequence of text strings to process
- layer_signature (str | int | list[str | int], required): Layer signature (or list of signatures) to capture activations from
- run_name (str | None, default None): Optional run name (generated if None)
- batch_size (int | None, default None): Optional batch size for processing (if None, processes all at once)
- dtype (dtype | None, default None): Optional dtype to convert activations to
- max_length (int | None, default None): Optional max length for tokenization
- autocast (bool, default True): Whether to use autocast
- autocast_dtype (dtype | None, default None): Optional dtype for autocast
- free_cuda_cache_every (int | None, default 0): Clear CUDA cache every N batches (0 or None to disable)
- verbose (bool, default False): Whether to log progress
- save_attention_mask (bool, default False): Whether to also save attention masks (automatically attaches ModelInputDetector)
- stop_after_last_layer (bool, default True): Whether to stop the model forward pass after the last requested layer to save memory and time

Returns:

- str: Run name used for saving

Raises:

- ValueError: If model or store is not initialized

Source code in src/mi_crow/language_model/activations.py
def save_activations(
    self,
    texts: Sequence[str],
    layer_signature: str | int | list[str | int],
    run_name: str | None = None,
    batch_size: int | None = None,
    *,
    dtype: torch.dtype | None = None,
    max_length: int | None = None,
    autocast: bool = True,
    autocast_dtype: torch.dtype | None = None,
    free_cuda_cache_every: int | None = 0,
    verbose: bool = False,
    save_in_batches: bool = True,
    save_attention_mask: bool = False,
    stop_after_last_layer: bool = True,
) -> str:
    """
    Save activations from a list of texts.

    Args:
        texts: Sequence of text strings to process
        layer_signature: Layer signature (or list of signatures) to capture activations from
        run_name: Optional run name (generated if None)
        batch_size: Optional batch size for processing (if None, processes all at once)
        dtype: Optional dtype to convert activations to
        max_length: Optional max length for tokenization
        autocast: Whether to use autocast
        autocast_dtype: Optional dtype for autocast
        free_cuda_cache_every: Clear CUDA cache every N batches (0 or None to disable)
        verbose: Whether to log progress
        save_attention_mask: Whether to also save attention masks (automatically attaches ModelInputDetector)
        stop_after_last_layer: Whether to stop model forward pass after the last requested layer
            to save memory and time. Defaults to True.

    Returns:
        Run name used for saving

    Raises:
        ValueError: If model or store is not initialized
    """
    if not texts:
        raise ValueError("Texts list cannot be empty")

    model, store = self._validate_save_prerequisites()

    device = torch.device(self.context.device)
    device_type = str(device.type)

    if batch_size is None:
        batch_size = len(texts)

    options = {
        "dtype": str(dtype) if dtype is not None else None,
        "max_length": max_length,
        "batch_size": int(batch_size),
        "stop_after_last_layer": stop_after_last_layer,
    }

    run_name, meta, layer_sig_list = self._prepare_save_metadata(layer_signature, None, run_name, options)

    if verbose:
        logger.info(
            f"Starting save_activations: run={run_name}, layers={layer_sig_list}, "
            f"batch_size={batch_size}, device={device_type}"
        )

    self._save_run_metadata(store, run_name, meta, verbose)

    hook_ids, attention_mask_hook_id = self._setup_activation_hooks(
        layer_sig_list, run_name, save_attention_mask, dtype=dtype
    )

    batch_counter = 0
    # Stop after last hooked layer if requested
    stop_after = layer_sig_list[-1] if (layer_sig_list and stop_after_last_layer) else None

    try:
        with torch.inference_mode():
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i : i + batch_size]
                batch_index = i // batch_size

                self._process_batch(
                    batch_texts,
                    run_name,
                    batch_index,
                    max_length,
                    autocast,
                    autocast_dtype,
                    dtype,
                    verbose,
                    save_in_batches=save_in_batches,
                    stop_after_layer=stop_after,
                )
                batch_counter += 1
                self._manage_cuda_cache(batch_counter, free_cuda_cache_every, device_type, verbose)
    finally:
        self._teardown_activation_hooks(hook_ids, attention_mask_hook_id)
        if verbose:
            logger.info(f"Completed save_activations: run={run_name}, batches_saved={batch_counter}")

    return run_name
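
Example

A sketch of capturing activations from two layers; the layer names are hypothetical and model-specific:

import torch

run_name = lm.activations.save_activations(
    ["text one", "text two", "text three"],
    layer_signature=["transformer.h.0", "transformer.h.5"],
    batch_size=2,
    dtype=torch.float16,  # store activations in half precision
    verbose=True,
)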

save_activations_dataset

save_activations_dataset(dataset, layer_signature, run_name=None, batch_size=32, *, dtype=None, max_length=None, autocast=True, autocast_dtype=None, free_cuda_cache_every=None, verbose=False, save_in_batches=True, save_attention_mask=False, stop_after_last_layer=True)

Save activations from a dataset.

Parameters:

- dataset (BaseDataset, required): Dataset to process
- layer_signature (str | int | list[str | int], required): Layer signature (or list of signatures) to capture activations from
- run_name (str | None, default None): Optional run name (generated if None)
- batch_size (int, default 32): Batch size for processing
- dtype (dtype | None, default None): Optional dtype to convert activations to
- max_length (int | None, default None): Optional max length for tokenization
- autocast (bool, default True): Whether to use autocast
- autocast_dtype (dtype | None, default None): Optional dtype for autocast
- free_cuda_cache_every (int | None, default None): Clear CUDA cache every N batches (None to auto-detect, 0 to disable)
- verbose (bool, default False): Whether to log progress
- save_attention_mask (bool, default False): Whether to also save attention masks (automatically attaches ModelInputDetector)
- stop_after_last_layer (bool, default True): Whether to stop the model forward pass after the last requested layer to save memory and time

Returns:

- str: Run name used for saving

Raises:

- ValueError: If model or store is not initialized

Source code in src/mi_crow/language_model/activations.py
def save_activations_dataset(
    self,
    dataset: BaseDataset,
    layer_signature: str | int | list[str | int],
    run_name: str | None = None,
    batch_size: int = 32,
    *,
    dtype: torch.dtype | None = None,
    max_length: int | None = None,
    autocast: bool = True,
    autocast_dtype: torch.dtype | None = None,
    free_cuda_cache_every: int | None = None,
    verbose: bool = False,
    save_in_batches: bool = True,
    save_attention_mask: bool = False,
    stop_after_last_layer: bool = True,
) -> str:
    """
    Save activations from a dataset.

    Args:
        dataset: Dataset to process
        layer_signature: Layer signature (or list of signatures) to capture activations from
        run_name: Optional run name (generated if None)
        batch_size: Batch size for processing
        dtype: Optional dtype to convert activations to
        max_length: Optional max length for tokenization
        autocast: Whether to use autocast
        autocast_dtype: Optional dtype for autocast
        free_cuda_cache_every: Clear CUDA cache every N batches (None to auto-detect, 0 to disable)
        verbose: Whether to log progress
        save_attention_mask: Whether to also save attention masks (automatically attaches ModelInputDetector)
        stop_after_last_layer: Whether to stop model forward pass after the last requested layer
            to save memory and time. Defaults to True.

    Returns:
        Run name used for saving

    Raises:
        ValueError: If model or store is not initialized
    """
    model, store = self._validate_save_prerequisites()

    device = torch.device(self.context.device)
    device_type = str(device.type)

    if free_cuda_cache_every is None:
        free_cuda_cache_every = 5 if device_type == "cuda" else 0

    options = {
        "dtype": str(dtype) if dtype is not None else None,
        "max_length": max_length,
        "batch_size": int(batch_size),
        "stop_after_last_layer": stop_after_last_layer,
    }

    run_name, meta, layer_sig_list = self._prepare_save_metadata(layer_signature, dataset, run_name, options)

    if verbose:
        logger.info(
            f"Starting save_activations_dataset: run={run_name}, layers={layer_sig_list}, "
            f"batch_size={batch_size}, device={device_type}"
        )

    self._save_run_metadata(store, run_name, meta, verbose)

    hook_ids, attention_mask_hook_id = self._setup_activation_hooks(
        layer_sig_list, run_name, save_attention_mask, dtype=dtype
    )

    batch_counter = 0
    # Stop after last hooked layer if requested
    stop_after = layer_sig_list[-1] if (layer_sig_list and stop_after_last_layer) else None

    try:
        with torch.inference_mode():
            for batch_index, batch in enumerate(dataset.iter_batches(batch_size)):
                texts = dataset.extract_texts_from_batch(batch)
                self._process_batch(
                    texts,
                    run_name,
                    batch_index,
                    max_length,
                    autocast,
                    autocast_dtype,
                    dtype,
                    verbose,
                    save_in_batches=save_in_batches,
                    stop_after_layer=stop_after,
                )
                batch_counter += 1

                self._manage_cuda_cache(batch_counter, free_cuda_cache_every, device_type, verbose)
    finally:
        self._teardown_activation_hooks(hook_ids, attention_mask_hook_id)
        if verbose:
            logger.info(f"Completed save_activations_dataset: run={run_name}, batches_saved={batch_counter}")

    return run_name

mi_crow.language_model.inference.InferenceEngine

InferenceEngine(language_model)

Handles inference operations for LanguageModel.

Initialize inference engine.

Parameters:

- language_model ('LanguageModel', required): LanguageModel instance
Source code in src/mi_crow/language_model/inference.py
def __init__(self, language_model: "LanguageModel"):
    """
    Initialize inference engine.

    Args:
        language_model: LanguageModel instance
    """
    self.lm = language_model

execute_inference

execute_inference(texts, tok_kwargs=None, autocast=True, autocast_dtype=None, with_controllers=True, stop_after_layer=None)

Execute inference on texts.

Parameters:

- texts (Sequence[str], required): Sequence of input texts
- tok_kwargs (Dict | None, default None): Optional tokenizer keyword arguments
- autocast (bool, default True): Whether to use automatic mixed precision
- autocast_dtype (dtype | None, default None): Optional dtype for autocast
- with_controllers (bool, default True): Whether to use controllers during inference
- stop_after_layer (str | int | None, default None): Optional layer signature (name or index) after which the forward pass should be stopped early

Returns:

- tuple[Any, Dict[str, Tensor]]: Tuple of (model_output, encodings)

Raises:

- ValueError: If texts is empty or tokenizer is not initialized

Source code in src/mi_crow/language_model/inference.py
def execute_inference(
        self,
        texts: Sequence[str],
        tok_kwargs: Dict | None = None,
        autocast: bool = True,
        autocast_dtype: torch.dtype | None = None,
        with_controllers: bool = True,
        stop_after_layer: str | int | None = None,
) -> tuple[Any, Dict[str, torch.Tensor]]:
    """
    Execute inference on texts.

    Args:
        texts: Sequence of input texts
        tok_kwargs: Optional tokenizer keyword arguments
        autocast: Whether to use automatic mixed precision
        autocast_dtype: Optional dtype for autocast
        with_controllers: Whether to use controllers during inference
        stop_after_layer: Optional layer signature (name or index) after which
            the forward pass should be stopped early

    Returns:
        Tuple of (model_output, encodings)

    Raises:
        ValueError: If texts is empty or tokenizer is not initialized
    """
    if not texts:
        raise ValueError("Texts list cannot be empty")

    if self.lm.tokenizer is None:
        raise ValueError("Tokenizer must be initialized before running inference")

    tok_kwargs = self._prepare_tokenizer_kwargs(tok_kwargs)
    logger.debug(f"[DEBUG] About to tokenize {len(texts)} texts...")
    enc = self.lm.tokenize(texts, **tok_kwargs)
    logger.debug(f"[DEBUG] Tokenization completed, shape: {enc['input_ids'].shape if isinstance(enc, dict) else 'N/A'}")

    device = torch.device(self.lm.context.device)
    device_type = str(device.type)

    sync_model_to_context_device(self.lm)

    enc = move_tensors_to_device(enc, device)

    self.lm.model.eval()

    self._setup_trackers(texts)
    self._setup_model_input_detectors(enc)

    controllers_to_restore = self._prepare_controllers(with_controllers)

    hook_handle = None
    try:
        if stop_after_layer is not None:
            # Register a temporary forward hook that stops the forward pass
            def _early_stop_hook(module: nn.Module, inputs: tuple, output: Any):
                raise _EarlyStopInference(output)

            hook_handle = self.lm.layers.register_forward_hook_for_layer(
                stop_after_layer, _early_stop_hook
            )

        output = self._run_model_forward(enc, autocast, device_type, autocast_dtype)
        return output, enc
    finally:
        if hook_handle is not None:
            try:
                hook_handle.remove()
            except Exception:
                pass
        self._restore_controllers(controllers_to_restore)
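
Example

A sketch of a single forward pass; stop_after_layer (not shown) can additionally cut the pass short after a given layer:

output, enc = lm.inference.execute_inference(
    ["The quick brown fox"],
    with_controllers=False,  # skip controller hooks for this pass
)
logits = lm.inference.extract_logits(output)
print(logits.shape)  # (batch, sequence, vocab)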

extract_logits

extract_logits(output)

Extract logits tensor from model output.

Parameters:

- output (Any, required): Model output

Returns:

- Tensor: Logits tensor

Source code in src/mi_crow/language_model/inference.py
def extract_logits(self, output: Any) -> torch.Tensor:
    """
    Extract logits tensor from model output.

    Args:
        output: Model output

    Returns:
        Logits tensor
    """
    return extract_logits_from_output(output)

infer_dataset

infer_dataset(dataset, run_name=None, batch_size=32, tok_kwargs=None, autocast=True, autocast_dtype=None, with_controllers=True, free_cuda_cache_every=0, clear_detectors_before=False, verbose=False, stop_after_layer=None, save_in_batches=True)

Run inference on whole dataset with metadata saving.

Parameters:

- dataset ('BaseDataset', required): Dataset to process
- run_name (str | None, default None): Optional run name (generated if None)
- batch_size (int, default 32): Batch size for processing
- tok_kwargs (Dict | None, default None): Optional tokenizer keyword arguments
- autocast (bool, default True): Whether to use automatic mixed precision
- autocast_dtype (dtype | None, default None): Optional dtype for autocast
- with_controllers (bool, default True): Whether to use controllers during inference
- free_cuda_cache_every (int | None, default 0): Clear CUDA cache every N batches (0 or None to disable)
- clear_detectors_before (bool, default False): If True, clears all detector state before running
- verbose (bool, default False): Whether to log progress
- stop_after_layer (str | int | None, default None): Optional layer signature (name or index) after which the forward pass should be stopped early
- save_in_batches (bool, default True): If True, save detector metadata in per-batch directories; if False, aggregate it under a single detectors directory for the run

Returns:

- str: Run name used for saving

Raises:

- ValueError: If model or store is not initialized

Source code in src/mi_crow/language_model/inference.py
def infer_dataset(
    self,
    dataset: "BaseDataset",
    run_name: str | None = None,
    batch_size: int = 32,
    tok_kwargs: Dict | None = None,
    autocast: bool = True,
    autocast_dtype: torch.dtype | None = None,
    with_controllers: bool = True,
    free_cuda_cache_every: int | None = 0,
    clear_detectors_before: bool = False,
    verbose: bool = False,
    stop_after_layer: str | int | None = None,
    save_in_batches: bool = True,
) -> str:
    """
    Run inference on whole dataset with metadata saving.

    Args:
        dataset: Dataset to process
        run_name: Optional run name (generated if None)
        batch_size: Batch size for processing
        tok_kwargs: Optional tokenizer keyword arguments
        autocast: Whether to use automatic mixed precision
        autocast_dtype: Optional dtype for autocast
        with_controllers: Whether to use controllers during inference
        free_cuda_cache_every: Clear CUDA cache every N batches (0 or None to disable)
        clear_detectors_before: If True, clears all detector state before running
        verbose: Whether to log progress
        stop_after_layer: Optional layer signature (name or index) after which
            the forward pass should be stopped early

    Returns:
        Run name used for saving

    Raises:
        ValueError: If model or store is not initialized
    """
    if clear_detectors_before:
        self.lm.clear_detectors()

    model: nn.Module | None = self.lm.model
    if model is None:
        raise ValueError("Model must be initialized before running")

    store = self.lm.store
    if store is None:
        raise ValueError("Store must be provided or set on the language model")

    device = torch.device(self.lm.context.device)
    device_type = str(device.type)

    options = {
        "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
        "batch_size": int(batch_size),
    }

    run_name, meta = self._prepare_run_metadata(dataset=dataset, run_name=run_name, options=options)

    if verbose:
        logger.info(
            f"Starting infer_dataset: run={run_name}, "
            f"batch_size={batch_size}, device={device_type}"
        )

    self._save_run_metadata(store, run_name, meta, verbose)

    batch_counter = 0

    with torch.inference_mode():
        for batch_index, batch in enumerate(dataset.iter_batches(batch_size)):
            if not batch:
                continue

            texts = dataset.extract_texts_from_batch(batch)

            self.execute_inference(
                texts,
                tok_kwargs=tok_kwargs,
                autocast=autocast,
                autocast_dtype=autocast_dtype,
                with_controllers=with_controllers,
                stop_after_layer=stop_after_layer,
            )

            self.lm.save_detector_metadata(run_name, batch_index, unified=not save_in_batches)

            batch_counter += 1

            if device_type == "cuda" and free_cuda_cache_every and free_cuda_cache_every > 0:
                if (batch_counter % free_cuda_cache_every) == 0:
                    torch.cuda.empty_cache()
                    if verbose:
                        logger.info("Emptied CUDA cache")

            if verbose:
                logger.info(f"Saved batch {batch_index} for run={run_name}")

    if verbose:
        logger.info(f"Completed infer_dataset: run={run_name}, batches_saved={batch_counter}")

    return run_name

infer_texts

infer_texts(texts, run_name=None, batch_size=None, tok_kwargs=None, autocast=True, autocast_dtype=None, with_controllers=True, clear_detectors_before=False, verbose=False, stop_after_layer=None, save_in_batches=True)

Run inference on list of strings with optional metadata saving.

Parameters:

- texts (Sequence[str], required): Sequence of input texts
- run_name (str | None, default None): Optional run name for saving metadata (if None, no metadata is saved)
- batch_size (int | None, default None): Optional batch size for processing (if None, processes all at once)
- tok_kwargs (Dict | None, default None): Optional tokenizer keyword arguments
- autocast (bool, default True): Whether to use automatic mixed precision
- autocast_dtype (dtype | None, default None): Optional dtype for autocast
- with_controllers (bool, default True): Whether to use controllers during inference
- clear_detectors_before (bool, default False): If True, clears all detector state before running
- verbose (bool, default False): Whether to log progress
- stop_after_layer (str | int | None, default None): Optional layer signature (name or index) after which the forward pass should be stopped early
- save_in_batches (bool, default True): If True, save detector metadata in per-batch directories; if False, aggregate all detector metadata for the run under a single detectors directory

Returns:

- tuple[Any, Dict[str, Tensor]] | tuple[List[Any], List[Dict[str, Tensor]]]: If batch_size is None or >= len(texts), a tuple of (model_output, encodings); if batch_size < len(texts), a tuple of (list of outputs, list of encodings)

Raises:

- ValueError: If texts is empty or tokenizer is not initialized

Source code in src/mi_crow/language_model/inference.py
def infer_texts(
    self,
    texts: Sequence[str],
    run_name: str | None = None,
    batch_size: int | None = None,
    tok_kwargs: Dict | None = None,
    autocast: bool = True,
    autocast_dtype: torch.dtype | None = None,
    with_controllers: bool = True,
    clear_detectors_before: bool = False,
    verbose: bool = False,
    stop_after_layer: str | int | None = None,
    save_in_batches: bool = True,
) -> tuple[Any, Dict[str, torch.Tensor]] | tuple[List[Any], List[Dict[str, torch.Tensor]]]:
    """
    Run inference on list of strings with optional metadata saving.

    Args:
        texts: Sequence of input texts
        run_name: Optional run name for saving metadata (if None, no metadata saved)
        batch_size: Optional batch size for processing (if None, processes all at once)
        tok_kwargs: Optional tokenizer keyword arguments
        autocast: Whether to use automatic mixed precision
        autocast_dtype: Optional dtype for autocast
        with_controllers: Whether to use controllers during inference
        clear_detectors_before: If True, clears all detector state before running
        verbose: Whether to log progress
        stop_after_layer: Optional layer signature (name or index) after which
            the forward pass should be stopped early
        save_in_batches: If True, save detector metadata in per-batch
            directories. If False, aggregate all detector metadata for
            the run under a single detectors directory.

    Returns:
        If batch_size is None or >= len(texts): Tuple of (model_output, encodings)
        If batch_size < len(texts): Tuple of (list of outputs, list of encodings)

    Raises:
        ValueError: If texts is empty or tokenizer is not initialized
    """
    if not texts:
        raise ValueError("Texts list cannot be empty")

    if self.lm.tokenizer is None:
        raise ValueError("Tokenizer must be initialized before running inference")

    if clear_detectors_before:
        self.lm.clear_detectors()

    store = self.lm.store
    if run_name is not None and store is None:
        raise ValueError("Store must be provided to save metadata")

    if batch_size is None or batch_size >= len(texts):
        output, enc = self.execute_inference(
            texts,
            tok_kwargs=tok_kwargs,
            autocast=autocast,
            autocast_dtype=autocast_dtype,
            with_controllers=with_controllers,
            stop_after_layer=stop_after_layer,
        )

        if run_name is not None:
            options = {
                "batch_size": len(texts),
                "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
            }
            _, meta = self._prepare_run_metadata(dataset=None, run_name=run_name, options=options)
            self._save_run_metadata(store, run_name, meta, verbose)
            self.lm.save_detector_metadata(run_name, 0, unified=not save_in_batches)

        return output, enc

    all_outputs = []
    all_encodings = []
    batch_counter = 0

    if run_name is not None:
        options = {
            "batch_size": batch_size,
            "max_length": tok_kwargs.get("max_length") if tok_kwargs else None,
        }
        _, meta = self._prepare_run_metadata(dataset=None, run_name=run_name, options=options)
        self._save_run_metadata(store, run_name, meta, verbose)

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        output, enc = self.execute_inference(
            batch_texts,
            tok_kwargs=tok_kwargs,
            autocast=autocast,
            autocast_dtype=autocast_dtype,
            with_controllers=with_controllers,
            stop_after_layer=stop_after_layer,
        )

        all_outputs.append(output)
        all_encodings.append(enc)

        if run_name is not None:
            self.lm.save_detector_metadata(run_name, batch_counter, unified=not save_in_batches)
            if verbose:
                logger.info(f"Saved batch {batch_counter} for run={run_name}")

        batch_counter += 1

    return all_outputs, all_encodings
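
A short usage sketch for the batched path; the texts and run name are illustrative, and `lm` is assumed to be a LanguageModel built with one of the factory helpers below:

```python
texts = ["The quick brown fox", "jumps over the lazy dog"]

# batch_size < len(texts), so lists of per-batch outputs and encodings come back.
outputs, encodings = lm.inference.infer_texts(
    texts,
    run_name="demo-run",          # enables detector-metadata saving
    batch_size=1,
    clear_detectors_before=True,  # avoid state leaking in from earlier runs
    verbose=True,
)
assert len(outputs) == len(encodings) == 2  # one entry per batch
```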

Utilities

mi_crow.language_model.initialization

Model initialization and factory methods.

create_from_huggingface

create_from_huggingface(cls, model_name, store, tokenizer_params=None, model_params=None, device=None)

Load a language model from the HuggingFace Hub.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cls | type['LanguageModel'] | LanguageModel class | required |
| model_name | str | HuggingFace model identifier | required |
| store | Store | Store instance for persistence | required |
| tokenizer_params | dict \| None | Optional tokenizer parameters | None |
| model_params | dict \| None | Optional model parameters | None |
| device | str \| device \| None | Target device ("cuda", "cpu", "mps"); the model is moved to this device after loading | None |

Returns:

| Type | Description |
| --- | --- |
| LanguageModel | LanguageModel instance |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If model_name is invalid or store is None |
| RuntimeError | If model loading fails |

Source code in src/mi_crow/language_model/initialization.py
def create_from_huggingface(
        cls: type["LanguageModel"],
        model_name: str,
        store: Store,
        tokenizer_params: dict | None = None,
        model_params: dict | None = None,
        device: str | torch.device | None = None,
) -> "LanguageModel":
    """
    Load a language model from HuggingFace Hub.

    Args:
        cls: LanguageModel class
        model_name: HuggingFace model identifier
        store: Store instance for persistence
        tokenizer_params: Optional tokenizer parameters
        model_params: Optional model parameters
        device: Target device ("cuda", "cpu", "mps"). Model will be moved to this device
            after loading.

    Returns:
        LanguageModel instance

    Raises:
        ValueError: If model_name is invalid
        RuntimeError: If model loading fails
    """
    if not model_name or not isinstance(model_name, str) or not model_name.strip():
        raise ValueError(f"model_name must be a non-empty string, got: {model_name!r}")

    if store is None:
        raise ValueError("store cannot be None")

    if tokenizer_params is None:
        tokenizer_params = {}
    if model_params is None:
        model_params = {}

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_params)
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_params)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model '{model_name}' from HuggingFace. Error: {e}"
        ) from e

    return cls(model, tokenizer, store, device=device)
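
For example, assuming the factory is exposed as a classmethod on LanguageModel and that Store takes a base path (both are assumptions for illustration):

```python
from mi_crow.language_model.language_model import LanguageModel  # import path assumed
from mi_crow.store import Store                                  # import path assumed

store = Store("./artifacts")  # hypothetical base path
lm = LanguageModel.create_from_huggingface(
    "gpt2",  # any causal-LM identifier on the Hub
    store,
    tokenizer_params={"use_fast": True},  # forwarded to AutoTokenizer.from_pretrained
    device="cuda",                        # model is moved here after loading
)
```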

create_from_local_torch

create_from_local_torch(cls, model_path, tokenizer_path, store, device=None)

Load a language model from local HuggingFace paths.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cls | type['LanguageModel'] | LanguageModel class | required |
| model_path | str | Path to the model directory or file | required |
| tokenizer_path | str | Path to the tokenizer directory or file | required |
| store | Store | Store instance for persistence | required |
| device | str \| device \| None | Optional device string or torch.device (defaults to 'cpu' if None) | None |

Returns:

| Type | Description |
| --- | --- |
| LanguageModel | LanguageModel instance |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If store is None |
| FileNotFoundError | If the model or tokenizer path doesn't exist |
| RuntimeError | If model loading fails |

Source code in src/mi_crow/language_model/initialization.py
def create_from_local_torch(
        cls: type["LanguageModel"],
        model_path: str,
        tokenizer_path: str,
        store: Store,
        device: str | torch.device | None = None,
) -> "LanguageModel":
    """
    Load a language model from local HuggingFace paths.

    Args:
        cls: LanguageModel class
        model_path: Path to the model directory or file
        tokenizer_path: Path to the tokenizer directory or file
        store: Store instance for persistence
        device: Optional device string or torch.device (defaults to 'cpu' if None)

    Returns:
        LanguageModel instance

    Raises:
        FileNotFoundError: If model or tokenizer paths don't exist
        RuntimeError: If model loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    model_path_obj = Path(model_path)
    tokenizer_path_obj = Path(tokenizer_path)

    if not model_path_obj.exists():
        raise FileNotFoundError(f"Model path does not exist: {model_path}")

    if not tokenizer_path_obj.exists():
        raise FileNotFoundError(f"Tokenizer path does not exist: {tokenizer_path}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model from local paths. "
            f"model_path={model_path!r}, tokenizer_path={tokenizer_path!r}. Error: {e}"
        ) from e

    return cls(model, tokenizer, store, device=device)
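
A usage sketch mirroring the Hub variant, with illustrative local paths (the classmethod exposure and the `store` value carry over from the previous example's assumptions):

```python
lm = LanguageModel.create_from_local_torch(
    model_path="./checkpoints/my-model",         # directory written by save_pretrained
    tokenizer_path="./checkpoints/my-tokenizer",
    store=store,
    device=None,  # defaults to 'cpu'
)
```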

initialize_model_id

initialize_model_id(model, provided_model_id=None)

Initialize model ID for LanguageModel.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | PyTorch model module | required |
| provided_model_id | str \| None | Optional model ID provided by user | None |

Returns:

| Type | Description |
| --- | --- |
| str | Model ID string |

Source code in src/mi_crow/language_model/initialization.py
def initialize_model_id(
        model: nn.Module,
        provided_model_id: str | None = None
) -> str:
    """
    Initialize model ID for LanguageModel.

    Args:
        model: PyTorch model module
        provided_model_id: Optional model ID provided by user

    Returns:
        Model ID string
    """
    return extract_model_id(model, provided_model_id)
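
A behavior sketch, assuming extract_model_id returns a user-supplied ID unchanged and otherwise derives one from the model itself (the derivation rule is an assumption):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# An explicit ID is expected to take precedence (per the docstring).
print(initialize_model_id(model, provided_model_id="my-gpt2"))  # "my-gpt2" (assumed)

# Without one, an ID is auto-extracted, e.g. from model.config.name_or_path
# (an assumption about extract_model_id's heuristic).
print(initialize_model_id(model))
```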

mi_crow.language_model.persistence

Model persistence (save/load) operations.

load_model_from_saved_file

load_model_from_saved_file(cls, saved_path, store, model_id=None, device=None)

Load a language model from a saved file (created by save_model).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cls | type['LanguageModel'] | LanguageModel class | required |
| saved_path | Path \| str | Path to the saved model file (.pt file) | required |
| store | 'Store' | Store instance for persistence | required |
| model_id | str \| None | Optional model identifier; if not provided, the model_id from the saved metadata is used. In either case it must name a valid HuggingFace model, because the architecture is re-downloaded before the saved weights are applied | None |
| device | str \| device \| None | Optional device string or torch.device (defaults to 'cpu' if None) | None |

Returns:

| Type | Description |
| --- | --- |
| LanguageModel | LanguageModel instance |

Raises:

| Type | Description |
| --- | --- |
| FileNotFoundError | If the saved file doesn't exist |
| ValueError | If store is None, the saved file format is invalid, model_id is required but not provided, or model_id cannot be resolved on HuggingFace |
| RuntimeError | If reading the file or loading the state dict fails |

Source code in src/mi_crow/language_model/persistence.py
def load_model_from_saved_file(
        cls: type["LanguageModel"],
        saved_path: Path | str,
        store: "Store",
        model_id: str | None = None,
        device: str | torch.device | None = None,
) -> "LanguageModel":
    """
    Load a language model from a saved file (created by save_model).

    Args:
        cls: LanguageModel class
        saved_path: Path to the saved model file (.pt file)
        store: Store instance for persistence
        model_id: Optional model identifier. If not provided, will use the model_id from saved metadata.
                 If provided, will be used to load the model architecture from HuggingFace.
        device: Optional device string or torch.device (defaults to 'cpu' if None)

    Returns:
        LanguageModel instance

    Raises:
        FileNotFoundError: If the saved file doesn't exist
        ValueError: If the saved file format is invalid or model_id is required but not provided
        RuntimeError: If model loading fails
    """
    if store is None:
        raise ValueError("store cannot be None")

    saved_path = Path(saved_path)
    if not saved_path.exists():
        raise FileNotFoundError(f"Saved model file not found: {saved_path}")

    # Load the saved payload
    try:
        payload = torch.load(saved_path, map_location='cpu')
    except Exception as e:
        raise RuntimeError(
            f"Failed to load model file {saved_path}. Error: {e}"
        ) from e

    # Validate payload structure
    if "model_state_dict" not in payload:
        raise ValueError(f"Invalid saved model format: missing 'model_state_dict' key in {saved_path}")
    if "metadata" not in payload:
        raise ValueError(f"Invalid saved model format: missing 'metadata' key in {saved_path}")

    model_state_dict = payload["model_state_dict"]
    metadata_dict = payload["metadata"]

    # Get model_id from metadata or use provided one
    saved_model_id = metadata_dict.get("model_id")
    if model_id is None:
        if saved_model_id is None:
            raise ValueError(
                "model_id not found in saved metadata and not provided. "
                "Please provide the model_id parameter."
            )
        model_id = saved_model_id

    # Load model and tokenizer from HuggingFace using model_id
    # This assumes model_id is a valid HuggingFace model name
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id)
    except Exception as e:
        raise ValueError(
            f"Failed to load model '{model_id}' from HuggingFace. "
            f"Error: {e}. "
            f"Please ensure model_id is a valid HuggingFace model name."
        ) from e

    # Load the saved state dict into the model
    try:
        model.load_state_dict(model_state_dict)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load state dict into model '{model_id}'. Error: {e}"
        ) from e

    # Create LanguageModel instance
    lm = cls(model, tokenizer, store, model_id=model_id, device=device)

    # Note: Hooks are not automatically restored as they require hook instances
    # The hook metadata is available in metadata_dict["hooks"] if needed

    from mi_crow.utils import get_logger
    logger = get_logger(__name__)
    logger.info(f"Loaded model from {saved_path} (model_id: {model_id})")

    return lm
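
A loading sketch; the file path is illustrative, `store` is assumed to be an existing Store instance, and the function is called directly here even though it may also be exposed as a classmethod:

```python
lm = load_model_from_saved_file(
    LanguageModel,
    saved_path="./artifacts/gpt2/model.pt",  # file produced by save_model
    store=store,
    device="cpu",
)
# Hooks from the original model are not re-attached automatically;
# their metadata lives in the saved payload's "hooks" entry if needed.
```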

save_model

save_model(language_model, path=None)

Save the model and its metadata to the store.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| language_model | 'LanguageModel' | LanguageModel instance to save | required |
| path | Path \| str \| None | Optional path to save the model; if None, defaults to {model_id}/model.pt relative to the store base path | None |

Returns:

| Type | Description |
| --- | --- |
| Path | Path where the model was saved |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If store is not set |
| OSError | If file operations fail |

Source code in src/mi_crow/language_model/persistence.py
def save_model(
        language_model: "LanguageModel",
        path: Path | str | None = None
) -> Path:
    """
    Save the model and its metadata to the store.

    Args:
        language_model: LanguageModel instance to save
        path: Optional path to save the model. If None, defaults to {model_id}/model.pt
              relative to the store base path.

    Returns:
        Path where the model was saved

    Raises:
        ValueError: If store is not set
        OSError: If file operations fail
    """
    if language_model.store is None:
        raise ValueError("Store must be provided or set on the language model")

    # Determine save path
    if path is None:
        save_path = Path(language_model.store.base_path) / language_model.model_id / "model.pt"
    else:
        save_path = Path(path)
        # If path is relative, make it relative to store base path
        if not save_path.is_absolute():
            save_path = Path(language_model.store.base_path) / save_path

    # Ensure parent directory exists
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # Collect hooks information
    hooks_info = collect_hooks_metadata(language_model.context)

    # Save model state dict
    model_state_dict = language_model.model.state_dict()

    # Create metadata
    metadata = ModelMetadata(
        model_id=language_model.model_id,
        hooks=hooks_info,
        model_path=str(save_path)
    )

    # Save everything in a single file
    payload = {
        "model_state_dict": model_state_dict,
        "metadata": asdict(metadata),
    }

    try:
        torch.save(payload, save_path)
    except OSError as e:
        raise OSError(
            f"Failed to save model to {save_path}. Error: {e}"
        ) from e

    from mi_crow.utils import get_logger
    logger = get_logger(__name__)
    logger.info(f"Saved model to {save_path}")

    return save_path
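
A round-trip sketch using the default save location, {model_id}/model.pt under the store base path (paths shown are illustrative):

```python
# Save: weights plus metadata (model_id, hook info) go into one .pt payload.
saved_path = save_model(lm)  # e.g. ./artifacts/gpt2/model.pt
print(f"Saved to {saved_path}")

# Later: the architecture is rebuilt from the model_id in the metadata and
# the saved state dict is applied (see load_model_from_saved_file above).
restored = load_model_from_saved_file(LanguageModel, saved_path, store)
```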