import math
from typing import List, Tuple


class HPUBucketingManager:
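    """Singleton that builds and serves padding buckets for HPU execution.

    Prompt/decode buckets are (batch_size, query_len, num_blocks) tuples;
    incoming shapes are padded up to the nearest bucket so graph shapes stay
    within a fixed, precompiled set. Helpers referenced below (get_config,
    logger, generate_buckets, generate_unified_buckets, calc_fallback_value,
    find_equal_or_closest_greater_config) are assumed to be provided
    elsewhere in this module or its imports.
    """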
_instance = None
prompt_buckets: List[Tuple[int, int, int]] = []
decode_buckets: List[Tuple[int, int, int]] = []
unified_buckets: List[Tuple[int, int, int]] = []
initialized = False
    def __new__(cls, *args, **kwargs):
        # Classic singleton: create the instance once, then always reuse it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len):
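        """Store the engine limits used to derive bucket ranges.

        num_hpu_blocks is intentionally left unset here; the caller fills it
        in once the KV-cache size is known, before decode buckets are built.
        """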
self.max_num_seqs = max_num_seqs
self.max_num_prefill_seqs = max_num_prefill_seqs
self.block_size = block_size
self.max_num_batched_tokens = max_num_batched_tokens
self.num_hpu_blocks = None
self.max_model_len = max_model_len
self.initialized = True
self.fallback_bs_base_step = 2
self.fallback_seq_base_step = 32
self.fallback_blocks_base_step = 32
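        # Optional sliding-window (sliced SDPA) mode: when enabled, prompt
        # query lengths at or above slice_thld must be multiples of slice_size
        # (see the prompt-bucket filter and the fallback rounding below).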
self.use_sliding_window = get_config().PT_HPU_SDPA_QKV_SLICE_MODE_FWD
if self.use_sliding_window:
self.slice_size = get_config().PT_HPU_SDPA_BC_FACTOR if \
get_config().PT_HPU_SDPA_BC_FACTOR is not None else 1024
self.slice_thld = get_config().VLLM_FUSEDSDPA_SLIDE_THLD if \
get_config().VLLM_FUSEDSDPA_SLIDE_THLD is not None else 8192
msg = (
f"use_sliding_window {self.use_sliding_window}, slice_size {self.slice_size}, threshold {self.slice_thld}"
)
logger().info(msg)
### GENERATE BUCKETS FUNCTIONS ###
def read_from_file(self, is_prompt):
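        """Load a pre-defined bucket list from the file named by VLLM_BUCKETING_FROM_FILE."""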
file_name = get_config().VLLM_BUCKETING_FROM_FILE
from vllm_gaudi.extension.bucketing.file_strategy import (FileBucketingStrategy)
strategy = FileBucketingStrategy()
return strategy.get_buckets(file_name, is_prompt)
def get_bucketing_strategy(self):
strategy = None
# TODO - we can use different strategies for decode and prompt
        use_exponential_bucketing = True if \
            get_config().VLLM_EXPONENTIAL_BUCKETING is None else \
            get_config().VLLM_EXPONENTIAL_BUCKETING
if use_exponential_bucketing:
from vllm_gaudi.extension.bucketing.exponential import (ExponentialBucketingStrategy)
strategy = ExponentialBucketingStrategy()
else:
from vllm_gaudi.extension.bucketing.linear import LinearBucketingStrategy
strategy = LinearBucketingStrategy()
return strategy
def generate_unified_buckets(self):
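        """Build buckets for unified attention: (query, shared_blocks, unique_blocks)."""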
if self.initialized:
            if get_config().VLLM_BUCKETING_FROM_FILE:
                # A bare `assert "<string>"` always passes; raise explicitly instead.
                raise AssertionError("Unified attention doesn't support bucketing from file")
from vllm_gaudi.extension.bucketing.unified import (UnifiedBucketingStrategy)
strategy = UnifiedBucketingStrategy()
query_cfg, shared_ctx_cfg, unique_ctx_cfg = strategy.get_unified_cfgs(
bs=self.max_num_seqs,
max_model_len=self.max_model_len,
block_size=self.block_size,
max_blocks=self.num_hpu_blocks,
max_num_batched_tokens=self.max_num_batched_tokens)
query_range = strategy.get_range(query_cfg)
shared_ctx_range = strategy.get_range(shared_ctx_cfg)
unique_ctx_range = strategy.get_range(unique_ctx_cfg)
self.unified_buckets = generate_unified_buckets(query_range, shared_ctx_range, unique_ctx_range,
self.max_num_seqs, self.block_size, self.max_model_len)
msg = (f"Generated {len(self.unified_buckets)} "
f"unified buckets [query, shared_blocks, unique_blocks]: "
f"{list(self.unified_buckets)}")
logger().info(msg)
else:
logger().info("Bucketing is off - skipping prompt buckets generation")
self.unified_buckets = []
return
def generate_prompt_buckets(self):
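        """Generate prompt (prefill) buckets, either from file or from a strategy."""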
if self.initialized:
buckets_from_file = None
bs_range = []
query_range = []
ctx_range = []
if get_config().VLLM_BUCKETING_FROM_FILE:
buckets_from_file = self.read_from_file(is_prompt=True)
else:
strategy = self.get_bucketing_strategy()
bs_cfg, query_cfg, ctx_cfg = strategy.get_prompt_cfgs(
max_num_prefill_seqs=self.max_num_prefill_seqs,
block_size=self.block_size,
max_num_batched_tokens=self.max_num_batched_tokens,
max_model_len=self.max_model_len)
bs_range = strategy.get_range(bs_cfg)
query_range = strategy.get_range(query_cfg)
ctx_range = strategy.get_range(ctx_cfg)
self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len,
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
buckets_from_file)
self.log_generate_info(True)
            if self.use_sliding_window:
                # Keep context-carrying buckets as-is; for ctx == 0, keep only
                # query lengths below the slice threshold or aligned to the
                # slice size (equivalent to the original, more verbose check).
                self.prompt_buckets = [
                    t for t in self.prompt_buckets
                    if t[2] != 0 or t[1] < self.slice_thld or t[1] % self.slice_size == 0
                ]
                self.log_generate_info(True)
else:
logger().info("Bucketing is off - skipping prompt buckets generation")
self.prompt_buckets = []
return
def generate_decode_buckets(self):
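        """Generate decode buckets (query length is fixed at 1 token per sequence)."""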
if self.initialized:
buckets_from_file = None
bs_range = []
query_range = []
ctx_range = []
if get_config().VLLM_BUCKETING_FROM_FILE:
buckets_from_file = self.read_from_file(is_prompt=False)
else:
strategy = self.get_bucketing_strategy()
bs_cfg, query_cfg, ctx_cfg = strategy.get_decode_cfgs(
max_num_seqs=self.max_num_seqs,
block_size=self.block_size,
max_num_batched_tokens=self.max_num_batched_tokens,
max_model_len=self.max_model_len,
max_blocks=self.num_hpu_blocks)
bs_range = strategy.get_range(bs_cfg)
query_range = strategy.get_range(query_cfg)
ctx_range = strategy.get_range(ctx_cfg)
                if get_config().use_contiguous_pa and self.num_hpu_blocks is not None \
                        and ctx_range[-1] < self.num_hpu_blocks:
                    ctx_range.append(self.num_hpu_blocks)
self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len,
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
buckets_from_file)
self.log_generate_info(False)
else:
logger().info("Bucketing is off - skipping decode buckets generation")
self.decode_buckets = []
return
def log_generate_info(self, is_prompt=False):
phase = 'prompt' if is_prompt else 'decode'
buckets = self.prompt_buckets if is_prompt else self.decode_buckets
msg = (f"Generated {len(buckets)} "
f"{phase} buckets [bs, query, num_blocks]: "
f"{list(buckets)}")
logger().info(msg)
### RETRIEVE BUCKETS FUNCTIONS ###
def generate_fallback_bucket(self, batch_size, seq_len, ctx):
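        """Round an unseen shape up to step-aligned values to form a new bucket."""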
assert self.max_num_batched_tokens is not None
new_batch_size = calc_fallback_value(batch_size, self.fallback_bs_base_step)
if self.use_sliding_window and seq_len >= self.slice_thld:
new_seq_len = math.ceil(seq_len / self.slice_size) * self.slice_size
else:
new_seq_len = min(calc_fallback_value(seq_len, self.fallback_seq_base_step), self.max_num_batched_tokens)
if self.num_hpu_blocks is None:
new_ctx = 0
else:
new_ctx = min(calc_fallback_value(ctx, self.fallback_blocks_base_step), self.num_hpu_blocks)
return (new_batch_size, new_seq_len, new_ctx)
def find_prompt_bucket(self, batch_size, seq_len, ctx=0):
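        """Return the smallest prepared prompt bucket covering the request,
        creating and registering a fallback bucket if none fits."""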
if self.initialized:
found_bucket = find_equal_or_closest_greater_config(self.prompt_buckets, (batch_size, seq_len, ctx))
if found_bucket is None:
new_bucket = self.generate_fallback_bucket(batch_size, seq_len, ctx)
logger().warning(f"Prompt bucket for {batch_size, seq_len, ctx}"
f" was not prepared. Adding new bucket: {new_bucket}")
self.prompt_buckets.append(new_bucket)
self.prompt_buckets.sort()
return new_bucket
return found_bucket
return (batch_size, seq_len, ctx)
def find_decode_bucket(self, batch_size, num_blocks):
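        """Return the smallest prepared decode bucket covering (batch_size, 1, num_blocks)."""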
if self.initialized:
found_bucket = find_equal_or_closest_greater_config(self.decode_buckets, (batch_size, 1, num_blocks))
if found_bucket is None:
new_bucket = self.generate_fallback_bucket(batch_size, 1, num_blocks)
logger().warning(f"Decode bucket for {batch_size, 1, num_blocks}"
f" was not prepared. Adding new bucket: {new_bucket}")
self.decode_buckets.append(new_bucket)
self.decode_buckets.sort()
return new_bucket
return found_bucket
return (batch_size, 1, num_blocks)
def find_unified_bucket(self, query, shared_ctx, unique_ctx, is_causal):
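        """Return the closest unified bucket, or the raw shape with a warning if none fits."""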
if self.initialized:
# TODO: handle is_causal
found_bucket = find_equal_or_closest_greater_config(self.unified_buckets,
(query, shared_ctx, unique_ctx, is_causal))
if found_bucket is None:
logger().warning(f"No bucket found for: {(query, shared_ctx, unique_ctx)}")
return (query, shared_ctx, unique_ctx)
return found_bucket
return (query, shared_ctx, unique_ctx)
def get_max_prompt_shape(self):
return max(b[1] for b in self.prompt_buckets) \
if len(self.prompt_buckets) > 0 else self.max_model_len
@classmethod
def get_instance(cls):
"""
Retrieve the singleton instance of the class.
"""
return cls._instance
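

# Usage sketch (illustrative only; the limits below are made-up numbers and
# `num_hpu_blocks` is assumed to be set by the caller once the KV-cache has
# been profiled):
#
#     manager = HPUBucketingManager()
#     manager.initialize(max_num_seqs=128, max_num_prefill_seqs=16,
#                        block_size=128, max_num_batched_tokens=8192,
#                        max_model_len=4096)
#     manager.generate_prompt_buckets()
#     manager.num_hpu_blocks = 4096  # known only after KV-cache allocation
#     manager.generate_decode_buckets()
#     padded = manager.find_decode_bucket(batch_size=3, num_blocks=17)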