Module dataclasses_json.core

View Source
import copy

import json

import warnings

from collections import defaultdict, namedtuple

# noinspection PyProtectedMember

from dataclasses import (MISSING,

                         _is_dataclass_instance,

                         fields,

                         is_dataclass  # type: ignore

                         )

from datetime import datetime, timezone

from decimal import Decimal

from enum import Enum

from typing import Any, Collection, Mapping, Union, get_type_hints, Tuple

from uuid import UUID

from typing_inspect import is_union_type  # type: ignore

from dataclasses_json import cfg

from dataclasses_json.utils import (_get_type_cons, _get_type_origin,

                                    _handle_undefined_parameters_safe,

                                    _is_collection, _is_mapping, _is_new_type,

                                    _is_optional, _isinstance_safe,

                                    _issubclass_safe)

Json = Union[dict, list, str, int, float, bool, None]

confs = ['encoder', 'decoder', 'mm_field', 'letter_case', 'exclude']

FieldOverride = namedtuple('FieldOverride', confs)

class _ExtendedEncoder(json.JSONEncoder):

    def default(self, o) -> Json:

        result: Json

        if _isinstance_safe(o, Collection):

            if _isinstance_safe(o, Mapping):

                result = dict(o)

            else:

                result = list(o)

        elif _isinstance_safe(o, datetime):

            result = o.timestamp()

        elif _isinstance_safe(o, UUID):

            result = str(o)

        elif _isinstance_safe(o, Enum):

            result = o.value

        elif _isinstance_safe(o, Decimal):

            result = str(o)

        else:

            result = json.JSONEncoder.default(self, o)

        return result

def _user_overrides_or_exts(cls):

    global_metadata = defaultdict(dict)

    encoders = cfg.global_config.encoders

    decoders = cfg.global_config.decoders

    mm_fields = cfg.global_config.mm_fields

    for field in fields(cls):

        if field.type in encoders:

            global_metadata[field.name]['encoder'] = encoders[field.type]

        if field.type in decoders:

            global_metadata[field.name]['decoder'] = decoders[field.type]

        if field.type in mm_fields:

            global_metadata[field.name]['mm_fields'] = mm_fields[field.type]

    try:

        cls_config = (cls.dataclass_json_config

                      if cls.dataclass_json_config is not None else {})

    except AttributeError:

        cls_config = {}

    overrides = {}

    for field in fields(cls):

        field_config = {}

        # first apply global overrides or extensions

        field_metadata = global_metadata[field.name]

        if 'encoder' in field_metadata:

            field_config['encoder'] = field_metadata['encoder']

        if 'decoder' in field_metadata:

            field_config['decoder'] = field_metadata['decoder']

        if 'mm_field' in field_metadata:

            field_config['mm_field'] = field_metadata['mm_field']

        # then apply class-level overrides or extensions

        field_config.update(cls_config)

        # last apply field-level overrides or extensions

        field_config.update(field.metadata.get('dataclasses_json', {}))

        overrides[field.name] = FieldOverride(*map(field_config.get, confs))

    return overrides

def _encode_json_type(value, default=_ExtendedEncoder().default):

    if isinstance(value, Json.__args__):  # type: ignore

        return value

    return default(value)

def _encode_overrides(kvs, overrides, encode_json=False):

    override_kvs = {}

    for k, v in kvs.items():

        if k in overrides:

            exclude = overrides[k].exclude

            # If the exclude predicate returns true, the key should be

            #  excluded from encoding, so skip the rest of the loop

            if exclude and exclude(v):

                continue

            letter_case = overrides[k].letter_case

            original_key = k

            k = letter_case(k) if letter_case is not None else k

            encoder = overrides[original_key].encoder

            v = encoder(v) if encoder is not None else v

        if encode_json:

            v = _encode_json_type(v)

        override_kvs[k] = v

    return override_kvs

def _decode_letter_case_overrides(field_names, overrides):

    """Override letter case of field names for encode/decode"""

    names = {}

    for field_name in field_names:

        field_override = overrides.get(field_name)

        if field_override is not None:

            letter_case = field_override.letter_case

            if letter_case is not None:

                names[letter_case(field_name)] = field_name

    return names

