ExtractionOptions
- class ExtractionOptions
Options to control the behavior of graphpatch during graph extraction. This is a keyword-only dataclass; to construct one, pass any number of the options listed below as keyword arguments.
- classes_to_skip_compiling
Set of Module classes to leave uncompiled. These modules will only be patchable at their inputs, outputs, parameters, and buffers. May be useful for working around compilation issues. Default: set().
- Type: Set[Type[Module]]
- copy_transformers_generation_config
If the wrapped Module is a Hugging Face transformers implementation, whether graphpatch should attempt to copy its generation config so that generation convenience functions like generate() can be used. Default: True.
- Type: bool
- custom_extraction_functions
Optional map from Module classes to callables generating a Graph, to be used in place of graphpatch’s normal extraction mechanism when encountering that class. Advanced feature; should not be necessary for ordinary use. See Custom extraction functions. Default: dict().
- error_on_compilation_failure
Treat failure to compile a submodule as an error, rather than falling back to module-level patching via OpaqueGraphModule. Default: False.
- Type: bool
- postprocessing_function
Optional function, called to modify the generated GraphModule. This function can modify the underlying Graph in place. The original module is passed for reference in case, for example, the needed modifications depend on its configuration. Advanced feature; should not be necessary for ordinary use. Default: None.
- Type: Callable[[GraphModule, Module], None] | None
- skip_compilation
Skip compilation for all modules. Only module inputs and outputs will be patchable. May be useful for faster iteration if you do not need to patch intermediate values. Default: False.
- Type: bool
- warn_on_compilation_failure
Issue a warning when compilation fails, then fall back to module-level patching for the failed module(s). Default: False.
- Type: bool
Example
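from graphpatch import ExtractionOptions, PatchableGraph

# MyUncompilableModule, my_model, and example_inputs stand for your own module class, model, and inputs.
options = ExtractionOptions(
    classes_to_skip_compiling={MyUncompilableModule},
    error_on_compilation_failure=True,
)
pg = PatchableGraph(my_model, options, **example_inputs)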