class LinearBucketingStrategy:
    """Produces (min, step, max) bucket configurations for the prompt and
    decode phases, derived from engine limits and overridable via the
    VLLM_{PROMPT,DECODE}_{BS,QUERY,CTX,BLOCK}_BUCKET_* settings read by
    ``read_bucket_settings``.
    """

    def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len):
        """Return prompt-phase bucket configs.

        Args:
            max_num_prefill_seqs: upper bound on sequences in one prefill batch.
            block_size: KV-cache block size in tokens.
            max_num_batched_tokens: token budget for one batch.
            max_model_len: maximum sequence length supported by the model.

        Returns:
            Tuple of three (min, step, max) configs:
            (batch-size cfg, query-length cfg, context-blocks cfg).
        """
        use_merged_prefill = get_config().merged_prefill

        prompt_bs_bucket_cfg = read_bucket_settings('prompt', 'bs', min=1, step=1, max=max_num_prefill_seqs)
        prompt_query_bucket_cfg = read_bucket_settings('prompt',
                                                       'query',
                                                       min=block_size,
                                                       step=block_size,
                                                       max=max_num_batched_tokens)
        # Ceiling division must use true division inside math.ceil; with `//`
        # the value is floored first and the ceil is a no-op, silently dropping
        # a partially-filled final block from the context bound.
        max_ctx = math.ceil((max_model_len - prompt_query_bucket_cfg[0]) / block_size)
        prompt_ctx_bucket_cfg = read_bucket_settings('prompt', 'ctx', min=0, step=1, max=max_ctx)

        if use_merged_prefill:
            # Merged prefill folds all sequences into one, so force bs=1 and
            # widen query/ctx buckets to cover the merged sequence.
            prev_prompt_bs_bucket_cfg = tuple(prompt_bs_bucket_cfg)
            prev_prompt_query_bucket_cfg = tuple(prompt_query_bucket_cfg)
            prev_prompt_ctx_bucket_cfg = tuple(prompt_ctx_bucket_cfg)

            prompt_bs_bucket_cfg = (1, 1, 1)
            query_min, query_step, _ = prev_prompt_query_bucket_cfg
            # Coarser step (x4) keeps the warmup graph count manageable for
            # the larger merged query range.
            prompt_query_bucket_cfg = (query_min, query_step * 4, max_num_batched_tokens)
            prompt_ctx_bucket_cfg = read_bucket_settings('prompt',
                                                         'ctx',
                                                         min=0,
                                                         step=4,
                                                         max=max_ctx * max_num_prefill_seqs)

            msg = ('Merged prefill is enabled!\n'
                   'Overriding prompt bucketing settings!\n'
                   f'prompt bs cfg: {prev_prompt_bs_bucket_cfg} -> {prompt_bs_bucket_cfg}\n'
                   f'prompt query cfg: {prev_prompt_query_bucket_cfg} -> {prompt_query_bucket_cfg}\n'
                   f'prompt ctx cfg: {prev_prompt_ctx_bucket_cfg} -> {prompt_ctx_bucket_cfg}\n')
            logger().info(msg)

        msg = ("Prompt bucket config (min, step, max_warmup) "
               f"bs:{prompt_bs_bucket_cfg}, "
               f"query:{prompt_query_bucket_cfg}, "
               f"blocks:{prompt_ctx_bucket_cfg}")
        logger().info(msg)

        return prompt_bs_bucket_cfg, prompt_query_bucket_cfg, prompt_ctx_bucket_cfg

    def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_model_len, max_blocks):
        """Return decode-phase bucket configs.

        Args:
            max_num_seqs: upper bound on sequences in one decode batch.
            block_size: KV-cache block size in tokens.
            max_num_batched_tokens: token budget for one batch (unused here,
                kept for interface symmetry with get_prompt_cfgs).
            max_model_len: maximum sequence length supported by the model.
            max_blocks: total KV-cache blocks actually available.

        Returns:
            Tuple of three (min, step, max) configs:
            (batch-size cfg, query-length cfg — fixed at (1, 1, 1) for decode,
            block-count cfg clamped to max_blocks).
        """
        contiguous_pa = get_config().use_contiguous_pa

        decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', min=1, step=32, max=max_num_seqs)
        # Decode processes one token per sequence, so the query bucket is fixed.
        decode_query_bucket_cfg = [1, 1, 1]
        # Ceiling division (see get_prompt_cfgs): `/` not `//` inside ceil so a
        # partially-filled final block still counts toward the bound.
        max_decode_blocks = max(math.ceil(max_model_len * max_num_seqs / block_size), block_size)
        if contiguous_pa:
            # Contiguous paged attention may address any block, so warm up
            # across the full allocation.
            max_decode_blocks = max_blocks
        decode_block_bucket_cfg = read_bucket_settings('decode', 'block', min=1, step=block_size, max=max_decode_blocks)

        # Clamp user-provided overrides that exceed the physically available
        # block count; warn so the override mismatch is visible.
        if decode_block_bucket_cfg[2] > max_blocks:
            logger().info(
                f'VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MAX={max_blocks}'
            )
            decode_block_bucket_cfg[2] = max_blocks
        if decode_block_bucket_cfg[0] > max_blocks:
            decode_block_bucket_min = max(1, max_blocks - decode_block_bucket_cfg[1])
            logger().info(
                f'VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_cfg[0]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MIN={decode_block_bucket_min}'
            )
            decode_block_bucket_cfg[0] = decode_block_bucket_min

        msg = ("Decode bucket config (min, step, max_warmup) "
               f"bs:{decode_bs_bucket_cfg}, "
               f"blocks:{decode_block_bucket_cfg}")
        logger().info(msg)

        return decode_bs_bucket_cfg, decode_query_bucket_cfg, decode_block_bucket_cfg

    def get_range(self, cfg):
        """Expand a (min, step, max) config into a sorted list of bucket sizes."""
        return sorted(warmup_range(cfg))