将给定模块的 forward 方法替换为重新编译的版本，其中所有不可追踪的代码模式都被移除，并替换为 torch.fx 函数包装器。
有关不可追踪代码模式及其解释的列表，请参阅 https://github.com/vllm-project/llm-compressor/pull/1411
参数：
- module (Module) – 需要替换其 forward 方法的模块
- ignore (List[str]) – 需要显式包装的函数名称列表
源代码位于 llmcompressor/pipelines/sequential/ast_helpers.py
@contextlib.contextmanager
def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
    """
    Replace the `forward` method of the given module with a recompiled version where
    all untraceable code patterns are removed and replaced with torch.fx function
    wrappers.

    This is a context manager. The recompiled `forward` is bound to `module` and
    patched onto it via `patch_attr` while the context is active. `patch_attr` is
    expected to restore the original attribute on exit.

    For a list of untraceable code patterns and their explanations, see
    https://github.com/vllm-project/llm-compressor/pull/1411

    :param module: module whose forward method should be replaced
    :param ignore: explicit list of function names to wrap
    :raises OSError: if `inspect.getsource` cannot retrieve the source of
        `module.forward`
    :raises TypeError: if `module.forward` is not a Python-level function whose
        source `inspect.getsource` can inspect
    """
    # get source code of module forward; dedent so the method parses at top level
    source = textwrap.dedent(inspect.getsource(module.forward))
    tree = ast.parse(source)

    # construct namespace for our new code, seeded with the globals of the module
    # that defines the class so names referenced inside `forward` still resolve
    defining_module = sys.modules[module.__class__.__module__]
    namespace = defining_module.__dict__.copy()
    # NOTE(review): `self` is exposed as a global, presumably so AutoWrapper can
    # evaluate `self.*` expressions against the live module instance -- confirm
    namespace.update({"torch.fx.wrap": torch.fx.wrap, "self": module})

    # autowrap untraceable code
    auto_wrapper = AutoWrapper(namespace, ignore)
    tree = auto_wrapper.auto_wrap(tree)
    source = ast.unparse(tree)

    # Register the generated source with linecache *before* compiling/executing
    # it. Tracebacks then show the autowrapped lines both for errors raised while
    # executing the generated module-level code and for later calls into the new
    # forward.
    filename = f"<Autowrapped {module.__class__.__name__} {id(module)}>"
    linecache.cache[filename] = (
        len(source),
        None,  # mtime of None: not a real file, so checkcache won't invalidate it
        [line + "\n" for line in source.splitlines()],
        filename,
    )

    # compile new forward function from autowrapped code
    code = compile(source, filename=filename, mode="exec")
    exec(code, namespace)  # ensure ns of functions is the same ns as torch.fx.wrap

    # patch forward with autowrapped forward (bound to this module instance)
    new_forward = namespace["forward"].__get__(module)
    with patch_attr(module, "forward", new_forward):
        yield