PatchableGraph
- class PatchableGraph(module: Module, *extraction_args: Any, _graphpatch_postprocessing_function: Callable[[GraphModule, Module], None] | None = None, **extraction_kwargs: Any)
PatchableGraph is a wrapper around torch.nn.Module allowing activation patching at any computational node.

Internally, PatchableGraph builds a torch.fx.GraphModule for the module and each of its submodules using torch.compile(). This exposes the computational structure of the module while remaining equivalent to the original: you can perform any operation you would with the original module using the PatchableGraph.

Note that the original module hierarchy is retained. For example, if you had a module foo containing a submodule bar, you would get back a GraphModule equivalent to foo, which has a sub-GraphModule bar equivalent to the original bar.

To perform activation patching, use the patch context manager. This method takes a mapping from NodePaths to a Patch or list of Patches to apply at the corresponding nodes. Note that the activation patches will only be applied inside the context block; using the PatchableGraph outside such a block is equivalent to running the original module.

Example:
>>> from graphpatch import PatchableGraph, ZeroPatch
>>> my_llm, my_tokenizer = MyLLM(), MyTokenizer()
>>> my_inputs = my_tokenizer("Hello, ")
>>> patchable_graph = PatchableGraph(my_llm, **my_inputs)
>>> # Patch the input to the third layer's MLP
>>> with patchable_graph.patch({"layers_2.mlp.x": [ZeroPatch()]}):
...     patched_output = patchable_graph(**my_inputs)
- Parameters:
  - module – The Module to wrap.
  - extraction_args – Arguments (example inputs) to be passed to the module during torch.compile().
  - _graphpatch_postprocessing_function – Optional function to call which will modify the generated torch.fx.GraphModule. This function can modify the underlying torch.fx.Graph in-place. The original module is passed for reference in case, for example, the needed modifications depend on its configuration.
  - extraction_kwargs – Keyword arguments to be passed to the module during torch.compile().
- property graph

  Convenience property for working in REPL and notebook environments. Exposes the full NodePath hierarchy of this PatchableGraph via recursive attribute access. Children of the current node can be tab-completed at each step. Has a custom __repr__() to display the subgraph rooted at the current path.

  Dynamically generated attributes:

  - <node_name>

    One attribute per child node, having the name of that child.

  - _code

    For submodules, the compiled GraphModule code; for other nodes, the partial stack trace of the original model.

  - _shape

    The shape of the Tensor observed at this node during compilation, if the value was a Tensor.
Example:

In [1]: pg.graph
Out[1]:
<root>: Graph(3)
├─x: Tensor(3, 2)
├─linear: Graph(5)
│ ├─input: Tensor(3, 2)
│ ├─weight: Tensor(3, 2)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
└─output: Tensor(3, 3)

In [2]: pg.graph.linear._code
Out[2]:
Calling context:
  File "/Users/evanlloyd/graphpatch/tests/fixtures/minimal_module.py", line 16, in forward
    return self.linear(x)
Compiled code:
def forward(self, input : torch.Tensor):
    input_1 = input
    weight = self.weight
    bias = self.bias
    linear = torch._C._nn.linear(input_1, weight, bias);  input_1 = weight = bias = None
    return linear

In [3]: pg.graph.output._shape
Out[3]: torch.Size([3, 3])
See Working with graphpatch for more discussion and examples.
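The attribute-based navigation described above can be sketched in plain Python. The `_NodeProxy` class below is hypothetical and only illustrates the general pattern (children exposed as attributes for tab completion, a custom `__repr__` rendering the subtree); it is not graphpatch's actual NodePath implementation.

```python
# Minimal sketch of attribute-based graph navigation. `_NodeProxy` is
# hypothetical; graphpatch's real NodePath objects are richer.
class _NodeProxy:
    def __init__(self, name, children=None, shape=None):
        self._name = name
        self._shape = shape
        self._children = children or {}
        # Expose each child as an attribute so it can be tab-completed.
        for child_name, child in self._children.items():
            setattr(self, child_name, child)

    def __repr__(self):
        # Render this node, then its children indented one level.
        if self._shape is not None:
            header = f"{self._name}: Tensor{self._shape}"
        else:
            header = f"{self._name}: Graph({len(self._children)})"
        lines = [header]
        for child in self._children.values():
            lines.extend("  " + line for line in repr(child).splitlines())
        return "\n".join(lines)


graph = _NodeProxy(
    "<root>",
    {
        "x": _NodeProxy("x", shape=(3, 2)),
        "output": _NodeProxy("output", shape=(3, 3)),
    },
)
print(graph.output._shape)  # (3, 3)
```

Each attribute access (`graph.output`) returns another proxy, so arbitrarily deep paths like `pg.graph.linear._code` fall out of the same mechanism.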
- patch(patch_map: Dict[str | NodePath, List[Patch[Tensor]] | Patch[Tensor]]) → Iterator[None]

  Context manager that will cause the given activation patches to be applied when running inference on the wrapped module.

  - Parameters:
    patch_map – A mapping from NodePath to a Patch or list of Patches to apply to each respective node during inference.
  - Yields:
    A context in which the given activation patch(es) will be applied when calling self.forward().
  - Raises:
    - KeyError – If any NodePath in patch_map does not exist in the graph.
    - ValueError – If patch_map has any invalid types.
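The "patches apply only inside the block" semantics can be sketched with a toy context manager. Everything below (`ToyModule`, the `"double"` node, the validation logic) is hypothetical and only models the documented behavior; graphpatch's real patch() hooks into the compiled GraphModules.

```python
from contextlib import contextmanager

# Toy model of patch(): patches are installed on entry, applied during
# the forward pass, and removed on exit. Hypothetical, for illustration.
class ToyModule:
    def __init__(self):
        self.active_patches = {}

    def __call__(self, x):
        value = x * 2  # the "activation" at node "double"
        for patch in self.active_patches.get("double", []):
            value = patch(value)
        return value

    @contextmanager
    def patch(self, patch_map):
        # Validate eagerly, mirroring the documented KeyError/ValueError.
        for node_path, patches in patch_map.items():
            if node_path != "double":
                raise KeyError(node_path)
            if not all(callable(p) for p in patches):
                raise ValueError("patches must be callable")
        self.active_patches = patch_map
        try:
            yield
        finally:
            # Patches never outlive the context block.
            self.active_patches = {}


module = ToyModule()
zero_patch = lambda value: 0  # analogous in spirit to ZeroPatch
with module.patch({"double": [zero_patch]}):
    patched = module(21)   # 0: patch active inside the block
unpatched = module(21)     # 42: behaves like the original outside
```

The try/finally in the generator is what guarantees the last documented property: once the block exits, calling the module is equivalent to running it unpatched.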
- save(*args: Any, **kwargs: Any) → None

  Wrapper around torch.save(), needed because some PatchableGraph internals must be handled specially before pickling. You will get an exception asking you to use this method if you call torch.save() directly on a PatchableGraph instance. All the normal caveats around pickling apply; you should not torch.load() anything you downloaded from the Internet.

  Future versions of graphpatch will likely remove this method in favor of a more secure serialization scheme.