def _decode_dataclass(cls, kvs, infer_missing):

    if _isinstance_safe(kvs, cls):

        return kvs

    overrides = _user_overrides_or_exts(cls)

    kvs = {} if kvs is None and infer_missing else kvs

    field_names = [field.name for field in fields(cls)]

    decode_names = _decode_letter_case_overrides(field_names, overrides)

    kvs = {decode_names.get(k, k): v for k, v in kvs.items()}

    missing_fields = {field for field in fields(cls) if field.name not in kvs}

    for field in missing_fields:

        if field.default is not MISSING:

            kvs[field.name] = field.default

        elif field.default_factory is not MISSING:

            kvs[field.name] = field.default_factory()

        elif infer_missing:

            kvs[field.name] = None

    # Perform undefined parameter action

    kvs = _handle_undefined_parameters_safe(cls, kvs, usage="from")

    init_kwargs = {}

    types = get_type_hints(cls)

    for field in fields(cls):

        # The field should be skipped from being added

        # to init_kwargs as it's not intended as a constructor argument.

        if not field.init:

            continue

        field_value = kvs[field.name]

        field_type = types[field.name]

        if field_value is None and not _is_optional(field_type):

            warning = (f"value of non-optional type {field.name} detected "

                       f"when decoding {cls.__name__}")

            if infer_missing:

                warnings.warn(

                    f"Missing {warning} and was defaulted to None by "

                    f"infer_missing=True. "

                    f"Set infer_missing=False (the default) to prevent this "

                    f"behavior.", RuntimeWarning)

            else:

                warnings.warn(f"`NoneType` object {warning}.", RuntimeWarning)

            init_kwargs[field.name] = field_value

            continue

        while True:

            if not _is_new_type(field_type):

                break

            field_type = field_type.__supertype__

        if (field.name in overrides

                and overrides[field.name].decoder is not None):

            # FIXME hack

            if field_type is type(field_value):

                init_kwargs[field.name] = field_value

            else:

                init_kwargs[field.name] = overrides[field.name].decoder(

                    field_value)

        elif is_dataclass(field_type):

            # FIXME this is a band-aid to deal with the value already being

            # serialized when handling nested marshmallow schema

            # proper fix is to investigate the marshmallow schema generation

            # code

            if is_dataclass(field_value):

                value = field_value

            else:

                value = _decode_dataclass(field_type, field_value,

                                          infer_missing)

            init_kwargs[field.name] = value

        elif _is_supported_generic(field_type) and field_type != str:

            init_kwargs[field.name] = _decode_generic(field_type,

                                                      field_value,

                                                      infer_missing)

        else:

            init_kwargs[field.name] = _support_extended_types(field_type,

                                                              field_value)

    return cls(**init_kwargs)

def _support_extended_types(field_type, field_value):

    if _issubclass_safe(field_type, datetime):

        # FIXME this is a hack to deal with mm already decoding

        # the issue is we want to leverage mm fields' missing argument

        # but need this for the object creation hook

        if isinstance(field_value, datetime):

            res = field_value

        else:

            tz = datetime.now(timezone.utc).astimezone().tzinfo

            res = datetime.fromtimestamp(field_value, tz=tz)

    elif _issubclass_safe(field_type, Decimal):

        res = (field_value

               if isinstance(field_value, Decimal)

               else Decimal(field_value))

    elif _issubclass_safe(field_type, UUID):

        res = (field_value

               if isinstance(field_value, UUID)

               else UUID(field_value))

    else:

        res = field_value

    return res

def _is_supported_generic(type_):

    not_str = not _issubclass_safe(type_, str)

    is_enum = _issubclass_safe(type_, Enum)

    return (not_str and _is_collection(type_)) or _is_optional(

        type_) or is_union_type(type_) or is_enum

