llmcompressor.utils.dev

Functions

dispatch_for_generation

dispatch_for_generation(
    model: PreTrainedModel,
) -> PreTrainedModel

Dispatch a model for autoregressive generation. This means that modules are dispatched evenly across the available devices and kept onloaded if possible. Removes any HF hooks that may have existed previously.

Parameters

  • model

    (PreTrainedModel) –

    model to dispatch

Returns

  • PreTrainedModel

    model which is dispatched

Source code in llmcompressor/utils/dev.py
def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
    """
    Dispatch a model for autoregressive generation. This means that modules are
    dispatched evenly across available devices and kept onloaded if possible.
    Removes any HF hooks that may have existed previously.

    :param model: model to dispatch
    :return: model which is dispatched
    """
    remove_dispatch(model)

    no_split_module_classes = model._get_no_split_modules("auto")
    max_memory = get_balanced_memory(
        model,
        dtype=model.dtype,
        no_split_module_classes=no_split_module_classes,
    )
    device_map = infer_auto_device_map(
        model,
        dtype=model.dtype,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
    )

    return dispatch_model(model, device_map=device_map)
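
A minimal usage sketch (the checkpoint name below is a placeholder, not part of this module): load a model, dispatch it across the available devices, then run generation.

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.utils.dev import dispatch_for_generation

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder stub

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# balance modules evenly across available devices and drop any stale HF hooks
model = dispatch_for_generation(model)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0]))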

patch_transformers_logger_level

patch_transformers_logger_level(level: int = logging.ERROR)

Context manager under which the transformers logger's level is modified.

This can be used with skip_weights_download to suppress warnings related to missing parameters in the checkpoint.

Parameters

  • level

    (int, default: ERROR ) –

    new logging level for the transformers logger. Logs whose level is below this level will not be logged

Source code in llmcompressor/utils/dev.py
@contextlib.contextmanager
def patch_transformers_logger_level(level: int = logging.ERROR):
    """
    Context under which the transformers logger's level is modified

    This can be used with `skip_weights_download` to squelch warnings related to
    missing parameters in the checkpoint

    :param level: new logging level for transformers logger. Logs whose level is below
        this level will not be logged
    """
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    restore_log_level = transformers_logger.getEffectiveLevel()

    transformers_logger.setLevel(level=level)
    yield
    transformers_logger.setLevel(level=restore_log_level)
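
A minimal sketch of using the context manager (the checkpoint name is a placeholder): log records from transformers.modeling_utils below the given level are suppressed while a model is loaded.

import logging

from transformers import AutoModelForCausalLM

from llmcompressor.utils.dev import patch_transformers_logger_level

# warnings such as "missing parameters in checkpoint" are silenced inside the context
with patch_transformers_logger_level(logging.ERROR):
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder stub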

skip_weights_download

skip_weights_download(
    model_class: Type[PreTrainedModel] = AutoModelForCausalLM,
)

Context manager under which models are initialized without having to download the model weight files. This differs from init_empty_weights in that weights are allocated onto their assigned devices with random values, rather than being placed on the meta device.

Parameters

  • model_class

    (Type[PreTrainedModel], default: AutoModelForCausalLM ) –

    class to patch, defaults to AutoModelForCausalLM

Source code in llmcompressor/utils/dev.py
@contextlib.contextmanager
def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM):
    """
    Context manager under which models are initialized without having to download
    the model weight files. This differs from `init_empty_weights` in that weights are
    allocated onto assigned devices with random values, as opposed to being on the meta
    device

    :param model_class: class to patch, defaults to `AutoModelForCausalLM`
    """
    original_fn = model_class.from_pretrained
    weights_files = [
        "*.bin",
        "*.safetensors",
        "*.pth",
        SAFE_WEIGHTS_INDEX_NAME,
        WEIGHTS_INDEX_NAME,
        "*.msgpack",
        "*.pt",
    ]

    @classmethod
    def patched(cls, *args, **kwargs):
        nonlocal tmp_dir

        # intercept model stub
        model_stub = args[0] if args else kwargs.pop("pretrained_model_name_or_path")

        # download files into tmp dir
        os.makedirs(tmp_dir, exist_ok=True)
        snapshot_download(
            repo_id=model_stub, local_dir=tmp_dir, ignore_patterns=weights_files
        )

        # make an empty weights file to avoid errors
        weights_file_path = os.path.join(tmp_dir, "model.safetensors")
        save_file({}, weights_file_path, metadata={"format": "pt"})

        # load from tmp dir
        model = original_fn(tmp_dir, **kwargs)

        # replace model_path
        model.name_or_path = model_stub
        model.config._name_or_path = model_stub

        return model

    with tempfile.TemporaryDirectory() as tmp_dir, patch_attr(
        model_class, "from_pretrained", patched
    ), skip_weights_initialize(), patch_transformers_logger_level():
        yield
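
A minimal usage sketch (the checkpoint name is a placeholder): inside the context, configuration and tokenizer files are downloaded as usual, but weight files are skipped, so the resulting model holds random, uninitialized parameters. This is useful for exercising pipelines without fetching large checkpoints.

from transformers import AutoModelForCausalLM

from llmcompressor.utils.dev import skip_weights_download

with skip_weights_download(AutoModelForCausalLM):
    # weight files are not downloaded; parameters are left with random values
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder stub

print(model.name_or_path)  # still reports the original stub, not the temporary directory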

skip_weights_initialize

skip_weights_initialize(use_zeros: bool = False)

Very similar to transformers.model_utils.no_init_weights, except that torch.Tensor initialization functions are also patched to account for tensors that are not initialized on the meta device.

Source code in llmcompressor/utils/dev.py
@contextlib.contextmanager
def skip_weights_initialize(use_zeros: bool = False):
    """
    Very similar to `transformers.model_utils.no_init_weights`, except that torch.Tensor
    initialization functions are also patched to account for tensors which are not
    initialized on the meta device
    """

    def skip(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if use_zeros:
            return tensor.fill_(0)
        return tensor

    with contextlib.ExitStack() as stack:
        for name in TORCH_INIT_FUNCTIONS.keys():
            stack.enter_context(patch_attr(torch.nn.init, name, skip))
            stack.enter_context(patch_attr(torch.Tensor, name, skip))
        yield
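
A minimal sketch: inside the context, torch.nn.init functions (and the matching torch.Tensor methods) are replaced with no-ops, or with zero-fills when use_zeros=True, so constructing large modules skips the usual random-initialization work.

import torch

from llmcompressor.utils.dev import skip_weights_initialize

with skip_weights_initialize(use_zeros=True):
    # the kaiming/uniform init calls in Linear.reset_parameters are patched,
    # so the parameters end up zero-filled
    layer = torch.nn.Linear(4096, 4096)

assert torch.all(layer.weight == 0)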