PatchableGraph

class PatchableGraph(module: Module, *extraction_args: Any, _graphpatch_postprocessing_function: Callable[[GraphModule, Module], None] | None = None, **extraction_kwargs: Any)

PatchableGraph is a wrapper around a torch.nn.Module that allows activation patching at any computational node.

Internally, PatchableGraph builds a torch.fx.GraphModule for the module and each of its submodules using torch.compile(). This exposes the computational structure of the module while remaining equivalent to the original: any operation you could perform with the original module, you can perform with the PatchableGraph.

Note that the original module hierarchy is retained. For example, if you had a module foo containing a submodule bar, you would get back a GraphModule equivalent to foo which has a sub-GraphModule bar, equivalent to the original bar.

To perform activation patching, use the patch context manager. It takes a mapping from NodePaths to a Patch or list of Patches to apply at each corresponding node. The activation patches are applied only inside the context block; using the PatchableGraph outside such a block is equivalent to running the original module.
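To make the context-block semantics concrete, here is a toy in plain Python. It does not use graphpatch or torch: `ToyPatchable`, `zero_patch`, and the node names `"hidden"` and `"output"` are hypothetical stand-ins, and applying a node's patches in list order is a choice made for this sketch, not documented graphpatch behavior. It only illustrates the contract described above: patches take effect inside the `with` block, and the computation is unchanged outside it.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, List

# In this toy, a "patch" is any function mapping the original value to a replacement.
PatchFn = Callable[[float], float]


class ToyPatchable:
    """Hypothetical stand-in for a wrapped module with two named nodes."""

    def __init__(self) -> None:
        self._active: Dict[str, List[PatchFn]] = {}

    def _node(self, name: str, value: float) -> float:
        # Apply this node's patches in list order (a choice made for this toy).
        for patch in self._active.get(name, []):
            value = patch(value)
        return value

    def __call__(self, x: float) -> float:
        hidden = self._node("hidden", x * 2.0)
        return self._node("output", hidden + 1.0)

    @contextmanager
    def patch(self, patch_map: Dict[str, List[PatchFn]]) -> Iterator[None]:
        previous = self._active
        self._active = patch_map
        try:
            yield
        finally:
            # Restore the previous patch state even if the block raises.
            self._active = previous


def zero_patch(value: float) -> float:
    """Loosely analogous to ZeroPatch: replace the activation with zero."""
    return 0.0


model = ToyPatchable()
print(model(3.0))  # 7.0
with model.patch({"hidden": [zero_patch]}):
    print(model(3.0))  # 1.0: hidden was zeroed, then 0.0 + 1.0
print(model(3.0))  # 7.0 again outside the block
```

Restoring the previous state in a `finally` clause is what guarantees the "outside the block, it is the original module" property even when the patched run raises.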

Example

>>> from graphpatch import PatchableGraph, ZeroPatch
>>> my_llm, my_tokenizer = MyLLM(), MyTokenizer()
>>> my_inputs = my_tokenizer("Hello, ")
>>> patchable_graph = PatchableGraph(my_llm, **my_inputs)
>>> # Patch the input to the third layer's MLP
>>> with patchable_graph.patch({"layers_2.mlp.x": [ZeroPatch()]}):
...     patched_output = patchable_graph(**my_inputs)

Parameters:
  • module – The Module to wrap.

  • extraction_args – Arguments (example inputs) to be passed to the module during torch.compile().

  • _graphpatch_postprocessing_function – Optional function called to modify the generated torch.fx.GraphModule. It may modify the underlying torch.fx.Graph in place. The original module is passed for reference in case, for example, the needed modifications depend on its configuration.

  • extraction_kwargs – Keyword arguments to be passed to the module during torch.compile().

property graph

Convenience property for working in REPL and notebook environments. Exposes the full NodePath hierarchy of this PatchableGraph via recursive attribute access. Children of the current node can be tab-completed at each step. Has a custom __repr__() to display the subgraph rooted at the current path. Dynamically generated attributes:

<node_name>

One attribute per child node, having the name of that child.

_code

For submodules, the compiled GraphModule code. For other nodes, the partial stack trace from the original model.

_shape

The shape of the Tensor observed at this node during compilation, if the value was a Tensor.

Example:

In [1]: pg.graph
Out[1]: 
<root>: Graph(3)
├─x: Tensor(3, 2)
├─linear: Graph(5)
│ ├─input: Tensor(3, 2)
│ ├─weight: Tensor(3, 2)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
└─output: Tensor(3, 3)

In [2]: pg.graph.linear._code
Out[2]: 
Calling context:
File "/Users/evanlloyd/graphpatch/tests/fixtures/minimal_module.py", line 16, in forward
    return self.linear(x)
Compiled code:
def forward(self, input : torch.Tensor):
    input_1 = input
    weight = self.weight
    bias = self.bias
    linear = torch._C._nn.linear(input_1, weight, bias);  input_1 = weight = bias = None
    return linear

In [3]: pg.graph.output._shape
Out[3]: torch.Size([3, 3])

See Working with graphpatch for more discussion and examples.
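The following is a rough sketch of how an attribute-based, tab-completable tree view like `graph` could be built in plain Python. `NodeView` and its constructor are hypothetical and not part of graphpatch; the node names and summaries are copied from the example output above and supplied as plain strings rather than derived from real tensors or graphs.

```python
from typing import Dict, List, Optional


class NodeView:
    """Hypothetical tree view: children via attribute access, tree via repr()."""

    def __init__(self, name: str, summary: str, children: Optional[List["NodeView"]] = None) -> None:
        self._name = name
        self._summary = summary
        self._children: Dict[str, "NodeView"] = {c._name: c for c in children or []}

    def __getattr__(self, attr: str) -> "NodeView":
        # Called only when normal lookup fails, so child names resolve here.
        children = self.__dict__.get("_children", {})
        if attr in children:
            return children[attr]
        raise AttributeError(attr)

    def __dir__(self) -> List[str]:
        # Interactive completers generally consult dir(), so expose child names.
        return sorted(set(super().__dir__()) | set(self._children))

    def _tree_lines(self, prefix: str = "") -> List[str]:
        lines: List[str] = []
        children = list(self._children.values())
        for i, child in enumerate(children):
            last = i == len(children) - 1
            branch = "└─" if last else "├─"
            lines.append(f"{prefix}{branch}{child._name}: {child._summary}")
            # Continue the vertical rule only beneath non-final siblings.
            lines.extend(child._tree_lines(prefix + ("  " if last else "│ ")))
        return lines

    def __repr__(self) -> str:
        return "\n".join([f"{self._name}: {self._summary}"] + self._tree_lines())


root = NodeView("<root>", "Graph(3)", [
    NodeView("x", "Tensor(3, 2)"),
    NodeView("linear", "Graph(5)", [
        NodeView("input", "Tensor(3, 2)"),
        NodeView("weight", "Tensor(3, 2)"),
        NodeView("bias", "Tensor(3)"),
        NodeView("linear", "Tensor(3, 3)"),
        NodeView("output", "Tensor(3, 3)"),
    ]),
    NodeView("output", "Tensor(3, 3)"),
])
print(root)              # same tree layout as the pg.graph output above
print(root.linear.bias)  # bias: Tensor(3)
```

The sketch keeps its own attributes underscore-prefixed so they are unlikely to collide with child node names, which are resolved dynamically through `__getattr__`.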

patch(patch_map: Dict[str | NodePath, List[Patch[Tensor]] | Patch[Tensor]]) → Iterator[None]

Context manager that will cause the given activation patches to be applied when running inference on the wrapped module.
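The `patch_map` type in the signature accepts either `str` or `NodePath` keys, and either a single Patch or a list of Patches per key. A hypothetical helper showing how such a mapping could be normalized into one uniform shape (string keys mapped to lists of patches) is sketched below. It assumes a `NodePath` converts to its dotted string form with `str()`, which this page does not state, and `FakePatch` merely stands in for a real Patch.

```python
from typing import Any, Dict, List, Mapping


def normalize_patch_map(patch_map: Mapping[Any, Any]) -> Dict[str, List[Any]]:
    """Hypothetical helper: turn {key: patch | [patches]} into {str: [patches]}."""
    normalized: Dict[str, List[Any]] = {}
    for key, value in patch_map.items():
        # Assumption: a NodePath renders as its dotted path string via str().
        path = str(key)
        # Wrap a lone patch in a list; copy lists so later caller edits don't leak in.
        normalized[path] = list(value) if isinstance(value, list) else [value]
    return normalized


class FakePatch:
    """Stand-in for a graphpatch Patch object (hypothetical)."""


zero = FakePatch()
result = normalize_patch_map({"layers_2.mlp.x": zero, "layers_3.mlp.x": [zero, zero]})
print({k: len(v) for k, v in result.items()})  # {'layers_2.mlp.x': 1, 'layers_3.mlp.x': 2}
```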

Parameters:

patch_map – A mapping from NodePath to a Patch or list of Patches to apply to each respective node during inference.

Yields:

A context in which the given activation patch(es) will be applied when calling self.forward().

Raises:

save(*args: Any, **kwargs: Any) → None

Wrapper around torch.save() because some PatchableGraph internals need to be handled specially before pickling. You will get an exception asking you to use this method if you call torch.save() directly on a PatchableGraph instance. All the normal caveats around pickling apply; you should not torch.load() anything you downloaded from the Internet.

Future versions of graphpatch will likely remove this method in favor of a more secure serialization scheme.
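This page doesn't describe how save() detects a direct torch.save() call. One common Python pattern with the same observable behavior is sketched below: `__getstate__` refuses to run unless a flag set by the dedicated save method is on. `GuardedWrapper` is hypothetical, plain `pickle.dump` stands in for `torch.save()`, and the class-level flag is not thread-safe; treat this as an illustration of the pattern, not graphpatch's implementation.

```python
import io
import pickle
from typing import Any, BinaryIO, Dict


class GuardedWrapper:
    """Hypothetical object that may only be pickled through its own save() method."""

    _saving = False  # class-level flag; set only while save() runs (not thread-safe)

    def __init__(self, payload: Dict[str, Any]) -> None:
        self.payload = payload

    def __getstate__(self) -> Dict[str, Any]:
        if not GuardedWrapper._saving:
            raise RuntimeError("Call GuardedWrapper.save() instead of pickling directly.")
        # A real wrapper would convert its unpicklable internals here.
        return {"payload": self.payload}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.payload = state["payload"]

    def save(self, file: BinaryIO) -> None:
        GuardedWrapper._saving = True
        try:
            pickle.dump(self, file)  # stand-in for torch.save(self, file)
        finally:
            GuardedWrapper._saving = False


wrapper = GuardedWrapper({"weights": [1, 2, 3]})
try:
    pickle.dumps(wrapper)  # direct pickling is refused
except RuntimeError as err:
    print(err)  # Call GuardedWrapper.save() instead of pickling directly.

buffer = io.BytesIO()
wrapper.save(buffer)
buffer.seek(0)
restored = pickle.load(buffer)  # only unpickle data from a source you trust
print(restored.payload)  # {'weights': [1, 2, 3]}
```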