跳到内容

llmcompressor.pipelines.cache

  • IntermediateValue

    数据类,递归定义了要卸载的值以及要加载到的设备

  • IntermediatesCache

    缓存,存储由批量、顺序生成的中间值(激活)

IntermediateValue dataclass

IntermediateValue(
    value: Union[Tensor, IntermediateValue, Any],
    device: Union[device, None],
)

数据类,递归定义了要卸载的值以及要加载到的设备

参数

  • (Union[Tensor, IntermediateValue, Any]) –

    可以是卸载的 Tensor、基本值或可递归值

  • device

    (Union[device, None]) –

    如果值是 Tensor,则为要加载 Tensor 的设备,否则为 None

IntermediatesCache

IntermediatesCache(
    batch_intermediates: Optional[
        List[IntermediateValues]
    ] = None,
    offload_device: Optional[device] = "cpu",
)

缓存,存储由模型批量、顺序执行生成的中间值(激活)。值在存储在缓存中时被卸载到 offload_device,在从缓存中获取时被加载到其原始设备。如果 offload_device 为 None,则不会卸载值。

当前支持数据类实例和元组的嵌套卸载

使用 emptyfrom_dataloader 类方法进行构建

方法

  • append

    将新值追加到缓存。新值将被分配下一个

  • delete

    从缓存中删除值

  • empty

    构建一个空缓存

  • fetch

    获取属于某个批次的值

  • from_dataloader

    使用提供的 DataLoader 初始化缓存

  • size

    返回缓存中已使用的内存,按设备分组,以字节为单位

  • update

    更新/放置属于某个批次的值

源代码在 llmcompressor/pipelines/cache.py
def __init__(
    self,
    batch_intermediates: Optional[List[IntermediateValues]] = None,
    offload_device: Optional[torch.device] = "cpu",
):
    self.batch_intermediates = batch_intermediates or []
    self.offload_device = offload_device

append

append(values: Dict[str, Any])

将新值追加到缓存。新值将被分配下一个可用的批次索引

参数

  • (Dict[str, Any]) –

    用于更新的键到值的字典映射

源代码在 llmcompressor/pipelines/cache.py
def append(self, values: Dict[str, Any]):
    """
    Append new values to the cache. The new values will be assigned the next
    available batch index

    :param values: dictionary mapping keys to values used for update
    """
    batch_index = len(self.batch_intermediates)
    self.batch_intermediates.append({})
    self.update(batch_index, values)

delete

delete(
    batch_index: int,
    consumed_names: Optional[List[str]] = None,
)

从缓存中删除值

参数

  • batch_index

    (int) –

    要删除其值的批次索引

  • consumed_names

    (Optional[List[str]], 默认值: None ) –

    要删除其值的键列表,默认为删除所有键

源代码在 llmcompressor/pipelines/cache.py
def delete(self, batch_index: int, consumed_names: Optional[List[str]] = None):
    """
    Delete values from the cache

    :param batch_index: index of batch whose values will be deleted
    :param consumed_names: list of keys whose values will be deleted, defaults to
        removing all keys
    """
    intermediates = self.batch_intermediates[batch_index]

    if consumed_names is None:
        consumed_names = list(intermediates.keys())

    for name in consumed_names:
        del intermediates[name]

empty classmethod

empty(num_batches: int, offload_device: device)

构建一个空缓存

参数

  • num_batches

    (int) –

    要存储的预期批次数

  • offload_device

    (device) –

    要卸载值的设备

源代码在 llmcompressor/pipelines/cache.py
@classmethod
def empty(cls, num_batches: int, offload_device: torch.device):
    """
    Construct an empty cache

    :param num_batches: the expected number of batches to be stored
    :param offload_device: device to offload values to
    """
    batch_intermediates = [{} for _ in range(num_batches)]
    return cls(batch_intermediates, offload_device)

fetch

fetch(
    batch_index: int,
    input_names: Optional[List[str]] = None,
) -> Dict[str, Any]

