Skip to content

Commit f709b33

Browse files
committed
add a limit to return value of ValueReader.ReadMultiLength to avoid memory overflow caused by malicious byte stream
1 parent 9138171 commit f709b33

File tree

2 files changed

+64
-50
lines changed

2 files changed

+64
-50
lines changed

readers.go

+4-27
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ import (
2424
"reflect"
2525
)
2626

27+
type Lenner interface {
28+
Len() int
29+
}
30+
2731
type typeReaderFunc func(length int, vr ValueReader, value reflect.Value, nesting int) error
2832

2933
func unsupported(length int, vr ValueReader, value reflect.Value, nesting int) error {
@@ -617,32 +621,6 @@ func valueReader(r ValueReader, value reflect.Value) error {
617621
}
618622

619623
func valueReader0(vr ValueReader, value reflect.Value, nesting int) error {
620-
// typ := value.Type()
621-
// if typ.Implements(TypeOfDecoder) {
622-
// if typ.Kind() == reflect.Ptr && value.IsNil() {
623-
// nvalue := reflect.New(typ.Elem())
624-
// value.Set(nvalue)
625-
// }
626-
// decoder, _ := value.Interface().(Decoder)
627-
// return decoder.Deserialization(vr)
628-
// }
629-
// if typ.Kind() == reflect.Ptr {
630-
// etyp := typ.Elem()
631-
// // if value.IsNil() {
632-
// // nvalue := reflect.New(etyp)
633-
// // value.Set(nvalue)
634-
// // }
635-
// if etyp.Implements(TypeOfDecoder) {
636-
// elem := value.Elem()
637-
// if elem.Kind() == reflect.Ptr && elem.IsNil() {
638-
// evalue := reflect.New(etyp.Elem())
639-
// elem.Set(evalue)
640-
// }
641-
// decoder, _ := elem.Interface().(Decoder)
642-
// return decoder.Deserialization(vr)
643-
// }
644-
// }
645-
646624
// decode itself if the value implements encoding.Decoder interface
647625
isDecoder, err := checkTypeOfDecoder(vr, value)
648626
if isDecoder || err != nil {
@@ -675,7 +653,6 @@ func valueReader1(th TypeHeader, length int, vr ValueReader, value reflect.Value
675653

676654
// big.Rat or *big.Rat
677655
if typ.AssignableTo(typeOfBigRat) || typ.AssignableTo(reflect.PtrTo(typeOfBigRat)) {
678-
// if typ.AssignableTo(typeOfBigRatPtr) {
679656
return bigRatReader0(th, int(length), vr, value, nesting)
680657
}
681658

valuereader.go

+60-23
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,18 @@ type ValueReader interface {
3333
}
3434

3535
type noBufValueReader struct {
36-
reader io.Reader
37-
eof bool
38-
readCount int
39-
header [1]byte
36+
reader io.Reader
37+
eof bool
38+
readCount int
39+
header [1]byte
40+
readerSize int
4041
}
4142

4243
func EndOfFile(err error) bool {
4344
return err == io.EOF || err == io.ErrUnexpectedEOF
4445
}
4546

46-
func (r *noBufValueReader) filerErr(err error) error {
47+
func (r *noBufValueReader) filterErr(err error) error {
4748
if EndOfFile(err) {
4849
r.eof = true
4950
return io.EOF
@@ -62,13 +63,17 @@ func (r *noBufValueReader) HasMore() bool {
6263
return true
6364
}
6465

66+
func (r *noBufValueReader) left() int {
67+
return r.readerSize - r.readCount
68+
}
69+
6570
func (r *noBufValueReader) ReadHeader() (TypeHeader, uint32, error) {
6671
if !r.HasMore() {
6772
return 0, 0, io.EOF
6873
}
6974
b, err := r.ReadByte()
7075
if err != nil {
71-
return 0, 0, r.filerErr(err)
76+
return 0, 0, r.filterErr(err)
7277
}
7378
return ParseRTLHeader(b)
7479
}
@@ -80,7 +85,7 @@ func (r *noBufValueReader) ReadByte() (byte, error) {
8085
n, err := io.ReadFull(r.reader, r.header[:])
8186
r.readCount += n
8287
if err != nil {
83-
return 0, r.filerErr(err)
88+
return 0, r.filterErr(err)
8489
}
8590
if n <= 0 {
8691
r.eof = true
@@ -96,29 +101,38 @@ func (r *noBufValueReader) Read(p []byte) (int, error) {
96101

97102
n, err := io.ReadFull(r.reader, p)
98103
r.readCount += n
99-
return n, r.filerErr(err)
104+
return n, r.filterErr(err)
100105
}
101106

102107
func (r *noBufValueReader) ReadBytes(length int, bytes []byte) ([]byte, error) {
103108
return ReadBytesFromReader(r, length, bytes)
104109
}
105110

106111
func (r *noBufValueReader) ReadMultiLength(length int) (uint64, error) {
107-
return ReadMultiLengthFromReader(r, length)
112+
ret, err := ReadMultiLengthFromReader(r, length)
113+
if err != nil {
114+
return 0, err
115+
}
116+
left := r.left()
117+
if left <= 0 || ret > uint64(left) {
118+
return 0, fmt.Errorf("%d bytes multi-length(%d) is larger than left(%d)", length, ret, left)
119+
}
120+
return ret, nil
108121
}
109122

110123
func (r *noBufValueReader) ReadMultiLengthBytes(length int, bytes []byte) ([]byte, error) {
111124
return ReadMultiLengthBytesFromReader(r, length, bytes)
112125
}
113126

114127
type bufValueReader struct {
115-
reader io.Reader // basic reader
116-
eof bool // if the reader EOF
117-
lastError error // error of last reading(if exist, except io.EOF)
118-
buffer []byte // buffered bytes
119-
available uint32 // length of available bytes in buffer
120-
offset uint32 // offset for buffer of reading
121-
readCount int // counting the read bytes
128+
reader io.Reader // basic reader
129+
eof bool // if the reader EOF
130+
lastError error // error of last reading(if exist, except io.EOF)
131+
buffer []byte // buffered bytes
132+
available uint32 // length of available bytes in buffer
133+
offset uint32 // offset for buffer of reading
134+
readCount int // counting the read bytes
135+
readerSize int // summary size of the reader, if it's a stream, set to MaxSliceSize
122136
}
123137

124138
func (r *bufValueReader) ResetCount() {
@@ -136,6 +150,10 @@ func (r *bufValueReader) HasMore() bool {
136150
return r.next()
137151
}
138152

153+
func (r *bufValueReader) left() int {
154+
return r.readerSize - r.readCount
155+
}
156+
139157
// next read more bytes to buffer when buffer is empty,
140158
// and return whether has more bytes in buffer
141159
func (r *bufValueReader) next() bool {
@@ -258,7 +276,15 @@ func (r *bufValueReader) ReadBytes(length int, bytes []byte) ([]byte, error) {
258276

259277
// ReadMultiLength read length of multi bytes header value's length
260278
func (r *bufValueReader) ReadMultiLength(length int) (uint64, error) {
261-
return ReadMultiLengthFromReader(r, length)
279+
ret, err := ReadMultiLengthFromReader(r, length)
280+
if err != nil {
281+
return 0, err
282+
}
283+
left := r.left()
284+
if left <= 0 || ret > uint64(left) {
285+
return 0, fmt.Errorf("%d bytes multi-length(%d) is larger than left(%d)", length, ret, left)
286+
}
287+
return ret, nil
262288
}
263289

264290
func (r *bufValueReader) ReadMultiLengthBytes(length int, bytes []byte) ([]byte, error) {
@@ -335,15 +361,26 @@ func ReadMultiLengthBytesFromReader(vr ValueReader, length int, bytes []byte) ([
335361
}
336362

337363
func NewValueReader(r io.Reader, bufferSize int) ValueReader {
364+
len := int(MaxSliceSize)
365+
lenner, ok := r.(Lenner)
366+
if ok {
367+
len = lenner.Len()
368+
}
338369
if bufferSize > 0 {
339370
return &bufValueReader{
340-
reader: r,
341-
eof: false,
342-
buffer: make([]byte, bufferSize),
343-
available: 0,
344-
offset: 0,
371+
reader: r,
372+
eof: false,
373+
buffer: make([]byte, bufferSize),
374+
available: 0,
375+
offset: 0,
376+
readerSize: len,
345377
}
346378
} else {
347-
return &noBufValueReader{reader: r, eof: false, readCount: 0}
379+
return &noBufValueReader{
380+
reader: r,
381+
eof: false,
382+
readCount: 0,
383+
readerSize: len,
384+
}
348385
}
349386
}

0 commit comments

Comments
 (0)