def trace(
    model_id: str,
    model_class: Type[PreTrainedModel],
    sequential_targets: list[str] | str | None = None,
    ignore: list[str] | str = DatasetArguments().tracing_ignore,
    modality: str = "text",
    trust_remote_code: bool = True,
    skip_weights: bool = True,
    device_map: str | dict = "cpu",
) -> Tuple[PreTrainedModel, list[Subgraph], dict[str, torch.Tensor]]:
    r"""
    Debug traceability by tracing a pre-trained model into subgraphs

    :param model_id: stub of the model to load
    :param model_class: class constructor of the pre-trained model. Can use either
        HF transformers classes or `Traceable` classes defined by LLM Compressor
    :param sequential_targets: targets for sequential tracing. A single string is
        treated as a one-element list. When None, the targets are inferred from
        the model via `get_no_split_params`
    :param ignore: patterns to ignore during tracing. A single string is treated
        as a one-element list
    :param modality: data modality for dummy tracing data, defaults to 'text'
    :param trust_remote_code: trust remote model code; forwarded to both the model
        and the processor `from_pretrained` calls
    :param skip_weights: if True, the model is loaded inside the
        `skip_weights_download` context; if False, it is loaded normally
    :param device_map: forwarded as `device_map` to `model_class.from_pretrained`,
        defaults to 'cpu'
    :return: tuple of (loaded model, traced subgraphs, collated sample used as
        the tracing input)

    Example usage from CLI
    llmcompressor.trace \
        --model_id Qwen/Qwen2-VL-2B-Instruct \
        --model_class Qwen2VLForConditionalGeneration \
        --sequential_targets Qwen2VLDecoderLayer \
        --ignore "lm_head" "re:visual.*" \
        --modality text
    """
    # Normalize `ignore` the same way `sequential_targets` is normalized below:
    # the signature accepts a bare string, which must not be handled downstream
    # as an iterable of single characters.
    if isinstance(ignore, str):
        ignore = [ignore]
    elif isinstance(ignore, list):
        # The default value is evaluated once at definition time and therefore
        # shared between calls; work on a copy so it can never be mutated.
        ignore = list(ignore)

    # Load model
    with skip_weights_download(model_class) if skip_weights else nullcontext():
        model = model_class.from_pretrained(
            model_id,
            device_map=device_map,
            torch_dtype="auto",
            trust_remote_code=trust_remote_code,
        )
    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=trust_remote_code
    )
    print("Loaded model")

    # Prepare sample data
    # NOTE(review): presumably `get_dataset_kwargs` picks the dataset/split for
    # the modality and carries `ignore` through as `tracing_ignore` -- confirm.
    dataset_args = DatasetArguments(**get_dataset_kwargs(modality, ignore))
    dataset = TextGenerationDataset.load_from_registry(
        dataset_args.dataset,
        dataset_args=dataset_args,
        split=dataset_args.splits["calibration"],
        processor=processor,
    )(add_labels=False)
    # Only the first sample is used as the tracing input
    sample = next(iter(dataset))
    sample = collate_sample(sample, device=model.device)
    print("Loaded sample data")

    # infer sequential targets
    if sequential_targets is None:
        sequential_targets = get_no_split_params(model)
    if isinstance(sequential_targets, str):
        sequential_targets = [sequential_targets]

    # Attempt trace
    print(
        "\nAttempting trace\n"
        f" model_id={model_id}\n"
        f" model_class={model_class.__name__}\n"
        f" dataset={dataset_args.dataset}\n"
        f" split={dataset.split}\n"
        f" inputs={sample.keys()}\n"
        f" sequential_targets={sequential_targets}\n"
        f" ignore={dataset_args.tracing_ignore}\n"
    )
    subgraphs = trace_subgraphs(
        model, sample, sequential_targets, dataset_args.tracing_ignore
    )
    print(f"Successfully traced model into {len(subgraphs)} subgraphs!\n")

    return model, subgraphs, sample