# Functions to transform a cdef class into a dataclass.

from collections import OrderedDict
from textwrap import dedent
import operator

from . import ExprNodes
from . import Nodes
from . import PyrexTypes
from . import Builtin
from . import Naming
from .Errors import error, warning
from .Code import UtilityCode, TempitaUtilityCode, PyxCodeWriter
from .Visitor import VisitorTransform
from .StringEncoding import EncodedString
from .TreeFragment import TreeFragment
from .ParseTreeTransforms import NormalizeTree, SkipDeclarations
from .Options import copy_inherited_directives

_dataclass_loader_utilitycode = None  # built on first use, then reused


def make_dataclasses_module_callnode(pos):
    """
    Return a call node that evaluates to the ``dataclasses`` module at runtime.

    The loader utility code (whose context embeds the "Dataclasses_fallback"
    Python source as a C string literal) is created the first time this is
    called and cached in the module-level ``_dataclass_loader_utilitycode``.
    """
    global _dataclass_loader_utilitycode
    loader = _dataclass_loader_utilitycode
    if not loader:
        fallback_impl = UtilityCode.load_cached("Dataclasses_fallback", "Dataclasses.py").impl
        fallback_source = EncodedString(fallback_impl)
        loader_context = {
            'cname': "dataclasses",
            'py_code': fallback_source.as_c_string_literal(),
        }
        loader = TempitaUtilityCode.load(
            "SpecificModuleLoader", "Dataclasses.c", context=loader_context)
        _dataclass_loader_utilitycode = loader
    no_arg_object_func = PyrexTypes.CFuncType(PyrexTypes.py_object_type, [])
    return ExprNodes.PythonCapiCallNode(
        pos, "__Pyx_Load_dataclasses_Module",
        no_arg_object_func,
        utility_code=loader,
        args=[],
    )

def make_dataclass_call_helper(pos, callable, kwds):
    """
    Build a call to the C helper ``__Pyx_DataclassesCallHelper(callable, kwds)``.

    Both arguments and the result are Python objects; the helper's
    implementation comes from the "DataclassesCallHelper" utility code.
    """
    helper_code = UtilityCode.load_cached("DataclassesCallHelper", "Dataclasses.c")
    helper_args = [
        PyrexTypes.CFuncTypeArg(arg_name, PyrexTypes.py_object_type, None)
        for arg_name in ("callable", "kwds")
    ]
    helper_type = PyrexTypes.CFuncType(PyrexTypes.py_object_type, helper_args)
    return ExprNodes.PythonCapiCallNode(
        pos,
        function_name="__Pyx_DataclassesCallHelper",
        func_type=helper_type,
        utility_code=helper_code,
        args=[callable, kwds],
    )


class RemoveAssignmentsToNames(VisitorTransform, SkipDeclarations):
    """
    Strip plain assignments to dataclass field names out of a cdef class body.

    Cython (and Python) normally treats

    class A:
         x = 1

    as generating a class attribute. In a dataclass, however, the `= 1` is the
    default value used to initialise an instance attribute. This transform
    removes such assignments so that no class attribute is generated, and
    records each removed right-hand side in ``removed_assignments`` (keyed by
    field name) so it can be used when generating the initialisation.
    """
    def __init__(self, names):
        super().__init__()
        # only assignments whose target is one of these names are removed
        self.names = names
        # field name -> rhs node of the most recent removed assignment
        self.removed_assignments = {}

    def visit_CClassNode(self, node):
        self.visitchildren(node)
        return node

    # Nested Python classes and functions are not part of the dataclass
    # body, so the transform does not descend into them.
    def visit_PyClassNode(self, node):
        return node

    def visit_FuncDefNode(self, node):
        return node

    def visit_SingleAssignmentNode(self, node):
        target = node.lhs
        if not (target.is_name and target.name in self.names):
            return node
        field_name = target.name
        if field_name in self.removed_assignments:
            warning(node.pos, ("Multiple assignments for '%s' in dataclass; "
                               "using most recent") % field_name, 1)
        self.removed_assignments[field_name] = node.rhs
        # an empty replacement removes the assignment statement
        return []

    # Cascaded assignment is believed to always be a syntax error with
    # annotations, so no visit_CascadedAssignmentNode is defined.

    def visit_Node(self, node):
        self.visitchildren(node)
        return node


