MultiplyInvokedModule
- class MultiplyInvokedModule
Wrapper around a module that was invoked multiple times by its parent when
graphpatch converted it into a GraphModule. This allows you to patch distinct invocations independently.

Example
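from torch.nn import Linear, Module

class Foo(Module):
    def __init__(self):
        super().__init__()
        # A single Linear submodule, shared by both calls in forward().
        self.bar = Linear(3, 3)

    def forward(self, x, y):
        # "bar" is invoked twice, once per input.
        return self.bar(x) + self.bar(y)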
In [1]: pg = PatchableGraph(Foo(), **inputs)

In [2]: print(pg._graph_module)
Out[2]:
CompiledGraphModule(
  (bar): MultiplyInvokedModule(
    (0-1): 2 x CompiledGraphModule()
  )
)

In [3]: pg.graph
Out[3]:
<root>: CompiledGraphModule
├─x: Tensor(3, 3)
├─y: Tensor(3, 3)
├─bar_0: CompiledGraphModule
│ ├─input: Tensor(3, 3)
│ ├─weight: Tensor(3, 3)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
├─bar_1: CompiledGraphModule
│ ├─input: Tensor(3, 3)
│ ├─weight: Tensor(3, 3)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
├─add: Tensor(3, 3)
└─output: Tensor(3, 3)
You can patch the two calls to the submodule “bar” independently:
>>> with pg.patch({"bar_0": ZeroPatch(), "bar_1": AddPatch(value=1)}):
...
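To make "independently" concrete, here is a minimal pure-Python sketch of the idea. It does not use graphpatch or PyTorch and is not how graphpatch is implemented. The names `bar`, `forward`, `zero_patch`, and `add_one_patch` are hypothetical stand-ins, scalars replace tensors, and a patch is assumed to be a plain function applied to one call's output. The point is only that each invocation of the shared submodule gets its own name, so a patch on one does not affect the other.

```python
from typing import Callable, Dict, Optional

# Assumption for this sketch: a "patch" is a function applied to one call's output.
Patch = Callable[[float], float]


def bar(x: float) -> float:
    # Hypothetical stand-in for the shared submodule (Linear in the example).
    return 2.0 * x + 1.0


def forward(x: float, y: float, patches: Optional[Dict[str, Patch]] = None) -> float:
    """Toy analogue of Foo.forward where each call to bar is separately addressable."""
    patches = patches or {}

    def identity(value: float) -> float:
        return value

    # The first invocation is addressed as "bar_0", the second as "bar_1".
    out_0 = patches.get("bar_0", identity)(bar(x))
    out_1 = patches.get("bar_1", identity)(bar(y))
    return out_0 + out_1


def zero_patch(value: float) -> float:
    # Toy analogue of zeroing an output.
    return 0.0


def add_one_patch(value: float) -> float:
    # Toy analogue of adding a constant to an output.
    return value + 1.0


unpatched = forward(1.0, 2.0)
only_first = forward(1.0, 2.0, {"bar_0": zero_patch})
both = forward(1.0, 2.0, {"bar_0": zero_patch, "bar_1": add_one_patch})
```

Patching only `"bar_0"` leaves the second call's contribution unchanged, which is the behavior the per-invocation names are meant to allow.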
See also: Multiple invocations of a submodule are treated independently.