Skip to content

Commit 08298b4

Browse files
Fix Contains for Distinct/Union with custom comparer (#112815)
* Fix Contains for Distinct/Union with custom comparer

  The recently-added optimization needs to factor in a non-default comparer.

* Address feedback

---------

Co-authored-by: Eirik Tsarpalis <eirik.tsarpalis@gmail.com>
1 parent 0580a0c commit 08298b4

File tree

3 files changed

+40
-8
lines changed

3 files changed

+40
-8
lines changed

src/libraries/System.Linq/src/System/Linq/Distinct.SpeedOpt.cs

+7-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@ private sealed partial class DistinctIterator<TSource>
1717

1818
public override TSource? TryGetFirst(out bool found) => _source.TryGetFirst(out found);
1919

20-
public override bool Contains(TSource value) => _source.Contains(value);
20+
public override bool Contains(TSource value) =>
21+
// If we're using the default comparer, then source.Distinct().Contains(value) is no different from
22+
// source.Contains(value), as the Distinct() won't remove anything that could have caused
23+
// Contains to return true. If, however, there is a custom comparer, Distinct might remove
24+
// the elements that would have matched, and thus we can't skip it.
25+
_comparer is null ? _source.Contains(value) :
26+
base.Contains(value);
2127
}
2228
}
2329
}

src/libraries/System.Linq/src/System/Linq/Union.SpeedOpt.cs

+15-5
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,26 @@ private HashSet<TSource> FillSet()
4848

4949
public override bool Contains(TSource value)
5050
{
51-
IEnumerable<TSource>? source;
52-
for (int i = 0; (source = GetEnumerable(i)) is not null; i++)
51+
// If there's no comparer (or it is the default comparer), then source1.Union(source2).Contains(value) is no different from
52+
// source1.Contains(value) || source2.Contains(value), as Union's set semantics won't remove
53+
// anything from either that could have matched. However, if there is a comparer, it's possible
54+
// the Union could end up removing items that would have matched, and thus we can't skip it.
55+
if (_comparer is null || _comparer == EqualityComparer<TSource>.Default)
5356
{
54-
if (source.Contains(value))
57+
IEnumerable<TSource>? source;
58+
for (int i = 0; (source = GetEnumerable(i)) is not null; i++)
5559
{
56-
return true;
60+
if (source.Contains(value))
61+
{
62+
return true;
63+
}
5764
}
65+
66+
return false;
5867
}
5968

60-
return default;
69+
70+
return base.Contains(value);
6171
}
6272
}
6373
}

src/libraries/System.Linq/tests/ContainsTests.cs

+18-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Collections.Generic;
5+
using System.Diagnostics;
56
using Xunit;
67

78
namespace System.Linq.Tests
@@ -189,8 +190,9 @@ public void FollowingVariousOperators()
189190
Assert.True(transform(source.Concat(source)).Distinct().Contains(2));
190191
Assert.False(transform(source.Concat(source)).Distinct().Contains(4));
191192
Assert.True(transform(source.Concat(source)).Distinct().Contains(1));
192-
Assert.True(transform(source.Concat(source)).Distinct(EqualityComparer<int>.Create((x, y) => true)).Contains(2));
193-
Assert.False(transform(source.Concat(source)).Distinct(EqualityComparer<int>.Create((x, y) => true)).Contains(0));
193+
Assert.True(transform(source.Concat(source)).Distinct(EqualityComparer<int>.Create((x, y) => true, x => 0)).Contains(1));
194+
Assert.False(transform(source.Concat(source)).Distinct(EqualityComparer<int>.Create((x, y) => true, x => 0)).Contains(2));
195+
Assert.False(transform(source.Concat(source)).Distinct(EqualityComparer<int>.Create((x, y) => true, x => 0)).Contains(0));
194196

195197
// OrderBy
196198
Assert.True(transformedSource.OrderBy(x => x).Contains(2));
@@ -246,13 +248,21 @@ public void FollowingVariousOperators()
246248

247249
// Union
248250
Assert.True(transformedSource.Union(transform([4])).Contains(4));
251+
Assert.True(transformedSource.Union(transform([4]), EqualityComparer<int>.Create((x, y) => true, x => 0)).Contains(1));
252+
Assert.False(transformedSource.Union(transform([4]), EqualityComparer<int>.Create((x, y) => true, x => 0)).Contains(4));
249253
Assert.False(transformedSource.Union(transform([3])).Contains(4));
250254
}
251255

252256
// DefaultIfEmpty
253257
Assert.True(Enumerable.Empty<int>().DefaultIfEmpty(1).Contains(1));
254258
Assert.False(Enumerable.Empty<int>().DefaultIfEmpty(1).Contains(0));
255259

260+
// Distinct
261+
Assert.True(new string[] { "a", "A" }.Distinct().Contains("a"));
262+
Assert.True(new string[] { "a", "A" }.Distinct().Contains("A"));
263+
Assert.True(new string[] { "a", "A" }.Distinct(StringComparer.OrdinalIgnoreCase).Contains("a"));
264+
Assert.False(new string[] { "a", "A" }.Distinct(StringComparer.OrdinalIgnoreCase).Contains("A"));
265+
256266
// Repeat
257267
Assert.True(Enumerable.Repeat(1, 5).Contains(1));
258268
Assert.False(Enumerable.Repeat(1, 5).Contains(2));
@@ -268,6 +278,12 @@ public void FollowingVariousOperators()
268278
Assert.False(new object[] { 1, "2", 3 }.OfType<int>().Contains(2));
269279
Assert.True(new object[] { 1, "2", 3 }.OfType<string>().Contains("2"));
270280
Assert.False(new object[] { 1, "2", 3 }.OfType<string>().Contains("4"));
281+
282+
// Union
283+
Assert.True(new string[] { "a" }.Union(new string[] { "A" }).Contains("a"));
284+
Assert.True(new string[] { "a" }.Union(new string[] { "A" }).Contains("A"));
285+
Assert.True(new string[] { "a" }.Union(new string[] { "A" }, StringComparer.OrdinalIgnoreCase).Contains("a"));
286+
Assert.False(new string[] { "a" }.Union(new string[] { "A" }, StringComparer.OrdinalIgnoreCase).Contains("A"));
271287
}
272288
}
273289
}

0 commit comments

Comments
 (0)