class TemplateCode:
    """
    Adds the ability to keep track of placeholder argument names to PyxCodeWriter.

    Also adds extra_stats which are nodes bundled at the end when this
    is converted to a tree.

    Writers returned by insertion_point() share the ``placeholders`` dict and
    the ``extra_stats`` list of the writer they were created from, so values
    registered through any of them are seen by generate_tree().
    """
    # Starting point of the placeholder-name search; an instance advances its
    # own copy as it hands out names.
    _placeholder_count = 0

    def __init__(self, writer=None, placeholders=None, extra_stats=None):
        self.writer = PyxCodeWriter() if writer is None else writer
        self.placeholders = {} if placeholders is None else placeholders
        self.extra_stats = [] if extra_stats is None else extra_stats

    def add_code_line(self, code_line):
        self.writer.putln(code_line)

    def add_code_chunk(self, code_chunk):
        self.writer.put_chunk(code_chunk)

    def reset(self):
        # don't attempt to reset placeholders - it really doesn't matter if
        # we have unused placeholders
        self.writer.reset()

    def empty(self):
        return self.writer.empty()

    def indent(self):
        self.writer.indent()

    def dedent(self):
        self.writer.dedent()

    def indenter(self, block_opener_line):
        return self.writer.indenter(block_opener_line)

    def new_placeholder(self, field_names, value):
        """Register ``value`` under a fresh placeholder name and return that name."""
        name = self._new_placeholder_name(field_names)
        self.placeholders[name] = value
        return name

    def add_extra_statements(self, statements):
        """
        Append nodes that generate_tree() places after the templated code.

        The list is shared with every insertion point, so this may be called
        on any of them.
        """
        self.extra_stats.extend(statements)

    def _new_placeholder_name(self, field_names):
        """Return a placeholder name not yet registered and not clashing with a field name."""
        while True:
            name = f"DATACLASS_PLACEHOLDER_{self._placeholder_count:d}"
            # Advance unconditionally: a name that is returned is registered
            # by new_placeholder() and must not be tried again.
            self._placeholder_count += 1
            if name not in self.placeholders and name not in field_names:
                # unused so far and doesn't conflict with a variable name
                # (which is unlikely but possible)
                return name

    def generate_tree(self, level='c_class'):
        stat_list_node = TreeFragment(
            self.writer.getvalue(),
            level=level,
            pipeline=[NormalizeTree(None)],
        ).substitute(self.placeholders)

        stat_list_node.stats += self.extra_stats
        return stat_list_node

    def insertion_point(self):
        new_writer = self.writer.insertion_point()
        return TemplateCode(
            writer=new_writer,
            placeholders=self.placeholders,
            extra_stats=self.extra_stats
        )


class _MISSING_TYPE:
    pass
MISSING = _MISSING_TYPE()


class Field:
    """
    Compile-time counterpart of the standard library's ``dataclasses.field``.

    One instance is created per dataclass attribute while a Cython dataclass
    is generated, recording the settings for that attribute. The settings
    are stored as expression nodes (e.g. a BoolNode rather than a bool) so
    they can be used directly when constructing code.
    """
    # Class-level fallbacks: an instance only gets its own attribute when a
    # value is actually supplied, so "is MISSING" means "not given".
    default = MISSING
    default_factory = MISSING
    private = False

    # settings that must be given as compile-time literal values
    literal_keys = ("repr", "hash", "init", "compare", "metadata")

    # default values are defined by the CPython dataclasses.field
    def __init__(self, pos, default=MISSING, default_factory=MISSING,
                 repr=None, hash=None, init=None,
                 compare=None, metadata=None,
                 is_initvar=False, is_classvar=False,
                 **additional_kwds):
        if default is not MISSING:
            self.default = default
        if default_factory is not MISSING:
            self.default_factory = default_factory

        def true_node():
            return ExprNodes.BoolNode(pos, value=True)

        def none_node():
            return ExprNodes.NoneNode(pos)

        self.repr = repr or true_node()
        self.hash = hash or none_node()
        self.init = init or true_node()
        self.compare = compare or true_node()
        self.metadata = metadata or none_node()
        self.is_initvar = is_initvar
        self.is_classvar = is_classvar

        # every accepted keyword is a named parameter above
        for bad_name, bad_node in additional_kwds.items():
            error(bad_node.pos,
                  "cython.dataclasses.field() got an unexpected keyword argument '%s'" % bad_name)

        option_nodes = [(key, getattr(self, key)) for key in self.literal_keys]
        for key, option_node in option_nodes:
            if option_node.is_literal:
                continue
            error(option_node.pos,
                  "cython.dataclasses.field parameter '%s' must be a literal value" % key)

    def iterate_record_node_arguments(self):
        """Yield ``(name, node)`` for each setting that has a value; literal settings first."""
        for key in (*self.literal_keys, 'default', 'default_factory'):
            value = getattr(self, key)
            if value is MISSING:
                continue
            yield key, value


