"""Base visitor that implements an identity AST transform.

Subclass TransformVisitor to perform non-trivial transformations.
"""

from __future__ import annotations

from collections.abc import Iterable
from typing import Optional, cast

from mypy.nodes import (
    GDEF,
    REVEAL_TYPE,
    Argument,
    AssertStmt,
    AssertTypeExpr,
    AssignmentExpr,
    AssignmentStmt,
    AwaitExpr,
    Block,
    BreakStmt,
    BytesExpr,
    CallExpr,
    CastExpr,
    ClassDef,
    ComparisonExpr,
    ComplexExpr,
    ConditionalExpr,
    ContinueStmt,
    Decorator,
    DelStmt,
    DictExpr,
    DictionaryComprehension,
    EllipsisExpr,
    EnumCallExpr,
    Expression,
    ExpressionStmt,
    FloatExpr,
    ForStmt,
    FuncDef,
    FuncItem,
    GeneratorExpr,
    GlobalDecl,
    IfStmt,
    Import,
    ImportAll,
    ImportFrom,
    IndexExpr,
    IntExpr,
    LambdaExpr,
    ListComprehension,
    ListExpr,
    MatchStmt,
    MemberExpr,
    MypyFile,
    NamedTupleExpr,
    NameExpr,
    NewTypeExpr,
    Node,
    NonlocalDecl,
    OperatorAssignmentStmt,
    OpExpr,
    OverloadedFuncDef,
    OverloadPart,
    ParamSpecExpr,
    PassStmt,
    PromoteExpr,
    RaiseStmt,
    RefExpr,
    ReturnStmt,
    RevealExpr,
    SetComprehension,
    SetExpr,
    SliceExpr,
    StarExpr,
    Statement,
    StrExpr,
    SuperExpr,
    SymbolTable,
    TempNode,
    TryStmt,
    TupleExpr,
    TypeAliasExpr,
    TypeApplication,
    TypedDictExpr,
    TypeVarExpr,
    TypeVarTupleExpr,
    UnaryExpr,
    Var,
    WhileStmt,
    WithStmt,
    YieldExpr,
    YieldFromExpr,
)
from mypy.patterns import (
    AsPattern,
    ClassPattern,
    MappingPattern,
    OrPattern,
    Pattern,
    SequencePattern,
    SingletonPattern,
    StarredPattern,
    ValuePattern,
)
from mypy.traverser import TraverserVisitor
from mypy.types import FunctionLike, ProperType, Type
from mypy.util import replace_object_state
from mypy.visitor import NodeVisitor


