Skip to content

Commit b1f87eb

Browse files
committed
Add new Connection parameter: max_cacheable_statement_size.
Closes: #115.
1 parent bde06b5 commit b1f87eb

File tree

3 files changed

+81
-21
lines changed

3 files changed

+81
-21
lines changed

asyncpg/connection.py

+47-20
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@ class Connection(metaclass=ConnectionMeta):
4343
'_addr', '_opts', '_command_timeout', '_listeners',
4444
'_server_version', '_server_caps', '_intro_query',
4545
'_reset_query', '_proxy', '_stmt_exclusive_section',
46-
'_ssl_context')
46+
'_max_cacheable_statement_size', '_ssl_context')
4747

4848
def __init__(self, protocol, transport, loop, addr, opts, *,
4949
statement_cache_size, command_timeout,
50-
max_cached_statement_lifetime, ssl_context):
50+
max_cached_statement_lifetime,
51+
max_cacheable_statement_size,
52+
ssl_context):
5153
self._protocol = protocol
5254
self._transport = transport
5355
self._loop = loop
@@ -61,6 +63,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6163
self._opts = opts
6264
self._ssl_context = ssl_context
6365

66+
self._max_cacheable_statement_size = max_cacheable_statement_size
6467
self._stmt_cache = _StatementCache(
6568
loop=loop,
6669
max_size=statement_cache_size,
@@ -69,22 +72,6 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6972

7073
self._stmts_to_close = set()
7174

72-
if command_timeout is not None:
73-
try:
74-
if isinstance(command_timeout, bool):
75-
raise ValueError
76-
77-
command_timeout = float(command_timeout)
78-
79-
if command_timeout < 0:
80-
raise ValueError
81-
82-
except ValueError:
83-
raise ValueError(
84-
'invalid command_timeout value: '
85-
'expected non-negative float (got {!r})'.format(
86-
command_timeout)) from None
87-
8875
self._command_timeout = command_timeout
8976

9077
self._listeners = {}
@@ -280,7 +267,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
280267
if statement is not None:
281268
return statement
282269

283-
if self._stmt_cache.get_max_size() or named:
270+
# Only use the cache when:
271+
# * `statement_cache_size` is greater than 0;
272+
# * query size is less than `max_cacheable_statement_size`.
273+
use_cache = self._stmt_cache.get_max_size() > 0
274+
if (use_cache and
275+
self._max_cacheable_statement_size and
276+
len(query) > self._max_cacheable_statement_size):
277+
use_cache = False
278+
279+
if use_cache or named:
284280
stmt_name = self._get_unique_id('stmt')
285281
else:
286282
stmt_name = ''
@@ -295,7 +291,8 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
295291
types = await self._types_stmt.fetch(list(ready))
296292
self._protocol.get_settings().register_data_types(types)
297293

298-
self._stmt_cache.put(query, statement)
294+
if use_cache:
295+
self._stmt_cache.put(query, statement)
299296

300297
# If we've just created a new statement object, check if there
301298
# are any statements for GC.
@@ -721,6 +718,7 @@ async def connect(dsn=None, *,
721718
timeout=60,
722719
statement_cache_size=100,
723720
max_cached_statement_lifetime=300,
721+
max_cacheable_statement_size=1024 * 15,
724722
command_timeout=None,
725723
ssl=None,
726724
__connection_class__=Connection,
@@ -772,6 +770,11 @@ async def connect(dsn=None, *,
772770
in the cache. Pass ``0`` to allow statements be cached
773771
indefinitely.
774772
773+
:param int max_cacheable_statement_size:
774+
the maximum size of a statement that can be cached (15KiB by
775+
default). Pass ``0`` to allow all statements to be cached
776+
regardless of their size.
777+
775778
:param float command_timeout:
776779
the default timeout for operations on this connection
777780
(the default is no timeout).
@@ -807,6 +810,29 @@ async def connect(dsn=None, *,
807810
if loop is None:
808811
loop = asyncio.get_event_loop()
809812

813+
local_vars = locals()
814+
for var_name in {'max_cacheable_statement_size',
815+
'max_cached_statement_lifetime',
816+
'statement_cache_size'}:
817+
var_val = local_vars[var_name]
818+
if var_val is None or isinstance(var_val, bool) or var_val < 0:
819+
raise ValueError(
820+
'{} is expected to be greater '
821+
'or equal to 0, got {!r}'.format(var_name, var_val))
822+
823+
if command_timeout is not None:
824+
try:
825+
if isinstance(command_timeout, bool):
826+
raise ValueError
827+
command_timeout = float(command_timeout)
828+
if command_timeout < 0:
829+
raise ValueError
830+
except ValueError:
831+
raise ValueError(
832+
'invalid command_timeout value: '
833+
'expected non-negative float (got {!r})'.format(
834+
command_timeout)) from None
835+
810836
addrs, opts = _parse_connect_params(
811837
dsn=dsn, host=host, port=port, user=user, password=password,
812838
database=database, opts=opts)
@@ -855,6 +881,7 @@ async def connect(dsn=None, *,
855881
pr, tr, loop, addr, opts,
856882
statement_cache_size=statement_cache_size,
857883
max_cached_statement_lifetime=max_cached_statement_lifetime,
884+
max_cacheable_statement_size=max_cacheable_statement_size,
858885
command_timeout=command_timeout, ssl_context=ssl)
859886

860887
pr.set_connection(con)

tests/test_connect.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ async def test_auth_unsupported(self):
193193
pass
194194

195195

196-
class TestConnectParams(unittest.TestCase):
196+
class TestConnectParams(tb.TestCase):
197197

198198
TESTS = [
199199
{
@@ -421,6 +421,18 @@ def test_connect_params(self):
421421
for testcase in self.TESTS:
422422
self.run_testcase(testcase)
423423

424+
async def test_connect_args_validation(self):
425+
for val in {-1, 'a', True, False}:
426+
with self.assertRaisesRegex(ValueError, 'non-negative'):
427+
await asyncpg.connect(command_timeout=val)
428+
429+
for arg in {'max_cacheable_statement_size',
430+
'max_cached_statement_lifetime',
431+
'statement_cache_size'}:
432+
for val in {None, -1, True, False}:
433+
with self.assertRaisesRegex(ValueError, 'greater or equal'):
434+
await asyncpg.connect(**{arg: val})
435+
424436

425437
class TestConnection(tb.ConnectedTestCase):
426438

tests/test_prepare.py

+21
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,24 @@ async def test_prepare_26_max_lifetime_max_size(self):
513513

514514
# Check that nothing crashes after the initial timeout
515515
await asyncio.sleep(1, loop=self.loop)
516+
517+
@tb.with_connection_options(max_cacheable_statement_size=50)
518+
async def test_prepare_27_max_cacheable_statement_size(self):
519+
cache = self.con._stmt_cache
520+
521+
await self.con.prepare('SELECT 1')
522+
self.assertEqual(len(cache), 1)
523+
524+
# Test that long and explicitly created prepared statements
525+
# are not cached.
526+
await self.con.prepare("SELECT \'" + "a" * 50 + "\'")
527+
self.assertEqual(len(cache), 1)
528+
529+
# Test that implicitly created long prepared statements
530+
# are not cached.
531+
await self.con.fetchval("SELECT \'" + "a" * 50 + "\'")
532+
self.assertEqual(len(cache), 1)
533+
534+
# Test that short prepared statements can still be cached.
535+
await self.con.prepare('SELECT 2')
536+
self.assertEqual(len(cache), 2)

0 commit comments

Comments
 (0)