Skip to content

Commit eff1f25

Browse files
authored
CSHARP-4779: Support Dictionary(IEnumerable<KeyValuePair<TKey, TValue>> collection) constructor in LINQ (#1657)
1 parent 7824d76 commit eff1f25

File tree

5 files changed

+393
-0
lines changed

5 files changed

+393
-0
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs

+99
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
*/
1515

1616
using System;
17+
using System.Collections.Generic;
1718
using System.Linq;
1819
using MongoDB.Bson;
1920
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
@@ -449,6 +450,79 @@ public override AstNode VisitMapExpression(AstMapExpression node)
449450
}
450451
}
451452

453+
// { $map : { input : { $map : { input : <innerInput>, as : "inner", in : { A : <exprA>, B : <exprB>, ... } } }, as: "outer", in : { F : '$$outer.A', G : "$$outer.B", ... } } }
454+
// => { $map : { input : <innerInput>, as: "inner", in : { F : <exprA>, G : <exprB>, ... } } }
455+
if (node.Input is AstMapExpression innerMapExpression &&
456+
node.As is var outerVar &&
457+
node.In is AstComputedDocumentExpression outerComputedDocumentExpression &&
458+
innerMapExpression.Input is var innerInput &&
459+
innerMapExpression.As is var innerVar &&
460+
innerMapExpression.In is AstComputedDocumentExpression innerComputedDocumentExpression &&
461+
outerComputedDocumentExpression.Fields.All(outerField =>
462+
outerField.Value is AstGetFieldExpression outerGetFieldExpression &&
463+
outerGetFieldExpression.Input == outerVar &&
464+
outerGetFieldExpression.FieldName is AstConstantExpression { Value : BsonString { Value : var matchingFieldName } } &&
465+
innerComputedDocumentExpression.Fields.Any(innerField => innerField.Path == matchingFieldName)))
466+
{
467+
var rewrittenOuterFields = new List<AstComputedField>();
468+
foreach (var outerField in outerComputedDocumentExpression.Fields)
469+
{
470+
var outerGetFieldExpression = (AstGetFieldExpression)outerField.Value;
471+
var matchingFieldName = ((AstConstantExpression)outerGetFieldExpression.FieldName).Value.AsString;
472+
var matchingInnerField = innerComputedDocumentExpression.Fields.Single(innerField => innerField.Path == matchingFieldName);
473+
var rewrittenOuterField = AstExpression.ComputedField(outerField.Path, matchingInnerField.Value);
474+
rewrittenOuterFields.Add(rewrittenOuterField);
475+
}
476+
477+
var simplified = AstExpression.Map(
478+
input: innerInput,
479+
@as: innerVar,
480+
@in: AstExpression.ComputedDocument(rewrittenOuterFields));
481+
482+
return Visit(simplified);
483+
}
484+
485+
// { $map : { input : [{ A : <exprA1>, B : <exprB1>, ... }, { A : <exprA2>, B : <exprB2>, ... }, ...], as : "item", in: { F : "$$item.A", G : "$$item.B", ... } } }
486+
// => [{ F : <exprA1>, G : <exprB1>", ... }, { F : <exprA2>, G : <exprB2>, ... }, ...]
487+
if (node.Input is AstComputedArrayExpression inputComputedArray &&
488+
inputComputedArray.Items.Count >= 1 &&
489+
inputComputedArray.Items[0] is AstComputedDocumentExpression firstComputedDocument &&
490+
firstComputedDocument.Fields.Select(inputField => inputField.Path).ToArray() is var inputFieldNames &&
491+
inputComputedArray.Items.Skip(1).All(otherItem =>
492+
otherItem is AstComputedDocumentExpression otherComputedDocument &&
493+
otherComputedDocument.Fields.Select(otherField => otherField.Path).SequenceEqual(inputFieldNames)) &&
494+
node.As is var itemVar &&
495+
node.In is AstComputedDocumentExpression mappedDocument &&
496+
mappedDocument.Fields.All(mappedField =>
497+
mappedField.Value is AstGetFieldExpression mappedGetField &&
498+
mappedGetField.Input == itemVar &&
499+
mappedGetField.FieldName is AstConstantExpression { Value : BsonString { Value : var matchingFieldName } } &&
500+
inputFieldNames.Contains(matchingFieldName)))
501+
{
502+
var rewrittenItems = new List<AstExpression>();
503+
foreach (var inputItem in inputComputedArray.Items)
504+
{
505+
var inputDocument = (AstComputedDocumentExpression)inputItem;
506+
507+
var rewrittenFields = new List<AstComputedField>();
508+
foreach (var mappedField in mappedDocument.Fields)
509+
{
510+
var mappedGetField = (AstGetFieldExpression)mappedField.Value;
511+
var matchingFieldName = ((AstConstantExpression)mappedGetField.FieldName).Value.AsString;
512+
var matchingInputField = inputDocument.Fields.Single(inputField => inputField.Path == matchingFieldName);
513+
var rewrittenField = AstExpression.ComputedField(mappedField.Path, matchingInputField.Value);
514+
rewrittenFields.Add(rewrittenField);
515+
}
516+
517+
var rewrittenItem = AstExpression.ComputedDocument(rewrittenFields);
518+
rewrittenItems.Add(rewrittenItem);
519+
}
520+
521+
var simplified = AstExpression.ComputedArray(rewrittenItems);
522+
523+
return Visit(simplified);
524+
}
525+
452526
return base.VisitMapExpression(node);
453527

