Patch#
- class Patch#
Base class for operations applying to nodes in a PatchableGraph. Derived classes should be keyword-only dataclasses (i.e. decorated with @dataclass(kw_only=True)) and override op().
- requires_clone#
Whether the operation modifies the original output. Defaults to True and is hidden from the constructor; derived classes can override it for read-only operations.
- Type:
bool
- path#
For nodes that output nested structures, the path within that structure that this operation should apply to. Hidden from the constructor, since setting the path will be handled by PatchableGraph.
- Type:
str | None
- op(original_output: PatchTarget) → PatchTarget#
The operation to perform at this node. Should take in a single argument, which will be populated with the original output at this node, and return a value of the same type.
- class AddPatch(*, value: Tensor | int | float | bool, slice: TensorSlice | None = None)#
Patch that adds a value to (optionally, a slice of) its target.
Example
pg = PatchableGraph(model, **example_inputs)
delta = torch.ones((seq_len - 1,))
with pg.patch({"output": AddPatch(value=delta, slice=(slice(1, None), 0))}):
    patched_outputs = pg(**sample_inputs)
- slice#
Slice to perform addition on. Applies to full target if None.
- Type:
TensorSlice | None
- value#
Value to add to target.
- Type:
torch.Tensor | int | float | bool
- class CustomPatch(*, requires_clone: bool = True, custom_op: Callable[[PatchTarget], PatchTarget])#
Convenience for one-off patch operations without the need to define a new Patch class. Also exposes the normally hidden requires_clone field for operations that do not require cloning.
Example
Replace the output of a layer’s MLP with that of a previous layer:
pg = PatchableGraph(model, **example_inputs)
with pg.patch(
    {
        "layers_0.mlp.output": [layer_0 := ProbePatch()],
        "layers_1.mlp.output": CustomPatch(custom_op=lambda t: layer_0.activation),
    }
):
    print(pg(**sample_inputs))
- custom_op#
Operation to perform. Replaces the output at this node with the return value of custom_op(original_output).
- Type:
Callable[[PatchTarget], PatchTarget]
- class ProbePatch#
Patch that records the last activation of its target.
Example
pg = PatchableGraph(model, **example_inputs)
probe = ProbePatch()
with pg.patch({"transformer.h_17.mlp.act.mul_3": probe}):
    pg(**sample_inputs)
print(probe.activation)
- activation#
Most recently recorded activation of its target, or None if nothing has been recorded yet.
- Type:
torch.Tensor | None
- class RecordPatch(*, activations: List[Tensor] = <factory>)#
Patch that records all activations of its target.
Example
Replace a layer’s output with a running mean of the previous layer’s activations:
pg = PatchableGraph(model, **example_inputs)
record = RecordPatch()
for i in range(10):
    with pg.patch(
        {
            "layers_0.output": record,
            "layers_1.output": CustomPatch(
                custom_op=lambda t: torch.mean(
                    torch.stack(record.activations, dim=2), dim=2
                )
            ),
        }
    ):
        print(pg(**sample_inputs[i]))
- activations#
List of activations.
- Type:
List[torch.Tensor]
- class ReplacePatch(*, slice: TensorSlice | None = None, value: Tensor | int | float | bool)#
Patch that replaces (optionally, a slice of) its target with the given value.
Example
pg = PatchableGraph(model, **example_inputs)
with pg.patch({"linear.input": ReplacePatch(value=42, slice=(slice(None), 0, 0))}):
    print(pg(**sample_inputs))
- slice#
Slice of the target to replace with value; applies to the whole tensor if None.
- Type:
TensorSlice | None
- value#
Value with which to replace the target or slice of the target.
- Type:
torch.Tensor | int | float | bool
- class ZeroPatch(*, slice: TensorSlice | None = None)#
Patch that zeroes out a slice of its target, or the whole tensor if no slice is provided.
Example
pg = PatchableGraph(model, **example_inputs)
with pg.patch({"layers_0.output": ZeroPatch()}):
    print(pg(**sample_inputs))
- slice#
Slice of the target to apply zeros to; applies to the whole tensor if None.
- Type:
TensorSlice | None
Types#
- TensorSlice: TensorSliceElement | List[TensorSlice] | Tuple[TensorSlice, ...]#
This is a datatype representing the indexing operation done when you slice a Tensor, as happens in code like
x[:, 5:8, 2] = 3
This is not a graphpatch-specific type (we have merely aliased it for convenience), but interacts with Python internals which may be unfamiliar.
Briefly, you will almost always want to pass a sequence (tuple or list) with as many elements as the dimensionality of your tensor. Within this sequence, elements can be either integers, subsequences, slices, or Tensors. Each element of the sequence will select a subset of the Tensor along the dimension with the corresponding index. An integer will select a single “row” along that dimension. A subsequence will select multiple “rows”. A slice will select a range of “rows”. (slice(None) selects all rows for that dimension, equivalent to writing a “:” within the bracket expression.) A Tensor will perform a complex operation that is out of the scope of this brief note.
For a concrete example, we can accomplish the above operation with the following ReplacePatch:
ReplacePatch(value=3, slice=((slice(None), slice(5, 8), 2)))
See also: Tensor Indexing API.
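To see which Python objects a bracket expression turns into, the following sketch uses a tiny helper class whose __getitem__ simply returns the index it receives. ShowIndex is a hypothetical name that is not part of graphpatch or PyTorch; the sketch relies only on standard Python indexing behavior.

```python
class ShowIndex:
    """Hypothetical helper: returns whatever index object it is given."""

    def __getitem__(self, index):
        return index


show = ShowIndex()

# Writing x[:, 5:8, 2] hands __getitem__ a tuple of (slice, slice, int).
index = show[:, 5:8, 2]
print(index)

# The same tuple written out by hand, in the form a patch's slice argument takes.
by_hand = (slice(None), slice(5, 8), 2)
```

The two values compare equal, which is why the bracket expression and the explicit tuple select the same elements.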
- TensorSliceElement: int | slice | torch.Tensor#
One component of a TensorSlice.