class UnifiedBucketingStrategy():
    """Build bucket configurations and sorted bucket ranges for unified warmup.

    Every configuration produced by ``get_unified_cfgs`` is a three-element
    list laid out as ``[min, max, turning_point]``.
    """

    def get_unified_cfgs(self, bs, max_model_len, block_size, max_blocks, max_num_batched_tokens):
        """Return the query, shared-context and unique-context bucket configs.

        Args:
            bs: Used as the ``turning_point`` of all three configs.
            max_model_len: Maximum model length; converted into a number of
                blocks (rounded up) to bound the shared-context config.
            block_size: Divisor used to convert ``max_model_len`` to blocks.
            max_blocks: Upper bound applied to both context configs.
            max_num_batched_tokens: ``max`` of the query config.

        Returns:
            Tuple ``(query_cfg, shared_ctx_cfg, unique_ctx_cfg)``, each a
            ``[min, max, turning_point]`` list:

            * query:      ``[1, max_num_batched_tokens, bs]``
            * shared ctx: ``[0, min(ceil(max_model_len / block_size), max_blocks), bs]``
            * unique ctx: ``[0, max_blocks, bs]``
        """
        # [min, max, turning_point]
        query_cfg = [1, max_num_batched_tokens, bs]
        # Round UP so a trailing partial block is still counted. Floor
        # division (``//``) inside ``math.ceil`` would make the ceil a no-op
        # and under-count whenever max_model_len is not a multiple of
        # block_size.
        blocks_for_max_len = math.ceil(max_model_len / block_size)
        max_shared_ctx = min(blocks_for_max_len, max_blocks)
        shared_ctx_cfg = [0, max_shared_ctx, bs]
        max_unique_ctx = max_blocks
        unique_ctx_cfg = [0, max_unique_ctx, bs]
        return query_cfg, shared_ctx_cfg, unique_ctx_cfg

    def get_range(self, cfg):
        """Return the values from ``warmup_unified_range(cfg)`` as a sorted list.

        ``cfg`` is presumably one of the ``[min, max, turning_point]`` lists
        returned by ``get_unified_cfgs`` — confirm against callers.
        """
        range_for_cfg = warmup_unified_range(cfg)
        return sorted(range_for_cfg)