
Cst rewrite

cst_rewrite

libcst write helpers + the import-index cache.

Everything that mutates Python source code lives here. It is split into six concerns:

  • class flatten (_flatten_class_to_top_level, _flatten_class_child)
  • rename (_rename_name_in_module, _rename_top_level_in_source)
  • delete (_delete_function_from_source, _delete_source_if_empty_tests)
  • statement reorder (_reorder_module_statements + ast helpers)
  • Path(__file__).parents[N] depth patch (_patch_file_dunder_depth)
  • import management — read-side analysis lives in tests_ast; here we insert, dedupe, backfill, and synthesise missing imports, plus the project-wide import index that keeps backfill O(1) per lookup.

Hybrid I/O: we read with ast (cheap, well-tested for analysis) and write with libcst (source-fidelity: comments, triple-quoted strings, blank lines, quote style all preserved — what ast.unparse silently loses).
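For contrast with libcst, this stdlib-only snippet shows what a plain ast round trip does to a small source string. It needs Python 3.9+ for ast.unparse and does not use any code from this module.

```python
import ast

# A comment, double quotes and two blank lines: all things a formatter-aware
# rewrite would want to keep.
SOURCE = 'x = "a"  # keep me\n\n\ny = 2\n'

# Parse and regenerate with the standard library only.
roundtrip = ast.unparse(ast.parse(SOURCE))
print(roundtrip)  # x = 'a'  /  y = 2  (comment, blank lines and quote style gone)
```

A libcst round trip, by contrast, returns the original text unchanged: for valid input, `cst.parse_module(src).code == src`.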

backfill_import(module, mapping)

Insert from {mod} import {name} for each (name → mod) in mapping.

The new imports go at the canonical position — after every from __future__ import … (or at the module top when none exists) and before the first non-import statement. Names already imported in module are skipped, so this is idempotent on already-correct sources.

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def backfill_import(module: cst.Module, mapping: dict[str, str]) -> cst.Module:
    """Insert ``from {mod} import {name}`` for each (name → mod) in *mapping*.

    The new imports go at the canonical position — after every
    ``from __future__ import …`` (or at the module top when none exists)
    and before the first non-import statement. Names already imported in
    *module* are skipped, so this is idempotent on already-correct sources.
    """
    if not mapping:
        return module

    new_imports = _build_import_lines(mapping, _existing_import_names(module))
    if not new_imports:
        return module

    body = list(module.body)
    insert_at = _future_import_insert_index(body)
    return module.with_changes(body=body[:insert_at] + new_imports + body[insert_at:])
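The placement and skip rules described for backfill_import can be prototyped with the standard-library ast module. This is a sketch under assumptions, not the module's code: the private helpers _future_import_insert_index and _existing_import_names are not shown on this page, so future_insert_index, existing_import_names and plan_backfill are hypothetical stand-ins, and the sketch only plans the insertion (statement index plus import lines) instead of editing a libcst tree.

```python
import ast


def future_insert_index(tree: ast.Module) -> int:
    # Position just past the last ``from __future__`` import; 0 ("module top")
    # when there is none, mirroring the prose. Assumption: a leading module
    # docstring is not skipped here; the real helper's handling is not shown.
    index = 0
    for i, stmt in enumerate(tree.body):
        if isinstance(stmt, ast.ImportFrom) and stmt.module == "__future__":
            index = i + 1
    return index


def existing_import_names(tree: ast.Module) -> set[str]:
    # Names bound by top-level imports (assumption: nested imports ignored).
    names: set[str] = set()
    for stmt in tree.body:
        if isinstance(stmt, (ast.Import, ast.ImportFrom)):
            for alias in stmt.names:
                # ``import a.b`` binds ``a``; ``from m import x as y`` binds ``y``.
                names.add(alias.asname or alias.name.split(".")[0])
    return names


def plan_backfill(source: str, mapping: dict[str, str]) -> tuple[int, list[str]]:
    # Returns (statement index to insert at, import lines still missing).
    tree = ast.parse(source)
    already = existing_import_names(tree)
    lines = [
        f"from {mod} import {name}"
        for name, mod in mapping.items()
        if name not in already
    ]
    return future_insert_index(tree), lines


SRC = '"""Doc."""\nfrom __future__ import annotations\nfrom pathlib import Path\n\nx = 1\n'
print(plan_backfill(SRC, {"Path": "pathlib", "Any": "typing"}))  # (2, ['from typing import Any'])
```

Once the planned lines are in place, a second call returns an empty list, which is the idempotence the docstring describes.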

backfill_missing_imports(source, target, project_path=None)

Copy imports from source into target for names target uses but doesn't define.

Falls back to scanning all test files under project_path if the immediate source doesn't have the import — covers cases where the original import was lost by an earlier move.

Hybrid: analyse with ast (cheap, well-tested), write with libcst so triple-quoted strings, blank lines, and comments in the target file are preserved byte-for-byte.

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def backfill_missing_imports(
    source: Path, target: Path, project_path: Path | None = None
) -> list[str]:
    """Copy imports from *source* into *target* for names target uses but doesn't define.

    Falls back to scanning all test files under ``project_path`` if the
    immediate source doesn't have the import — covers cases where the
    original import was lost by an earlier move.

    Hybrid: analyse with ast (cheap, well-tested), write with libcst so
    triple-quoted strings, blank lines, and comments in the target file
    are preserved byte-for-byte.
    """
    recoverable = _resolve_recoverable_imports(source, target, project_path)
    if recoverable is None:
        return []
    unresolved = recoverable.pop(_UNRESOLVED_KEY, None)
    msgs: list[str] = []
    if recoverable:
        top_level_ast, type_checking_ast, msgs = _partition_imports(
            recoverable, source.name
        )
        _write_backfilled_imports(target, top_level_ast, type_checking_ast)
    if isinstance(unresolved, set):
        msgs.extend(
            f"unresolved import for `{name}` in {target.name} "
            "(no donor found; left for manual fix)"
            for name in sorted(unresolved)
        )
    return msgs
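The "uses but doesn't define" check and the donor lookup can be approximated with ast alone. The sketch below is hypothetical: undefined_names and recover_imports are illustrative names, not this module's private helpers, and the scope model is deliberately flat (any binding anywhere in the file counts; builtins are ignored). It does not implement the project-wide fallback or the TYPE_CHECKING partitioning.

```python
import ast
import builtins


def undefined_names(source: str) -> set[str]:
    # Names loaded somewhere in *source* that nothing in the file binds.
    tree = ast.parse(source)
    loaded: set[str] = set()
    bound: set[str] = set(dir(builtins))
    for node in ast.walk(tree):
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load):
                loaded.add(node.id)
            else:
                bound.add(node.id)  # Store / Del targets
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            bound.add(node.name)
        elif isinstance(node, ast.arg):
            bound.add(node.arg)  # function parameters
        elif isinstance(node, ast.ExceptHandler) and node.name:
            bound.add(node.name)  # ``except E as name``
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                bound.add(alias.asname or alias.name.split(".")[0])
    return loaded - bound


def recover_imports(donor: str, target: str) -> dict[str, str]:
    # Map each undefined name in *target* to a single-name import taken from *donor*.
    missing = undefined_names(target)
    found: dict[str, str] = {}
    for stmt in ast.parse(donor).body:
        if not isinstance(stmt, (ast.Import, ast.ImportFrom)):
            continue
        for alias in stmt.names:
            bound = alias.asname or alias.name.split(".")[0]
            if bound not in missing:
                continue
            clause = alias.name + (f" as {alias.asname}" if alias.asname else "")
            if isinstance(stmt, ast.ImportFrom):
                mod = "." * stmt.level + (stmt.module or "")
                found[bound] = f"from {mod} import {clause}"
            else:
                found[bound] = f"import {clause}"
    return found


donor = "import json\nfrom pathlib import Path, PurePath\n\ndef helper(): ...\n"
target = "def test_it(tmp_path):\n    assert Path(tmp_path).exists()\n    json.dumps({})\n"
print(recover_imports(donor, target))
```

Names returned by `undefined_names(target)` that get no key in the result correspond to the "unresolved import" messages that backfill_missing_imports emits when no donor is found.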

dedupe_imports(module)

Public wrapper around _dedupe_imports_cst.

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def dedupe_imports(module: cst.Module) -> cst.Module:
    """Public wrapper around :func:`_dedupe_imports_cst`."""
    return _dedupe_imports_cst(module)

delete_function(module, func_name)

Drop top-level function func_name from module.

Neighbouring statements (and their attached blank-line spacing) are preserved by libcst's leading-lines semantics. In-memory counterpart of _delete_function_from_source.

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def delete_function(module: cst.Module, func_name: str) -> cst.Module:
    """Drop top-level function *func_name* from *module*.

    Neighbouring statements (and their attached blank-line spacing) are
    preserved by libcst's leading-lines semantics. In-memory counterpart
    of :func:`_delete_function_from_source`.
    """
    new_body = [
        stmt
        for stmt in module.body
        if not (isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name)
    ]
    return module.with_changes(body=new_body)
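To see why the leading-lines ownership matters, here is a deliberately naive stdlib alternative that is not what delete_function does: it uses ast line spans (end_lineno, Python 3.8+) to splice a top-level function out of the raw text. The name delete_top_level_function is hypothetical. Untouched lines survive byte-for-byte, but the blank lines that separated the deleted function from its neighbours are left behind.

```python
import ast


def delete_top_level_function(source: str, func_name: str) -> str:
    # Remove every top-level def/async def named *func_name* by line splicing.
    lines = source.splitlines(keepends=True)
    spans = []
    for stmt in ast.parse(source).body:
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) and stmt.name == func_name:
            # Decorators sit above the ``def`` line, so start from the first one.
            start = min([stmt.lineno] + [d.lineno for d in stmt.decorator_list])
            spans.append((start - 1, stmt.end_lineno))  # 0-based, end-exclusive
    # Delete from the bottom up so earlier spans keep valid indices.
    for start, end in reversed(spans):
        del lines[start:end]
    return "".join(lines)


SRC = (
    "import pytest\n\n\n"
    "@pytest.mark.slow\ndef test_a():\n    assert True\n\n\n"
    "def test_b():  # keep\n    assert 1 == 1\n"
)
print(delete_top_level_function(SRC, "test_a"))
```

The result keeps four blank lines between `import pytest` and `test_b`. In libcst those blank lines before `@pytest.mark.slow` belong to `test_a`'s `leading_lines`, so dropping the node from `module.body` removes them too, while `test_b` keeps its own spacing.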

flatten_class(module, class_name)

Flatten the class_name class into top-level functions.

In-memory variant of _flatten_class_to_top_level. Class-level pytest marks are propagated onto each promoted method; method-level decorators are preserved verbatim; the class docstring is dropped.

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def flatten_class(module: cst.Module, class_name: str) -> cst.Module:
    """Flatten the *class_name* class into top-level functions.

    In-memory variant of :func:`_flatten_class_to_top_level`. Class-level
    pytest marks are propagated onto each promoted method; method-level
    decorators are preserved verbatim; the class docstring is dropped.
    """
    new_body: list[cst.BaseStatement] = []
    for stmt in module.body:
        if not (isinstance(stmt, cst.ClassDef) and stmt.name.value == class_name):
            new_body.append(stmt)
            continue
        class_decos = tuple(d for d in stmt.decorators if _is_pytest_mark_decorator(d))
        for child in stmt.body.body:
            promoted = _flatten_class_child(child, class_decos)
            if promoted is not None:
                new_body.append(promoted)
    return module.with_changes(body=new_body)
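The mark-propagation rule can be illustrated with the standard ast module. This is a hypothetical sketch (flatten_class_sketch and _is_pytest_mark are illustrative names), not _flatten_class_child: it prepends class-level marks to each method's own decorators (an assumed ordering), drops every non-function child including the docstring, and leaves parameter lists such as `self` untouched because the real helper's handling of them is not shown here. Unlike the libcst version, ast.unparse also discards comments and formatting.

```python
import ast


def _is_pytest_mark(deco: ast.expr) -> bool:
    # True for ``@pytest.mark.<name>`` and ``@pytest.mark.<name>(...)``.
    target = deco.func if isinstance(deco, ast.Call) else deco
    return (
        isinstance(target, ast.Attribute)
        and isinstance(target.value, ast.Attribute)
        and target.value.attr == "mark"
        and isinstance(target.value.value, ast.Name)
        and target.value.value.id == "pytest"
    )


def flatten_class_sketch(source: str, class_name: str) -> str:
    tree = ast.parse(source)
    new_body: list[ast.stmt] = []
    for stmt in tree.body:
        if not (isinstance(stmt, ast.ClassDef) and stmt.name == class_name):
            new_body.append(stmt)
            continue
        class_marks = [d for d in stmt.decorator_list if _is_pytest_mark(d)]
        for child in stmt.body:
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # Assumption: class marks go before the method's own decorators.
                child.decorator_list = class_marks + child.decorator_list
                new_body.append(child)
            # Docstring and any other class-body statement are dropped here.
    tree.body = new_body
    return ast.unparse(tree)


SRC = (
    "import pytest\n\n"
    "@pytest.mark.slow\n"
    "class TestThing:\n"
    '    """Doc."""\n\n'
    "    def test_a(self):\n"
    "        assert True\n\n"
    '    @pytest.mark.parametrize("x", [1])\n'
    "    def test_b(self, x):\n"
    "        assert x\n"
)
out = flatten_class_sketch(SRC, "TestThing")
print(out)
```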

invalidate_import_index(project_path)

Drop the cached import index for project_path.

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def invalidate_import_index(project_path: Path) -> None:
    """Drop the cached import index for *project_path*."""
    _PROJECT_IMPORT_INDEX_CACHE.pop(project_path, None)

patch_file_depth(module, depth_delta=0)

Rewrite Path(__file__).parents[N] literals by depth_delta.

In-memory variant of _patch_file_dunder_depth that targets the subscript form only; the chained .parent.parent form is left for the file-level helper. Identity transform when depth_delta is 0 or the pattern is absent; an index that would drop to 0 or below is also left unchanged.

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def patch_file_depth(module: cst.Module, depth_delta: int = 0) -> cst.Module:
    """Rewrite ``Path(__file__).parents[N]`` literals by *depth_delta*.

    In-memory variant of :func:`_patch_file_dunder_depth` that targets the
    subscript form only — the chained ``.parent.parent`` form is left for
    the file-level helper. Identity transform when *depth_delta* is 0 or
    the pattern is absent.
    """
    if depth_delta == 0:
        return module

    class _DunderPatcher(cst.CSTTransformer):
        def leave_Subscript(
            self,
            original_node: cst.Subscript,
            updated_node: cst.Subscript,
        ) -> cst.BaseExpression:
            value = updated_node.value
            if not isinstance(value, cst.Attribute):
                return updated_node
            if value.attr.value != "parents":
                return updated_node
            if not _is_file_dunder_chain(value.value):
                return updated_node
            slices = updated_node.slice
            if len(slices) != 1:
                return updated_node
            elt = slices[0].slice
            if not isinstance(elt, cst.Index):
                return updated_node
            n_node = elt.value
            if not isinstance(n_node, cst.Integer):
                return updated_node
            new_n = int(n_node.value) + depth_delta
            if new_n <= 0:
                return updated_node
            return updated_node.with_changes(
                slice=[
                    cst.SubscriptElement(
                        slice=cst.Index(value=cst.Integer(value=str(new_n)))
                    )
                ]
            )

    result = module.visit(_DunderPatcher())
    assert isinstance(result, cst.Module)
    return result
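The matching logic of `_DunderPatcher` can be mirrored with the standard ast module to check which expressions it would touch. This is a hypothetical sketch (ParentsDepthPatcher and patch_depth_sketch are illustrative names): it only recognises the literal call `Path(__file__)` because `_is_file_dunder_chain` is not shown on this page, it assumes Python 3.9+ (where `Subscript.slice` holds the index expression directly), and ast.unparse loses formatting, so it is for reasoning, not for rewriting files. Note that ast stores the index as an `int`, whereas libcst's `Integer.value` is a string, hence the `int(...)` call in the listing.

```python
import ast


def _is_path_file_call(node: ast.expr) -> bool:
    # True for the literal call ``Path(__file__)`` (assumption: no other forms).
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "Path"
        and len(node.args) == 1
        and isinstance(node.args[0], ast.Name)
        and node.args[0].id == "__file__"
    )


class ParentsDepthPatcher(ast.NodeTransformer):
    def __init__(self, delta: int) -> None:
        self.delta = delta

    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
        self.generic_visit(node)
        value, index = node.value, node.slice
        if (
            isinstance(value, ast.Attribute)
            and value.attr == "parents"
            and _is_path_file_call(value.value)
            and isinstance(index, ast.Constant)
            and type(index.value) is int  # excludes bools
        ):
            new_n = index.value + self.delta
            if new_n > 0:  # same guard as the listing: 0 or below is left alone
                index.value = new_n
        return node


def patch_depth_sketch(source: str, delta: int) -> str:
    tree = ParentsDepthPatcher(delta).visit(ast.parse(source))
    return ast.unparse(tree)


print(patch_depth_sketch("ROOT = Path(__file__).parents[2]\n", 1))  # ROOT = Path(__file__).parents[3]
```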

rename_function(module, old_name, new_name)

Rename top-level function old_name to new_name across module.

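Updates the FunctionDef itself, every Name node whose value is old_name (libcst also represents attribute names as Name, so obj.old_name is rewritten too), and any string-literal argument (e.g. pytest.mark.parametrize("old", …)) that matches old_name exactly. In-memory counterpart of _rename_name_in_module.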

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def rename_function(module: cst.Module, old_name: str, new_name: str) -> cst.Module:
    """Rename top-level function *old_name* to *new_name* across *module*.

    Updates the ``FunctionDef`` itself, any ``Name`` reference, and any
    string-literal argument (e.g. ``pytest.mark.parametrize("old", …)``)
    that matches *old_name*. In-memory counterpart of
    :func:`_rename_name_in_module`.
    """
    mapping = {old_name: new_name}

    class _Renamer(cst.CSTTransformer):
        def leave_Name(
            self, original_node: cst.Name, updated_node: cst.Name
        ) -> cst.BaseExpression:
            if updated_node.value in mapping:
                return updated_node.with_changes(value=mapping[updated_node.value])
            return updated_node

        def leave_FunctionDef(
            self,
            original_node: cst.FunctionDef,
            updated_node: cst.FunctionDef,
        ) -> cst.BaseStatement:
            if updated_node.name.value in mapping:
                return updated_node.with_changes(
                    name=cst.Name(value=mapping[updated_node.name.value])
                )
            return updated_node

        def leave_SimpleString(
            self,
            original_node: cst.SimpleString,
            updated_node: cst.SimpleString,
        ) -> cst.BaseExpression:
            raw = updated_node.value
            if len(raw) < 2 or raw[0] not in {'"', "'"}:
                return updated_node
            inner = raw[1:-1]
            if inner in mapping:
                return updated_node.with_changes(
                    value=f"{raw[0]}{mapping[inner]}{raw[0]}"
                )
            return updated_node

    result = module.visit(_Renamer())
    assert isinstance(result, cst.Module)
    return result
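The string-literal rule in `leave_SimpleString` operates on the raw token text (libcst's `SimpleString.value` includes any prefix and the quotes). Pulling the same logic out into a plain function, here given the hypothetical name rename_string_token, makes its edge cases easy to check without libcst:

```python
def rename_string_token(raw: str, mapping: dict[str, str]) -> str:
    # Prefixed tokens (r"...", b"...") start with a letter, not a quote: skipped.
    if len(raw) < 2 or raw[0] not in {'"', "'"}:
        return raw
    inner = raw[1:-1]
    # Triple-quoted tokens keep quote characters in `inner`, so they never match.
    if inner in mapping:
        return f"{raw[0]}{mapping[inner]}{raw[0]}"
    return raw


m = {"test_old": "test_new"}
print(rename_string_token('"test_old"', m))  # "test_new"
```

Only a whole-string match is rewritten: prefixed strings, triple-quoted strings, and strings that merely contain the old name (such as `"test_old, x"`) are left alone.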

resolve_import_for_symbol(project_path, symbol)

Return the import statements that bring symbol into scope, or None if symbol is not in the index.

The result is a pair (ast.stmt, ast.stmt | None). Builds (and caches in _PROJECT_IMPORT_INDEX_CACHE) a project-wide index of top-level FunctionDef / AsyncFunctionDef / ClassDef definitions across every .py file under project_path. Drop the cache via invalidate_import_index after mutating the file tree so the next call rebuilds.

Source code in packages/axm-audit/src/axm_audit/core/fix/cst_rewrite.py
def resolve_import_for_symbol(
    project_path: Path, symbol: str
) -> tuple[ast.stmt, ast.stmt | None] | None:
    """Return the import statement that brings *symbol* into scope, or ``None``.

    Builds (and caches in ``_PROJECT_IMPORT_INDEX_CACHE``) a project-wide
    index of top-level FunctionDef / AsyncFunctionDef / ClassDef
    definitions across every ``.py`` file under *project_path*. Drop the
    cache via :func:`invalidate_import_index` after mutating the file
    tree so the next call rebuilds.
    """
    if project_path not in _PROJECT_IMPORT_INDEX_CACHE:
        _PROJECT_IMPORT_INDEX_CACHE[project_path] = _build_project_symbol_index(
            project_path
        )
    return _PROJECT_IMPORT_INDEX_CACHE[project_path].get(symbol)
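The build-once, invalidate-on-change pattern shared by resolve_import_for_symbol and invalidate_import_index can be sketched with the standard library. Everything below is hypothetical: _CACHE, build_symbol_index, resolve and invalidate are illustrative names; module names are derived naively from the path relative to project_path (a src/ layout would need a different root); and each entry is a single `ast.ImportFrom` rather than the pair the real index stores.

```python
from __future__ import annotations

import ast
from pathlib import Path

# Hypothetical stand-in for _PROJECT_IMPORT_INDEX_CACHE.
_CACHE: dict[Path, dict[str, ast.ImportFrom]] = {}


def build_symbol_index(project_path: Path) -> dict[str, ast.ImportFrom]:
    # Map each top-level def/class name to ``from <module> import <name>``.
    index: dict[str, ast.ImportFrom] = {}
    for path in sorted(project_path.rglob("*.py")):
        rel = path.relative_to(project_path).with_suffix("")
        parts = rel.parts[:-1] if rel.name == "__init__" else rel.parts
        module = ".".join(parts)
        if not module:  # a root-level __init__.py has no dotted name here
            continue
        tree = ast.parse(path.read_text(encoding="utf-8"))
        for stmt in tree.body:
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                node = ast.ImportFrom(
                    module=module, names=[ast.alias(name=stmt.name, asname=None)], level=0
                )
                index.setdefault(stmt.name, node)  # first definition wins
    return index


def resolve(project_path: Path, symbol: str) -> ast.ImportFrom | None:
    # Build the index on first use, then answer each lookup with a dict get.
    if project_path not in _CACHE:
        _CACHE[project_path] = build_symbol_index(project_path)
    return _CACHE[project_path].get(symbol)


def invalidate(project_path: Path) -> None:
    # Forget the index so the next resolve() rescans the tree.
    _CACHE.pop(project_path, None)
```

Files added after the first lookup stay invisible until `invalidate` is called, which is why the docstring asks callers to drop the cache after mutating the file tree.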