def process_class_get_fields(node):
    """
    Collect the dataclass fields of the cdef class ``node``.

    Fields of the nearest base class that already has ``dataclass_fields``
    come first, followed by this class's own variable entries in order of
    definition. Assignments to those names in the class body are removed
    (see RemoveAssignmentsToNames) and become the field's default, or, for a
    ``dataclasses.field(...)`` call, its settings. The resulting ordered
    mapping of name -> Field is stored on ``node.entry.type.dataclass_fields``
    and returned.
    """
    var_entries = node.scope.var_entries
    # order of definition is used in the dataclass
    var_entries = sorted(var_entries, key=operator.attrgetter('pos'))
    var_names = [entry.name for entry in var_entries]

    # don't treat `x = 1` as an assignment of a class attribute within the dataclass
    transform = RemoveAssignmentsToNames(var_names)
    transform(node)
    default_value_assignments = transform.removed_assignments

    # Start from the fields of the nearest dataclass base, warning about any
    # base whose fields cannot be known here.
    base_type = node.base_type
    fields = OrderedDict()
    while base_type:
        if base_type.is_external or not base_type.scope.implemented:
            warning(node.pos, "Cannot reliably handle Cython dataclasses with base types "
                "in external modules since it is not possible to tell what fields they have", 2)
        if base_type.dataclass_fields:
            fields = base_type.dataclass_fields.copy()
            break
        base_type = base_type.base_type

    for entry in var_entries:
        name = entry.name
        is_initvar = entry.declared_with_pytyping_modifier("dataclasses.InitVar")
        # TODO - classvars aren't included in "var_entries" so are missed here
        # and thus this code is never triggered
        is_classvar = entry.declared_with_pytyping_modifier("typing.ClassVar")
        if name in default_value_assignments:
            assignment = default_value_assignments[name]
            if (isinstance(assignment, ExprNodes.CallNode) and (
                    assignment.function.as_cython_attribute() == "dataclasses.field" or
                    Builtin.exprnode_to_known_standard_library_name(
                        assignment.function, node.scope) == "dataclasses.field")):
                # I believe most of this is well-enforced when it's treated as a directive
                # but it doesn't hurt to make sure
                valid_general_call = (isinstance(assignment, ExprNodes.GeneralCallNode)
                        and isinstance(assignment.positional_args, ExprNodes.TupleNode)
                        and not assignment.positional_args.args
                        and (assignment.keyword_args is None or isinstance(assignment.keyword_args, ExprNodes.DictNode)))
                valid_simple_call = (isinstance(assignment, ExprNodes.SimpleCallNode) and not assignment.args)
                if not (valid_general_call or valid_simple_call):
                    error(assignment.pos, "Call to 'cython.dataclasses.field' must only consist "
                          "of compile-time keyword arguments")
                    continue  # the field is left out after reporting the error
                keyword_args = (assignment.keyword_args.as_python_dict()
                                if valid_general_call and assignment.keyword_args else {})
                # "pos", "is_initvar" and "is_classvar" are parameters of Field()
                # itself, not options of dataclasses.field(). Forwarding "pos"
                # would raise a TypeError (multiple values for 'pos') inside the
                # compiler, and the other two would be silently overwritten
                # below, so report them like any other unknown keyword.
                for internal_name in ("pos", "is_initvar", "is_classvar"):
                    if internal_name in keyword_args:
                        bad_node = keyword_args.pop(internal_name)
                        error(bad_node.pos, "cython.dataclasses.field() got an unexpected "
                              "keyword argument '%s'" % internal_name)
                if 'default' in keyword_args and 'default_factory' in keyword_args:
                    error(assignment.pos, "cannot specify both default and default_factory")
                    continue
                field = Field(node.pos, **keyword_args)
            else:
                if assignment.type in [Builtin.list_type, Builtin.dict_type, Builtin.set_type]:
                    # The standard library module generates a TypeError at runtime
                    # in this situation.
                    # Error message is copied from CPython
                    error(assignment.pos, "mutable default <class '{}'> for field {} is not allowed: "
                          "use default_factory".format(assignment.type.name, name))

                field = Field(node.pos, default=assignment)
        else:
            field = Field(node.pos)
        field.is_initvar = is_initvar
        field.is_classvar = is_classvar
        if entry.visibility == "private":
            field.private = True
        fields[name] = field
    node.entry.type.dataclass_fields = fields
    return fields