454528
static AstExpression UltimateGetFieldInput(AstGetFieldExpression getField)
@@ -567,7 +641,32 @@ arg is AstBinaryExpression argBinaryExpression &&
567641
return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2);
568642
}
569643

644+
// { $arrayToObject : [[{ k : 'A', v : <exprA> }, { k : 'B', v : <exprB> }, ...]] } => { A : <exprA>, B : <exprB>, ... }
645+
if (node.Operator == AstUnaryOperator.ArrayToObject &&
646+
arg is AstComputedArrayExpression computedArrayExpression &&
647+
computedArrayExpression.Items.All(
648+
item =>
649+
item is AstComputedDocumentExpression computedDocumentExpression &&
650+
computedDocumentExpression.Fields.Count == 2 &&
651+
computedDocumentExpression.Fields[0].Path == "k" &&
652+
computedDocumentExpression.Fields[1].Path == "v" &&
653+
computedDocumentExpression.Fields[0].Value is AstConstantExpression { Value : { IsString : true } }))
654+
{
655+
var computedFields = computedArrayExpression.Items.Select(KeyValuePairDocumentToComputedField);
656+
return AstExpression.ComputedDocument(computedFields);
657+
}
658+
570659
return node.Update(arg);
660+
661+
static AstComputedField KeyValuePairDocumentToComputedField(AstExpression expression)
662+
{
663+
// caller has verified that expression is of the form: { k : <stringConstant>, v : <valueExpression> }
664+
var keyValuePairDocumentExpression = (AstComputedDocumentExpression)expression;
665+
var keyConstantExpression = (AstConstantExpression)keyValuePairDocumentExpression.Fields[0].Value;
666+
var valueExpression = keyValuePairDocumentExpression.Fields[1].Value;
667+
668+
return AstExpression.ComputedField(keyConstantExpression.Value.AsString, valueExpression);
669+
}
571670
}
572671
}
573672
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Collections.Generic;
17+
using System.Reflection;
18+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
19+
20+
namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection
21+
{
22+
internal static class DictionaryConstructor
23+
{
24+
public static bool IsWithIEnumerableKeyValuePairConstructor(ConstructorInfo constructor)
25+
{
26+
var declaringType = constructor.DeclaringType;
27+
var parameters = constructor.GetParameters();
28+
return
29+
declaringType.IsConstructedGenericType &&
30+
declaringType.GetGenericTypeDefinition() == typeof(Dictionary<,>) &&
31+
parameters.Length == 1 &&
32+
parameters[0].ParameterType.ImplementsIEnumerable(out var enumerableType) &&
33+
enumerableType.IsConstructedGenericType &&
34+
enumerableType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>);
35+
}
36+
}
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.Generic;
18+
using System.Linq.Expressions;
19+
using MongoDB.Bson;
20+
using MongoDB.Bson.Serialization;
21+
using MongoDB.Bson.Serialization.Options;
22+
using MongoDB.Bson.Serialization.Serializers;
23+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
24+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
25+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
26+
27+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
28+
{
29+
internal static class NewDictionaryExpressionToAggregationExpressionTranslator
30+
{
31+
public static bool CanTranslate(NewExpression expression)
32+
=> DictionaryConstructor.IsWithIEnumerableKeyValuePairConstructor(expression.Constructor);
33+
34+
public static TranslatedExpression Translate(TranslationContext context, NewExpression expression)
35+
{
36+
var arguments = expression.Arguments;
37+
38+
var collectionExpression = arguments[0];
39+
var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression);
40+
var itemSerializer = ArraySerializerHelper.GetItemSerializer(collectionTranslation.Serializer);
41+
42+
IBsonSerializer keySerializer;
43+
IBsonSerializer valueSerializer;
44+
AstExpression collectionTranslationAst;
45+
46+
if (itemSerializer is IBsonDocumentSerializer itemDocumentSerializer)
47+
{
48+
if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo))
49+
{
50+
throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Key member");
51+
}
52+
keySerializer = keyMemberSerializationInfo.Serializer;
53+
54+
if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo))
55+
{
56+
throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Value member");
57+
}
58+
valueSerializer = valueMemberSerializationInfo.Serializer;
59+
60+
if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v")
61+
{
62+
collectionTranslationAst = collectionTranslation.Ast;
63+
}
64+
else
65+
{
66+
var pairVar = AstExpression.Var("pair");
67+
var computedDocumentAst = AstExpression.ComputedDocument([
68+
AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)),
69+
AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName))
70+
]);
71+
72+
collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst);
73+
}
74+
}
75+
else
76+
{
77+
throw new ExpressionNotSupportedException(expression);
78+
}
79+
80+
if (keySerializer is not IRepresentationConfigurable { Representation: BsonType.String })
81+
{
82+
throw new ExpressionNotSupportedException(expression, because: "key does not serialize as a string");
83+
}
84+
85+
var ast = AstExpression.Unary(AstUnaryOperator.ArrayToObject, collectionTranslationAst);
86+
var resultSerializer = CreateResultSerializer(keySerializer, valueSerializer);
87+
return new TranslatedExpression(expression, ast, resultSerializer);
88+
}
89+
90+
private static IBsonSerializer CreateResultSerializer(IBsonSerializer keySerializer, IBsonSerializer valueSerializer)
91+
{
92+
var dictionaryType = typeof(Dictionary<,>).MakeGenericType(keySerializer.ValueType, valueSerializer.ValueType);
93+
var serializerType = typeof(DictionaryInterfaceImplementerSerializer<,,>).MakeGenericType(dictionaryType, keySerializer.ValueType, valueSerializer.ValueType);
94+
95+
return (IBsonSerializer)Activator.CreateInstance(serializerType, DictionaryRepresentation.Document, keySerializer, valueSerializer);
96+
}
97+
}
98+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs

+4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr
5050
{
5151
return NewKeyValuePairExpressionToAggregationExpressionTranslator.Translate(context, expression);
5252
}
53+
if (NewDictionaryExpressionToAggregationExpressionTranslator.CanTranslate(expression))
54+
{
55+
return NewDictionaryExpressionToAggregationExpressionTranslator.Translate(context, expression);
56+
}
5357
return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty<MemberBinding>());
5458
}
5559
}

0 commit comments

Comments
 (0)