@@ -43,11 +43,13 @@ class Connection(metaclass=ConnectionMeta):
43
43
'_addr' , '_opts' , '_command_timeout' , '_listeners' ,
44
44
'_server_version' , '_server_caps' , '_intro_query' ,
45
45
'_reset_query' , '_proxy' , '_stmt_exclusive_section' ,
46
- '_ssl_context' )
46
+ '_max_cacheable_statement_size' , ' _ssl_context' )
47
47
48
48
def __init__ (self , protocol , transport , loop , addr , opts , * ,
49
49
statement_cache_size , command_timeout ,
50
- max_cached_statement_lifetime , ssl_context ):
50
+ max_cached_statement_lifetime ,
51
+ max_cacheable_statement_size ,
52
+ ssl_context ):
51
53
self ._protocol = protocol
52
54
self ._transport = transport
53
55
self ._loop = loop
@@ -61,6 +63,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
61
63
self ._opts = opts
62
64
self ._ssl_context = ssl_context
63
65
66
+ self ._max_cacheable_statement_size = max_cacheable_statement_size
64
67
self ._stmt_cache = _StatementCache (
65
68
loop = loop ,
66
69
max_size = statement_cache_size ,
@@ -69,22 +72,6 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
69
72
70
73
self ._stmts_to_close = set ()
71
74
72
- if command_timeout is not None :
73
- try :
74
- if isinstance (command_timeout , bool ):
75
- raise ValueError
76
-
77
- command_timeout = float (command_timeout )
78
-
79
- if command_timeout < 0 :
80
- raise ValueError
81
-
82
- except ValueError :
83
- raise ValueError (
84
- 'invalid command_timeout value: '
85
- 'expected non-negative float (got {!r})' .format (
86
- command_timeout )) from None
87
-
88
75
self ._command_timeout = command_timeout
89
76
90
77
self ._listeners = {}
@@ -280,7 +267,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
280
267
if statement is not None :
281
268
return statement
282
269
283
- if self ._stmt_cache .get_max_size () or named :
270
+ # Only use the cache when:
271
+ # * `statement_cache_size` is greater than 0;
272
+ # * query size is less than `max_cacheable_statement_size`.
273
+ use_cache = self ._stmt_cache .get_max_size () > 0
274
+ if (use_cache and
275
+ self ._max_cacheable_statement_size and
276
+ len (query ) > self ._max_cacheable_statement_size ):
277
+ use_cache = False
278
+
279
+ if use_cache or named :
284
280
stmt_name = self ._get_unique_id ('stmt' )
285
281
else :
286
282
stmt_name = ''
@@ -295,7 +291,8 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
295
291
types = await self ._types_stmt .fetch (list (ready ))
296
292
self ._protocol .get_settings ().register_data_types (types )
297
293
298
- self ._stmt_cache .put (query , statement )
294
+ if use_cache :
295
+ self ._stmt_cache .put (query , statement )
299
296
300
297
# If we've just created a new statement object, check if there
301
298
# are any statements for GC.
@@ -721,6 +718,7 @@ async def connect(dsn=None, *,
721
718
timeout = 60 ,
722
719
statement_cache_size = 100 ,
723
720
max_cached_statement_lifetime = 300 ,
721
+ max_cacheable_statement_size = 1024 * 15 ,
724
722
command_timeout = None ,
725
723
ssl = None ,
726
724
__connection_class__ = Connection ,
@@ -772,6 +770,11 @@ async def connect(dsn=None, *,
772
770
in the cache. Pass ``0`` to allow statements be cached
773
771
indefinitely.
774
772
773
+ :param int max_cacheable_statement_size:
774
+ the maximum size of a statement that can be cached (15KiB by
775
+ default). Pass ``0`` to allow all statements to be cached
776
+ regardless of their size.
777
+
775
778
:param float command_timeout:
776
779
the default timeout for operations on this connection
777
780
(the default is no timeout).
@@ -807,6 +810,29 @@ async def connect(dsn=None, *,
807
810
if loop is None :
808
811
loop = asyncio .get_event_loop ()
809
812
813
+ local_vars = locals ()
814
+ for var_name in {'max_cacheable_statement_size' ,
815
+ 'max_cached_statement_lifetime' ,
816
+ 'statement_cache_size' }:
817
+ var_val = local_vars [var_name ]
818
+ if var_val is None or isinstance (var_val , bool ) or var_val < 0 :
819
+ raise ValueError (
820
+ '{} is expected to be greater '
821
+ 'or equal to 0, got {!r}' .format (var_name , var_val ))
822
+
823
+ if command_timeout is not None :
824
+ try :
825
+ if isinstance (command_timeout , bool ):
826
+ raise ValueError
827
+ command_timeout = float (command_timeout )
828
+ if command_timeout < 0 :
829
+ raise ValueError
830
+ except ValueError :
831
+ raise ValueError (
832
+ 'invalid command_timeout value: '
833
+ 'expected non-negative float (got {!r})' .format (
834
+ command_timeout )) from None
835
+
810
836
addrs , opts = _parse_connect_params (
811
837
dsn = dsn , host = host , port = port , user = user , password = password ,
812
838
database = database , opts = opts )
@@ -855,6 +881,7 @@ async def connect(dsn=None, *,
855
881
pr , tr , loop , addr , opts ,
856
882
statement_cache_size = statement_cache_size ,
857
883
max_cached_statement_lifetime = max_cached_statement_lifetime ,
884
+ max_cacheable_statement_size = max_cacheable_statement_size ,
858
885
command_timeout = command_timeout , ssl_context = ssl )
859
886
860
887
pr .set_connection (con )
0 commit comments