def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
    """
    Turn the cdef class ``node`` into a dataclass.

    ``dataclass_args`` is None or a pair of (positional arguments, mapping of
    keyword name -> node) given to the decorator; positional arguments are an
    error and keyword values must be BoolNodes. The generated
    ``__dataclass_params__``/``__dataclass_fields__`` assignments and special
    methods are analysed in the class scope (with annotation typing off) and
    appended to the class body.
    """
    # default argument values from https://docs.python.org/3/library/dataclasses.html
    kwargs = dict(init=True, repr=True, eq=True,
                  order=False, unsafe_hash=False,
                  frozen=False, kw_only=False, match_args=True)
    if dataclass_args is not None:
        if dataclass_args[0]:
            error(node.pos, "cython.dataclasses.dataclass takes no positional arguments")
        for k, v in dataclass_args[1].items():
            if k not in kwargs:
                error(node.pos,
                      "cython.dataclasses.dataclass() got an unexpected keyword argument '%s'" % k)
                # don't let an unknown option leak into __dataclass_params__
                continue
            if not isinstance(v, ExprNodes.BoolNode):
                error(node.pos,
                      "Arguments passed to cython.dataclasses.dataclass must be True or False")
                # a non-BoolNode may have no usable ``.value``
                continue
            kwargs[k] = v.value

    kw_only = kwargs['kw_only']

    fields = process_class_get_fields(node)

    dataclass_module = make_dataclasses_module_callnode(node.pos)

    # create __dataclass_params__ attribute. I try to use the exact
    # `_DataclassParams` class defined in the standard library module if at all possible
    # for maximum duck-typing compatibility.
    dataclass_params_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
                                                    attribute=EncodedString("_DataclassParams"))
    # kwargs already holds 'kw_only'; 'slots' and 'weakref_slot' are always False.
    param_items = list(kwargs.items()) + [('slots', False), ('weakref_slot', False)]
    dataclass_params_keywords = ExprNodes.DictNode.from_pairs(
        node.pos,
        [(ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
          ExprNodes.BoolNode(node.pos, value=v))
         for k, v in param_items])
    dataclass_params = make_dataclass_call_helper(
        node.pos, dataclass_params_func, dataclass_params_keywords)
    dataclass_params_assignment = Nodes.SingleAssignmentNode(
        node.pos,
        lhs = ExprNodes.NameNode(node.pos, name=EncodedString("__dataclass_params__")),
        rhs = dataclass_params)

    dataclass_fields_stats = _set_up_dataclass_fields(node, fields, dataclass_module)

    stats = Nodes.StatListNode(node.pos,
                               stats=[dataclass_params_assignment] + dataclass_fields_stats)

    code = TemplateCode()
    generate_init_code(code, kwargs['init'], node, fields, kw_only)
    generate_match_args(code, kwargs['match_args'], node, fields, kw_only)
    generate_repr_code(code, kwargs['repr'], node, fields)
    generate_eq_code(code, kwargs['eq'], node, fields)
    generate_order_code(code, kwargs['order'], node, fields)
    generate_hash_code(code, kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields)

    stats.stats += code.generate_tree().stats

    # turn off annotation typing, so all arguments to __init__ are accepted as
    # generic objects and thus can accept _HAS_DEFAULT_FACTORY.
    # Type conversion comes later
    comp_directives = Nodes.CompilerDirectivesNode(node.pos,
        directives=copy_inherited_directives(node.scope.directives, annotation_typing=False),
        body=stats)

    comp_directives.analyse_declarations(node.scope)
    # probably already in this scope, but it doesn't hurt to make sure
    analyse_decs_transform.enter_scope(node, node.scope)
    analyse_decs_transform.visit(comp_directives)
    analyse_decs_transform.exit_scope()

    node.body.stats.append(comp_directives)


