Skip to content

Commit 192b6fe

Browse files
committed
Allow empty Record objects.
1 parent 27ce0aa commit 192b6fe

File tree

5 files changed

+46
-19
lines changed

5 files changed

+46
-19
lines changed

asyncpg/protocol/prepared_stmt.pyx

+8-2
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,11 @@ cdef class PreparedStatementState:
141141
Codec codec
142142
list codecs
143143

144-
if self.cols_num == 0 or self.cols_desc is not None:
144+
if self.cols_desc is not None:
145+
return
146+
147+
if self.cols_num == 0:
148+
self.cols_desc = record.ApgRecordDesc_New({}, ())
145149
return
146150

147151
cols_mapping = collections.OrderedDict()
@@ -219,7 +223,9 @@ cdef class PreparedStatementState:
219223
fnum, self.cols_num))
220224

221225
if rows_codecs is None or len(rows_codecs) < fnum:
222-
raise RuntimeError('invalid rows_codecs')
226+
if fnum > 0:
227+
# It's OK to have no rows_codecs for empty records
228+
raise RuntimeError('invalid rows_codecs')
223229

224230
dec_row = record.ApgRecord_New(self.cols_desc, fnum)
225231
for i in range(fnum):

asyncpg/protocol/protocol.pyx

+6-1
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,12 @@ def _create_record(object mapping, tuple elems):
479479
object rec
480480
int32_t i
481481

482-
desc = record.ApgRecordDesc_New(mapping, tuple(mapping) if mapping else ())
482+
if mapping is None:
483+
desc = record.ApgRecordDesc_New({}, ())
484+
else:
485+
desc = record.ApgRecordDesc_New(
486+
mapping, tuple(mapping) if mapping else ())
487+
483488
rec = record.ApgRecord_New(desc, len(elems))
484489
for i in range(len(elems)):
485490
elem = elems[i]

asyncpg/protocol/record/recordobj.c

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ ApgRecord_New(PyObject *desc, Py_ssize_t size)
2222
ApgRecordObject *o;
2323
Py_ssize_t i;
2424

25-
if (size < 1 || desc == NULL || !ApgRecordDesc_CheckExact(desc)) {
25+
if (size < 0 || desc == NULL || !ApgRecordDesc_CheckExact(desc)) {
2626
PyErr_BadInternalCall();
2727
return NULL;
2828
}
@@ -346,7 +346,9 @@ record_repr(ApgRecordObject *v)
346346
_PyUnicodeWriter writer;
347347

348348
n = Py_SIZE(v);
349-
assert(n > 0);
349+
if (n == 0) {
350+
return PyUnicode_FromString("<Record>");
351+
}
350352

351353
keys_iter = PyObject_GetIter(v->desc->keys);
352354
if (keys_iter == NULL) {

tests/test_prepare.py

+5
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,8 @@ async def test_prepare_21_errors(self):
395395
await stmt.fetchval(0)
396396

397397
self.assertEqual(await stmt.fetchval(5), 2)
398+
399+
async def test_prepare_22_empty(self):
400+
result = await self.con.fetchrow('SELECT')
401+
self.assertEqual(result, ())
402+
self.assertEqual(repr(result), '<Record>')

tests/test_record.py

+23-14
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import gc
1111
import pickle
1212
import sys
13-
import unittest
1413

1514
from asyncpg import _testbase as tb
1615
from asyncpg.protocol.protocol import _create_record as Record
@@ -37,10 +36,6 @@ def checkref(self, *objs):
3736
self.fail('refcounts differ for {!r}: {:+}'.format(
3837
objs[i], after - before))
3938

40-
def test_record_zero_length(self):
41-
with self.assertRaises(SystemError):
42-
Record({}, ())
43-
4439
def test_record_gc(self):
4540
elem = object()
4641
mapping = {}
@@ -153,10 +148,7 @@ def test_record_keys(self):
153148
r = Record(R_AB, (42, 43))
154149
vv = r.keys()
155150
self.assertEqual(tuple(vv), ('a', 'b'))
156-
157-
# test invalid record
158-
with self.assertRaisesRegex(TypeError, 'not iterable'):
159-
Record(None, (42, 43)).keys()
151+
self.assertEqual(list(Record(None, (42, 43)).keys()), [])
160152

161153
def test_record_items(self):
162154
r = Record(R_AB, (42, 43))
@@ -188,9 +180,12 @@ def test_record_items(self):
188180
self.assertEqual(list(r.items()), [('a', 42)])
189181
r = Record(R_AB, (42,))
190182
self.assertEqual(list(r.items()), [('a', 42)])
191-
r = Record(None, (42, 43))
192-
with self.assertRaises(TypeError):
193-
list(r.items())
183+
184+
# Try to iterate over exhausted items() iterator
185+
r = Record(R_A, (42, 43))
186+
it = r.items()
187+
list(it)
188+
list(it)
194189

195190
def test_record_hash(self):
196191
AB = collections.namedtuple('AB', ('a', 'b'))
@@ -220,8 +215,7 @@ def test_record_contains(self):
220215
self.assertNotIn('z', r)
221216

222217
r = Record(None, (42, 43))
223-
with self.assertRaises(TypeError):
224-
self.assertIn('a', r)
218+
self.assertNotIn('a', r)
225219

226220
with self.assertRaises(TypeError):
227221
type(r).__contains__(None, 'a')
@@ -281,6 +275,21 @@ def test_record_not_pickleable(self):
281275
with self.assertRaises(Exception):
282276
pickle.dumps(r)
283277

278+
def test_record_empty(self):
279+
r = Record(None, ())
280+
self.assertEqual(r, ())
281+
self.assertLess(r, (1,))
282+
self.assertEqual(len(r), 0)
283+
self.assertFalse(r)
284+
self.assertNotIn('a', r)
285+
self.assertEqual(repr(r), '<Record>')
286+
self.assertEqual(str(r), '<Record>')
287+
with self.assertRaisesRegex(KeyError, 'aaa'):
288+
r['aaa']
289+
self.assertEqual(dict(r.items()), {})
290+
self.assertEqual(list(r.keys()), [])
291+
self.assertEqual(list(r.values()), [])
292+
284293
async def test_record_duplicate_colnames(self):
285294
"""Test that Record handles duplicate column names."""
286295
r = await self.con.fetchrow('SELECT 1 as a, 2 as a')

0 commit comments

Comments
 (0)