获取属于某个批次的值

参数

  • batch_index

    (int) –

    正在获取其值的批次索引

  • input_names

    (Optional[List[str]], 默认值: None ) –

    要获取其值的键列表

返回

  • Dict[str, Any]

    键到已加载值的字典映射

源代码在 llmcompressor/pipelines/cache.py
def fetch(
    self, batch_index: int, input_names: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Fetch values belonging to a batch

    :param batch_index: index of batch whose values are being fetched
    :param input_names: list of keys whose values are being fetched
    :return: dictionary mapping keys to onloaded values
    """
    intermediates = self.batch_intermediates[batch_index]

    return {
        key: self._onload_value(subgraph_input)
        for key, subgraph_input in intermediates.items()
        if input_names is None or key in input_names
    }

from_dataloader classmethod

from_dataloader(
    dataloader: DataLoader,
    model_device: device = torch.device("cpu"),
    offload_device: Optional[device] = torch.device("cpu"),
)

使用提供的 DataLoader 初始化缓存

参数

  • dataloader

    (DataLoader) –

    生成要缓存的值的 DataLoader

  • model_device

    (device, default: device('cpu') ) –

    获取时要将值加载到的设备

  • offload_device

    (Optional[device], default: device('cpu') ) –

    要卸载值的设备

源代码在 llmcompressor/pipelines/cache.py
@classmethod
def from_dataloader(
    cls,
    dataloader: torch.utils.data.DataLoader,
    model_device: torch.device = torch.device("cpu"),
    offload_device: Optional[torch.device] = torch.device("cpu"),
):
    """
    Initialize a cache with data from the provided dataloader

    :param dataloader: dataloader which generates values to be cached
    :param model_device: device which values will be onloaded to when fetched
    :param offload_device: device to offload values to
    """
    batch_intermediates = [
        {
            key: cls._offload_value(value, offload_device, model_device)
            for key, value in batch.items()
        }
        for batch in tqdm(dataloader, desc="Preparing cache")
    ]

    return cls(batch_intermediates, offload_device)

size

size() -> Dict[torch.device, int]

返回缓存中已使用的内存,按设备分组,以字节为单位

返回

  • Dict[device, int]

    映射到缓存中字节数的 Torch 设备字典

源代码在 llmcompressor/pipelines/cache.py
def size(self) -> Dict[torch.device, int]:
    """
    Returns the memory used by cached values, keyed by device, in bytes

    :return: dictionary mapping torch device to number of bytes in cache
    """
    sizes = defaultdict(lambda: 0)

    def _size_helper(intermediate: IntermediateValue) -> int:
        value = intermediate.value

        match value:
            case torch.Tensor():
                sizes[value.device] += value.nbytes
            case list() | tuple():
                for v in value:
                    _size_helper(v)
            case dict():
                for v in value.values():
                    _size_helper(v)
            case _ if is_dataclass(value):
                for field in fields(value):
                    _size_helper(getattr(value, field.name))
            case _:
                # this handles primitive values that don't match any other cases
                sizes[torch.device("cpu")] += sys.getsizeof(value, 0)

    for intermediates in self.batch_intermediates:
        for value in intermediates.values():
            _size_helper(value)

    return dict(sizes)

update

update(batch_index: int, values: Dict[str, Any])

更新/放置属于某个批次的值

参数

  • batch_index

    (int) –

    要更新其值的批次索引

  • (Dict[str, Any]) –

    用于更新的键到值的字典映射

源代码在 llmcompressor/pipelines/cache.py
def update(self, batch_index: int, values: Dict[str, Any]):
    """
    Update/put values belonging to a batch

    :param batch_index: index of batch whose values will be updated
    :param values: dictionary mapping keys to values used for update
    """
    device = self.offload_device
    intermediates = {k: self._offload_value(v, device) for k, v in values.items()}
    self.batch_intermediates[batch_index].update(intermediates)