# Activation Control with Hooks

This guide covers manipulating activations directly with hooks for fine-grained model control.
## Overview

Activation control provides:

- Direct manipulation of layer activations
- Fine-grained control without SAEs
- Custom intervention patterns
- Multi-layer coordination
## When to Use Activation Control

Use activation control when:

- You need direct control over activations
- SAE-based manipulation is insufficient
- You want custom intervention patterns
- You're experimenting with new control methods
## Basic Activation Control

### Using Detector Hooks for Inspection

First, inspect activations to understand what you're working with:
```python
from mi_crow.hooks import LayerActivationDetector

# Create detector
detector = LayerActivationDetector("transformer.h.0.attn.c_attn")

# Register hook
hook_id = lm.layers.register_hook("transformer.h.0.attn.c_attn", detector)

# Run inference
outputs, encodings = lm.inference.execute_inference(["Hello, world!"])

# Inspect activations
activations = detector.get_captured()
print(f"Activations shape: {activations.shape}")
print(f"Mean: {activations.mean().item()}")
print(f"Std: {activations.std().item()}")

# Cleanup
lm.layers.unregister_hook(hook_id)
```
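Global statistics are often enough, but the captured tensor can also be sliced for finer-grained views. A short sketch on the `activations` tensor fetched above, assuming the capture is shaped `[batch, sequence, hidden]` (check `activations.shape` first):

```python
# Per-token activation norms, shape [batch, sequence]
per_token_norms = activations.norm(dim=-1)
print(f"Per-token norms: {per_token_norms[0].tolist()}")

# The five strongest neurons at the final position of the first sequence
top_neurons = activations[0, -1].abs().topk(5)
print(f"Top neuron indices at final position: {top_neurons.indices.tolist()}")
```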
### Using Controller Hooks for Modification

Modify activations directly:
```python
from mi_crow.hooks import FunctionController

# Create a controller that scales activations
controller = FunctionController(
    layer_signature="transformer.h.0.attn.c_attn",
    function=lambda x: x * 1.5  # Scale by 1.5
)

# Register hook
hook_id = lm.layers.register_hook("transformer.h.0.attn.c_attn", controller)

# Run inference with the modification applied
outputs, encodings = lm.inference.execute_inference(["Hello, world!"])

# Cleanup
lm.layers.unregister_hook(hook_id)
```
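The same `FunctionController` pattern also supports additive interventions such as steering. A minimal sketch, assuming you already have a direction vector matching the hooked layer's output width; the random vector below is a stand-in for a real, derived direction:

```python
import torch

# Hypothetical steering direction; in practice you would derive it from data
# (e.g. a difference of mean activations), not sample it randomly.
hidden_dim = 2304  # assumption: this layer's output width (3 * 768 for GPT-2's fused QKV projection)
steering_vector = torch.randn(hidden_dim)
steering_vector = steering_vector / steering_vector.norm()

# Add the direction to every position's activation, scaled by a strength factor
steer = FunctionController(
    layer_signature="transformer.h.0.attn.c_attn",
    function=lambda x: x + 3.0 * steering_vector.to(x.device, x.dtype)
)

hook_id = lm.layers.register_hook("transformer.h.0.attn.c_attn", steer)
outputs, encodings = lm.inference.execute_inference(["Hello, world!"])
lm.layers.unregister_hook(hook_id)
```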
## Custom Controller Implementation

Create custom controllers for specific needs:

### Scaling Controller
```python
from mi_crow.hooks import Controller
from mi_crow.hooks.hook import HookType
import torch


class ScalingController(Controller):
    """Scales activations by a factor."""

    def __init__(self, layer_signature: str | int, scale_factor: float):
        super().__init__(hook_type=HookType.FORWARD, layer_signature=layer_signature)
        self.scale_factor = scale_factor

    def modify_activations(self, module, inputs, output):
        # Pass through anything that isn't a plain tensor (e.g. None or tuples)
        if not isinstance(output, torch.Tensor):
            return output
        return output * self.scale_factor


# Use
controller = ScalingController("transformer.h.0.attn.c_attn", scale_factor=1.5)
hook_id = lm.layers.register_hook("transformer.h.0.attn.c_attn", controller)
```
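The same class doubles as an ablation tool: with `scale_factor=0.0` it zeroes the layer's output entirely, a common way to test how much downstream behavior depends on that layer.

```python
# Zero the layer's output to test how much downstream behavior depends on it
ablation = ScalingController("transformer.h.0.attn.c_attn", scale_factor=0.0)
```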
### Selective Neuron Controller
```python
class SelectiveController(Controller):
    """Modifies only specific neurons."""

    def __init__(self, layer_signature: str | int, neuron_indices: list[int], scale: float):
        super().__init__(hook_type=HookType.FORWARD, layer_signature=layer_signature)
        self.neuron_indices = set(neuron_indices)
        self.scale = scale

    def modify_activations(self, module, inputs, output):
        if output is None or not isinstance(output, torch.Tensor):
            return output
        modified = output.clone()
        for idx in self.neuron_indices:
            if idx < modified.shape[-1]:
                modified[..., idx] *= self.scale
        return modified


# Use
controller = SelectiveController(
    "transformer.h.0.attn.c_attn",
    neuron_indices=[42, 100, 200],
    scale=2.0
)
hook_id = lm.layers.register_hook("transformer.h.0.attn.c_attn", controller)
```
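For more than a handful of neurons, the per-index Python loop becomes slow. Indexing with a tensor applies the same scaling in one vectorized step; a sketch under the same assumptions as above:

```python
class VectorizedSelectiveController(Controller):
    """Like SelectiveController, but scales all target neurons in one indexing op."""

    def __init__(self, layer_signature: str | int, neuron_indices: list[int], scale: float):
        super().__init__(hook_type=HookType.FORWARD, layer_signature=layer_signature)
        self.neuron_indices = torch.tensor(neuron_indices, dtype=torch.long)
        self.scale = scale

    def modify_activations(self, module, inputs, output):
        if not isinstance(output, torch.Tensor):
            return output
        modified = output.clone()
        # Drop indices outside this layer's width, then scale the rest at once
        valid = self.neuron_indices[self.neuron_indices < modified.shape[-1]]
        modified[..., valid] *= self.scale
        return modified
```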
## Multi-Layer Interventions

Coordinate interventions across multiple layers:

### Sequential Modifications
```python
# Modify an early layer
early_controller = FunctionController("transformer.h.0.attn.c_attn", lambda x: x * 1.2)
hook1 = lm.layers.register_hook("transformer.h.0.attn.c_attn", early_controller)

# Modify a late layer
late_controller = FunctionController("transformer.h.10.attn.c_attn", lambda x: x * 0.8)
hook2 = lm.layers.register_hook("transformer.h.10.attn.c_attn", late_controller)

# Both apply during the forward pass
outputs, encodings = lm.inference.execute_inference(["Hello, world!"])

# Cleanup
lm.layers.unregister_hook(hook1)
lm.layers.unregister_hook(hook2)
```
### Cross-Layer Communication
```python
class CoordinatedController(Controller):
    """Uses information from another layer."""

    def __init__(self, layer_signature, reference_detector):
        super().__init__(hook_type=HookType.FORWARD, layer_signature=layer_signature)
        self.reference_detector = reference_detector

    def modify_activations(self, module, inputs, output):
        # The reference layer must run earlier in the forward pass than this one,
        # so the detector has already captured this pass's activations.
        ref_activations = self.reference_detector.get_captured()
        if ref_activations is not None and output is not None:
            # Scale based on the reference layer's mean activation
            scale = 1.0 + 0.1 * ref_activations.mean().item()
            return output * scale
        return output


# Setup: detector on an early layer, coordinated controller on a later one
ref_detector = LayerActivationDetector("transformer.h.0.attn.c_attn")
ref_hook_id = lm.layers.register_hook("transformer.h.0.attn.c_attn", ref_detector)
coordinated = CoordinatedController("transformer.h.5.attn.c_attn", ref_detector)
hook_id = lm.layers.register_hook("transformer.h.5.attn.c_attn", coordinated)

# Cleanup: unregister both the controller and the reference detector
lm.layers.unregister_hook(hook_id)
lm.layers.unregister_hook(ref_hook_id)
```
## A/B Testing with Hooks

Compare behavior with and without interventions:

### Baseline
```python
# Get baseline output
baseline_outputs, _ = lm.inference.execute_inference(
    ["Your prompt"],
    with_controllers=False  # Disable all controllers
)
```
### With Intervention
```python
# Apply the intervention
controller = FunctionController("layer_0", lambda x: x * 1.5)
hook_id = lm.layers.register_hook("layer_0", controller)

# Get modified output
intervention_outputs, _ = lm.inference.execute_inference(
    ["Your prompt"],
    with_controllers=True  # Enable controllers
)

# Compare
difference = intervention_outputs.logits - baseline_outputs.logits

# Cleanup
lm.layers.unregister_hook(hook_id)
```
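Raw logit differences are hard to interpret on their own. A summary statistic such as the per-position KL divergence between the two output distributions is often more readable; a minimal sketch in plain PyTorch, assuming both runs used the same prompt and tokenization:

```python
import torch.nn.functional as F

# Next-token log-probabilities for both runs
baseline_logprobs = F.log_softmax(baseline_outputs.logits, dim=-1)
intervention_logprobs = F.log_softmax(intervention_outputs.logits, dim=-1)

# KL(intervention || baseline), summed over the vocabulary, per position
kl_per_token = (intervention_logprobs.exp()
                * (intervention_logprobs - baseline_logprobs)).sum(dim=-1)
print(f"Mean KL across positions: {kl_per_token.mean().item():.4f}")
```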
## Advanced Patterns

### Conditional Control
```python
class ConditionalController(Controller):
    """Applies a modification conditionally."""

    def __init__(self, layer_signature, condition_fn, modification_fn):
        super().__init__(hook_type=HookType.FORWARD, layer_signature=layer_signature)
        self.condition_fn = condition_fn
        self.modification_fn = modification_fn

    def modify_activations(self, module, inputs, output):
        if output is None:
            return output
        if self.condition_fn(output):
            return self.modification_fn(output)
        return output


# Use: only modify if the activation magnitude is high
controller = ConditionalController(
    "layer_0",
    condition_fn=lambda x: x.abs().mean() > 1.0,
    modification_fn=lambda x: x * 1.5
)
```
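As with any controller, the condition takes effect only once the hook is registered, and it is re-evaluated on every forward pass:

```python
hook_id = lm.layers.register_hook("layer_0", controller)
outputs, encodings = lm.inference.execute_inference(["Your prompt"])
lm.layers.unregister_hook(hook_id)
```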
### Pipeline of Transformations
```python
class PipelineController(Controller):
    """Applies multiple transformations in order."""

    def __init__(self, layer_signature, transformations):
        super().__init__(hook_type=HookType.FORWARD, layer_signature=layer_signature)
        self.transformations = transformations

    def modify_activations(self, module, inputs, output):
        if output is None:
            return output
        result = output
        for transform in self.transformations:
            result = transform(result)
        return result


# Use: transformations apply in list order
pipeline = PipelineController(
    "layer_0",
    transformations=[
        lambda x: x * 1.2,                           # Scale
        lambda x: torch.clamp(x, -2, 2),             # Clamp
        lambda x: (x - x.mean()) / (x.std() + 1e-8)  # Normalize
    ]
)
```
## Best Practices

- **Always clean up**: Unregister hooks when you're done
- **Use context managers**: For automatic cleanup (see the patterns below)
- **Test incrementally**: Start with simple modifications
- **Monitor effects**: Compare outputs before and after each intervention
- **Document interventions**: Record what each hook does and why
## Common Patterns

### Context Manager Pattern
```python
class HookContext:
    """Context manager for hook lifecycle."""

    def __init__(self, layers, hook, layer_signature):
        self.layers = layers
        self.hook = hook
        self.layer_signature = layer_signature
        self.hook_id = None

    def __enter__(self):
        self.hook_id = self.layers.register_hook(self.layer_signature, self.hook)
        return self.hook

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Compare against None so a falsy hook id (e.g. 0) is still unregistered
        if self.hook_id is not None:
            self.layers.unregister_hook(self.hook_id)


# Usage
with HookContext(lm.layers, controller, "layer_0") as hook:
    outputs, encodings = lm.inference.execute_inference(["Hello, world!"])
# Hook automatically unregistered here
```
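If a dedicated class feels heavy, the standard library's `contextlib.contextmanager` expresses the same register/unregister lifecycle as a generator function; a minimal sketch under the same assumptions:

```python
from contextlib import contextmanager

@contextmanager
def hooked(layers, layer_signature, hook):
    """Register a hook for the duration of the with-block, then unregister it."""
    hook_id = layers.register_hook(layer_signature, hook)
    try:
        yield hook
    finally:
        layers.unregister_hook(hook_id)

# Usage
with hooked(lm.layers, "layer_0", controller):
    outputs, encodings = lm.inference.execute_inference(["Hello, world!"])
```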
### Try/Finally Pattern
```python
hook_id = None
try:
    controller = FunctionController("layer_0", lambda x: x * 1.5)
    hook_id = lm.layers.register_hook("layer_0", controller)
    outputs, encodings = lm.inference.execute_inference(["Hello, world!"])
finally:
    if hook_id is not None:
        lm.layers.unregister_hook(hook_id)
```
## Next Steps

After learning activation control:
- Hooks: Controllers - Detailed controller guide
- Hooks: Advanced - Advanced hook patterns
- Concept Manipulation - SAE-based control
- Examples - See example notebooks
## Related Examples

- `examples/08_inference_with_hooks.ipynb` - Complete hooks example
- `examples/03_load_concepts.ipynb` - Concept manipulation with hooks