Skip to content

Index

_cst

Private CST primitives shared by move/rename/split tooling.

This sub-package is intentionally internal: symbols here are subject to change without notice. External consumers should use the public axm_anvil API instead.

Block dataclass

A top-level symbol definition extracted from a module.

Carries the CST statement that defines the symbol, the leading formatting lines that immediately precede it (used to preserve `# --- Section ---` comments), and the set of external names referenced by its body.

Source code in packages/axm-anvil/src/axm_anvil/_cst/blocks.py
Python
@dataclass
class Block:
    """A top-level symbol definition extracted from a module.

    Carries the CST statement that defines the symbol, the leading
    formatting lines that immediately precede it (used to preserve
    ``# --- Section ---`` comments), and the set of external names
    referenced by its body.
    """

    # Name of the top-level symbol this block defines.
    name: str
    # The module-level statement that defines ``name``.
    node: cst.BaseStatement
    # Formatting lines directly preceding ``node``; a fresh empty list by default.
    leading_lines: list[cst.EmptyLine] = field(default_factory=list)
    # Names referenced by ``node``; a fresh empty set by default.
    referenced_names: set[str] = field(default_factory=set)

ReferenceCollector

Bases: CSTVisitor

Collect referenced names within a CST node.

Visits all `Name` occurrences and records only the root of any `Attribute` chain (`foo.bar.baz` -> `"foo"`).

Source code in packages/axm-anvil/src/axm_anvil/_cst/visitors.py
Python
class ReferenceCollector(cst.CSTVisitor):
    """Gather the names a CST subtree refers to.

    Every bare ``Name`` is recorded. For an ``Attribute`` chain such as
    ``foo.bar.baz`` only the leftmost identifier (``"foo"``) is recorded;
    the attribute names to its right are never treated as references.
    """

    def __init__(self) -> None:
        super().__init__()
        # Identifiers collected so far.
        self.names: set[str] = set()

    def visit_Name(self, node: cst.Name) -> None:  # noqa: N802
        """Add the identifier of a plain ``Name`` node."""
        identifier = node.value
        self.names.add(identifier)

    def visit_Attribute(self, node: cst.Attribute) -> bool:  # noqa: N802
        """Record the chain's root and stop libcst from descending further."""
        head: cst.BaseExpression = node.value
        while isinstance(head, cst.Attribute):
            head = head.value
        if not isinstance(head, cst.Name):
            # Non-name root (call, subscript, ...): walk it explicitly so the
            # names it contains are still collected.
            head.visit(self)
        else:
            self.names.add(head.value)
        # Returning False tells libcst not to visit this node's children;
        # the relevant part of the chain was already handled above.
        return False
visit_Attribute(node)

Record only the root of an Attribute chain; skip nested visit.

Source code in packages/axm-anvil/src/axm_anvil/_cst/visitors.py
Python
def visit_Attribute(self, node: cst.Attribute) -> bool:  # noqa: N802
    """Record the root identifier of an ``Attribute`` chain and skip its children.

    A ``Name`` root is added to ``self.names`` directly; any other root
    expression is visited with this collector so its names are gathered.
    """
    head: cst.BaseExpression = node.value
    while isinstance(head, cst.Attribute):
        head = head.value
    if not isinstance(head, cst.Name):
        head.visit(self)
    else:
        self.names.add(head.value)
    # False => libcst does not descend into this node's children.
    return False
visit_Name(node)

Record a bare Name reference.

Source code in packages/axm-anvil/src/axm_anvil/_cst/visitors.py
Python
def visit_Name(self, node: cst.Name) -> None:  # noqa: N802
    """Add the identifier carried by a plain ``Name`` node to ``self.names``."""
    identifier = node.value
    self.names.add(identifier)

RemoveSymbols

Bases: _DepthTracker

Remove targeted top-level ClassDef, FunctionDef, or constant assignments (Assign / AnnAssign) from a module.

Surrounding formatting (comments, blank lines, indentation of other top-level symbols) is preserved thanks to libcst's lossless tree. Non-assignment SimpleStatementLine nodes (imports, docstrings, bare expressions) are left untouched.