def generate_init_code(code, init, node, fields, kw_only):
    """
    Generate ``__init__`` unless ``init`` is false or the class defines its own.

    Notes on CPython generated "__init__":
    * Implemented in `_init_fn`.
    * The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as
      the default argument for fields that need constructing with a factory
      function is copied from the CPython implementation. (`None` isn't
      suitable because it could also be a value for the user to pass.)
      There's no real reason why it needs importing from the dataclasses module
      though - it could equally be a value generated by Cython when the module loads.
    * seen_default and the associated error message are copied directly from Python
    * Call to user-defined __post_init__ function (if it exists) is copied from
      CPython.
    * The bare "*" for kw_only is only emitted when at least one parameter
      follows it: "def __init__(self, *):" is a syntax error, and CPython
      likewise only adds "*" when there are keyword-only init fields.

    Cython behaviour deviates a little here (to be decided if this is right...)
    Because the class variable from the assignment does not exist Cython fields will
    return None (or whatever their type default is) if not initialized while Python
    dataclasses will fall back to looking up the class variable.
    """
    if not init or node.scope.lookup_here("__init__"):
        return

    # selfname behaviour copied from the cpython module
    selfname = "__dataclass_self__" if "self" in fields else "self"
    # parameters after self; kept separate so "*" can be inserted conditionally
    field_args = []

    function_start_point = code.insertion_point()
    code = code.insertion_point()
    code.indent()

    # create a temp to get _HAS_DEFAULT_FACTORY
    dataclass_module = make_dataclasses_module_callnode(node.pos)
    has_default_factory = ExprNodes.AttributeNode(
        node.pos,
        obj=dataclass_module,
        attribute=EncodedString("_HAS_DEFAULT_FACTORY")
    )

    default_factory_placeholder = code.new_placeholder(fields, has_default_factory)

    seen_default = False
    for name, field in fields.items():
        entry = node.scope.lookup(name)
        if entry.annotation:
            annotation = f": {entry.annotation.string.value}"
        else:
            annotation = ""
        assignment = ''
        if field.default is not MISSING or field.default_factory is not MISSING:
            if field.init.value:
                seen_default = True
            if field.default_factory is not MISSING:
                ph_name = default_factory_placeholder
            else:
                ph_name = code.new_placeholder(fields, field.default)  # 'default' should be a node
            assignment = f" = {ph_name}"
        elif seen_default and not kw_only and field.init.value:
            error(entry.pos, ("non-default argument '%s' follows default argument "
                              "in dataclass __init__") % name)
            code.reset()
            return

        if field.init.value:
            field_args.append(f"{name}{annotation}{assignment}")

        if field.is_initvar:
            continue
        elif field.default_factory is MISSING:
            if field.init.value:
                code.add_code_line(f"{selfname}.{name} = {name}")
            elif assignment:
                # not an argument to the function, but is still initialized
                code.add_code_line(f"{selfname}.{name}{assignment}")
        else:
            ph_name = code.new_placeholder(fields, field.default_factory)
            if field.init.value:
                # close to:
                # def __init__(self, name=_PLACEHOLDER_VALUE):
                #     self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name
                code.add_code_line(
                    f"{selfname}.{name} = {ph_name}() if {name} is {default_factory_placeholder} else {name}"
                )
            else:
                # still need to use the default factory to initialize
                code.add_code_line(f"{selfname}.{name} = {ph_name}()")

    if node.scope.lookup("__post_init__"):
        post_init_vars = ", ".join(name for name, field in fields.items()
                                   if field.is_initvar)
        code.add_code_line(f"{selfname}.__post_init__({post_init_vars})")

    if code.empty():
        code.add_code_line("pass")

    args = [selfname]
    if kw_only and field_args:
        args.append("*")
    args.extend(field_args)
    arg_list = ", ".join(args)
    function_start_point.add_code_line(f"def __init__({arg_list}):")


def generate_match_args(code, match_args, node, fields, global_kw_only):
    """
    Generates a tuple containing what would be the positional args to __init__

    Note that this is generated even if the user overrides init.

    Fields with ``init=False`` are not ``__init__`` arguments at all, so they
    are excluded (CPython builds ``__match_args__`` from its init fields in
    the same way); keyword-only fields are excluded too.
    """
    if not match_args or node.scope.lookup_here("__match_args__"):
        return
    positional_arg_names = []
    for field_name, field in fields.items():
        if not field.init.value:
            continue  # not a parameter of __init__
        # TODO hasattr and global_kw_only can be removed once full kw_only support is added
        field_is_kw_only = global_kw_only or (
            hasattr(field, 'kw_only') and field.kw_only.value
        )
        if not field_is_kw_only:
            positional_arg_names.append(field_name)
    code.add_code_line("__match_args__ = %s" % str(tuple(positional_arg_names)))


def generate_repr_code(code, repr, node, fields):
    """
    Generate ``__repr__`` unless ``repr`` is false or ``__repr__`` is
    already found from the class scope.

    The core of the CPython implementation is just:
    ['return self.__class__.__qualname__ + f"(' +
                     ', '.join([f"{f.name}={{self.{f.name}!r}}"
                                for f in fields]) +
                     ')"'],

    The generated code uses ``type(self).__qualname__``, falling back to
    ``type(self).__name__`` if that is missing or empty. Fields with
    ``repr=False`` and InitVars are left out.

    CPython also guards against recursive repr invocations with a wrapper
    decorator that captures a set keyed by id and thread. Here the guard is
    only generated when some Python-object field's type is not
    ``is_gc_simple``: a class-level ``threading.local()`` holds a per-thread
    set of the ids currently being repr'd (keyed only by id). The set is
    created lazily inside ``__repr__`` because attributes of a
    ``threading.local`` are only visible in the thread that set them; creating
    it once at class-creation time left every other thread without it.
    """
    if not repr or node.scope.lookup("__repr__"):
        return

    # The recursive guard is likely a little costly, so skip it if possible.
    # is_gc_simple defines where it can contain recursive objects
    needs_recursive_guard = False
    for name in fields.keys():
        entry = node.scope.lookup(name)
        type_ = entry.type
        if type_.is_memoryviewslice:
            type_ = type_.dtype
        if not type_.is_pyobject:
            continue  # no GC
        if not type_.is_gc_simple:
            needs_recursive_guard = True
            break

    if needs_recursive_guard:
        code.add_code_chunk("""
            __pyx_recursive_repr_guard = __import__('threading').local()
        """)

    with code.indenter("def __repr__(self):"):
        if needs_recursive_guard:
            code.add_code_chunk("""
                key = id(self)
                guard_local = self.__pyx_recursive_repr_guard
                guard_set = getattr(guard_local, 'running', None)
                if guard_set is None:
                    guard_set = set()
                    guard_local.running = guard_set
                if key in guard_set: return '...'
                guard_set.add(key)
                try:
            """)
            code.indent()

        strs = ["%s={self.%s!r}" % (name, name)
                for name, field in fields.items()
                if field.repr.value and not field.is_initvar]
        format_string = ", ".join(strs)

        code.add_code_chunk(f'''
            name = getattr(type(self), "__qualname__", None) or type(self).__name__
            return f'{{name}}({format_string})'
        ''')
        if needs_recursive_guard:
            code.dedent()
            with code.indenter("finally:"):
                code.add_code_line("guard_set.remove(key)")