class TransformVisitor(NodeVisitor[Node]):
    """Transform a semantically analyzed AST (or subtree) to an identical copy.

    Use the node() method to transform an AST node.

    Subclass to perform a non-identity transform.

    Notes:

     * This can only be used to transform functions or classes, not top-level
       statements, and/or modules as a whole.
     * Do not duplicate TypeInfo nodes. This would generally not be desirable.
     * Only update some name binding cross-references, but only those that
       refer to Var, Decorator or FuncDef nodes, not those targeting ClassDef or
       TypeInfo nodes.
     * Types are not transformed, but you can override type() to also perform
       type transformation.

    TODO nested classes and functions have not been tested well enough
    """

    def __init__(self) -> None:
        # To simplify testing, set this flag to True if you want to transform
        # all statements in a file (this is prohibited in normal mode).
        self.test_only = False
        # There may be multiple references to a Var node. Keep track of
        # Var translations using a dictionary.
        self.var_map: dict[Var, Var] = {}
        # These are uninitialized placeholder nodes used temporarily for nested
        # functions while we are transforming a top-level function. This maps an
        # untransformed node to a placeholder (which will later become the
        # transformed node).
        self.func_placeholder_map: dict[FuncDef, FuncDef] = {}

    def visit_mypy_file(self, node: MypyFile) -> MypyFile:
        assert self.test_only, "This visitor should not be used for whole files."
        # NOTE: The 'names' and 'imports' instance variables will be empty!
        ignored_lines = {line: codes.copy() for line, codes in node.ignored_lines.items()}
        new = MypyFile(self.statements(node.defs), [], node.is_bom, ignored_lines=ignored_lines)
        new._fullname = node._fullname
        new.path = node.path
        new.names = SymbolTable()
        return new

    def visit_import(self, node: Import) -> Import:
        return Import(node.ids.copy())

    def visit_import_from(self, node: ImportFrom) -> ImportFrom:
        return ImportFrom(node.id, node.relative, node.names.copy())

    def visit_import_all(self, node: ImportAll) -> ImportAll:
        return ImportAll(node.id, node.relative)

    def copy_argument(self, argument: Argument) -> Argument:
        arg = Argument(
            self.visit_var(argument.variable),
            argument.type_annotation,
            argument.initializer,
            argument.kind,
        )

        # Refresh lines of the inner things
        arg.set_line(argument)

        return arg

    def visit_func_def(self, node: FuncDef) -> FuncDef:
        # Note that a FuncDef must be transformed to a FuncDef.

        # These contortions are needed to handle the case of recursive
        # references inside the function being transformed.
        # Set up placeholder nodes for references within this function
        # to other functions defined inside it.
        # Don't create an entry for this function itself though,
        # since we want self-references to point to the original
        # function if this is the top-level node we are transforming.
        init = FuncMapInitializer(self)
        for stmt in node.body.body:
            stmt.accept(init)

        new = FuncDef(
            node.name,
            [self.copy_argument(arg) for arg in node.arguments],
            self.block(node.body),
            cast(Optional[FunctionLike], self.optional_type(node.type)),
        )

        self.copy_function_attributes(new, node)

        new._fullname = node._fullname
        new.is_decorated = node.is_decorated
        new.is_conditional = node.is_conditional
        new.abstract_status = node.abstract_status
        new.is_static = node.is_static
        new.is_class = node.is_class
        new.is_property = node.is_property
        new.is_final = node.is_final
        new.original_def = node.original_def

        if node in self.func_placeholder_map:
            # There is a placeholder definition for this function. Replace
            # the attributes of the placeholder with those form the transformed
            # function. We know that the classes will be identical (otherwise
            # this wouldn't work).
            result = self.func_placeholder_map[node]
            replace_object_state(result, new)
            return result
        else:
            return new

    def visit_lambda_expr(self, node: LambdaExpr) -> LambdaExpr:
        new = LambdaExpr(
            [self.copy_argument(arg) for arg in node.arguments],
            self.block(node.body),
            cast(Optional[FunctionLike], self.optional_type(node.type)),
        )
        self.copy_function_attributes(new, node)
        return new

    def copy_function_attributes(self, new: FuncItem, original: FuncItem) -> None:
        new.info = original.info
        new.min_args = original.min_args
        new.max_pos = original.max_pos
        new.is_overload = original.is_overload
        new.is_generator = original.is_generator
        new.is_coroutine = original.is_coroutine
        new.is_async_generator = original.is_async_generator
        new.is_awaitable_coroutine = original.is_awaitable_coroutine
        new.line = original.line

    def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDef:
        items = [cast(OverloadPart, item.accept(self)) for item in node.items]
        for newitem, olditem in zip(items, node.items):
            newitem.line = olditem.line
        new = OverloadedFuncDef(items)
        new._fullname = node._fullname
        new_type = self.optional_type(node.type)
        assert isinstance(new_type, ProperType)
        new.type = new_type
        new.info = node.info
        new.is_static = node.is_static
        new.is_class = node.is_class
        new.is_property = node.is_property
        new.is_final = node.is_final
        if node.impl:
            new.impl = cast(OverloadPart, node.impl.accept(self))
        return new

    def visit_class_def(self, node: ClassDef) -> ClassDef:
        new = ClassDef(
            node.name,
            self.block(node.defs),
            node.type_vars,
            self.expressions(node.base_type_exprs),
            self.optional_expr(node.metaclass),
        )
        new.fullname = node.fullname
        new.info = node.info
        new.decorators = [self.expr(decorator) for decorator in node.decorators]
        return new

    def visit_global_decl(self, node: GlobalDecl) -> GlobalDecl:
        return GlobalDecl(node.names.copy())

    def visit_nonlocal_decl(self, node: NonlocalDecl) -> NonlocalDecl:
        return NonlocalDecl(node.names.copy())

    def visit_block(self, node: Block) -> Block:
        return Block(self.statements(node.body), is_unreachable=node.is_unreachable)

    def visit_decorator(self, node: Decorator) -> Decorator:
        # Note that a Decorator must be transformed to a Decorator.
        func = self.visit_func_def(node.func)
        func.line = node.func.line
        new = Decorator(func, self.expressions(node.decorators), self.visit_var(node.var))
        new.is_overload = node.is_overload
        return new

    def visit_var(self, node: Var) -> Var:
        # Note that a Var must be transformed to a Var.
        if node in self.var_map:
            return self.var_map[node]
        new = Var(node.name, self.optional_type(node.type))
        new.line = node.line
        new._fullname = node._fullname
        new.info = node.info
        new.is_self = node.is_self
        new.is_ready = node.is_ready
        new.is_initialized_in_class = node.is_initialized_in_class
        new.is_staticmethod = node.is_staticmethod
        new.is_classmethod = node.is_classmethod
        new.is_property = node.is_property
        new.is_final = node.is_final
        new.final_value = node.final_value
        new.final_unset_in_class = node.final_unset_in_class
        new.final_set_in_init = node.final_set_in_init
        new.set_line(node)
        self.var_map[node] = new
        return new

    def visit_expression_stmt(self, node: ExpressionStmt) -> ExpressionStmt:
        return ExpressionStmt(self.expr(node.expr))

    def visit_assignment_stmt(self, node: AssignmentStmt) -> AssignmentStmt:
        return self.duplicate_assignment(node)

    def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt:
        new = AssignmentStmt(
            self.expressions(node.lvalues),
            self.expr(node.rvalue),
            self.optional_type(node.unanalyzed_type),
        )
        new.line = node.line
        new.is_final_def = node.is_final_def
        new.type = self.optional_type(node.type)
        return new

    def visit_operator_assignment_stmt(
        self, node: OperatorAssignmentStmt
    ) -> OperatorAssignmentStmt:
        return OperatorAssignmentStmt(node.op, self.expr(node.lvalue), self.expr(node.rvalue))

    def visit_while_stmt(self, node: WhileStmt) -> WhileStmt:
        return WhileStmt(
            self.expr(node.expr), self.block(node.body), self.optional_block(node.else_body)
        )

    def visit_for_stmt(self, node: ForStmt) -> ForStmt:
        new = ForStmt(
            self.expr(node.index),
            self.expr(node.expr),
            self.block(node.body),
            self.optional_block(node.else_body),
            self.optional_type(node.unanalyzed_index_type),
        )
        new.is_async = node.is_async
        new.index_type = self.optional_type(node.index_type)
        return new

    def visit_return_stmt(self, node: ReturnStmt) -> ReturnStmt:
        return ReturnStmt(self.optional_expr(node.expr))

    def visit_assert_stmt(self, node: AssertStmt) -> AssertStmt:
        return AssertStmt(self.expr(node.expr), self.optional_expr(node.msg))

    def visit_del_stmt(self, node: DelStmt) -> DelStmt:
        return DelStmt(self.expr(node.expr))

    def visit_if_stmt(self, node: IfStmt) -> IfStmt:
        return IfStmt(
            self.expressions(node.expr),
            self.blocks(node.body),
            self.optional_block(node.else_body),
        )

    def visit_break_stmt(self, node: BreakStmt) -> BreakStmt:
        return BreakStmt()

    def visit_continue_stmt(self, node: ContinueStmt) -> ContinueStmt:
        return ContinueStmt()

    def visit_pass_stmt(self, node: PassStmt) -> PassStmt:
        return PassStmt()

    def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt:
        return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr))

    def visit_try_stmt(self, node: TryStmt) -> TryStmt:
        new = TryStmt(
            self.block(node.body),
            self.optional_names(node.vars),
            self.optional_expressions(node.types),
            self.blocks(node.handlers),
            self.optional_block(node.else_body),
            self.optional_block(node.finally_body),
        )
        new.is_star = node.is_star
        return new

    def visit_with_stmt(self, node: WithStmt) -> WithStmt:
        new = WithStmt(
            self.expressions(node.expr),
            self.optional_expressions(node.target),
            self.block(node.body),
            self.optional_type(node.unanalyzed_type),
        )
        new.is_async = node.is_async
        new.analyzed_types = [self.type(typ) for typ in node.analyzed_types]
        return new

    def visit_as_pattern(self, p: AsPattern) -> AsPattern:
        return AsPattern(
            pattern=self.pattern(p.pattern) if p.pattern is not None else None,
            name=self.duplicate_name(p.name) if p.name is not None else None,
        )

    def visit_or_pattern(self, p: OrPattern) -> OrPattern:
        return OrPattern([self.pattern(pat) for pat in p.patterns])

    def visit_value_pattern(self, p: ValuePattern) -> ValuePattern:
        return ValuePattern(self.expr(p.expr))

    def visit_singleton_pattern(self, p: SingletonPattern) -> SingletonPattern:
        return SingletonPattern(p.value)

    def visit_sequence_pattern(self, p: SequencePattern) -> SequencePattern:
        return SequencePattern([self.pattern(pat) for pat in p.patterns])

    def visit_starred_pattern(self, p: StarredPattern) -> StarredPattern:
        return StarredPattern(self.duplicate_name(p.capture) if p.capture is not None else None)

    def visit_mapping_pattern(self, p: MappingPattern) -> MappingPattern:
        return MappingPattern(
            keys=[self.expr(expr) for expr in p.keys],
            values=[self.pattern(pat) for pat in p.values],
            rest=self.duplicate_name(p.rest) if p.rest is not None else None,
        )

    def visit_class_pattern(self, p: ClassPattern) -> ClassPattern:
        class_ref = p.class_ref.accept(self)
        assert isinstance(class_ref, RefExpr)
        return ClassPattern(
            class_ref=class_ref,
            positionals=[self.pattern(pat) for pat in p.positionals],
            keyword_keys=list(p.keyword_keys),
            keyword_values=[self.pattern(pat) for pat in p.keyword_values],
        )

    def visit_match_stmt(self, o: MatchStmt) -> MatchStmt:
        return MatchStmt(
            subject=self.expr(o.subject),
            patterns=[self.pattern(p) for p in o.patterns],
            guards=self.optional_expressions(o.guards),
            bodies=self.blocks(o.bodies),
        )

    def visit_star_expr(self, node: StarExpr) -> StarExpr:
        return StarExpr(node.expr)

    def visit_int_expr(self, node: IntExpr) -> IntExpr:
        return IntExpr(node.value)

    def visit_str_expr(self, node: StrExpr) -> StrExpr:
        return StrExpr(node.value)

    def visit_bytes_expr(self, node: BytesExpr) -> BytesExpr:
        return BytesExpr(node.value)

    def visit_float_expr(self, node: FloatExpr) -> FloatExpr:
        return FloatExpr(node.value)

    def visit_complex_expr(self, node: ComplexExpr) -> ComplexExpr:
        return ComplexExpr(node.value)

    def visit_ellipsis(self, node: EllipsisExpr) -> EllipsisExpr:
        return EllipsisExpr()

    def visit_name_expr(self, node: NameExpr) -> NameExpr:
        return self.duplicate_name(node)

    def duplicate_name(self, node: NameExpr) -> NameExpr:
        # This method is used when the transform result must be a NameExpr.
        # visit_name_expr() is used when there is no such restriction.
        new = NameExpr(node.name)
        self.copy_ref(new, node)
        new.is_special_form = node.is_special_form
        return new

    def visit_member_expr(self, node: MemberExpr) -> MemberExpr:
        member = MemberExpr(self.expr(node.expr), node.name)
        if node.def_var:
            # This refers to an attribute and we don't transform attributes by default,
            # just normal variables.
            member.def_var = node.def_var
        self.copy_ref(member, node)
        return member

    def copy_ref(self, new: RefExpr, original: RefExpr) -> None:
        new.kind = original.kind
        new.fullname = original.fullname
        target = original.node
        if isinstance(target, Var):
            # Do not transform references to global variables. See
            # testGenericFunctionAliasExpand for an example where this is important.
            if original.kind != GDEF:
                target = self.visit_var(target)
        elif isinstance(target, Decorator):
            target = self.visit_var(target.var)
        elif isinstance(target, FuncDef):
            # Use a placeholder node for the function if it exists.
            target = self.func_placeholder_map.get(target, target)
        new.node = target
        new.is_new_def = original.is_new_def
        new.is_inferred_def = original.is_inferred_def

    def visit_yield_from_expr(self, node: YieldFromExpr) -> YieldFromExpr:
        return YieldFromExpr(self.expr(node.expr))

    def visit_yield_expr(self, node: YieldExpr) -> YieldExpr:
        return YieldExpr(self.optional_expr(node.expr))

    def visit_await_expr(self, node: AwaitExpr) -> AwaitExpr:
        return AwaitExpr(self.expr(node.expr))

    def visit_call_expr(self, node: CallExpr) -> CallExpr:
        return CallExpr(
            self.expr(node.callee),
            self.expressions(node.args),
            node.arg_kinds.copy(),
            node.arg_names.copy(),
            self.optional_expr(node.analyzed),
        )

    def visit_op_expr(self, node: OpExpr) -> OpExpr:
        new = OpExpr(
            node.op,
            self.expr(node.left),
            self.expr(node.right),
            cast(Optional[TypeAliasExpr], self.optional_expr(node.analyzed)),
        )
        new.method_type = self.optional_type(node.method_type)
        return new

    def visit_comparison_expr(self, node: ComparisonExpr) -> ComparisonExpr:
        new = ComparisonExpr(node.operators, self.expressions(node.operands))
        new.method_types = [self.optional_type(t) for t in node.method_types]
        return new

    def visit_cast_expr(self, node: CastExpr) -> CastExpr:
        return CastExpr(self.expr(node.expr), self.type(node.type))

    def visit_assert_type_expr(self, node: AssertTypeExpr) -> AssertTypeExpr:
        return AssertTypeExpr(self.expr(node.expr), self.type(node.type))

    def visit_reveal_expr(self, node: RevealExpr) -> RevealExpr:
        if node.kind == REVEAL_TYPE:
            assert node.expr is not None
            return RevealExpr(kind=REVEAL_TYPE, expr=self.expr(node.expr))
        else:
            # Reveal locals expressions don't have any sub expressions
            return node

    def visit_super_expr(self, node: SuperExpr) -> SuperExpr:
        call = self.expr(node.call)
        assert isinstance(call, CallExpr)
        new = SuperExpr(node.name, call)
        new.info = node.info
        return new

    def visit_assignment_expr(self, node: AssignmentExpr) -> AssignmentExpr:
        return AssignmentExpr(self.duplicate_name(node.target), self.expr(node.value))

    def visit_unary_expr(self, node: UnaryExpr) -> UnaryExpr:
        new = UnaryExpr(node.op, self.expr(node.expr))
        new.method_type = self.optional_type(node.method_type)
        return new

    def visit_list_expr(self, node: ListExpr) -> ListExpr:
        return ListExpr(self.expressions(node.items))

    def visit_dict_expr(self, node: DictExpr) -> DictExpr:
        return DictExpr(
            [(self.expr(key) if key else None, self.expr(value)) for key, value in node.items]
        )

    def visit_tuple_expr(self, node: TupleExpr) -> TupleExpr:
        return TupleExpr(self.expressions(node.items))

    def visit_set_expr(self, node: SetExpr) -> SetExpr:
        return SetExpr(self.expressions(node.items))

    def visit_index_expr(self, node: IndexExpr) -> IndexExpr:
        new = IndexExpr(self.expr(node.base), self.expr(node.index))
        if node.method_type:
            new.method_type = self.type(node.method_type)
        if node.analyzed:
            if isinstance(node.analyzed, TypeApplication):
                new.analyzed = self.visit_type_application(node.analyzed)
            else:
                new.analyzed = self.visit_type_alias_expr(node.analyzed)
            new.analyzed.set_line(node.analyzed)
        return new

    def visit_type_application(self, node: TypeApplication) -> TypeApplication:
        return TypeApplication(self.expr(node.expr), self.types(node.types))

    def visit_list_comprehension(self, node: ListComprehension) -> ListComprehension:
        generator = self.duplicate_generator(node.generator)
        generator.set_line(node.generator)
        return ListComprehension(generator)

    def visit_set_comprehension(self, node: SetComprehension) -> SetComprehension:
        generator = self.duplicate_generator(node.generator)
        generator.set_line(node.generator)
        return SetComprehension(generator)

    def visit_dictionary_comprehension(
        self, node: DictionaryComprehension
    ) -> DictionaryComprehension:
        return DictionaryComprehension(
            self.expr(node.key),
            self.expr(node.value),
            [self.expr(index) for index in node.indices],
            [self.expr(s) for s in node.sequences],
            [[self.expr(cond) for cond in conditions] for conditions in node.condlists],
            node.is_async,
        )

    def visit_generator_expr(self, node: GeneratorExpr) -> GeneratorExpr:
        return self.duplicate_generator(node)

    def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr:
        return GeneratorExpr(
            self.expr(node.left_expr),
            [self.expr(index) for index in node.indices],
            [self.expr(s) for s in node.sequences],
            [[self.expr(cond) for cond in conditions] for conditions in node.condlists],
            node.is_async,
        )

    def visit_slice_expr(self, node: SliceExpr) -> SliceExpr:
        return SliceExpr(
            self.optional_expr(node.begin_index),
            self.optional_expr(node.end_index),
            self.optional_expr(node.stride),
        )

    def visit_conditional_expr(self, node: ConditionalExpr) -> ConditionalExpr:
        return ConditionalExpr(
            self.expr(node.cond), self.expr(node.if_expr), self.expr(node.else_expr)
        )

    def visit_type_var_expr(self, node: TypeVarExpr) -> TypeVarExpr:
        return TypeVarExpr(
            node.name,
            node.fullname,
            self.types(node.values),
            self.type(node.upper_bound),
            self.type(node.default),
            variance=node.variance,
        )

    def visit_paramspec_expr(self, node: ParamSpecExpr) -> ParamSpecExpr:
        return ParamSpecExpr(
            node.name,
            node.fullname,
            self.type(node.upper_bound),
            self.type(node.default),
            variance=node.variance,
        )

    def visit_type_var_tuple_expr(self, node: TypeVarTupleExpr) -> TypeVarTupleExpr:
        return TypeVarTupleExpr(
            node.name,
            node.fullname,
            self.type(node.upper_bound),
            node.tuple_fallback,
            self.type(node.default),
            variance=node.variance,
        )

    def visit_type_alias_expr(self, node: TypeAliasExpr) -> TypeAliasExpr:
        return TypeAliasExpr(node.node)

    def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr:
        res = NewTypeExpr(node.name, node.old_type, line=node.line, column=node.column)
        res.info = node.info
        return res

    def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr:
        return NamedTupleExpr(node.info)

    def visit_enum_call_expr(self, node: EnumCallExpr) -> EnumCallExpr:
        return EnumCallExpr(node.info, node.items, node.values)

    def visit_typeddict_expr(self, node: TypedDictExpr) -> Node:
        return TypedDictExpr(node.info)

    def visit__promote_expr(self, node: PromoteExpr) -> PromoteExpr:
        return PromoteExpr(node.type)

    def visit_temp_node(self, node: TempNode) -> TempNode:
        return TempNode(self.type(node.type))

    def node(self, node: Node) -> Node:
        new = node.accept(self)
        new.set_line(node)
        return new

    def mypyfile(self, node: MypyFile) -> MypyFile:
        new = node.accept(self)
        assert isinstance(new, MypyFile)
        new.set_line(node)
        return new

    def expr(self, expr: Expression) -> Expression:
        new = expr.accept(self)
        assert isinstance(new, Expression)
        new.set_line(expr)
        return new

    def stmt(self, stmt: Statement) -> Statement:
        new = stmt.accept(self)
        assert isinstance(new, Statement)
        new.set_line(stmt)
        return new

    def pattern(self, pattern: Pattern) -> Pattern:
        new = pattern.accept(self)
        assert isinstance(new, Pattern)
        new.set_line(pattern)
        return new

    # Helpers
    #
    # All the node helpers also propagate line numbers.

    def optional_expr(self, expr: Expression | None) -> Expression | None:
        if expr:
            return self.expr(expr)
        else:
            return None

    def block(self, block: Block) -> Block:
        new = self.visit_block(block)
        new.line = block.line
        return new

    def optional_block(self, block: Block | None) -> Block | None:
        if block:
            return self.block(block)
        else:
            return None

    def statements(self, statements: list[Statement]) -> list[Statement]:
        return [self.stmt(stmt) for stmt in statements]

    def expressions(self, expressions: list[Expression]) -> list[Expression]:
        return [self.expr(expr) for expr in expressions]

    def optional_expressions(
        self, expressions: Iterable[Expression | None]
    ) -> list[Expression | None]:
        return [self.optional_expr(expr) for expr in expressions]

    def blocks(self, blocks: list[Block]) -> list[Block]:
        return [self.block(block) for block in blocks]

    def names(self, names: list[NameExpr]) -> list[NameExpr]:
        return [self.duplicate_name(name) for name in names]

    def optional_names(self, names: Iterable[NameExpr | None]) -> list[NameExpr | None]:
        result: list[NameExpr | None] = []
        for name in names:
            if name:
                result.append(self.duplicate_name(name))
            else:
                result.append(None)
        return result

    def type(self, type: Type) -> Type:
        # Override this method to transform types.
        return type

    def optional_type(self, type: Type | None) -> Type | None:
        if type:
            return self.type(type)
        else:
            return None

    def types(self, types: list[Type]) -> list[Type]:
        return [self.type(type) for type in types]


class FuncMapInitializer(TraverserVisitor):
    """This traverser creates mappings from nested FuncDefs to placeholder FuncDefs.

    The placeholders will later be replaced with transformed nodes.
    """

    def __init__(self, transformer: TransformVisitor) -> None:
        self.transformer = transformer

    def visit_func_def(self, node: FuncDef) -> None:
        if node not in self.transformer.func_placeholder_map:
            # Haven't seen this FuncDef before, so create a placeholder node.
            self.transformer.func_placeholder_map[node] = FuncDef(
                node.name, node.arguments, node.body, None
            )
        super().visit_func_def(node)