Source code in packages/axm-anvil/src/axm_anvil/_cst/transformers.py
Python
class RemoveSymbols(_DepthTracker):
    """Remove targeted top-level ``ClassDef``, ``FunctionDef``, or constant
    assignments (``Assign`` / ``AnnAssign``) from a module.

    Surrounding formatting (comments, blank lines, indentation of other
    top-level symbols) is preserved thanks to libcst's lossless tree.
    Non-assignment small statements (imports, docstrings, bare
    expressions) are left untouched, even when they share a physical
    line with a removed assignment (``X = 1; import os`` keeps
    ``import os``).

    Only single-target assignments whose target is a plain ``Name`` are
    matched; chained (``A = B = 1``) and tuple (``A, B = ...``)
    assignments are never removed.
    """

    def __init__(self, names_to_remove: set[str]) -> None:
        super().__init__()
        # Names of the top-level symbols to drop.
        self._targets = names_to_remove
        # Only depth 0 is acted upon; presumably _DepthTracker updates this
        # while walking nested scopes.
        self._depth = 0

    def leave_ClassDef(  # noqa: N802
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef | cst.RemovalSentinel:
        """Drop the class when its top-level name matches a removal target."""
        if self._depth == 0 and updated_node.name.value in self._targets:
            return cst.RemoveFromParent()
        return updated_node

    def leave_FunctionDef(  # noqa: N802
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef | cst.RemovalSentinel:
        """Drop the function when its top-level name matches a removal target."""
        if self._depth == 0 and updated_node.name.value in self._targets:
            return cst.RemoveFromParent()
        return updated_node

    def _should_remove_assign(self, node: cst.Assign) -> bool:
        """Return True for ``NAME = value`` where ``NAME`` is a target."""
        return (
            len(node.targets) == 1
            and isinstance(node.targets[0].target, cst.Name)
            and node.targets[0].target.value in self._targets
        )

    def _should_remove_ann_assign(self, node: cst.AnnAssign) -> bool:
        """Return True for ``NAME: T [= value]`` where ``NAME`` is a target."""
        return isinstance(node.target, cst.Name) and node.target.value in self._targets

    def _should_remove_stmt(self, inner: cst.BaseSmallStatement) -> bool:
        """Return True when a small statement is an assignment to a target."""
        if isinstance(inner, cst.Assign):
            return self._should_remove_assign(inner)
        if isinstance(inner, cst.AnnAssign):
            return self._should_remove_ann_assign(inner)
        return False

    def leave_SimpleStatementLine(  # noqa: N802
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> cst.SimpleStatementLine | cst.RemovalSentinel:
        """Drop top-level assignments that target a removed name.

        Only the matching small statements are dropped; unrelated statements
        sharing the same ``;``-separated line are kept. The whole line is
        removed only when none of its statements survive.
        """
        if self._depth != 0:
            return updated_node
        body = updated_node.body
        kept: list[cst.BaseSmallStatement] = [
            inner for inner in body if not self._should_remove_stmt(inner)
        ]
        if len(kept) == len(body):
            return updated_node
        if not kept:
            return cst.RemoveFromParent()
        if kept[-1] is not body[-1]:
            # The new last statement's ``;`` used to separate it from a
            # removed statement; reset it so no dangling semicolon remains.
            kept[-1] = kept[-1].with_changes(semicolon=cst.MaybeSentinel.DEFAULT)
        return updated_node.with_changes(body=kept)
leave_ClassDef(original_node, updated_node)

Drop the class when its top-level name matches a removal target.

Source code in packages/axm-anvil/src/axm_anvil/_cst/transformers.py
Python
def leave_ClassDef(  # noqa: N802
    self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef | cst.RemovalSentinel:
    """Remove a class defined at depth zero whose name is a removal target.

    When ``self._depth`` is non-zero, or the name is not targeted, the
    updated node is returned unchanged.
    """
    at_top_level = self._depth == 0
    if not at_top_level or updated_node.name.value not in self._targets:
        return updated_node
    return cst.RemoveFromParent()
leave_FunctionDef(original_node, updated_node)

Drop the function when its top-level name matches a removal target.

Source code in packages/axm-anvil/src/axm_anvil/_cst/transformers.py
Python
def leave_FunctionDef(  # noqa: N802
    self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef | cst.RemovalSentinel:
    """Remove a function defined at depth zero whose name is a removal target.

    When ``self._depth`` is non-zero, or the name is not targeted, the
    updated node is returned unchanged.
    """
    at_top_level = self._depth == 0
    if not at_top_level or updated_node.name.value not in self._targets:
        return updated_node
    return cst.RemoveFromParent()
leave_SimpleStatementLine(original_node, updated_node)

Drop top-level statement lines whose assignments target a removed name.

Source code in packages/axm-anvil/src/axm_anvil/_cst/transformers.py
Python
def leave_SimpleStatementLine(  # noqa: N802
    self,
    original_node: cst.SimpleStatementLine,
    updated_node: cst.SimpleStatementLine,
) -> cst.SimpleStatementLine | cst.RemovalSentinel:
    """Drop top-level assignments that target a removed name.

    Only the matching small statements are dropped; unrelated statements
    sharing the same ``;``-separated line (``X = 1; import os``) are kept.
    The whole line is removed only when none of its statements survive.
    """
    if self._depth != 0:
        return updated_node
    body = updated_node.body
    kept: list[cst.BaseSmallStatement] = [
        inner for inner in body if not self._should_remove_stmt(inner)
    ]
    if len(kept) == len(body):
        return updated_node
    if not kept:
        return cst.RemoveFromParent()
    if kept[-1] is not body[-1]:
        # The new last statement's ``;`` used to separate it from a removed
        # statement; reset it so no dangling semicolon remains.
        kept[-1] = kept[-1].with_changes(semicolon=cst.MaybeSentinel.DEFAULT)
    return updated_node.with_changes(body=kept)

detect_overload_group(tree, symbol_name)

Return the ordered overload group for symbol_name.

Includes every top-level `FunctionDef` with that name when at least one is decorated with `@overload` (or `@typing.overload` or a resolved alias). Returns `[]` otherwise.

Source code in packages/axm-anvil/src/axm_anvil/_cst/overloads.py
Python
def detect_overload_group(tree: cst.Module, symbol_name: str) -> list[cst.FunctionDef]:
    """Return the ordered overload group for ``symbol_name``.

    Includes every top-level ``FunctionDef`` with that name when at
    least one is decorated with ``@overload`` (or ``@typing.overload``
    or a resolved alias). Returns ``[]`` otherwise.

    Only ``tree.body`` (module scope) is scanned, so nested functions and
    methods are never part of the group. Matches are returned in source
    order.
    """
    funcs = [
        stmt
        for stmt in tree.body
        if isinstance(stmt, cst.FunctionDef) and stmt.name.value == symbol_name
    ]
    if not funcs:
        return []
    # Resolve overload aliases only once there is something to check.
    aliases = _collect_overload_aliases(tree)
    has_overload = any(
        _is_overload_decorator(dec, aliases)
        for func in funcs
        for dec in func.decorators
    )
    return funcs if has_overload else []

dotted_name(node)

Convert a Name / Attribute chain to its dotted string form.

Returns an empty string for any other node type.

Source code in packages/axm-anvil/src/axm_anvil/_cst/visitors.py
Python
def dotted_name(node: cst.CSTNode) -> str:
    """Render a ``Name`` / ``Attribute`` chain as a dotted string.

    ``foo.bar.baz`` becomes ``"foo.bar.baz"``. If ``node`` is not such a
    chain, or the chain bottoms out in something other than a ``Name``
    (a call, subscript, ...), the empty string is returned.
    """
    segments: list[str] = []
    current: cst.CSTNode = node
    # Peel attribute names off from the right until the root is reached.
    while isinstance(current, cst.Attribute):
        segments.append(current.attr.value)
        current = current.value
    if not isinstance(current, cst.Name) or not current.value:
        return ""
    segments.append(current.value)
    return ".".join(reversed(segments))

extract_blocks(tree, symbol_names)

Extract Block records for each requested top-level symbol.

Supports `ClassDef`, `FunctionDef`, `Assign`, and `AnnAssign` at module scope. Missing symbols are silently omitted.

Source code in packages/axm-anvil/src/axm_anvil/_cst/blocks.py
Python
def extract_blocks(tree: cst.Module, symbol_names: Sequence[str]) -> list[Block]:
    """Extract ``Block`` records for each requested top-level symbol.

    Supports ``ClassDef``, ``FunctionDef``, ``Assign``, and ``AnnAssign``
    at module scope. Missing symbols are silently omitted.

    Blocks are returned in source order, not in ``symbol_names`` order. A
    name defined more than once at module scope yields one block per
    definition.
    """
    wanted = set(symbol_names)
    blocks: list[Block] = []
    for index, stmt in enumerate(tree.body):
        name: str | None = None
        if isinstance(stmt, cst.ClassDef | cst.FunctionDef):
            if stmt.name.value in wanted:
                name = stmt.name.value
        elif isinstance(stmt, cst.SimpleStatementLine):
            name = _assigned_name_in(stmt, wanted)
        if name is None:
            continue
        # Leading formatting is only needed for statements we keep, so it is
        # computed after a match instead of for every module statement.
        leading = _leading_lines_for(tree, stmt, index)
        blocks.append(_make_block(name, stmt, leading))
    return blocks