Hooks API
Hook system for intercepting and managing model activations during inference.
Core Hook Classes
mi_crow.hooks.hook.Hook
Hook(layer_signature=None, hook_type=HookType.FORWARD, hook_id=None)
Bases: ABC
Abstract base class for hooks that can be registered on language model layers.
Hooks provide a way to intercept and process activations during model inference. Each hook exposes a PyTorch-compatible callable via get_torch_hook() and adds functionality such as enabling/disabling and unique identification.
Initialize a hook.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| layer_signature | str \| int \| None | Layer name or index to attach the hook to | None |
| hook_type | HookType \| str | Type of hook: HookType.FORWARD or HookType.PRE_FORWARD | HookType.FORWARD |
| hook_id | str \| None | Unique identifier (auto-generated if not provided) | None |

Raises:

| Type | Description |
|---|---|
| ValueError | If the hook_type string is invalid |

Source code in src/mi_crow/hooks/hook.py
disable
disable()
Disable this hook.
Source code in src/mi_crow/hooks/hook.py
enable
enable()
Enable this hook.
Source code in src/mi_crow/hooks/hook.py
get_torch_hook
get_torch_hook()
Return a PyTorch-compatible hook function.
The returned callable checks the hook's enabled flag before executing and, if the hook is enabled, calls the abstract _hook_fn method.
Returns:

| Type | Description |
|---|---|
| Callable | A callable compatible with PyTorch's register_forward_hook or register_forward_pre_hook APIs. |

Source code in src/mi_crow/hooks/hook.py
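The enable/disable gate can be illustrated with plain PyTorch. The sketch below is a hypothetical stand-in (GatedHook is not part of mi_crow, and the real _hook_fn signature may differ): the callable returned by get_torch_hook() checks the enabled flag and only then delegates to _hook_fn.

```python
import torch
from torch import nn


class GatedHook:
    """Hypothetical stand-in for Hook: an enabled flag in front of a hook body."""

    def __init__(self):
        self.enabled = True
        self.calls = 0

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def _hook_fn(self, module, inputs, output):
        # Subclass-specific work would go here; this sketch only counts calls.
        self.calls += 1
        return None  # None leaves the module output unchanged

    def get_torch_hook(self):
        def torch_hook(module, inputs, output):
            if not self.enabled:
                return None  # disabled: act as a no-op
            return self._hook_fn(module, inputs, output)

        return torch_hook


layer = nn.Linear(4, 4)
hook = GatedHook()
handle = layer.register_forward_hook(hook.get_torch_hook())

layer(torch.randn(2, 4))  # counted
hook.disable()
layer(torch.randn(2, 4))  # skipped while disabled
hook.enable()
layer(torch.randn(2, 4))  # counted again
handle.remove()

print(hook.calls)  # 2
```

Because the flag is checked inside the callable, the hook can be toggled without unregistering and re-registering it with PyTorch.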
set_context
set_context(context)
Set the LanguageModelContext for this hook.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| context | LanguageModelContext | The LanguageModelContext instance | required |

Source code in src/mi_crow/hooks/hook.py
mi_crow.hooks.detector.Detector
Detector(hook_type=HookType.FORWARD, hook_id=None, store=None, layer_signature=None)
Bases: Hook
Abstract base class for detector hooks that collect metadata during inference.
Detectors can accumulate data across batches and optionally save it to a Store. They are designed to observe and record information without modifying activations.
Initialize a detector hook.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hook_type | HookType \| str | Type of hook (HookType.FORWARD or HookType.PRE_FORWARD) | HookType.FORWARD |
| hook_id | str \| None | Unique identifier | None |
| store | Store \| None | Optional Store for saving metadata | None |
| layer_signature | str \| int \| None | Layer to attach to (optional, for compatibility) | None |

Source code in src/mi_crow/hooks/detector.py
process_activations (abstract method)
process_activations(module, input, output)
Process activations from the hooked layer.
This is where detector-specific logic goes (e.g., tracking top activations or computing statistics).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The PyTorch module being hooked | required |
| input | HOOK_FUNCTION_INPUT | Tuple of input tensors to the module | required |
| output | HOOK_FUNCTION_OUTPUT | Output tensor(s) from the module | required |

Raises:

| Type | Description |
|---|---|
| Exception | Subclasses may raise exceptions for invalid inputs or processing errors |

Source code in src/mi_crow/hooks/detector.py
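A detector only observes. The following is a minimal sketch of that contract in plain PyTorch; MeanActivationDetector is hypothetical and is registered directly as a forward hook rather than through mi_crow's own machinery. It accumulates one statistic per batch and returns None, so the layer output is left untouched.

```python
import torch
from torch import nn


class MeanActivationDetector:
    """Hypothetical detector: records each batch's output mean, never modifies it."""

    def __init__(self):
        self.batch_means = []  # accumulated across batches

    def process_activations(self, module, input, output):
        # Observe only: detach so no autograd graph is kept alive.
        self.batch_means.append(output.detach().mean().item())

    def __call__(self, module, input, output):
        self.process_activations(module, input, output)
        return None  # returning None keeps the original output


layer = nn.Identity()
detector = MeanActivationDetector()
handle = layer.register_forward_hook(detector)

out = layer(torch.ones(2, 3))
layer(torch.zeros(2, 3))
handle.remove()

print(detector.batch_means)  # [1.0, 0.0]
```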
mi_crow.hooks.controller.Controller
Controller(hook_type=HookType.FORWARD, hook_id=None, layer_signature=None)
Bases: Hook
Abstract base class for controller hooks that modify activations during inference.
Controllers can modify inputs (pre_forward) or outputs (forward) of layers. They are designed to actively change the behavior of the model during inference.
Initialize a controller hook.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hook_type | HookType \| str | Type of hook (HookType.FORWARD or HookType.PRE_FORWARD) | HookType.FORWARD |
| hook_id | str \| None | Unique identifier | None |
| layer_signature | str \| int \| None | Layer to attach to (optional, for compatibility) | None |

Source code in src/mi_crow/hooks/controller.py
modify_activations (abstract method)
modify_activations(module, inputs, output)
Modify activations from the hooked layer.
For pre_forward hooks, it receives the input tensor and should return the modified input tensor. For forward hooks, it receives the input and output tensors and should return the modified output tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The PyTorch module being hooked | required |
| inputs | Tensor \| None | Input tensor (None for forward hooks if not available) | required |
| output | Tensor \| None | Output tensor (None for pre_forward hooks) | required |

Returns:

| Type | Description |
|---|---|
| Tensor \| None | Modified input tensor (for pre_forward) or modified output tensor (for forward). Return None to keep the original tensor unchanged. |

Raises:

| Type | Description |
|---|---|
| Exception | Subclasses may raise exceptions for invalid inputs or modification errors |

Source code in src/mi_crow/hooks/controller.py
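This return-value convention mirrors PyTorch's own hook semantics, sketched below with plain PyTorch functions rather than Controller subclasses: a forward hook's non-None return replaces the output, a forward pre-hook's non-None return replaces the positional inputs (PyTorch wraps a single tensor into a tuple), and returning None leaves the tensor unchanged.

```python
import torch
from torch import nn

layer = nn.Linear(3, 3, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.eye(3))  # identity weights make outputs easy to check

x = torch.tensor([[1.0, 2.0, 3.0]])


def negate_input(module, args):
    # Forward pre-hook: the returned tensor replaces the positional input.
    return -args[0]


def double_output(module, inputs, output):
    # Forward hook: the returned tensor replaces the module output.
    return output * 2


def observe_only(module, inputs, output):
    # Returning None keeps whatever output the previous hooks produced.
    return None


handles = [
    layer.register_forward_pre_hook(negate_input),
    layer.register_forward_hook(double_output),
    layer.register_forward_hook(observe_only),
]
with torch.no_grad():
    y = layer(x)
for handle in handles:
    handle.remove()

# y == [[-2., -4., -6.]]: negated, passed through identity weights, then doubled
print(y)
```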
Implementations
mi_crow.hooks.implementations.layer_activation_detector.LayerActivationDetector
LayerActivationDetector(layer_signature, hook_id=None, target_dtype=None)
Bases: Detector
Detector hook that captures and saves activations during inference.
This detector extracts activations from layer outputs and stores them for later use (e.g., saving to disk, further analysis).
Initialize the layer activation detector.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| layer_signature | str \| int | Layer to capture activations from | required |
| hook_id | str \| None | Unique identifier for this hook | None |
| target_dtype | dtype \| None | Optional dtype to convert activations to before storing | None |

Raises:

| Type | Description |
|---|---|
| ValueError | If layer_signature is None |

Source code in src/mi_crow/hooks/implementations/layer_activation_detector.py
clear_captured
clear_captured()
Clear captured activations for the current batch.
Source code in src/mi_crow/hooks/implementations/layer_activation_detector.py
get_captured
get_captured()
Get the captured activations from the current batch.
Returns:

| Type | Description |
|---|---|
| Tensor \| None | The captured activation tensor from the current batch, or None if no activations have been captured yet |

Source code in src/mi_crow/hooks/implementations/layer_activation_detector.py
process_activations
process_activations(module, input, output)
Extract and store activations from output.
Handles various output types:
- Plain tensors
- Tuples/lists of tensors (the first tensor is used)
- Objects with a last_hidden_state attribute (e.g., HuggingFace outputs)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The PyTorch module being hooked | required |
| input | HOOK_FUNCTION_INPUT | Tuple of input tensors to the module | required |
| output | HOOK_FUNCTION_OUTPUT | Output tensor(s) from the module | required |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If tensor extraction or storage fails |

Source code in src/mi_crow/hooks/implementations/layer_activation_detector.py
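The capture, retrieve, clear cycle can be sketched with a plain forward hook. ActivationCapture is hypothetical and only mirrors the documented behavior (take the first tensor from tuple outputs, optionally convert to target_dtype, keep one batch at a time); it omits Store handling and the last_hidden_state case.

```python
import torch
from torch import nn


class ActivationCapture:
    """Hypothetical capture hook mirroring get_captured() / clear_captured()."""

    def __init__(self, target_dtype=None):
        self.target_dtype = target_dtype
        self._captured = None

    def __call__(self, module, input, output):
        tensor = output[0] if isinstance(output, (tuple, list)) else output
        tensor = tensor.detach()
        if self.target_dtype is not None:
            tensor = tensor.to(self.target_dtype)  # e.g. store float16 to save memory
        self._captured = tensor  # only the latest batch is kept
        return None  # observe only

    def get_captured(self):
        return self._captured

    def clear_captured(self):
        self._captured = None


layer = nn.Linear(8, 4)
capture = ActivationCapture(target_dtype=torch.float16)
handle = layer.register_forward_hook(capture)
layer(torch.randn(5, 8))
handle.remove()

acts = capture.get_captured()
print(acts.shape, acts.dtype)  # torch.Size([5, 4]) torch.float16
capture.clear_captured()
```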
mi_crow.hooks.implementations.model_input_detector.ModelInputDetector
ModelInputDetector(layer_signature=None, hook_id=None, save_input_ids=True, save_attention_mask=False, special_token_ids=None)
Bases: Detector
Detector hook that captures and saves tokenized inputs from model forward pass.
This detector is designed to be attached to the root model module and captures:
- Tokenized inputs (input_ids) from the model's forward pass
- Attention masks (optional) that exclude both padding and special tokens

It uses a PRE_FORWARD hook to capture inputs before they are processed, which is useful for saving tokenized inputs for analysis or training.
Initialize the model input detector.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| layer_signature | str \| int \| None | Layer to capture from (typically the root model; can be None) | None |
| hook_id | str \| None | Unique identifier for this hook | None |
| save_input_ids | bool | Whether to save the input_ids tensor | True |
| save_attention_mask | bool | Whether to save the attention_mask tensor (excludes padding and special tokens) | False |
| special_token_ids | Optional[List[int] \| Set[int]] | Optional list or set of special token IDs. If None, they are extracted from the LanguageModel context. | None |

Source code in src/mi_crow/hooks/implementations/model_input_detector.py
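The saved attention mask excludes both padding and special tokens. One way to express that rule in plain PyTorch is to AND the tokenizer's attention mask with a torch.isin test against the special token IDs. This is a sketch of the idea, not ModelInputDetector's actual code, and the token IDs below are made up for illustration.

```python
import torch

# Hypothetical IDs: 0 = pad, 1 = BOS, 2 = EOS.
input_ids = torch.tensor([
    [1, 15, 16, 17, 2],
    [1, 20, 2, 0, 0],
])
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1],
    [1, 1, 1, 0, 0],
])
special_token_ids = {0, 1, 2}

# True wherever a position holds a special token.
is_special = torch.isin(input_ids, torch.tensor(sorted(special_token_ids)))
# Keep only real (non-padding) positions that are not special tokens.
content_mask = attention_mask.bool() & ~is_special

print(content_mask.long().tolist())  # [[0, 1, 1, 1, 0], [0, 1, 0, 0, 0]]
```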
clear_captured
clear_captured()
Clear all captured inputs for the current batch.
Source code in src/mi_crow/hooks/implementations/model_input_detector.py
get_captured_attention_mask
get_captured_attention_mask()
Get the captured attention_mask from the current batch (excludes padding and special tokens).
Source code in src/mi_crow/hooks/implementations/model_input_detector.py
get_captured_input_ids
get_captured_input_ids()
Get the captured input_ids from the current batch.
Source code in src/mi_crow/hooks/implementations/model_input_detector.py
process_activations
process_activations(module, input, output)
Extract and store tokenized inputs.
Note: When a HuggingFace model is called with keyword arguments (**kwargs), the input tuple may be empty. In that case, use set_inputs_from_encodings() to set the inputs manually from the encodings dictionary returned by lm.inference.execute_inference().
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The PyTorch module being hooked (typically the root model) | required |
| input | HOOK_FUNCTION_INPUT | Tuple of input tensors/dicts to the module | required |
| output | HOOK_FUNCTION_OUTPUT | Output from the module (None for PRE_FORWARD hooks) | required |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If tensor extraction or storage fails |

Source code in src/mi_crow/hooks/implementations/model_input_detector.py
set_inputs_from_encodings
set_inputs_from_encodings(encodings, module=None)
Manually set inputs from encodings dictionary.
This is useful when the model is called with keyword arguments, because PyTorch forward pre-hooks do not receive keyword arguments by default.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| encodings | Dict[str, Tensor] | Dictionary of encoded inputs (e.g., from lm.inference.execute_inference() or lm.tokenize()) | required |
| module | Optional[Module] | Optional module used to extract special token IDs. If None, a DummyModule is used. | None |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If tensor extraction or storage fails |

Source code in src/mi_crow/hooks/implementations/model_input_detector.py
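The empty-input situation is standard PyTorch behavior and can be reproduced directly. In this sketch (TinyModel is hypothetical; with_kwargs requires PyTorch 2.0 or later), a default forward pre-hook sees an empty args tuple when the model is called with keyword arguments only, while a hook registered with with_kwargs=True receives them. It only illustrates why set_inputs_from_encodings() exists; it does not describe how mi_crow registers its hooks.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    def forward(self, input_ids=None, attention_mask=None):
        return input_ids.float().sum()


model = TinyModel()
seen = {}


def plain_pre_hook(module, args):
    seen["args"] = args  # positional inputs only


def kwargs_pre_hook(module, args, kwargs):
    seen["kwargs"] = dict(kwargs)  # keyword inputs, available with with_kwargs=True


h1 = model.register_forward_pre_hook(plain_pre_hook)
h2 = model.register_forward_pre_hook(kwargs_pre_hook, with_kwargs=True)

encodings = {"input_ids": torch.tensor([[5, 6, 7]]), "attention_mask": torch.ones(1, 3)}
model(**encodings)  # called with keyword arguments only

h1.remove()
h2.remove()

print(seen["args"])            # ()
print(sorted(seen["kwargs"]))  # ['attention_mask', 'input_ids']
```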
mi_crow.hooks.implementations.model_output_detector.ModelOutputDetector
ModelOutputDetector(layer_signature=None, hook_id=None, save_output_logits=True, save_output_hidden_state=False)
Bases: Detector
Detector hook that captures and saves model outputs.
This detector is designed to be attached to the root model module and captures:
- Model outputs (logits) from the model's forward pass
- Hidden states (optional) from the model's forward pass

It uses a FORWARD hook to capture outputs after they are computed, which is useful for saving model outputs for analysis or training.
Initialize the model output detector.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| layer_signature | str \| int \| None | Layer to capture from (typically the root model; can be None) | None |
| hook_id | str \| None | Unique identifier for this hook | None |
| save_output_logits | bool | Whether to save output logits (if available) | True |
| save_output_hidden_state | bool | Whether to save last_hidden_state (if available) | False |

Source code in src/mi_crow/hooks/implementations/model_output_detector.py
clear_captured
clear_captured()
Clear all captured outputs for the current batch.
Source code in src/mi_crow/hooks/implementations/model_output_detector.py
get_captured_output_hidden_state
get_captured_output_hidden_state()
Get the captured output hidden state from the current batch.
Source code in src/mi_crow/hooks/implementations/model_output_detector.py
get_captured_output_logits
get_captured_output_logits()
Get the captured output logits from the current batch.
Source code in src/mi_crow/hooks/implementations/model_output_detector.py
process_activations
process_activations(module, input, output)
Extract and store model outputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The PyTorch module being hooked (typically the root model) | required |
| input | HOOK_FUNCTION_INPUT | Tuple of input tensors/dicts to the module | required |
| output | HOOK_FUNCTION_OUTPUT | Output from the module | required |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If tensor extraction or storage fails |

Source code in src/mi_crow/hooks/implementations/model_output_detector.py
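As a rough illustration of pulling logits and last_hidden_state off an output object "if available", here is a hypothetical helper. pick_outputs is not part of mi_crow, and SimpleNamespace stands in for a HuggingFace-style output; missing attributes simply yield None.

```python
from types import SimpleNamespace

import torch


def pick_outputs(output, save_logits=True, save_hidden_state=False):
    """Hypothetical helper: collect logits / last_hidden_state when present."""
    captured = {}
    if save_logits:
        logits = getattr(output, "logits", None)
        captured["logits"] = logits.detach() if isinstance(logits, torch.Tensor) else None
    if save_hidden_state:
        hidden = getattr(output, "last_hidden_state", None)
        captured["hidden_state"] = hidden.detach() if isinstance(hidden, torch.Tensor) else None
    return captured


# SimpleNamespace objects stand in for model output objects in this sketch.
full = SimpleNamespace(logits=torch.randn(1, 3, 10), last_hidden_state=torch.randn(1, 3, 16))
logits_only = SimpleNamespace(logits=torch.randn(1, 3, 10))

a = pick_outputs(full, save_hidden_state=True)
b = pick_outputs(logits_only, save_hidden_state=True)
print(a["logits"].shape, a["hidden_state"].shape)
print(b["hidden_state"])  # None
```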
mi_crow.hooks.implementations.function_controller.FunctionController
FunctionController(layer_signature, function, hook_type=HookType.FORWARD, hook_id=None)
Bases: Controller
A controller that applies a user-provided function to tensors during inference.
This controller lets users apply an arbitrary function to activations. The function is applied to:
- Single tensors, directly
- Every tensor in a tuple or list (default behavior)

Example:

```python
from mi_crow.hooks.implementations.function_controller import FunctionController

# Scale activations by 2
controller = FunctionController(
    layer_signature="layer_0",
    function=lambda x: x * 2.0,
)
```

Initialize a function controller.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| layer_signature | str \| int | Layer to attach to | required |
| function | Callable[[Tensor], Tensor] | Function to apply to tensors. Must take a torch.Tensor and return a torch.Tensor. | required |
| hook_type | HookType \| str | Type of hook (HookType.FORWARD or HookType.PRE_FORWARD) | HookType.FORWARD |
| hook_id | str \| None | Unique identifier | None |

Raises:

| Type | Description |
|---|---|
| ValueError | If function is None or not callable |

Source code in src/mi_crow/hooks/implementations/function_controller.py
modify_activations
modify_activations(module, inputs, output)
Apply the user-provided function to activations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | nn.Module | The PyTorch module being hooked | required |
| inputs | Tensor \| None | Input tensor (None for forward hooks) | required |
| output | Tensor \| None | Output tensor (None for pre_forward hooks) | required |

Returns:

| Type | Description |
|---|---|
| Tensor \| None | Modified tensor with the function applied, or None if the target tensor is None |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If the function raises an exception when applied to a tensor |

Source code in src/mi_crow/hooks/implementations/function_controller.py
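The "single tensor, or every tensor in a tuple/list" rule, plus the RuntimeError wrapping listed under Raises, can be sketched as a small standalone helper. apply_to_tensors is hypothetical and is not the controller's internal code; non-tensor items such as None are passed through unchanged in this sketch.

```python
import torch


def apply_to_tensors(function, value):
    """Hypothetical helper: apply `function` to a tensor, or to each tensor in a tuple/list."""
    try:
        if isinstance(value, torch.Tensor):
            return function(value)
        if isinstance(value, (tuple, list)):
            mapped = [function(v) if isinstance(v, torch.Tensor) else v for v in value]
            return tuple(mapped) if isinstance(value, tuple) else mapped
    except Exception as exc:
        # Surface failures in the user function as RuntimeError.
        raise RuntimeError(f"function failed on activation: {exc}") from exc
    return value  # None and other non-tensor values pass through unchanged


def double(x):
    return x * 2.0


single = apply_to_tensors(double, torch.tensor([1.0, 2.0]))
pair = apply_to_tensors(double, (torch.tensor([1.0]), None))
print(single)   # tensor([2., 4.])
print(pair[1])  # None
```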
Utilities
mi_crow.hooks.utils
Utility functions for hook implementations.
apply_modification_to_output
apply_modification_to_output(output, modified_tensor, target_device=None)
Apply a modified tensor to an output object in-place.
Handles various output formats:
- Plain tensors: modifies the tensor directly (in-place)
- Tuples/lists of tensors: replaces the first tensor
- Objects with a last_hidden_state attribute: sets last_hidden_state

If target_device is provided, output tensors are first moved to target_device to keep them consistent with the desired device (e.g., context.device). Otherwise, modified_tensor is moved to match the output's current device.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output | HOOK_FUNCTION_OUTPUT | Output object to modify | required |
| modified_tensor | Tensor | Modified tensor to apply | required |
| target_device | device \| None | Optional target device. If provided, output tensors are moved to this device before the modification is applied. If None, the output's current device is used. | None |

Source code in src/mi_crow/hooks/utils.py
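The three output shapes listed above can be handled roughly as follows. This is a hypothetical sketch (write_back is not the library function) that covers only the target_device=None path, moving the modified tensor to the output's device. Because tuples cannot be mutated, it assumes tuple outputs are updated by copying into their first tensor; the real implementation may differ.

```python
from types import SimpleNamespace

import torch


def write_back(output, modified):
    """Hypothetical sketch: write `modified` into a hook output on the output's device."""
    if isinstance(output, torch.Tensor):
        output.copy_(modified.to(output.device))  # in-place on the existing tensor
    elif isinstance(output, list):
        output[0] = modified.to(output[0].device)  # lists can be updated by index
    elif isinstance(output, tuple):
        # Tuples are immutable, so this sketch copies into the first element instead.
        output[0].copy_(modified.to(output[0].device))
    elif hasattr(output, "last_hidden_state"):
        output.last_hidden_state = modified.to(output.last_hidden_state.device)
    return output


t = torch.zeros(2)
write_back(t, torch.ones(2))

tup = (torch.zeros(2), "extra")
write_back(tup, torch.full((2,), 3.0))

obj = SimpleNamespace(last_hidden_state=torch.zeros(1, 2))
write_back(obj, torch.ones(1, 2))
```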
extract_tensor_from_input
extract_tensor_from_input(input)
Extract the first tensor from input sequence.
Handles various input formats:
- A tensor in the first position
- A tuple/list of tensors in the first position
- Empty or None inputs

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | HOOK_FUNCTION_INPUT | Input sequence (tuple/list of tensors) | required |

Returns:

| Type | Description |
|---|---|
| Tensor \| None | First tensor found, or None if no tensor is found |

Source code in src/mi_crow/hooks/utils.py
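A minimal sketch of the input cases listed above, assuming "first tensor" means the first torch.Tensor encountered; first_tensor_from_input is a hypothetical name, not the library function.

```python
import torch


def first_tensor_from_input(inputs):
    """Hypothetical sketch: return the first tensor in a hook's positional inputs."""
    if not inputs:
        return None  # None or an empty tuple (e.g. model called with keyword arguments only)
    first = inputs[0]
    if isinstance(first, torch.Tensor):
        return first
    if isinstance(first, (tuple, list)):
        return next((item for item in first if isinstance(item, torch.Tensor)), None)
    return None


x = torch.ones(2, 3)
print(first_tensor_from_input(()))                           # None
print(first_tensor_from_input((x,)) is x)                    # True
print(first_tensor_from_input(([x, torch.zeros(1)],)) is x)  # True
```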
extract_tensor_from_output
extract_tensor_from_output(output)
Extract tensor from output (handles various output types).
Handles various output formats:
- Plain tensors
- Tuples/lists of tensors (the first tensor is used)
- Objects with a last_hidden_state attribute (e.g., HuggingFace outputs)
- None outputs

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output | HOOK_FUNCTION_OUTPUT | Output from the module (tensor, tuple, or object with attributes) | required |

Returns:

| Type | Description |
|---|---|
| Tensor \| None | First tensor found, or None if no tensor is found |

Source code in src/mi_crow/hooks/utils.py
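The output cases can be sketched the same way. first_tensor_from_output is hypothetical, and SimpleNamespace stands in for a HuggingFace-style output object that exposes last_hidden_state.

```python
from types import SimpleNamespace

import torch


def first_tensor_from_output(output):
    """Hypothetical sketch: find the first tensor in a module output."""
    if output is None:
        return None
    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, (tuple, list)):
        return next((item for item in output if isinstance(item, torch.Tensor)), None)
    hidden = getattr(output, "last_hidden_state", None)  # HuggingFace-style outputs
    return hidden if isinstance(hidden, torch.Tensor) else None


h = torch.randn(1, 4, 8)
print(first_tensor_from_output((h, None)) is h)                            # True
print(first_tensor_from_output(SimpleNamespace(last_hidden_state=h)) is h)  # True
print(first_tensor_from_output(None))                                      # None
```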