MultiplyInvokedModule

class MultiplyInvokedModule

Wrapper around a submodule that its parent invoked more than once when graphpatch converted the parent into a GraphModule. Each invocation is tracked separately, which lets you patch distinct invocations independently.

Example

from torch.nn import Linear, Module

class Foo(Module):
    def __init__(self):
        super().__init__()
        self.bar = Linear(3, 3)

    def forward(self, x, y):
        # The same submodule is invoked twice, once per input.
        return self.bar(x) + self.bar(y)

In [1]: pg = PatchableGraph(Foo(), **inputs)
In [2]: print(pg._graph_module)
Out[2]:
    CompiledGraphModule(
        (bar): MultiplyInvokedModule(
            (0-1): 2 x CompiledGraphModule()
        )
    )
In [3]: pg.graph
Out[3]:
<root>: CompiledGraphModule
├─x: Tensor(3, 3)
├─y: Tensor(3, 3)
├─bar_0: CompiledGraphModule
│ ├─input: Tensor(3, 3)
│ ├─weight: Tensor(3, 3)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
├─bar_1: CompiledGraphModule
│ ├─input: Tensor(3, 3)
│ ├─weight: Tensor(3, 3)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
├─add: Tensor(3, 3)
└─output: Tensor(3, 3)

You can patch the two calls to the submodule “bar” independently:

>>> with pg.patch({"bar_0": ZeroPatch(), "bar_1": AddPatch(value=1)}):
...     ...
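
To make the idea of per-invocation patching concrete, here is a minimal pure-Python sketch. It is not graphpatch's implementation: `PerCallWrapper`, `double`, and the module-level `forward` are hypothetical names invented for illustration, and plain lists stand in for tensors. The sketch only assumes what the example above shows, namely that successive calls to "bar" are addressed as "bar_0" and "bar_1".

```python
from typing import Callable, Dict, List

Vector = List[float]


class PerCallWrapper:
    """Toy stand-in (hypothetical): routes each successive call of a
    shared function to its own named slot, e.g. "bar_0", "bar_1"."""

    def __init__(self, fn: Callable[[Vector], Vector], name: str) -> None:
        self.fn = fn
        self.name = name
        # Maps a per-invocation name to a function that rewrites the output.
        self.patches: Dict[str, Callable[[Vector], Vector]] = {}
        self._calls = 0

    def reset(self) -> None:
        # Start counting invocations again at the top of each forward pass.
        self._calls = 0

    def __call__(self, x: Vector) -> Vector:
        key = f"{self.name}_{self._calls}"
        self._calls += 1
        out = self.fn(x)
        patch = self.patches.get(key)
        return patch(out) if patch is not None else out


def double(x: Vector) -> Vector:
    # Stand-in for the shared Linear layer.
    return [2.0 * v for v in x]


bar = PerCallWrapper(double, "bar")


def forward(x: Vector, y: Vector) -> Vector:
    bar.reset()
    return [a + b for a, b in zip(bar(x), bar(y))]


unpatched = forward([1.0, 2.0], [3.0, 4.0])

# Zero the first invocation's output and add 1 to the second's,
# mirroring the ZeroPatch / AddPatch(value=1) example above.
bar.patches = {
    "bar_0": lambda out: [0.0 for _ in out],
    "bar_1": lambda out: [v + 1.0 for v in out],
}
patched = forward([1.0, 2.0], [3.0, 4.0])
```

Because the wrapper keys patches by invocation index rather than by the shared function, the two calls can be modified differently even though they run the same underlying code with the same parameters.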

See also: “Multiple invocations of a submodule are treated independently.”