class HabanaHighLevelProfiler:
    """High-level profiler emitting Chrome-trace ('X'/'C') events.

    Events are serialized to JSON strings and pushed onto a shared queue;
    a background FileWriter (or full_trace_handler) drains the queue.
    """
    # Shared, process-wide queue of serialized trace-event strings
    # (class attribute: all instances feed the same queue).
    profiling_trace_events: queue.Queue = queue.Queue()
    # Maps event category -> trace 'tid' so each category gets its own
    # row in the trace viewer.
    event_tid = {'counter': 1, 'external': 2, 'internal': 3}
    # LIFO stack of events opened by start() and awaiting end().
    event_cache: List[Any] = []
def __init__(self, vllm_instance_id=None):
    """Set up high-level profiling for this process.

    Profiling is active only when the config flag is set AND this is
    rank 0 (``RANK`` env var), to avoid one trace file per worker. When
    active, a JSON-array trace file is created and a background
    FileWriter drains the shared event queue into it.

    :param vllm_instance_id: optional identifier used in the trace file
        name; autogenerated from pid + uuid when ``None``.
    """
    self.enabled = get_config().high_level_profiler_enabled and int(os.getenv('RANK', '0')) == 0
    self.pid = os.getpid()
    if self.enabled:
        self.vllm_instance_id = vllm_instance_id if vllm_instance_id is not None \
            else f"vllm-instance-{self.pid}-{str(uuid.uuid4().hex)}"
        msg = f'Profiler enabled for: {self.vllm_instance_id}'
        logger().info(msg)
        self.filename = f'server_events_{self.vllm_instance_id}.json'
        # initialize the trace file (JSON Array Format)
        with open(self.filename, 'w') as outfile:
            outfile.write('[')
        file_writer = FileWriter(self.filename, self.profiling_trace_events)
        file_writer.start()
    if os.getenv('VLLM_PROFILER_ENABLED') == 'full':
        # In 'full' mode events accumulate in the shared queue and are
        # merged into the PyTorch trace by full_trace_handler(); the
        # separate trace file/FileWriter above is deliberately skipped.
        self.enabled = True  # don't save separate high-level traces
    self.gc_track_recompiles = get_config().track_graph_compilation
    self.num_graph_compilations = 0
def _dump_with_sep(self, entry):
entry = json.dumps(entry) + ','
self.profiling_trace_events.put(entry)
def get_timestamp_us(self):
return time.time() * 1000000.0
def record_counter(self, ts, counter):
if self.enabled:
self._dump_with_sep({
'pid': self.pid,
'tid': self.event_tid['counter'],
'ph': 'C',
'name': 'utils',
'ts': ts,
'args': counter
})
def start(self, type, name, args=None):
if self.enabled:
ts = self.get_timestamp_us()
if args is not None and 'counter' in args:
self.record_counter(ts, args['counter'])
del args['counter']
event = {
'pid': self.pid,
'tid': self.event_tid[type],
'ph': 'X',
'name': name,
'ts': ts,
'dur': None,
'args': args
}
self.event_cache.append(event)
def end(self):
if self.enabled:
ts = self.get_timestamp_us()
if not self.event_cache:
logger().warning('Profiler: end() call does not have matching start() call. '
'Disabling profiler.')
self.enabled = False
return
event = self.event_cache.pop()
event['dur'] = ts - event['ts']
self._dump_with_sep(event)
def full_trace_handler(self, dir_name, use_gzip=False):
    """Return a torch-profiler ``on_trace_ready``-style callback that
    merges this profiler's queued high-level events into the exported
    PyTorch chrome trace and writes the combined trace under *dir_name*.

    :param dir_name: output directory (created if missing).
    :param use_gzip: write a ``.gz``-compressed trace instead of plain JSON.
    """
    def handler_fn(prof) -> None:
        if not os.path.isdir(dir_name):
            try:
                os.makedirs(dir_name, exist_ok=True)
            except Exception as e:
                raise RuntimeError("Can't create directory: " + dir_name) from e
        file_name = f"vllm.{time.time_ns()}.pt.trace.json"
        file_path = os.path.join(dir_name, file_name)
        # Export the torch trace to disk, reload it as JSON so the
        # high-level events can be appended, then drop the temp file.
        prof.export_chrome_trace(file_path)
        with open(file_path) as f:
            pytorch_trace = json.load(f)
        os.remove(file_path)
        # Our timestamps are absolute wall-clock microseconds; shift them
        # onto the torch trace's time base (nanoseconds -> microseconds).
        base = pytorch_trace['baseTimeNanoseconds'] / 1000
        events = self.profiling_trace_events
        # Drain the shared queue; each entry is a JSON string with a
        # trailing ',' separator (see _dump_with_sep) that is stripped.
        while True:
            try:
                event_str = events.get_nowait()
                event = json.loads(event_str[:-1])
                event['ts'] = event['ts'] - base
                pytorch_trace['traceEvents'].append(event)
            except queue.Empty:
                break
        # Metadata event naming a process row in the merged trace.
        # NOTE(review): this labels pid 1, while the high-level events
        # carry pid = os.getpid() -- confirm the intended row mapping.
        pytorch_trace['traceEvents'].append({
            "args": {
                "name": "vLLM"
            },
            "name": "process_name",
            "ph": "M",
            "pid": 1,
            "tid": 0,
            "ts": 0.0
        })
        if use_gzip:
            file_path = file_path + ".gz"
            with gzip.open(file_path, 'wt', encoding="ascii") as zipfile:
                json.dump(pytorch_trace, zipfile)
        else:
            with open(file_path, "w") as outfile:
                outfile.write(json.dumps(pytorch_trace))
        logger().info("Saved full profiling to %s", file_path)
    return handler_fn
@contextmanager
def record_event(self, type, name, args=None):
if self.enabled:
self.start(type, name, args)
with self.track_graph_compile(type, args) \
if self.gc_track_recompiles \
else contextlib.nullcontext():
yield
self.end()
else:
yield
def record_block(self, type, name, ts, dur, args=None):
if self.enabled:
event = {
'pid': self.pid,
'tid': self.event_tid[type],
'ph': 'X',
'name': name,
'ts': ts,
'dur': dur,
'args': args
}
self._dump_with_sep(event)
@contextmanager
def track_graph_compile(self, type, args=None):
    """Context manager that watches for HPU graph compilations while the
    body runs and records each compiled recipe as a profiler event plus a
    cumulative-compilations counter.

    :param type: event row key -- one of the keys of ``event_tid``.
    :param args: optional payload attached to each compile event.
    """
    start = self.get_timestamp_us()
    # Imported lazily so the module stays loadable without Habana SW.
    import habana_frameworks.torch as htorch
    from habana_frameworks.torch.hpu.metrics import metric_localcontext
    with metric_localcontext("graph_compilation") as gc:
        yield
        # Synchronize so compilations triggered by the body are reflected
        # in the metric before the context is read.
        htorch.hpu.synchronize()
    # NOTE(review): gc.stats() indices assumed -- [0][1] a total
    # compilation count, [3][1] a per-recipe (name, duration) list;
    # confirm against the habana metrics API.
    if gc.stats()[0][1] != 0:
        compile_start_time = start
        for recipe in gc.stats()[3][1]:
            recipe_name = recipe[0]
            compile_time = recipe[1]
            self.num_graph_compilations += 1
            self.record_counter(compile_start_time, {'cumulative_graph_compilations': self.num_graph_compilations})
            self.record_block(type, 'GRAPH COMPILE: ' + recipe_name, compile_start_time, compile_time, args)
            # Lay recipes out back-to-back starting at the block start.
            compile_start_time += compile_time