1
1
# pyright: reportMissingTypeArgument=true, reportMissingParameterType=true
2
- import copy
3
2
import dataclasses
4
3
import enum
5
- import functools
6
4
import inspect
7
5
import itertools
8
6
import json
45
43
_RE_SNAKE_CASE_2 = re .compile (r"[A-Z]" )
46
44
47
45
48
- @functools .lru_cache (1024 )
46
+ __not_valid = object ()
47
+
48
+ __to_snake_case_cache : Dict [str , str ] = {}
49
+
50
+
49
51
def to_snake_case (s : str ) -> str :
50
- s = _RE_SNAKE_CASE_1 .sub ("_" , s )
51
- if not s :
52
- return s
53
- return s [0 ].lower () + _RE_SNAKE_CASE_2 .sub (lambda matched : "_" + matched .group (0 ).lower (), s [1 :])
52
+ result = __to_snake_case_cache .get (s , __not_valid )
53
+ if result is __not_valid :
54
+ s = _RE_SNAKE_CASE_1 .sub ("_" , s )
55
+ if not s :
56
+ result = s
57
+ else :
58
+ result = s [0 ].lower () + _RE_SNAKE_CASE_2 .sub (lambda matched : "_" + matched .group (0 ).lower (), s [1 :])
59
+ __to_snake_case_cache [s ] = result
60
+ return cast (str , result )
54
61
55
62
56
63
_RE_CAMEL_CASE_1 = re .compile (r"^[\-_\.]" )
57
64
_RE_CAMEL_CASE_2 = re .compile (r"[\-_\.\s]([a-z])" )
58
65
66
+ __to_snake_camel_cache : Dict [str , str ] = {}
67
+
59
68
60
- @functools .lru_cache (1024 )
61
69
def to_camel_case (s : str ) -> str :
62
- s = _RE_CAMEL_CASE_1 .sub ("" , str (s ))
63
- if not s :
64
- return s
65
- return str (s [0 ]).lower () + _RE_CAMEL_CASE_2 .sub (
66
- lambda matched : str (matched .group (1 )).upper (),
67
- s [1 :],
68
- )
70
+ result = __to_snake_camel_cache .get (s , __not_valid )
71
+ if result is __not_valid :
72
+ s = _RE_CAMEL_CASE_1 .sub ("" , s )
73
+ if not s :
74
+ result = s
75
+ else :
76
+ result = str (s [0 ]).lower () + _RE_CAMEL_CASE_2 .sub (
77
+ lambda matched : str (matched .group (1 )).upper (),
78
+ s [1 :],
79
+ )
80
+ __to_snake_camel_cache [s ] = result
81
+ return cast (str , result )
69
82
70
83
71
84
class CamelSnakeMixin :
@@ -110,21 +123,13 @@ def _decode_case(cls, s: str) -> str:
110
123
return s
111
124
112
125
113
- __default_config : Optional [DefaultConfig ] = None
114
-
115
-
116
- def __get_default_config () -> DefaultConfig :
117
- global __default_config
118
-
119
- if __default_config is None :
120
- __default_config = DefaultConfig ()
121
- return __default_config
126
+ __default_config = DefaultConfig ()
122
127
123
128
124
129
def __get_config (obj : Any , entry_protocol : Type [_T ]) -> _T :
125
130
if isinstance (obj , entry_protocol ):
126
131
return obj
127
- return cast (_T , __get_default_config () )
132
+ return cast (_T , __default_config )
128
133
129
134
130
135
def encode_case (obj : Any , field : dataclasses .Field ) -> str : # type: ignore
@@ -357,23 +362,32 @@ def from_json(
357
362
358
363
359
364
def as_dict (
360
- value : Any , * , remove_defaults : bool = False , dict_factory : Callable [[Any ], Dict [str , Any ]] = dict
365
+ value : Any ,
366
+ * ,
367
+ remove_defaults : bool = False ,
368
+ dict_factory : Callable [[Any ], Dict [str , Any ]] = dict ,
369
+ encode : bool = True ,
361
370
) -> Dict [str , Any ]:
362
371
if not dataclasses .is_dataclass (value ):
363
372
raise TypeError ("as_dict() should be called on dataclass instances" )
364
373
365
- return cast (Dict [str , Any ], _as_dict_inner (value , remove_defaults , dict_factory ))
374
+ return cast (Dict [str , Any ], _as_dict_inner (value , remove_defaults , dict_factory , encode ))
366
375
367
376
368
- def _as_dict_inner (value : Any , remove_defaults : bool , dict_factory : Callable [[Any ], Dict [str , Any ]]) -> Any :
377
+ def _as_dict_inner (
378
+ value : Any ,
379
+ remove_defaults : bool ,
380
+ dict_factory : Callable [[Any ], Dict [str , Any ]],
381
+ encode : bool = True ,
382
+ ) -> Any :
369
383
if dataclasses .is_dataclass (value ):
370
384
result = []
371
385
for f in dataclasses .fields (value ):
372
386
v = _as_dict_inner (getattr (value , f .name ), remove_defaults , dict_factory )
373
387
374
388
if remove_defaults and v == f .default :
375
389
continue
376
- result .append ((f .name , v ))
390
+ result .append ((encode_case ( value , f ) if encode else f .name , v ))
377
391
return dict_factory (result )
378
392
379
393
if isinstance (value , tuple ) and hasattr (value , "_fields" ):
@@ -388,7 +402,7 @@ def _as_dict_inner(value: Any, remove_defaults: bool, dict_factory: Callable[[An
388
402
for k , v in value .items ()
389
403
)
390
404
391
- return copy . deepcopy ( value )
405
+ return value
392
406
393
407
394
408
class TypeValidationError (Exception ):
0 commit comments