Refactor: unify task detection into one core; derive modality from model class, not config-field heuristics #877

@timenick

Summary

detect_task (used by inspect/eval) and _detect_task_and_class_from_config
(used by config/build) are two implementations of the same task decision over
the same MODEL_CLASS_MAPPING data. They have drifted in three places, and the
modality-disambiguation step (D2) reconstructs modality from config field names —
a heuristic that is provably weaker than information the pipeline already holds.

Proposal: extract a single task-detection core with two thin entry points, and
derive modality from the resolved model class's main_input_name.

Internal refactor only — public detect_task / resolve_task_and_model_class
signatures stay unchanged.

Motivation: the D2 false positive (concrete bug)

D2 upgrades feature-extraction → image-feature-extraction when the config has a
top-level image_size or patch_size (OR semantics). patch_size is not exclusive
to vision: spectrogram transformers patchify their mel-spectrogram. Verified via
the real detection path:

| declared architecture | inspect task today | correct modality |
| --- | --- | --- |
| Wav2Vec2Model | feature-extraction → text dataset → fails | audio |
| WhisperModel | feature-extraction → text dataset → fails | audio |
| ASTModel | image-feature-extraction → image dataset → fails | audio |

feature-extraction routes to a text dataset (eval: STS-B; build: TextDataset) and
image-feature-extraction to an image dataset, so audio backbones fail in both
eval and build regardless of which branch D2 picks.
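
The false positive can be reproduced with a minimal sketch of the heuristic as described above. The function name `d2_upgrade`, the field tuple, and the trimmed config dictionaries are illustrative assumptions rather than the project's code. Only the OR semantics over top-level image_size / patch_size is taken from the description.

```python
# Illustrative re-creation of the D2 heuristic described in this issue.
# `d2_upgrade` and the config dicts are hypothetical stand-ins.
VISION_HINT_FIELDS = ("image_size", "patch_size")


def d2_upgrade(task: str, config: dict) -> str:
    """Upgrade feature-extraction when ANY vision-looking field is present."""
    if task != "feature-extraction":
        return task
    if any(field in config for field in VISION_HINT_FIELDS):  # OR semantics
        return "image-feature-extraction"
    return task


# Trimmed configs: only fields relevant to the heuristic are shown.
vit_like = {"model_type": "vit", "image_size": 224, "patch_size": 16}
ast_like = {"model_type": "audio-spectrogram-transformer", "patch_size": 16}
wav2vec2_like = {"model_type": "wav2vec2", "hidden_size": 768}

results = {
    "vit": d2_upgrade("feature-extraction", vit_like),
    "ast": d2_upgrade("feature-extraction", ast_like),  # audio model, vision task
    "wav2vec2": d2_upgrade("feature-extraction", wav2vec2_like),
}

for name, task in results.items():
    print(f"{name}: {task}")
```

In this sketch the AST-like config takes the image branch purely because patch_size is present, and the Wav2Vec2-like config stays on the text branch. Neither outcome is an audio task.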

Root cause

Modality is a property of the model class (known with certainty), but the pipeline
collapses class → task through TasksManager (modality-blind by design), discarding
modality, then D2 reconstructs it from config fields. The class that carries the
authoritative signal is already resolved at that exact point
(_detect_task_from_config → _resolve_model_class_from_config), so D2 pays a
heuristic's fragility to avoid a cost already incurred.

main_input_name (an HF framework convention) is the authoritative, offline,
architecture-agnostic modality signal:

| main_input_name | modality | feature-extraction upgrades to |
| --- | --- | --- |
| input_ids | text | feature-extraction |
| pixel_values | image | image-feature-extraction |
| input_values / input_features | audio | audio-feature-extraction |

It also handles the CLIP text/vision split correctly
(CLIPTextModelWithProjection → input_ids, CLIPVisionModelWithProjection → pixel_values),
which the config-field table cannot.
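
A minimal sketch of the class-based lookup, assuming the resolved class exposes main_input_name as a class attribute. The `Stub*` classes, the two dictionaries, and `upgrade_feature_extraction` are hypothetical names. The stubs stand in for real model classes so the example runs without any ML framework, and their attribute values follow the table above.

```python
# Sketch: derive modality from a class-level `main_input_name`.
# All names here are hypothetical; stubs replace real model classes.
MODALITY_BY_MAIN_INPUT = {
    "input_ids": "text",
    "pixel_values": "image",
    "input_values": "audio",
    "input_features": "audio",
}

UPGRADE_BY_MODALITY = {
    "text": "feature-extraction",
    "image": "image-feature-extraction",
    "audio": "audio-feature-extraction",
}


class StubClipTextModel:
    main_input_name = "input_ids"


class StubClipVisionModel:
    main_input_name = "pixel_values"


class StubSpectrogramModel:
    main_input_name = "input_values"


def upgrade_feature_extraction(task, model_class):
    """Refine feature-extraction using the class's declared main input."""
    if task != "feature-extraction":
        return task
    main_input = getattr(model_class, "main_input_name", None)
    modality = MODALITY_BY_MAIN_INPUT.get(main_input)
    # Unknown or missing input names leave the task unchanged.
    return UPGRADE_BY_MODALITY.get(modality, task)


for cls in (StubClipTextModel, StubClipVisionModel, StubSpectrogramModel):
    print(cls.__name__, "->", upgrade_feature_extraction("feature-extraction", cls))
```

The two CLIP stubs share a model family but resolve to different tasks, which is the split a config-field table cannot express.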

Three inconsistencies the merge eliminates

  1. Model-type override: detect_task's distinct_tasks short-circuit vs
    _detect_task_and_class_from_config's (model_type, None) sentinel
    reverse-lookup. These are two mechanisms over the same dict; #841
    ("fix(task): make detect_task architecture-aware for multi-task model
    types") had to sync this pair.
  2. Model-id override: get_default_task_for_model_id (e.g. prajjwal1/bert-tiny)
    is applied on the build path only; detect_task skips it, so inspect can
    disagree with build today.
  3. Modality signal: config fields (the AST bug) vs the resolved class.

Proposed architecture

_resolve_task_override(model_type, model_id) -> task | None
        single place encoding model_type / model_id canonical task overrides
        (replaces the short-circuit AND the sentinel AND folds in the model-id override)

_detect(config) -> (task, model_class | None, source)
        the one task-detection core: override → wrapped-library → resolve class
        → infer task → fill-mask→seq2seq upgrade

detect_task(config)              = _detect → modality-upgrade → drop class
resolve_task_and_model_class C1  = _detect → ensure class → modality-upgrade

Build-specific class resolution (get_model_class_for_task, specialization, arch
fallback) stays in the build entry layer. The short-circuit's "answer without
importing optimum" optimization is preserved when the override hits.

Decisions (proposed)

  • (a) Remove the D2 config-field table outright. Every path yielding
    feature-extraction either comes from the override mapping (already
    modality-aware) or from TasksManager (class resolved → main_input_name
    available). No path holds feature-extraction without a class, so the
    heuristic is dead weight. Keep _resolve_task_modality as the single modality
    entry point, re-implemented on main_input_name.
  • (b) Audio = detection-correct + register the name; defer real eval support.
    AST/Whisper → audio-feature-extraction; add it to KNOWN_TASKS and
    TASK_ABBREV. Build quant then falls back to RandomDataset (works), and eval
    errors cleanly as unsupported. A real audio evaluator/dataset is follow-on work.

Out of scope

  • Universal "ONNX-input-intersection → RandomDataset" calibration fallback
    (defense-in-depth; separate datasets-layer change).
  • A real audio-feature-extraction evaluator + default dataset.

Testing

Parametrized pytest over representative configs (text / image / audio / CLIP-dual /
SAM / bert-tiny), asserting: detect_task == task from resolve_task_and_model_class
for the same config; AST → audio-feature-extraction; bert-tiny model-id override
fires on both paths; vision backbones still → image-feature-extraction; no
regression on the #841 table (bart-mnli, sam, clip).
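
The core agreement assertion might take the shape below. This is a framework-free sketch: the case table, the `unified_*` and `drifted_detect` stand-ins, and the assumption that resolve_task_and_model_class returns a (task, class) pair are all hypothetical. Real cases would use the project's config fixtures.

```python
# Hypothetical case table: (name, config, expected task).
CASES = [
    ("text", {"main_input_name": "input_ids"}, "feature-extraction"),
    ("image", {"main_input_name": "pixel_values"}, "image-feature-extraction"),
    ("audio", {"main_input_name": "input_values"}, "audio-feature-extraction"),
]

TASK_BY_INPUT = {
    "input_ids": "feature-extraction",
    "pixel_values": "image-feature-extraction",
    "input_values": "audio-feature-extraction",
}


def unified_detect(config):
    """Stand-in for detect_task after the refactor."""
    return TASK_BY_INPUT[config["main_input_name"]]


def unified_resolve(config):
    """Stand-in for resolve_task_and_model_class; the class is left as None."""
    return unified_detect(config), None


def drifted_detect(config):
    """Stand-in for an entry point that lacks the audio branch."""
    task = unified_detect(config)
    return "feature-extraction" if task == "audio-feature-extraction" else task


def check_case(detect, resolve, config, expected):
    inspected = detect(config)
    built, _model_class = resolve(config)
    assert inspected == built, f"paths disagree: {inspected!r} vs {built!r}"
    assert inspected == expected, f"expected {expected!r}, got {inspected!r}"


def failing_cases(detect, resolve):
    """Run every case and collect the names whose assertions fail."""
    failed = []
    for name, config, expected in CASES:
        try:
            check_case(detect, resolve, config, expected)
        except AssertionError:
            failed.append(name)
    return failed


print("unified:", failing_cases(unified_detect, unified_resolve))
print("drifted:", failing_cases(drifted_detect, unified_resolve))
```

Under pytest, the same table could feed `pytest.mark.parametrize` with `check_case` as the test body, so each config reports as its own test.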
