
Overloads

overloads

Detect @overload groups for a given symbol in a module.

detect_overload_group(tree, symbol_name)

Return the ordered overload group for symbol_name.

Includes, in source order, every top-level FunctionDef named symbol_name, provided at least one of them is decorated with @overload, @typing.overload, or an alias that resolves to it. Returns [] otherwise, including when no top-level function has that name.
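The grouping rule above can be illustrated with Python's standard-library `ast` module instead of libcst. This is a minimal sketch, not the axm-anvil implementation. `detect_overload_group_ast` and `_decorator_name` are hypothetical names, and alias resolution is deliberately left out, so only the literal spellings `overload` and `typing.overload` are recognised.

```python
from __future__ import annotations

import ast

# Decorator spellings this sketch accepts (assumption: no alias resolution).
OVERLOAD_NAMES = {"overload", "typing.overload"}


def _decorator_name(dec: ast.expr) -> str:
    """Flatten ``Name``/``Attribute`` decorators to a dotted string, else ""."""
    if isinstance(dec, ast.Name):
        return dec.id
    if isinstance(dec, ast.Attribute):
        base = _decorator_name(dec.value)
        return f"{base}.{dec.attr}" if base else ""
    return ""


def detect_overload_group_ast(
    tree: ast.Module, symbol_name: str
) -> list[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Hypothetical ``ast`` analogue of ``detect_overload_group``."""
    # Top-level statements only, kept in source order. libcst's FunctionDef
    # also covers ``async def``, so both ast node types are matched here.
    funcs = [
        stmt
        for stmt in tree.body
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
        and stmt.name == symbol_name
    ]
    # One overload-decorated definition is enough to return the whole group.
    has_overload = any(
        _decorator_name(dec) in OVERLOAD_NAMES
        for func in funcs
        for dec in func.decorator_list
    )
    return funcs if has_overload else []


source = '''
from typing import overload

@overload
def parse(x: int) -> int: ...
@overload
def parse(x: str) -> str: ...
def parse(x):
    return x

def helper(): ...
'''
tree = ast.parse(source)
print(len(detect_overload_group_ast(tree, "parse")))  # 3
print(detect_overload_group_ast(tree, "helper"))  # []
```

Because the check is an `any(...)` over the whole group, the undecorated implementation `def parse(x)` is returned alongside its two `@overload` stubs, while a name with no overload-decorated definition yields `[]`.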

Source code in packages/axm-anvil/src/axm_anvil/_cst/overloads.py
def detect_overload_group(tree: cst.Module, symbol_name: str) -> list[cst.FunctionDef]:
    """Return the ordered overload group for ``symbol_name``.

    Includes every top-level ``FunctionDef`` with that name when at
    least one is decorated with ``@overload`` (or ``@typing.overload``
    or a resolved alias). Returns ``[]`` otherwise.
    """
    aliases = _collect_overload_aliases(tree)
    funcs = [
        stmt
        for stmt in tree.body
        if isinstance(stmt, cst.FunctionDef) and stmt.name.value == symbol_name
    ]
    if not funcs:
        return []
    has_overload = any(
        _is_overload_decorator(dec, aliases)
        for func in funcs
        for dec in func.decorators
    )
    if not has_overload:
        return []
    return funcs
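`_collect_overload_aliases` and `_is_overload_decorator` are private helpers whose source is not shown on this page. The sketch below guesses at one plausible shape for that alias step, again using the stdlib `ast` module rather than libcst. `collect_overload_aliases`, `dotted_name`, and `is_overload_decorator` are hypothetical names, and only top-level imports from `typing` are considered.

```python
from __future__ import annotations

import ast


def collect_overload_aliases(tree: ast.Module) -> set[str]:
    """Hypothetical: dotted decorator spellings that resolve to typing.overload."""
    # Assumption: bare ``overload`` and ``typing.overload`` are always accepted.
    aliases = {"overload", "typing.overload"}
    for stmt in tree.body:
        if isinstance(stmt, ast.ImportFrom) and stmt.module == "typing":
            # from typing import overload as ov  ->  "ov"
            for alias in stmt.names:
                if alias.name == "overload":
                    aliases.add(alias.asname or alias.name)
        elif isinstance(stmt, ast.Import):
            # import typing as t  ->  "t.overload"
            for alias in stmt.names:
                if alias.name == "typing":
                    aliases.add(f"{alias.asname or alias.name}.overload")
    return aliases


def dotted_name(node: ast.expr) -> str:
    """Flatten ``Name``/``Attribute`` chains; other decorator forms yield ""."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return f"{base}.{node.attr}" if base else ""
    return ""


def is_overload_decorator(dec: ast.expr, aliases: set[str]) -> bool:
    """True when the decorator's dotted spelling is a known overload alias."""
    return dotted_name(dec) in aliases


tree = ast.parse("import typing as t\nfrom typing import overload as ov\n")
aliases = collect_overload_aliases(tree)
print(sorted(aliases))  # ['ov', 'overload', 't.overload', 'typing.overload']

dec = ast.parse("@ov\ndef f(): ...\n").body[0].decorator_list[0]
print(is_overload_decorator(dec, aliases))  # True
```

Collecting aliases once per module and then doing a set lookup per decorator keeps the per-function check cheap, which fits the way `detect_overload_group` calls `_collect_overload_aliases` once before scanning decorators.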