def generate_cmp_code(code, op, funcname, node, fields):
    """
    Generate the comparison method ``funcname`` for operator ``op`` (one of
    ``==``, ``<``, ``<=``, ``>``, ``>=`` as used by the callers in this module),
    unless the class itself already defines ``funcname``.

    Only fields with ``compare=True`` that are not InitVars are compared, in
    field order, emulating a tuple comparison. Comparing against an object of
    a different class returns NotImplemented.
    """
    if node.scope.lookup_here(funcname):
        return

    names = [name for name, field in fields.items()
             if field.compare.value and not field.is_initvar]

    with code.indenter(f"def {funcname}(self, other):"):
        code.add_code_chunk(f"""
            if other.__class__ is not self.__class__: return NotImplemented

            cdef {node.class_name} other_cast
            other_cast = <{node.class_name}>other
        """)

        # The Python implementation of dataclasses.py does a tuple comparison
        # (roughly):
        #  return self._attributes_to_tuple() {op} other._attributes_to_tuple()
        #
        # For the Cython implementation a tuple comparison isn't an option because
        # not all attributes can be converted to Python objects and stored in a tuple
        #
        # TODO - better diagnostics of whether the types support comparison before
        #    generating the code. Plus, do we want to convert C structs to dicts and
        #    compare them that way (I think not, but it might be in demand)?
        op_without_equals = op.replace('=', '')

        for name in names:
            if op != '==':
                # tuple comparison rules - early elements take precedence
                code.add_code_line(f"if self.{name} {op_without_equals} other_cast.{name}: return True")
            code.add_code_line(f"if self.{name} != other_cast.{name}: return False")
        # all compared fields equal: true only for the inclusive operators
        code.add_code_line(f"return {'True' if '=' in op else 'False'}")  # "() == ()" is True


def generate_eq_code(code, eq, node, fields):
    """Generate ``__eq__`` via generate_cmp_code when ``eq`` is true; otherwise do nothing."""
    if eq:
        generate_cmp_code(code, "==", "__eq__", node, fields)


def generate_order_code(code, order, node, fields):
    """Generate ``__lt__``, ``__le__``, ``__gt__`` and ``__ge__`` when ``order`` is true."""
    if order:
        ordering_methods = (
            ("<", "__lt__"),
            ("<=", "__le__"),
            (">", "__gt__"),
            (">=", "__ge__"),
        )
        for operator_symbol, method_name in ordering_methods:
            generate_cmp_code(code, operator_symbol, method_name, node, fields)


