Patch

class Patch

Base class for operations applied to nodes in a PatchableGraph. Derived classes should be keyword-only dataclasses (i.e., decorated with @dataclass(kw_only=True)) and should override op().

requires_clone

Whether the operation modifies the original output. Defaults to True and is hidden from the constructor; derived classes can override it for read-only operations.

Type:

bool

path

For nodes that output nested structures, the path within that structure to which this operation should apply. Hidden from the constructor, since PatchableGraph sets the path.

Type:

str | None

op(original_output: PatchTarget) → PatchTarget

The operation to perform at this node. Takes a single argument, which is populated with the original output at this node, and should return a value of the same type.

class AddPatch(*, value: Tensor | int | float | bool, slice: TensorSlice | None = None)

Patch that adds a value to (optionally, a slice of) its target.

Example

import torch
from graphpatch import AddPatch, PatchableGraph

pg = PatchableGraph(model, **example_inputs)
# Adds delta to target[1:, 0] only.
delta = torch.ones((seq_len - 1,))
with pg.patch({"output": AddPatch(value=delta, slice=(slice(1, None), 0))}):
    patched_outputs = pg(**sample_inputs)

slice

Slice of the target to add value to; applies to the whole target if None.

Type:

TensorSlice | None

value

Value to add to target.

Type:

torch.Tensor | int | float | bool

class CustomPatch(*, requires_clone: bool = True, custom_op: Callable[[PatchTarget], PatchTarget])

Convenience class for one-off patch operations that avoids defining a new Patch class. It also exposes the normally hidden requires_clone field, for operations that do not require cloning.

Example

Replace the output of a layer’s MLP with that of a previous layer:

from graphpatch import CustomPatch, PatchableGraph, ProbePatch

pg = PatchableGraph(model, **example_inputs)
with pg.patch(
    {
        "layers_0.mlp.output": [layer_0 := ProbePatch()],
        "layers_1.mlp.output": CustomPatch(custom_op=lambda t: layer_0.activation),
    }
):
    print(pg(**sample_inputs))

custom_op

Operation to perform. The output at this node is replaced with the return value of custom_op(original_output).

Type:

Callable[[PatchTarget], PatchTarget]

requires_clone

Whether the operation modifies the original output tensor. Defaults to True. For read-only operations, set to False to avoid creating unnecessary copies.

Type:

bool
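The effect of requires_clone can be shown with plain Python lists standing in for tensors. apply_patch below is a hypothetical helper, not graphpatch's implementation; it only models the idea that an operation which mutates its input should receive a copy, while a read-only operation can safely skip the copy.

```python
import copy
from typing import Callable, List


def apply_patch(
    original: List[float],
    custom_op: Callable[[List[float]], List[float]],
    requires_clone: bool = True,
) -> List[float]:
    # Hypothetical helper: copy the target first only when the op may modify it.
    target = copy.deepcopy(original) if requires_clone else original
    return custom_op(target)


def zero_first(t: List[float]) -> List[float]:
    t[0] = 0.0  # modifies its argument in place
    return t


def doubled(t: List[float]) -> List[float]:
    return [2 * x for x in t]  # read-only: builds a new list


original = [1.0, 2.0, 3.0]
print(apply_patch(original, zero_first), original)  # [0.0, 2.0, 3.0] [1.0, 2.0, 3.0]
print(apply_patch(original, doubled, requires_clone=False), original)  # [2.0, 4.0, 6.0] [1.0, 2.0, 3.0]
apply_patch(original, zero_first, requires_clone=False)
print(original)  # [0.0, 2.0, 3.0]: the in-place op overwrote the original
```

The last call shows the failure mode the flag guards against: with copying disabled, an in-place operation overwrites the original output.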

class ProbePatch

Patch that records the last activation of its target.

Example

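from graphpatch import PatchableGraph, ProbePatch

pg = PatchableGraph(model, **example_inputs)
probe = ProbePatch()
with pg.patch({"transformer.h_17.mlp.act.mul_3": probe}):
    pg(**sample_inputs)
# The probe retains the recorded value after the patch context exits.
print(probe.activation)

activation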

The most recent activation recorded at its target, or None if nothing has been recorded yet.

Type:

torch.Tensor | None
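Conceptually, a probe is a pass-through operation that remembers the last value it saw. ProbeSketch below is a hypothetical stand-in, not graphpatch's ProbePatch, with plain values in place of Tensors; activation is declared with field(init=False) to mirror the empty constructor signature shown above.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


# Hypothetical stand-in illustrating ProbePatch's documented behavior.
@dataclass
class ProbeSketch:
    # Not a constructor argument; None until something has been recorded.
    activation: Optional[Any] = field(default=None, init=False)

    def op(self, original_output: Any) -> Any:
        self.activation = original_output  # overwrite any earlier value
        return original_output  # pass the output through unchanged


probe = ProbeSketch()
print(probe.activation)  # None
for output in (1.0, 2.0, 3.0):  # e.g. three forward passes
    probe.op(output)
print(probe.activation)  # 3.0
```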

class RecordPatch(*, activations: List[Tensor] = <factory>)

Patch that records all activations of its target.

Example

Replace a layer’s output with a running mean of the previous layer’s activations:

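import torch
from graphpatch import CustomPatch, PatchableGraph, RecordPatch

pg = PatchableGraph(model, **example_inputs)
record = RecordPatch()
for i in range(10):
    with pg.patch(
        {
            # Append layer 0's output to record.activations on every forward pass.
            "layers_0.output": record,
            # Replace layer 1's output with the elementwise mean of everything recorded so far.
            "layers_1.output": CustomPatch(
                custom_op=lambda t: torch.mean(
                    torch.stack(record.activations, dim=2), dim=2
                )
            ),
        }
    ):
        print(pg(**sample_inputs[i]))

activations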

List of activations.

Type:

List[torch.Tensor]

class ReplacePatch(*, slice: TensorSlice | None = None, value: Tensor | int | float | bool)

Patch that replaces (optionally, a slice of) its target with the given value.

Example

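from graphpatch import PatchableGraph, ReplacePatch

pg = PatchableGraph(model, **example_inputs)
# Equivalent to target[:, 0, 0] = 42 on the value at node "linear.input".
with pg.patch({"linear.input": ReplacePatch(value=42, slice=(slice(None), 0, 0))}):
    print(pg(**sample_inputs))

slice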

Slice of the target to replace with value; applies to the whole tensor if None.

Type:

TensorSlice | None

value

Value with which to replace the target or slice of the target.

Type:

torch.Tensor | int | float | bool

class ZeroPatch(*, slice: TensorSlice | None = None)

Patch that zeroes out a slice of its target, or the whole tensor if no slice is provided.

Example

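from graphpatch import PatchableGraph, ZeroPatch

pg = PatchableGraph(model, **example_inputs)
# No slice is given, so the entire output of layers_0 is zeroed.
with pg.patch({"layers_0.output": ZeroPatch()}):
    print(pg(**sample_inputs))

slice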

Slice of the target to zero out; applies to the whole tensor if None.

Type:

TensorSlice | None

Types

TensorSlice: TensorSliceElement | List[TensorSlice] | Tuple[TensorSlice, ...]

A datatype representing the indexing operation performed when you slice a Tensor, as in code like

x[:, 5:8, 2] = 3

This is not a graphpatch-specific type (we have merely aliased it for convenience), but it interacts with Python indexing internals that may be unfamiliar.

Briefly, you will almost always want to pass a sequence (tuple or list) with as many elements as your tensor has dimensions. Each element selects a subset of the Tensor along the dimension at the corresponding position, and may be an integer, a subsequence, a slice, or a Tensor. An integer selects a single “row” along that dimension, a subsequence selects multiple “rows”, and a slice selects a range of “rows”. (slice(None) selects all rows along that dimension, equivalent to writing “:” inside the brackets.) A Tensor triggers a more complex operation (advanced indexing) that is beyond the scope of this brief note.

For a concrete example, we can accomplish the above operation with the following ReplacePatch:

ReplacePatch(value=3, slice=(slice(None), slice(5, 8), 2))
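One way to see exactly what a bracket expression produces is to hand it to an object that simply records its key. ShowKey below is a hypothetical helper written for this illustration; it confirms that x[:, 5:8, 2] = 3 passes Python the same tuple used as the slice argument in the ReplacePatch above.

```python
class ShowKey:
    """Hypothetical helper: remembers the key and value of the last item assignment."""

    def __setitem__(self, key, value):
        self.key = key
        self.value = value


x = ShowKey()
x[:, 5:8, 2] = 3
print(x.key)  # (slice(None, None, None), slice(5, 8, None), 2)
print(x.key == (slice(None), slice(5, 8), 2))  # True
```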

See also: Tensor Indexing API.
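For integer and slice elements, the shape of the selected region follows directly from the rules above: an integer removes its dimension, a slice keeps it with a possibly shorter length, and dimensions not covered by the sequence are kept whole. The hypothetical sliced_shape function below sketches those basic-indexing rules (it deliberately ignores subsequences, Tensors, Ellipsis, and None) and can help check that a value passed to AddPatch or ReplacePatch has a compatible shape.

```python
from typing import Sequence, Tuple, Union

Index = Union[int, slice]


def sliced_shape(
    shape: Sequence[int], index: Union[Index, Tuple[Index, ...]]
) -> Tuple[int, ...]:
    """Shape of tensor[index] under basic indexing with ints and slices only."""
    if not isinstance(index, tuple):
        index = (index,)  # a lone int or slice indexes the first dimension
    if len(index) > len(shape):
        raise IndexError("too many indices for tensor")
    result = []
    for size, element in zip(shape, index):
        if isinstance(element, slice):
            # slice.indices clamps start/stop/step to this dimension's size.
            result.append(len(range(*element.indices(size))))
        elif isinstance(element, int):
            if not -size <= element < size:
                raise IndexError(f"index {element} out of range for size {size}")
            # An integer selects a single row, removing the dimension.
        else:
            raise TypeError("only ints and slices are handled in this sketch")
    # Dimensions beyond the end of the index are kept whole.
    result.extend(shape[len(index):])
    return tuple(result)


print(sliced_shape((4, 10, 5), (slice(None), slice(5, 8), 2)))  # (4, 3)
print(sliced_shape((12, 768), (slice(1, None), 0)))  # (11,)
```

The second call assumes, purely for illustration, an AddPatch target of shape (seq_len, d) with seq_len = 12 and d = 768; under that assumption the selected region has seq_len - 1 elements, matching the shape of delta in the AddPatch example.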

TensorSliceElement: int | slice | torch.Tensor

One component of a TensorSlice.

PatchTarget: TypeVar

Generic type parameter that patches specialize according to the data type they expect. In practice it is almost always specialized to Tensor.