def _decode_generic(type_, value, infer_missing):

    if value is None:

        res = value

    elif _issubclass_safe(type_, Enum):

        # Convert to an Enum using the type as a constructor.

        # Assumes a direct match is found.

        res = type_(value)

    # FIXME this is a hack to fix a deeper underlying issue. A refactor is due.

    elif _is_collection(type_):

        if _is_mapping(type_):

            k_type, v_type = getattr(type_, "__args__", (Any, Any))

            # a mapping type has `.keys()` and `.values()`

            # (see collections.abc)

            ks = _decode_dict_keys(k_type, value.keys(), infer_missing)

            vs = _decode_items(v_type, value.values(), infer_missing)

            xs = zip(ks, vs)

        else:

            xs = _decode_items(type_.__args__[0], value, infer_missing)

        # get the constructor if using corresponding generic type in `typing`

        # otherwise fallback on constructing using type_ itself

        try:

            res = _get_type_cons(type_)(xs)

        except (TypeError, AttributeError):

            res = type_(xs)

    else:  # Optional or Union

        if not hasattr(type_, "__args__"):

            # Any, just accept

            res = value

        elif _is_optional(type_) and len(type_.__args__) == 2:  # Optional

            type_arg = type_.__args__[0]

            if is_dataclass(type_arg) or is_dataclass(value):

                res = _decode_dataclass(type_arg, value, infer_missing)

            elif _is_supported_generic(type_arg):

                res = _decode_generic(type_arg, value, infer_missing)

            else:

                res = _support_extended_types(type_arg, value)

        else:  # Union (already decoded or unsupported 'from_json' used)

            res = value

    return res

def _decode_dict_keys(key_type, xs, infer_missing):

    """

    Because JSON object keys must be strs, we need the extra step of decoding

    them back into the user's chosen python type

    """

    decode_function = key_type

    # handle NoneType keys... it's weird to type a Dict as NoneType keys

    # but it's valid...

    if key_type is None or key_type == Any:

        decode_function = key_type = (lambda x: x)

    # handle a nested python dict that has tuples for keys. E.g. for

    # Dict[Tuple[int], int], key_type will be typing.Tuple[int], but

    # decode_function should be tuple, so map() doesn't break.

    #

    # Note: _get_type_origin() will return typing.Tuple for python

    # 3.6 and tuple for 3.7 and higher.

    elif _get_type_origin(key_type) in {tuple, Tuple}:

        decode_function = tuple

        key_type = key_type

    return map(decode_function, _decode_items(key_type, xs, infer_missing))

def _decode_items(type_arg, xs, infer_missing):

    """

    This is a tricky situation where we need to check both the annotated

    type info (which is usually a type from `typing`) and check the

    value's type directly using `type()`.

    If the type_arg is a generic we can use the annotated type, but if the

    type_arg is a typevar we need to extract the reified type information

    hence the check of `is_dataclass(vs)`

    """

    if is_dataclass(type_arg) or is_dataclass(xs):

        items = (_decode_dataclass(type_arg, x, infer_missing)

                 for x in xs)

    elif _is_supported_generic(type_arg):

        items = (_decode_generic(type_arg, x, infer_missing) for x in xs)

    else:

        items = xs

    return items

def _asdict(obj, encode_json=False):

    """

    A re-implementation of `asdict` (based on the original in the `dataclasses`

    source) to support arbitrary Collection and Mapping types.

    """

    if _is_dataclass_instance(obj):

        result = []

        overrides = _user_overrides_or_exts(obj)

        for field in fields(obj):

            if overrides[field.name].encoder:

                value = getattr(obj, field.name)

            else:

                value = _asdict(

                    getattr(obj, field.name),

                    encode_json=encode_json

                )

            result.append((field.name, value))

        result = _handle_undefined_parameters_safe(cls=obj, kvs=dict(result),

                                                   usage="to")

        return _encode_overrides(dict(result), _user_overrides_or_exts(obj),

                                 encode_json=encode_json)

    elif isinstance(obj, Mapping):

        return dict((_asdict(k, encode_json=encode_json),

                     _asdict(v, encode_json=encode_json)) for k, v in

                    obj.items())

    elif isinstance(obj, Collection) and not isinstance(obj, str) \

            and not isinstance(obj, bytes):

        return list(_asdict(v, encode_json=encode_json) for v in obj)

    else:

        return copy.deepcopy(obj)

Variables

Json
confs

Classes

FieldOverride

class FieldOverride(
    /,
    *args,
    **kwargs
)

FieldOverride(encoder, decoder, mm_field, letter_case, exclude)

Ancestors (in MRO)

  • builtins.tuple

Class variables

decoder
encoder
exclude
letter_case
mm_field

Methods

count
def count(
    self,
    value,
    /
)

Return number of occurrences of value.

index
def index(
    self,
    value,
    start=0,
    stop=9223372036854775807,
    /
)

Return first index of value.

Raises ValueError if the value is not present.