def generate_hash_code(code, unsafe_hash, eq, frozen, node, fields):
    """
    Copied from CPython implementation - the intention is to follow this as far as
    is possible:
    #    +------------------- unsafe_hash= parameter
    #    |       +----------- eq= parameter
    #    |       |       +--- frozen= parameter
    #    |       |       |
    #    v       v       v    |        |        |
    #                         |   no   |  yes   |  <--- class has explicitly defined __hash__
    # +=======+=======+=======+========+========+
    # | False | False | False |        |        | No __eq__, use the base class __hash__
    # +-------+-------+-------+--------+--------+
    # | False | False | True  |        |        | No __eq__, use the base class __hash__
    # +-------+-------+-------+--------+--------+
    # | False | True  | False | None   |        | <-- the default, not hashable
    # +-------+-------+-------+--------+--------+
    # | False | True  | True  | add    |        | Frozen, so hashable, allows override
    # +-------+-------+-------+--------+--------+
    # | True  | False | False | add    | raise  | Has no __eq__, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | False | True  | add    | raise  | Has no __eq__, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | True  | False | add    | raise  | Not frozen, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | True  | True  | add    | raise  | Frozen, so hashable
    # +=======+=======+=======+========+========+
    # For boxes that are blank, __hash__ is untouched and therefore
    # inherited from the base class.  If the base is object, then
    # id-based hashing is used.

    The Python implementation creates a tuple of all the fields, then hashes them.
    This implementation creates a tuple of all the hashes of all the fields and hashes that.
    The reason for this slight difference is to avoid to-Python conversions for anything
    that Cython knows how to hash directly (It doesn't look like this currently applies to
    anything though...).

    A field takes part in the hash according to its ``hash`` setting; when
    that is None (the default) its ``compare`` setting is used instead, as in
    CPython. ``hash`` is a node, so None is detected with ``is_none`` - a
    NoneNode's ``.value`` is not the Python None object.
    """

    hash_entry = node.scope.lookup_here("__hash__")
    if hash_entry:
        # TODO ideally assignment of __hash__ to None shouldn't trigger this
        # but difficult to get the right information here
        if unsafe_hash:
            # error message taken from CPython dataclasses module
            error(node.pos, "Cannot overwrite attribute __hash__ in class %s" % node.class_name)
        return

    if not unsafe_hash:
        if not eq:
            return
        if not frozen:
            code.add_extra_statements([
                Nodes.SingleAssignmentNode(
                    node.pos,
                    lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")),
                    rhs=ExprNodes.NoneNode(node.pos),
                )
            ])
            return

    names = [
        name for name, field in fields.items()
        if not field.is_initvar and (
            field.compare.value if field.hash.is_none else field.hash.value)
    ]

    # make a tuple of the hashes
    hash_tuple_items = ", ".join("self.%s" % name for name in names)
    if hash_tuple_items:
        hash_tuple_items += ","  # ensure that one arg form is a tuple

    # if we're here we want to generate a hash
    with code.indenter("def __hash__(self):"):
        code.add_code_line(f"return hash(({hash_tuple_items}))")


def get_field_type(pos, entry):
    """
    Return the node used as the ``type`` attribute of a dataclass field.

    When the field has an annotation, its string form is returned (the
    dataclasses module likewise records the annotation). Otherwise, e.g. for
    attributes declared with cdef, a string node describing the declared C
    type is built as a fallback.
    """
    if not entry.annotation:
        # No annotation: this should only happen for cdef attributes, so
        # Cython is free to choose; a display string of the declared type
        # is used (returning PyType_Type would be an alternative).
        type_description = EncodedString(entry.type.declaration_code("", for_display=1))
        return ExprNodes.UnicodeNode(pos, value=type_description)

    # cdef classes don't currently appear to generate an __annotations__
    # dict, so the annotation node itself is returned. If they ever do,
    # this should instead index the class's __annotations__ with the field
    # name so the node is not duplicated.
    # (TODO: remove .string if we ditch PEP563)
    return entry.annotation.string


class FieldRecordNode(ExprNodes.ExprNode):
    """
    __dataclass_fields__ contains a bunch of field objects recording how each field
    of the dataclass was initialized (mainly corresponding to the arguments passed to
    the "field" function). This node is used for the attributes of these field objects.

    If possible, coerces `arg` to a Python object.
    Otherwise, generates a sensible backup string.
    """
    subexprs = ['arg']

    def __init__(self, pos, arg):
        super().__init__(pos, arg=arg)

    def analyse_types(self, env):
        # analyse_types returns the analysed node, which need not be the
        # same object as the one it was called on; keep the returned node.
        self.arg = self.arg.analyse_types(env)
        self.type = self.arg.type
        return self

    def coerce_to_pyobject(self, env):
        if self.arg.type.can_coerce_to_pyobject(env):
            return self.arg.coerce_to_pyobject(env)
        else:
            # A string representation of the code that gave the field seems like a reasonable
            # fallback. This'll mostly happen for "default" and "default_factory" where the
            # type may be a C-type that can't be converted to Python.
            return self._make_string()

    def _make_string(self):
        """Render ``arg`` back to source-like text and wrap it in a string node."""
        from .AutoDocTransforms import AnnotationWriter
        writer = AnnotationWriter(description="Dataclass field")
        string = writer.write(self.arg)
        return ExprNodes.UnicodeNode(self.pos, value=EncodedString(string))

    def generate_evaluation_code(self, code):
        return self.arg.generate_evaluation_code(code)


