Hooks API

Hook system for intercepting and managing model activations during inference.

Core Hook Classes

mi_crow.hooks.hook.Hook

Hook(layer_signature=None, hook_type=HookType.FORWARD, hook_id=None)

Bases: ABC

Abstract base class for hooks that can be registered on language model layers.

Hooks provide a way to intercept and process activations during model inference. They expose PyTorch-compatible callables via get_torch_hook() while providing additional functionality like enable/disable and unique identification.

Initialize a hook.

Parameters:

- layer_signature (str | int | None): Layer name or index to attach the hook to. Default: None.
- hook_type (HookType | str): Type of hook, HookType.FORWARD or HookType.PRE_FORWARD. Default: HookType.FORWARD.
- hook_id (str | None): Unique identifier (auto-generated if not provided). Default: None.

Raises:

- ValueError: If the hook_type string is invalid.

Source code in src/mi_crow/hooks/hook.py
def __init__(
        self,
        layer_signature: str | int | None = None,
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None
):
    """
    Initialize a hook.

    Args:
        layer_signature: Layer name or index to attach hook to
        hook_type: Type of hook - HookType.FORWARD or HookType.PRE_FORWARD
        hook_id: Unique identifier (auto-generated if not provided)

    Raises:
        ValueError: If hook_type string is invalid
    """
    self.layer_signature = layer_signature
    self.hook_type = self._normalize_hook_type(hook_type)
    self.id = hook_id if hook_id is not None else str(uuid.uuid4())
    self._enabled = True
    self._torch_hook_handle = None
    self._context: Optional["LanguageModelContext"] = None

context property

context

Get the LanguageModelContext associated with this hook.

enabled property

enabled

Whether this hook is currently enabled.

disable

disable()

Disable this hook.

Source code in src/mi_crow/hooks/hook.py
def disable(self) -> None:
    """Disable this hook."""
    self._enabled = False

enable

enable()

Enable this hook.

Source code in src/mi_crow/hooks/hook.py
def enable(self) -> None:
    """Enable this hook."""
    self._enabled = True

get_torch_hook

get_torch_hook()

Return a PyTorch-compatible hook function.

The returned callable will check the enabled flag before executing and call the abstract _hook_fn method.

Returns:

- Callable: A callable compatible with PyTorch's register_forward_hook or register_forward_pre_hook APIs.

Source code in src/mi_crow/hooks/hook.py
def get_torch_hook(self) -> Callable:
    """
    Return a PyTorch-compatible hook function.

    The returned callable will check the enabled flag before executing
    and call the abstract _hook_fn method.

    Returns:
        A callable compatible with PyTorch's register_forward_hook or
        register_forward_pre_hook APIs.
    """
    if self.hook_type == HookType.PRE_FORWARD:
        return self._create_pre_forward_wrapper()
    else:
        return self._create_forward_wrapper()
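
The callable returned by get_torch_hook() is ultimately handed to PyTorch's native hook registration APIs. As a point of reference, here is a minimal sketch of that underlying mechanism using plain PyTorch only (no mi_crow imports; the hook function and module here are illustrative stand-ins):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
captured = {}

# A forward hook receives (module, inputs, output), matching HookType.FORWARD.
def forward_hook(module, inputs, output):
    captured["output"] = output.detach()

handle = model.register_forward_hook(forward_hook)
_ = model(torch.randn(3, 4))
handle.remove()  # roughly what disabling/removing a Hook boils down to
```

For HookType.PRE_FORWARD, the equivalent registration call is register_forward_pre_hook, whose hooks receive (module, inputs) before the forward pass runs.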

set_context

set_context(context)

Set the LanguageModelContext for this hook.

Parameters:

- context ('LanguageModelContext'): The LanguageModelContext instance. Required.
Source code in src/mi_crow/hooks/hook.py
def set_context(self, context: "LanguageModelContext") -> None:
    """Set the LanguageModelContext for this hook.

    Args:
        context: The LanguageModelContext instance
    """
    self._context = context

mi_crow.hooks.detector.Detector

Detector(hook_type=HookType.FORWARD, hook_id=None, store=None, layer_signature=None)

Bases: Hook

Abstract base class for detector hooks that collect metadata during inference.

Detectors can accumulate data across batches and optionally save it to a Store. They are designed to observe and record information without modifying activations.

Initialize a detector hook.

Parameters:

Name Type Description Default
hook_type HookType | str

Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)

FORWARD
hook_id str | None

Unique identifier

None
store Store | None

Optional Store for saving metadata

None
layer_signature str | int | None

Layer to attach to (optional, for compatibility)

None
Source code in src/mi_crow/hooks/detector.py
def __init__(
        self,
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None,
        store: Store | None = None,
        layer_signature: str | int | None = None
):
    """
    Initialize a detector hook.

    Args:
        hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)
        hook_id: Unique identifier
        store: Optional Store for saving metadata
        layer_signature: Layer to attach to (optional, for compatibility)
    """
    super().__init__(layer_signature=layer_signature, hook_type=hook_type, hook_id=hook_id)
    self.store = store
    self.metadata: Dict[str, Any] = {}
    self.tensor_metadata: Dict[str, torch.Tensor] = {}

process_activations abstractmethod

process_activations(module, input, output)

Process activations from the hooked layer.

This is where detector-specific logic goes (e.g., tracking top activations, computing statistics, etc.).

Parameters:

- module (Module): The PyTorch module being hooked. Required.
- input (HOOK_FUNCTION_INPUT): Tuple of input tensors to the module. Required.
- output (HOOK_FUNCTION_OUTPUT): Output tensor(s) from the module. Required.

Raises:

- Exception: Subclasses may raise exceptions for invalid inputs or processing errors.

Source code in src/mi_crow/hooks/detector.py
@abc.abstractmethod
def process_activations(
        self,
        module: torch.nn.Module,
        input: HOOK_FUNCTION_INPUT,
        output: HOOK_FUNCTION_OUTPUT
) -> None:
    """
    Process activations from the hooked layer.

    This is where detector-specific logic goes (e.g., tracking top activations,
    computing statistics, etc.).

    Args:
        module: The PyTorch module being hooked
        input: Tuple of input tensors to the module
        output: Output tensor(s) from the module

    Raises:
        Exception: Subclasses may raise exceptions for invalid inputs or processing errors
    """
    raise NotImplementedError("process_activations must be implemented by subclasses")
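
A detector observes activations without changing them: its hook records data and returns nothing, so PyTorch leaves the layer output untouched. A minimal sketch of this pattern in plain PyTorch (the statistic and layer here are illustrative, not part of mi_crow):

```python
import torch
import torch.nn as nn

stats = {"mean_abs": []}

# Detector-style hook: record a statistic, return None so the output
# passes through unchanged.
def detector_hook(module, inputs, output):
    stats["mean_abs"].append(output.detach().abs().mean().item())

layer = nn.ReLU()
handle = layer.register_forward_hook(detector_hook)
layer(torch.tensor([[-1.0, 2.0, -3.0]]))
handle.remove()
```

In the Detector class, such per-batch results would accumulate in the metadata / tensor_metadata dictionaries rather than a local dict.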

mi_crow.hooks.controller.Controller

Controller(hook_type=HookType.FORWARD, hook_id=None, layer_signature=None)

Bases: Hook

Abstract base class for controller hooks that modify activations during inference.

Controllers can modify inputs (pre_forward) or outputs (forward) of layers. They are designed to actively change the behavior of the model during inference.

Initialize a controller hook.

Parameters:

- hook_type (HookType | str): Type of hook (HookType.FORWARD or HookType.PRE_FORWARD). Default: HookType.FORWARD.
- hook_id (str | None): Unique identifier. Default: None.
- layer_signature (str | int | None): Layer to attach to (optional, for compatibility). Default: None.
Source code in src/mi_crow/hooks/controller.py
def __init__(
        self,
        hook_type: HookType | str = HookType.FORWARD,
        hook_id: str | None = None,
        layer_signature: str | int | None = None
):
    """
    Initialize a controller hook.

    Args:
        hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)
        hook_id: Unique identifier
        layer_signature: Layer to attach to (optional, for compatibility)
    """
    super().__init__(layer_signature=layer_signature, hook_type=hook_type, hook_id=hook_id)

modify_activations abstractmethod

modify_activations(module, inputs, output)

Modify activations from the hooked layer.

For pre_forward hooks: receives input tensor, should return modified input tensor. For forward hooks: receives input and output tensors, should return modified output tensor.

Parameters:

- module (Module): The PyTorch module being hooked. Required.
- inputs (Tensor | None): Input tensor (None for forward hooks if not available). Required.
- output (Tensor | None): Output tensor (None for pre_forward hooks). Required.

Returns:

- Tensor | None: Modified input tensor (for pre_forward) or modified output tensor (for forward). Return None to keep the original tensor unchanged.

Raises:

- Exception: Subclasses may raise exceptions for invalid inputs or modification errors.

Source code in src/mi_crow/hooks/controller.py
@abc.abstractmethod
def modify_activations(
        self,
        module: nn.Module,
        inputs: torch.Tensor | None,
        output: torch.Tensor | None
) -> torch.Tensor | None:
    """
    Modify activations from the hooked layer.

    For pre_forward hooks: receives input tensor, should return modified input tensor.
    For forward hooks: receives input and output tensors, should return modified output tensor.

    Args:
        module: The PyTorch module being hooked
        inputs: Input tensor (None for forward hooks if not available)
        output: Output tensor (None for pre_forward hooks)

    Returns:
        Modified input tensor (for pre_forward) or modified output tensor (for forward).
        Return None to keep original tensor unchanged.

    Raises:
        Exception: Subclasses may raise exceptions for invalid inputs or modification errors
    """
    raise NotImplementedError("modify_activations must be implemented by subclasses")
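
A controller relies on a PyTorch behavior: when a forward hook returns a non-None value, that value replaces the layer's output. A minimal sketch of the mechanism in plain PyTorch (the clamping function is an illustrative stand-in for a subclass's modify_activations logic):

```python
import torch
import torch.nn as nn

# Controller-style hook: return a modified tensor and PyTorch substitutes
# it for the layer's output.
def zero_negative_hook(module, inputs, output):
    return output.clamp(min=0.0)

layer = nn.Identity()
handle = layer.register_forward_hook(zero_negative_hook)
out = layer(torch.tensor([-1.0, 0.5, -2.0]))
handle.remove()
```

The pre_forward case is symmetric: a pre-forward hook that returns a value replaces the layer's inputs before the forward pass runs.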

Implementations

mi_crow.hooks.implementations.layer_activation_detector.LayerActivationDetector

LayerActivationDetector(layer_signature, hook_id=None, target_dtype=None)

Bases: Detector

Detector hook that captures and saves activations during inference.

This detector extracts activations from layer outputs and stores them for later use (e.g., saving to disk, further analysis).

Initialize the activation saver detector.

Parameters:

- layer_signature (str | int): Layer to capture activations from. Required.
- hook_id (str | None): Unique identifier for this hook. Default: None.
- target_dtype (dtype | None): Optional dtype to convert activations to before storing. Default: None.

Raises:

- ValueError: If layer_signature is None.

Source code in src/mi_crow/hooks/implementations/layer_activation_detector.py
def __init__(self, layer_signature: str | int, hook_id: str | None = None, target_dtype: torch.dtype | None = None):
    """
    Initialize the activation saver detector.

    Args:
        layer_signature: Layer to capture activations from
        hook_id: Unique identifier for this hook
        target_dtype: Optional dtype to convert activations to before storing

    Raises:
        ValueError: If layer_signature is None
    """
    if layer_signature is None:
        raise ValueError("layer_signature cannot be None for LayerActivationDetector")

    super().__init__(hook_type=HookType.FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
    self.target_dtype = target_dtype

clear_captured

clear_captured()

Clear captured activations for current batch.

Source code in src/mi_crow/hooks/implementations/layer_activation_detector.py
def clear_captured(self) -> None:
    """Clear captured activations for current batch."""
    self.tensor_metadata.pop("activations", None)
    self.metadata.pop("activations_shape", None)

get_captured

get_captured()

Get the captured activations from the current batch.

Returns:

- Tensor | None: The captured activation tensor from the current batch, or None if no activations have been captured yet.

Source code in src/mi_crow/hooks/implementations/layer_activation_detector.py
def get_captured(self) -> torch.Tensor | None:
    """
    Get the captured activations from the current batch.

    Returns:
        The captured activation tensor from the current batch or None if no activations captured yet
    """
    return self.tensor_metadata.get("activations")

process_activations

process_activations(module, input, output)

Extract and store activations from output.

Handles various output types:

- Plain tensors
- Tuples/lists of tensors (takes the first tensor)
- Objects with a last_hidden_state attribute (e.g., HuggingFace outputs)

Parameters:

- module (Module): The PyTorch module being hooked. Required.
- input (HOOK_FUNCTION_INPUT): Tuple of input tensors to the module. Required.
- output (HOOK_FUNCTION_OUTPUT): Output tensor(s) from the module. Required.

Raises:

- RuntimeError: If tensor extraction or storage fails.

Source code in src/mi_crow/hooks/implementations/layer_activation_detector.py
def process_activations(
    self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
) -> None:
    """
    Extract and store activations from output.

    Handles various output types:
    - Plain tensors
    - Tuples/lists of tensors (takes first tensor)
    - Objects with last_hidden_state attribute (e.g., HuggingFace outputs)

    Args:
        module: The PyTorch module being hooked
        input: Tuple of input tensors to the module
        output: Output tensor(s) from the module

    Raises:
        RuntimeError: If tensor extraction or storage fails
    """
    try:
        tensor = extract_tensor_from_output(output)

        if tensor is not None:
            if tensor.is_cuda:
                tensor_cpu = tensor.detach().to("cpu", non_blocking=True)
            else:
                tensor_cpu = tensor.detach()

            if self.target_dtype is not None:
                tensor_cpu = tensor_cpu.to(self.target_dtype)

            self.tensor_metadata["activations"] = tensor_cpu
            self.metadata["activations_shape"] = tuple(tensor_cpu.shape)
    except Exception as e:
        layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
        raise RuntimeError(
            f"Error extracting activations in LayerActivationDetector {self.id} (layer={layer_sig}): {e}"
        ) from e
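
The extraction step has to cope with the three output shapes listed above. A hedged stand-in for that logic (extract_tensor_from_output itself lives in mi_crow; this sketch only mirrors the documented cases, and FakeHFOutput is a hypothetical stand-in for a HuggingFace output object):

```python
import torch

def extract_tensor(output):
    """Mirror of the documented extraction cases (illustrative only)."""
    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, (tuple, list)):
        # Take the first tensor in the container, as documented.
        return next((t for t in output if isinstance(t, torch.Tensor)), None)
    return getattr(output, "last_hidden_state", None)

class FakeHFOutput:
    def __init__(self, hidden):
        self.last_hidden_state = hidden

t = torch.ones(2, 3)
assert extract_tensor(t) is t
assert extract_tensor((t, "aux")) is t
assert extract_tensor(FakeHFOutput(t)) is t
```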

mi_crow.hooks.implementations.model_input_detector.ModelInputDetector

ModelInputDetector(layer_signature=None, hook_id=None, save_input_ids=True, save_attention_mask=False, special_token_ids=None)

Bases: Detector

Detector hook that captures and saves tokenized inputs from model forward pass.

This detector is designed to be attached to the root model module and captures:

- Tokenized inputs (input_ids) from the model's forward pass
- Attention masks (optional) that exclude both padding and special tokens

Uses PRE_FORWARD hook to capture inputs before they are processed. Useful for saving tokenized inputs for analysis or training.

Initialize the model input detector.

Parameters:

- layer_signature (str | int | None): Layer to capture from (typically the root model; can be None). Default: None.
- hook_id (str | None): Unique identifier for this hook. Default: None.
- save_input_ids (bool): Whether to save the input_ids tensor. Default: True.
- save_attention_mask (bool): Whether to save the attention_mask tensor (excludes padding and special tokens). Default: False.
- special_token_ids (Optional[List[int] | Set[int]]): Optional list/set of special token IDs. If None, they are extracted from the LanguageModel context. Default: None.
Source code in src/mi_crow/hooks/implementations/model_input_detector.py
def __init__(
    self,
    layer_signature: str | int | None = None,
    hook_id: str | None = None,
    save_input_ids: bool = True,
    save_attention_mask: bool = False,
    special_token_ids: Optional[List[int] | Set[int]] = None,
):
    """
    Initialize the model input detector.

    Args:
        layer_signature: Layer to capture from (typically the root model, can be None)
        hook_id: Unique identifier for this hook
        save_input_ids: Whether to save input_ids tensor
        save_attention_mask: Whether to save attention_mask tensor (excludes padding and special tokens)
        special_token_ids: Optional list/set of special token IDs. If None, will extract from LanguageModel context.
    """
    super().__init__(hook_type=HookType.PRE_FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
    self.save_input_ids = save_input_ids
    self.save_attention_mask = save_attention_mask
    self.special_token_ids = set(special_token_ids) if special_token_ids is not None else None

clear_captured

clear_captured()

Clear all captured inputs for current batch.

Source code in src/mi_crow/hooks/implementations/model_input_detector.py
def clear_captured(self) -> None:
    """Clear all captured inputs for current batch."""
    keys_to_remove = ["input_ids", "attention_mask"]
    for key in keys_to_remove:
        self.tensor_metadata.pop(key, None)
        self.metadata.pop(f"{key}_shape", None)

get_captured_attention_mask

get_captured_attention_mask()

Get the captured attention_mask from the current batch (excludes padding and special tokens).

Source code in src/mi_crow/hooks/implementations/model_input_detector.py
def get_captured_attention_mask(self) -> torch.Tensor | None:
    """Get the captured attention_mask from the current batch (excludes padding and special tokens)."""
    return self.tensor_metadata.get("attention_mask")

get_captured_input_ids

get_captured_input_ids()

Get the captured input_ids from the current batch.

Source code in src/mi_crow/hooks/implementations/model_input_detector.py
def get_captured_input_ids(self) -> torch.Tensor | None:
    """Get the captured input_ids from the current batch."""
    return self.tensor_metadata.get("input_ids")

process_activations

process_activations(module, input, output)

Extract and store tokenized inputs.

Note: For HuggingFace models called with **kwargs, the input tuple may be empty. In such cases, use set_inputs_from_encodings() to manually set inputs from the encodings dictionary returned by lm.inference.execute_inference().

Parameters:

- module (Module): The PyTorch module being hooked (typically the root model). Required.
- input (HOOK_FUNCTION_INPUT): Tuple of input tensors/dicts to the module. Required.
- output (HOOK_FUNCTION_OUTPUT): Output from the module (None for PRE_FORWARD hooks). Required.

Raises:

- RuntimeError: If tensor extraction or storage fails.

Source code in src/mi_crow/hooks/implementations/model_input_detector.py
def process_activations(
    self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
) -> None:
    """
    Extract and store tokenized inputs.

    Note: For HuggingFace models called with **kwargs, the input tuple may be empty.
    In such cases, use set_inputs_from_encodings() to manually set inputs from
    the encodings dictionary returned by lm.inference.execute_inference().

    Args:
        module: The PyTorch module being hooked (typically the root model)
        input: Tuple of input tensors/dicts to the module
        output: Output from the module (None for PRE_FORWARD hooks)

    Raises:
        RuntimeError: If tensor extraction or storage fails
    """
    try:
        if self.save_input_ids:
            input_ids = self._extract_input_ids(input)
            if input_ids is not None:
                if input_ids.is_cuda:
                    input_ids_cpu = input_ids.detach().to("cpu", non_blocking=True)
                else:
                    input_ids_cpu = input_ids.detach()
                self.tensor_metadata["input_ids"] = input_ids_cpu
                self.metadata["input_ids_shape"] = tuple(input_ids_cpu.shape)

        if self.save_attention_mask:
            input_ids = self._extract_input_ids(input)
            if input_ids is not None:
                original_attention_mask = self._extract_attention_mask(input)
                combined_mask = self._create_combined_attention_mask(input_ids, original_attention_mask, module)
                if combined_mask.is_cuda:
                    combined_mask_cpu = combined_mask.detach().to("cpu", non_blocking=True)
                else:
                    combined_mask_cpu = combined_mask.detach()
                self.tensor_metadata["attention_mask"] = combined_mask_cpu
                self.metadata["attention_mask_shape"] = tuple(combined_mask_cpu.shape)

    except Exception as e:
        layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
        raise RuntimeError(
            f"Error extracting inputs in ModelInputDetector {self.id} (layer={layer_sig}): {e}"
        ) from e

set_inputs_from_encodings

set_inputs_from_encodings(encodings, module=None)

Manually set inputs from encodings dictionary.

This is useful when the model is called with keyword arguments, as PyTorch's pre_forward hook doesn't receive kwargs.

Parameters:

- encodings (Dict[str, Tensor]): Dictionary of encoded inputs (e.g., from lm.inference.execute_inference() or lm.tokenize()). Required.
- module (Optional[Module]): Optional module for extracting special token IDs. If None, a DummyModule is used. Default: None.

Raises:

- RuntimeError: If tensor extraction or storage fails.

Source code in src/mi_crow/hooks/implementations/model_input_detector.py
def set_inputs_from_encodings(
    self, encodings: Dict[str, torch.Tensor], module: Optional[torch.nn.Module] = None
) -> None:
    """
    Manually set inputs from encodings dictionary.

    This is useful when the model is called with keyword arguments,
    as PyTorch's pre_forward hook doesn't receive kwargs.

    Args:
        encodings: Dictionary of encoded inputs (e.g., from lm.inference.execute_inference() or lm.tokenize())
        module: Optional module for extracting special token IDs. If None, will use DummyModule.

    Raises:
        RuntimeError: If tensor extraction or storage fails
    """
    try:
        if self.save_input_ids and "input_ids" in encodings:
            input_ids = encodings["input_ids"]
            self.tensor_metadata["input_ids"] = input_ids.detach().to("cpu")
            self.metadata["input_ids_shape"] = tuple(input_ids.shape)

        if self.save_attention_mask and "input_ids" in encodings:
            input_ids = encodings["input_ids"]
            if module is None:

                class DummyModule:
                    pass

                module = DummyModule()

            original_attention_mask = encodings.get("attention_mask")
            combined_mask = self._create_combined_attention_mask(input_ids, original_attention_mask, module)
            self.tensor_metadata["attention_mask"] = combined_mask.detach().to("cpu")
            self.metadata["attention_mask_shape"] = tuple(combined_mask.shape)
    except Exception as e:
        raise RuntimeError(f"Error setting inputs from encodings in ModelInputDetector {self.id}: {e}") from e
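
The reason this method exists can be demonstrated directly in plain PyTorch: a pre-forward hook only sees positional arguments, so a kwargs-only call (the common HuggingFace style) gives the hook an empty input tuple. A small illustration (KwModel is a hypothetical stand-in for such a model):

```python
import torch
import torch.nn as nn

seen = {}

def pre_hook(module, inputs):
    # Records the positional args the hook actually receives.
    seen["inputs"] = inputs

class KwModel(nn.Module):
    def forward(self, input_ids=None):
        return input_ids

m = KwModel()
m.register_forward_pre_hook(pre_hook)
m(input_ids=torch.tensor([[1, 2, 3]]))
# seen["inputs"] is an empty tuple: the kwargs never reached the hook,
# which is exactly the gap set_inputs_from_encodings fills.
```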

mi_crow.hooks.implementations.model_output_detector.ModelOutputDetector

ModelOutputDetector(layer_signature=None, hook_id=None, save_output_logits=True, save_output_hidden_state=False)

Bases: Detector

Detector hook that captures and saves model outputs.

This detector is designed to be attached to the root model module and captures:

- Model outputs (logits) from the model's forward pass
- Hidden states (optional) from the model's forward pass

Uses FORWARD hook to capture outputs after they are computed. Useful for saving model outputs for analysis or training.

Initialize the model output detector.

Parameters:

- layer_signature (str | int | None): Layer to capture from (typically the root model; can be None). Default: None.
- hook_id (str | None): Unique identifier for this hook. Default: None.
- save_output_logits (bool): Whether to save output logits (if available). Default: True.
- save_output_hidden_state (bool): Whether to save last_hidden_state (if available). Default: False.
Source code in src/mi_crow/hooks/implementations/model_output_detector.py
def __init__(
    self,
    layer_signature: str | int | None = None,
    hook_id: str | None = None,
    save_output_logits: bool = True,
    save_output_hidden_state: bool = False,
):
    """
    Initialize the model output detector.

    Args:
        layer_signature: Layer to capture from (typically the root model, can be None)
        hook_id: Unique identifier for this hook
        save_output_logits: Whether to save output logits (if available)
        save_output_hidden_state: Whether to save last_hidden_state (if available)
    """
    super().__init__(hook_type=HookType.FORWARD, hook_id=hook_id, store=None, layer_signature=layer_signature)
    self.save_output_logits = save_output_logits
    self.save_output_hidden_state = save_output_hidden_state

clear_captured

clear_captured()

Clear all captured outputs for current batch.

Source code in src/mi_crow/hooks/implementations/model_output_detector.py
def clear_captured(self) -> None:
    """Clear all captured outputs for current batch."""
    keys_to_remove = ["output_logits", "output_hidden_state"]
    for key in keys_to_remove:
        self.tensor_metadata.pop(key, None)
        self.metadata.pop(f"{key}_shape", None)

get_captured_output_hidden_state

get_captured_output_hidden_state()

Get the captured output hidden state from the current batch.

Source code in src/mi_crow/hooks/implementations/model_output_detector.py
def get_captured_output_hidden_state(self) -> torch.Tensor | None:
    """Get the captured output hidden state from the current batch."""
    return self.tensor_metadata.get("output_hidden_state")

get_captured_output_logits

get_captured_output_logits()

Get the captured output logits from the current batch.

Source code in src/mi_crow/hooks/implementations/model_output_detector.py
def get_captured_output_logits(self) -> torch.Tensor | None:
    """Get the captured output logits from the current batch."""
    return self.tensor_metadata.get("output_logits")

process_activations

process_activations(module, input, output)

Extract and store model outputs.

Parameters:

- module (Module): The PyTorch module being hooked (typically the root model). Required.
- input (HOOK_FUNCTION_INPUT): Tuple of input tensors/dicts to the module. Required.
- output (HOOK_FUNCTION_OUTPUT): Output from the module. Required.

Raises:

- RuntimeError: If tensor extraction or storage fails.

Source code in src/mi_crow/hooks/implementations/model_output_detector.py
def process_activations(
    self, module: torch.nn.Module, input: HOOK_FUNCTION_INPUT, output: HOOK_FUNCTION_OUTPUT
) -> None:
    """
    Extract and store model outputs.

    Args:
        module: The PyTorch module being hooked (typically the root model)
        input: Tuple of input tensors/dicts to the module
        output: Output from the module

    Raises:
        RuntimeError: If tensor extraction or storage fails
    """
    try:
        # Extract and save outputs
        logits, hidden_state = self._extract_output_tensor(output)

        if self.save_output_logits and logits is not None:
            if logits.is_cuda:
                logits_cpu = logits.detach().to("cpu", non_blocking=True)
            else:
                logits_cpu = logits.detach()
            self.tensor_metadata["output_logits"] = logits_cpu
            self.metadata["output_logits_shape"] = tuple(logits_cpu.shape)

        if self.save_output_hidden_state and hidden_state is not None:
            if hidden_state.is_cuda:
                hidden_state_cpu = hidden_state.detach().to("cpu", non_blocking=True)
            else:
                hidden_state_cpu = hidden_state.detach()
            self.tensor_metadata["output_hidden_state"] = hidden_state_cpu
            self.metadata["output_hidden_state_shape"] = tuple(hidden_state_cpu.shape)

    except Exception as e:
        layer_sig = str(self.layer_signature) if self.layer_signature is not None else "unknown"
        raise RuntimeError(
            f"Error extracting outputs in ModelOutputDetector {self.id} (layer={layer_sig}): {e}"
        ) from e

mi_crow.hooks.implementations.function_controller.FunctionController

FunctionController(layer_signature, function, hook_type=HookType.FORWARD, hook_id=None)

Bases: Controller

A controller that applies a user-provided function to tensors during inference.

This controller allows users to pass any function and apply it to activations. The function will be applied to:

- Single tensors directly
- All tensors in tuples/lists (default behavior)

Example:

    # Scale activations by 2
    controller = FunctionController(
        layer_signature="layer_0",
        function=lambda x: x * 2.0
    )

Initialize a function controller.

Parameters:

- layer_signature (str | int): Layer to attach to. Required.
- function (Callable[[Tensor], Tensor]): Function to apply to tensors. Must take a torch.Tensor and return a torch.Tensor. Required.
- hook_type (HookType | str): Type of hook (HookType.FORWARD or HookType.PRE_FORWARD). Default: HookType.FORWARD.
- hook_id (str | None): Unique identifier. Default: None.

Raises:

- ValueError: If function is None or not callable.

Source code in src/mi_crow/hooks/implementations/function_controller.py
def __init__(
    self,
    layer_signature: str | int,
    function: Callable[[torch.Tensor], torch.Tensor],
    hook_type: HookType | str = HookType.FORWARD,
    hook_id: str | None = None,
):
    """
    Initialize a function controller.

    Args:
        layer_signature: Layer to attach to
        function: Function to apply to tensors. Must take a torch.Tensor and return a torch.Tensor
        hook_type: Type of hook (HookType.FORWARD or HookType.PRE_FORWARD)
        hook_id: Unique identifier

    Raises:
        ValueError: If function is None or not callable
    """
    if function is None:
        raise ValueError("function cannot be None")

    if not callable(function):
        raise ValueError(f"function must be callable, got: {type(function)}")

    super().__init__(hook_type=hook_type, hook_id=hook_id, layer_signature=layer_signature)
    self.function = function

modify_activations

modify_activations(module, inputs, output)

Apply the user-provided function to activations.

Parameters:

- module ('nn.Module'): The PyTorch module being hooked. Required.
- inputs (Tensor | None): Input tensor (None for forward hooks). Required.
- output (Tensor | None): Output tensor (None for pre_forward hooks). Required.

Returns:

- Tensor | None: Modified tensor with the function applied, or None if the target tensor is None.

Raises:

- RuntimeError: If the function raises an exception when applied to the tensor.

Source code in src/mi_crow/hooks/implementations/function_controller.py
def modify_activations(
    self,
    module: "nn.Module",
    inputs: torch.Tensor | None,
    output: torch.Tensor | None
) -> torch.Tensor | None:
    """
    Apply the user-provided function to activations.

    Args:
        module: The PyTorch module being hooked
        inputs: Input tensor (None for forward hooks)
        output: Output tensor (None for pre_forward hooks)

    Returns:
        Modified tensor with function applied, or None if target tensor is None

    Raises:
        RuntimeError: If function raises an exception when applied to tensor
    """
    target = output if self.hook_type == HookType.FORWARD else inputs

    if target is None or not isinstance(target, torch.Tensor):
        return target

    try:
        result = self.function(target)
        if not isinstance(result, torch.Tensor):
            raise TypeError(
                f"Function must return a torch.Tensor, got: {type(result)}"
            )
        return result
    except Exception as e:
        raise RuntimeError(
            f"Error applying function in FunctionController {self.id}: {e}"
        ) from e
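
The two behaviors worth noting above are the target selection (forward hooks transform the output, pre-forward hooks the input) and the wrapping of user-function failures in a RuntimeError. A standalone, torch-free sketch of that control flow, using plain numbers in place of tensors (`apply_function` and `hook_id="demo"` are illustrative, not library names):

```python
# Standalone sketch (no mi_crow / no torch): mirrors how
# modify_activations picks its target and wraps user-function errors.
def apply_function(hook_type, function, inputs, output, hook_id="demo"):
    # Forward hooks transform the layer output; pre-forward hooks the input.
    target = output if hook_type == "forward" else inputs
    if target is None:
        return target  # nothing to modify
    try:
        return function(target)
    except Exception as e:
        raise RuntimeError(
            f"Error applying function in FunctionController {hook_id}: {e}"
        ) from e

print(apply_function("forward", lambda x: x * 2, None, 3.0))      # 6.0
print(apply_function("pre_forward", lambda x: x * 2, 4.0, None))  # 8.0
```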

Utilities

mi_crow.hooks.utils

Utility functions for hook implementations.

apply_modification_to_output

apply_modification_to_output(output, modified_tensor, target_device=None)

Apply a modified tensor to an output object in-place.

Handles various output formats:

- Plain tensors: modifies the tensor directly (in-place)
- Tuples/lists of tensors: replaces first tensor
- Objects with last_hidden_state attribute: sets last_hidden_state

If target_device is provided, output tensors are moved to target_device first, ensuring consistency with the desired device (e.g., context.device). Otherwise, modified_tensor is moved to match output's current device.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| output | HOOK_FUNCTION_OUTPUT | Output object to modify | required |
| modified_tensor | Tensor | Modified tensor to apply | required |
| target_device | device \| None | Optional target device. If provided, output tensors are moved to this device before applying modification. If None, uses output's current device. | None |
Source code in src/mi_crow/hooks/utils.py
def apply_modification_to_output(
    output: HOOK_FUNCTION_OUTPUT,
    modified_tensor: torch.Tensor,
    target_device: torch.device | None = None
) -> None:
    """
    Apply a modified tensor to an output object in-place.

    Handles various output formats:
    - Plain tensors: modifies the tensor directly (in-place)
    - Tuples/lists of tensors: replaces first tensor
    - Objects with last_hidden_state attribute: sets last_hidden_state

    If target_device is provided, output tensors are moved to target_device first,
    ensuring consistency with the desired device (e.g., context.device).
    Otherwise, modified_tensor is moved to match output's current device.

    Args:
        output: Output object to modify
        modified_tensor: Modified tensor to apply
        target_device: Optional target device. If provided, output tensors are moved
            to this device before applying modification. If None, uses output's current device.
    """
    if output is None:
        return

    if isinstance(output, torch.Tensor):
        if target_device is not None:
            if output.device != target_device:
                output = output.to(target_device)
            if modified_tensor.device != target_device:
                modified_tensor = modified_tensor.to(target_device)
        else:
            if modified_tensor.device != output.device:
                modified_tensor = modified_tensor.to(output.device)
        output.data.copy_(modified_tensor.data)
        return

    if isinstance(output, (tuple, list)):
        for i, item in enumerate(output):
            if isinstance(item, torch.Tensor):
                if target_device is not None:
                    if item.device != target_device:
                        item = item.to(target_device)
                        if isinstance(output, list):
                            output[i] = item
                    if modified_tensor.device != target_device or modified_tensor.dtype != item.dtype:
                        modified_tensor = modified_tensor.to(device=target_device, dtype=item.dtype)
                else:
                    if modified_tensor.device != item.device or modified_tensor.dtype != item.dtype:
                        modified_tensor = modified_tensor.to(device=item.device, dtype=item.dtype)
                if isinstance(output, tuple):
                    item.data.copy_(modified_tensor.data)
                else:
                    output[i] = modified_tensor
                break
        return

    if hasattr(output, "last_hidden_state"):
        original_tensor = output.last_hidden_state
        if isinstance(original_tensor, torch.Tensor):
            if target_device is not None:
                if original_tensor.device != target_device:
                    output.last_hidden_state = original_tensor.to(target_device)
                    original_tensor = output.last_hidden_state
                if modified_tensor.device != target_device:
                    modified_tensor = modified_tensor.to(target_device)
            else:
                if modified_tensor.device != original_tensor.device:
                    modified_tensor = modified_tensor.to(original_tensor.device)
        output.last_hidden_state = modified_tensor
        return
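
The device rule applied in every branch above can be condensed: an explicit target_device wins, and both sides are moved to it; otherwise modified_tensor is moved to match the output tensor's current device. A minimal torch-free sketch of that precedence, with `resolve_device` as a hypothetical helper name and device strings standing in for torch.device objects:

```python
# Standalone sketch (no torch): the device-resolution precedence used
# throughout apply_modification_to_output. An explicit target_device
# wins; otherwise the output tensor's current device is used.
def resolve_device(output_device, target_device=None):
    return target_device if target_device is not None else output_device

print(resolve_device("cuda:0", target_device="cpu"))  # cpu
print(resolve_device("cuda:0"))                       # cuda:0
```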

extract_tensor_from_input

extract_tensor_from_input(input)

Extract the first tensor from input sequence.

Handles various input formats:

- Direct tensor in first position
- Tuple/list of tensors in first position
- Empty or None inputs

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input | HOOK_FUNCTION_INPUT | Input sequence (tuple/list of tensors) | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor \| None | First tensor found, or None if no tensor found |

Source code in src/mi_crow/hooks/utils.py
def extract_tensor_from_input(input: HOOK_FUNCTION_INPUT) -> torch.Tensor | None:
    """
    Extract the first tensor from input sequence.

    Handles various input formats:
    - Direct tensor in first position
    - Tuple/list of tensors in first position
    - Empty or None inputs

    Args:
        input: Input sequence (tuple/list of tensors)

    Returns:
        First tensor found, or None if no tensor found
    """
    if not input or len(input) == 0:
        return None

    first_item = input[0]
    if isinstance(first_item, torch.Tensor):
        return first_item

    if isinstance(first_item, (tuple, list)):
        for item in first_item:
            if isinstance(item, torch.Tensor):
                return item

    return None

extract_tensor_from_output

extract_tensor_from_output(output)

Extract tensor from output (handles various output types).

Handles various output formats:

- Plain tensors
- Tuples/lists of tensors (takes first tensor)
- Objects with last_hidden_state attribute (e.g., HuggingFace outputs)
- None outputs

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| output | HOOK_FUNCTION_OUTPUT | Output from module (tensor, tuple, or object with attributes) | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor \| None | First tensor found, or None if no tensor found |

Source code in src/mi_crow/hooks/utils.py
def extract_tensor_from_output(output: HOOK_FUNCTION_OUTPUT) -> torch.Tensor | None:
    """
    Extract tensor from output (handles various output types).

    Handles various output formats:
    - Plain tensors
    - Tuples/lists of tensors (takes first tensor)
    - Objects with last_hidden_state attribute (e.g., HuggingFace outputs)
    - None outputs

    Args:
        output: Output from module (tensor, tuple, or object with attributes)

    Returns:
        First tensor found, or None if no tensor found
    """
    if output is None:
        return None

    if isinstance(output, torch.Tensor):
        return output

    if isinstance(output, (tuple, list)):
        for item in output:
            if isinstance(item, torch.Tensor):
                return item

    # Try common HuggingFace output objects
    if hasattr(output, "last_hidden_state"):
        maybe = getattr(output, "last_hidden_state")
        if isinstance(maybe, torch.Tensor):
            return maybe

    return None
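
The fallback order above (plain tensor, then first tensor in a tuple/list, then last_hidden_state) can be exercised without torch. The sketch below is a standalone analogue using stub classes; `Tensor`, `HFOutput`, and `first_from_output` are illustrative names, not mi_crow or torch API:

```python
# Standalone sketch (no torch): the fallback order of
# extract_tensor_from_output, using stub classes for illustration.
class Tensor:  # stand-in for torch.Tensor
    pass

class HFOutput:  # stand-in for a HuggingFace-style output object
    def __init__(self, hidden):
        self.last_hidden_state = hidden

def first_from_output(output):
    if output is None:
        return None
    if isinstance(output, Tensor):
        return output
    if isinstance(output, (tuple, list)):
        for item in output:
            if isinstance(item, Tensor):
                return item
    # Fall back to HuggingFace-style objects.
    hidden = getattr(output, "last_hidden_state", None)
    return hidden if isinstance(hidden, Tensor) else None

t = Tensor()
print(first_from_output(t) is t)            # True: plain tensor
print(first_from_output(("x", t)) is t)     # True: first tensor in a tuple
print(first_from_output(HFOutput(t)) is t)  # True: HuggingFace-style object
print(first_from_output(None))              # None
```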