def copy_blocks(key_caches, value_caches, block_mapping):
    """Copy cache entries in place from source to destination indices.

    For every ``(key_cache, value_cache)`` pair, the slices along dim 0
    selected by the source indices are written to the destination indices
    of the same tensor.

    Args:
        key_caches: Indexable sequence of key-cache tensors, modified in
            place.
        value_caches: Iterable of value-cache tensors, paired with
            ``key_caches`` via ``zip`` and modified in place.
        block_mapping: Integer tensor with at least two dimensions; after
            swapping dims 0 and 1, row 0 holds the source indices and row 1
            the destination indices (i.e. one ``(src, dst)`` pair per row of
            the input). An empty tensor makes the call a no-op.

    Returns:
        None. All work is done through in-place updates of the caches.
    """
    if block_mapping.numel() == 0:
        return

    # Turn the per-pair layout into one row of sources, one of destinations.
    block_mapping = block_mapping.transpose(0, 1)
    src = block_mapping[0]
    dst = block_mapping[1]

    for key_cache, value_cache in zip(key_caches, value_caches):
        # index_select materializes the source slices before index_copy_
        # writes, so a destination that is also a source still receives the
        # pre-copy contents.
        key_cache.index_copy_(0, dst, key_cache.index_select(0, src))
        value_cache.index_copy_(0, dst, value_cache.index_select(0, src))

    # Guard against an empty cache list: there is nothing to flush, and
    # indexing key_caches[0] would raise IndexError.
    if key_caches and key_caches[0].device.type == 'hpu':
        # NOTE(review): presumably flushes the pending lazy-mode graph on
        # Habana HPU; `htorch` must be provided by the enclosing module.
        htorch.core.mark_step()