Notes on compilation

PatchableGraph constructs a hierarchy of graphs matching the original submodule hierarchy. Submodules will always be instances of either CompiledGraphModule, OpaqueGraphModule, or container types such as ModuleDict and ModuleList. For example, if you had a module foo containing a submodule bar, you might get back a GraphModule equivalent to foo which has a sub-GraphModule bar, equivalent to the original bar:

PatchableGraph(
    (_graph_module): CompiledGraphModule(
        (bar): OpaqueGraphModule()
    )
)

Here, _graph_module is the compiled version of the original root module (foo) and foo.bar is an opaque wrapper around the original bar.

graphpatch makes a best effort to compile every submodule. When this succeeds, the corresponding submodule will be of class CompiledGraphModule. This includes some workarounds for cases where a native torch.compile() would fail, such as for quantized linear modules from the bitsandbytes library, which accelerate uses for model quantization. When this fails, by default graphpatch will fall back to constructing a wrapper around the original model code that allows for patching inputs, outputs, parameters, and buffers. The corresponding submodule will be of class OpaqueGraphModule.

Because torch.compile() is still new and somewhat rough to use in practice, I have made the default behavior handle this fallback silently. You can change this by passing ExtractionOptions to PatchableGraph. For example, the option error_on_compilation_failure treats a compilation failure as an error, which gives you access to the original exception. PyTorch does not offer much guidance on diagnosing these errors, but you might start with PyTorch 2.0 Troubleshooting.
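The silent fallback and its opt-out can be pictured with a small dependency-free sketch. This is not graphpatch's implementation: extract, Compiled, Opaque, and flaky_compile are hypothetical names invented for illustration, and the error_on_failure flag only mimics the role that error_on_compilation_failure plays in ExtractionOptions.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Compiled:
    """Hypothetical stand-in for a successfully compiled submodule."""
    result: Any


@dataclass
class Opaque:
    """Hypothetical stand-in for a wrapper around the original, uncompiled code."""
    original: Any
    reason: Exception


def extract(module: Any, compile_fn: Callable[[Any], Any], error_on_failure: bool = False):
    """Try to compile `module`; fall back to an opaque wrapper unless told to raise."""
    try:
        return Compiled(compile_fn(module))
    except Exception as exc:
        if error_on_failure:
            # Re-raise so the caller can inspect the original exception.
            raise
        # Default: keep going silently, remembering why compilation failed.
        return Opaque(module, exc)


def flaky_compile(module: Any) -> str:
    # Hypothetical compiler that rejects anything that is not a string.
    if not isinstance(module, str):
        raise TypeError(f"cannot compile {module!r}")
    return module.upper()
```

Calling extract(42, flaky_compile) returns an Opaque wrapper, while extract(42, flaky_compile, error_on_failure=True) raises the underlying TypeError.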

Custom extraction functions

As an advanced option, you can pass custom functions for handling the conversion of modules into graphs with the extraction option custom_extraction_functions, which is a dict mapping from subtypes of Module to functions taking a Module and outputting a Graph. This may be easier in some cases than diagnosing issues with torch.compile(), and gives fine-grained control over the generated graph. Example:

from torch import Tensor
from torch.fx import Graph
from torch.nn import Module

class MyUncompilableModule(Module):
    def forward(self, foo: Tensor, bar: Tensor):
        return uncompilable_operation(foo, bar)

def extract_my_module(module: Module) -> Graph:
    graph = Graph()

    # Note that placeholders must exactly match the names of the arguments to forward.
    foo = graph.placeholder("foo")
    bar = graph.placeholder("bar")
    operation = graph.call_function(uncompilable_operation, (foo, bar))
    # graphpatch will respect the names of any nodes in the graph, which can make subsequent
    # patching operations easier to parse.
    operation.name = "my_custom_name"

    # Note that the output must be wrapped in a single-element tuple.
    graph.output((operation,))
    return graph

pg = PatchableGraph(
    module_instance,
    ExtractionOptions(custom_extraction_functions={MyUncompilableModule: extract_my_module}),
    example_foo,
    example_bar,
)

When using this option, make sure that your graph has placeholders with targets exactly matching the names of the inputs to your module’s forward() function. This is needed because graphpatch runs a sanity check on these inputs, which corrects for names that are sometimes mangled by the normal compilation process. Your graph’s output must also be wrapped in a single-element tuple, as in the above example, to match the behavior of torch.compile().
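One way to catch a placeholder mismatch before handing a graph to graphpatch is to compare the placeholder names against the signature of forward(). The following is a minimal sketch using only the standard library. The helpers forward_argument_names and check_placeholders are hypothetical and not part of graphpatch, and a plain class stands in for a torch Module so the example has no dependencies.

```python
import inspect
from typing import Callable, List


def forward_argument_names(forward: Callable) -> List[str]:
    """Return the argument names of a forward() method, excluding self."""
    params = inspect.signature(forward).parameters
    return [name for name in params if name != "self"]


def check_placeholders(forward: Callable, placeholder_names: List[str]) -> List[str]:
    """Return the forward() arguments that have no identically named placeholder."""
    return [
        name
        for name in forward_argument_names(forward)
        if name not in placeholder_names
    ]


class MyUncompilableModule:
    # Plain class standing in for a torch Module (assumption made for this sketch).
    def forward(self, foo, bar):
        return foo + bar
```

An empty list from check_placeholders(MyUncompilableModule.forward, ["foo", "bar"]) means every argument has a matching placeholder; any names it returns are the ones still missing from the graph.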

As another example, graphpatch internally uses this mechanism to handle the bitsandbytes class Linear8bitLt, so that its weights can be patched as if they were an ordinary tensor. The extraction function constructs the desired Graph by hand:

def compile_8_bit_linear(module):
    graph = Graph()
    x = graph.placeholder("x", torch.Tensor)
    cb = graph.get_attr("CB")
    scb = graph.get_attr("SCB")
    bias = graph.get_attr("bias")
    threshold = graph.get_attr("threshold")
    mul = graph.call_function(operator.mul, (cb, scb))
    weight = graph.call_function(operator.truediv, (mul, 127))
    weight.name = "weight"
    output = graph.call_function(matmul_8bit, (x, weight, bias, threshold))
    graph.output((output,))
    return graph

Multiple invocations of a submodule are treated independently

While this may be a rare edge case in practice, graphpatch handles a submodule that is called multiple times by treating each invocation as an independent graph that can be patched separately. A (somewhat contrived) example, from a model used in our test cases:

class TupleOutputModule(Module):
    _shape = (2, 3)

    def __init__(self):
        super().__init__()
        self.linear = Linear(*TupleOutputModule._shape)

    def forward(self, x):
        return (self.linear(x), self.linear(x + 1))

>>> pg = PatchableGraph(TupleOutputModule(), torch.ones(3, 2))
>>> pg
PatchableGraph(
    (_graph_module): CompiledGraphModule(
        (linear): MultiplyInvokedModule(
            (0-1): 2 x CompiledGraphModule()
        )
    )
)
>>> pg.graph
<root>: CompiledGraphModule
├─x: Tensor(3, 2)
├─linear_0: CompiledGraphModule
│ ├─input: Tensor(3, 2)
│ ├─weight: Tensor(3, 2)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
├─add: Tensor(3, 2)
├─linear_1: CompiledGraphModule
│ ├─input: Tensor(3, 2)
│ ├─weight: Tensor(3, 2)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
└─output: tuple(2)
  ├─sub_0: Tensor(3, 3)
  └─sub_1: Tensor(3, 3)

Note that linear is now an instance of MultiplyInvokedModule, which is a subclass of ModuleList, and there are two nodes corresponding to it in the graph, linear_0 and linear_1. The two invocations can be patched independently:

with pg.patch(
    {
        "linear_0.output": [AddPatch(value=torch.ones((1,)))],
        "linear_1.output": [ZeroPatch()],
    }
):
    ...
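
The naming pattern seen above (linear_0 and linear_1 for a submodule called twice, a plain name for a submodule called once) can be mimicked with a short standard-library sketch. This is an illustration of the naming scheme only, under the assumption that singly invoked submodules keep their unsuffixed name. The function invocation_node_names is hypothetical and not part of graphpatch.

```python
from collections import Counter
from typing import List


def invocation_node_names(calls: List[str]) -> List[str]:
    """Name each call in order; submodules called more than once get an index suffix."""
    totals = Counter(calls)   # how many times each submodule is called overall
    seen: Counter = Counter() # how many calls to each submodule we have named so far
    names = []
    for name in calls:
        if totals[name] > 1:
            names.append(f"{name}_{seen[name]}")
            seen[name] += 1
        else:
            # Assumption: a submodule invoked once keeps its original name.
            names.append(name)
    return names
```

Passing ["linear", "linear"] yields the two node names used as patch targets in the example above.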