def _set_up_dataclass_fields(node, fields, dataclass_module):
    """
    Build the statements that create and populate ``__dataclass_fields__``.

    For every public (non-private) field, a Field object is created by calling
    the ``field`` attribute of ``dataclass_module`` through the
    DataclassesCallHelper. Its ``name``, ``type`` and ``_field_type`` attributes
    are then assigned so the fields introspect like those of a Python dataclass.

    Side effect: non-literal, non-name ``default``/``default_factory``
    expressions on the public field objects in ``fields`` are replaced by
    NameNodes referring to newly declared module-level variables.

    Returns a list of statement nodes, in order:
      1. the assignments to those module-level variables,
      2. the assignment of the ``__dataclass_fields__`` dict,
      3. the per-field attribute assignments.
    """
    # For defaults and default_factories containing things like lambda,
    # they're already declared in the class scope, and it creates a big
    # problem if multiple copies are floating around in both the __init__
    # function, and in the __dataclass_fields__ structure.
    # Therefore, create module-level constants holding these values and
    # pass those around instead
    #
    # If possible we use the `Field` class defined in the standard library
    # module so that the information stored here is as close to a regular
    # dataclass as is possible.
    global_scope = node.scope.global_scope()
    variables_assignment_stats = []
    for name, field in fields.items():
        if field.private:
            continue  # doesn't appear in the public interface
        for attrname in ("default", "default_factory"):
            field_default = getattr(field, attrname)
            if field_default is MISSING or field_default.is_literal or field_default.is_name:
                # some simple cases where we don't need to set up
                # the variable as a module-level constant
                continue
            module_field_name = global_scope.mangle(
                global_scope.mangle(Naming.dataclass_field_default_cname, node.class_name),
                name)
            # create an entry in the global scope for this variable to live
            field_node = ExprNodes.NameNode(field_default.pos, name=EncodedString(module_field_name))
            field_node.entry = global_scope.declare_var(
                field_node.name, type=field_default.type or PyrexTypes.unspecified_type,
                pos=field_default.pos, cname=field_node.name, is_cdef=True,
                # TODO: do we need to set 'pytyping_modifiers' here?
            )
            # replace the field so that future users just receive the namenode
            setattr(field, attrname, field_node)

            variables_assignment_stats.append(
                Nodes.SingleAssignmentNode(field_default.pos, lhs=field_node, rhs=field_default))

    placeholders = {}
    field_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
                                         attribute=EncodedString("field"))
    dc_fields = ExprNodes.DictNode(node.pos, key_value_pairs=[])
    dc_fields_namevalue_assignments = []

    public_fields = [(name, field) for name, field in fields.items() if not field.private]
    for index, (name, field) in enumerate(public_fields):
        # Placeholder names are numbered instead of derived from the field name:
        # name-based placeholders ("PLACEHOLDER_<name>" and
        # "PLACEHOLDER_FIELD_TYPE_<name>") collide for fields such as "x" and
        # "FIELD_TYPE_x", silently giving one field the other's substitution.
        # The two prefixes below cannot produce the same string.
        type_placeholder_name = f"PLACEHOLDER_TYPE_{index}"
        field_type_placeholder_name = f"PLACEHOLDER_FIELD_TYPE_{index}"

        placeholders[type_placeholder_name] = get_field_type(
            node.pos, node.scope.entries[name]
        )

        # defining these make the fields introspect more like a Python dataclass
        if field.is_initvar:
            field_type_attribute = "_FIELD_INITVAR"
        elif field.is_classvar:
            # TODO - currently this isn't triggered
            field_type_attribute = "_FIELD_CLASSVAR"
        else:
            field_type_attribute = "_FIELD"
        placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
            node.pos, obj=dataclass_module,
            attribute=EncodedString(field_type_attribute)
        )

        # keyword arguments for the "field" call, one per recorded field argument
        dc_field_keywords = ExprNodes.DictNode.from_pairs(
            node.pos,
            [(ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
              FieldRecordNode(node.pos, arg=v))
             for k, v in field.iterate_record_node_arguments()]
        )
        dc_field_call = make_dataclass_call_helper(
            node.pos, field_func, dc_field_keywords
        )
        dc_fields.key_value_pairs.append(
            ExprNodes.DictItemNode(
                node.pos,
                key=ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(name)),
                value=dc_field_call))
        dc_fields_namevalue_assignments.append(
            dedent(f"""\
                __dataclass_fields__[{name!r}].name = {name!r}
                __dataclass_fields__[{name!r}].type = {type_placeholder_name}
                __dataclass_fields__[{name!r}]._field_type = {field_type_placeholder_name}
            """))

    dataclass_fields_assignment = Nodes.SingleAssignmentNode(
        node.pos,
        lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__dataclass_fields__")),
        rhs=dc_fields)

    dc_fields_namevalue_fragment = TreeFragment(
        "\n".join(dc_fields_namevalue_assignments),
        level="c_class",
        pipeline=[NormalizeTree(None)])
    dc_fields_namevalue_stats = dc_fields_namevalue_fragment.substitute(placeholders).stats

    return (variables_assignment_stats
            + [dataclass_fields_assignment]
            + dc_fields_namevalue_stats)
