Skip to content

Commit

Permalink
[release/9.0] Add explicit casts for generated null in method calls. (#…
Browse files Browse the repository at this point in the history
…34523)

* Add explicit casts for generated null in method calls.

Fixes #34515

* Update src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs

Co-authored-by: Shay Rojansky <[email protected]>

---------

Co-authored-by: Shay Rojansky <[email protected]>
  • Loading branch information
AndriySvyryd and roji authored Aug 27, 2024
1 parent ab8e014 commit b59aa3a
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 28 deletions.
18 changes: 16 additions & 2 deletions src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1949,7 +1949,7 @@ protected override Expression VisitMethodCall(MethodCallExpression call)

// Extension syntax
if (call.Method.IsDefined(typeof(ExtensionAttribute), inherit: false)
&& !(arguments[0].Expression is LiteralExpressionSyntax literal && literal.IsKind(SyntaxKind.NullLiteralExpression)))
&& !IsNull(arguments[0].Expression))
{
Result = InvocationExpression(
MemberAccessExpression(
Expand Down Expand Up @@ -2029,6 +2029,14 @@ void ProcessType(Type type)
}
}
}

static bool IsNull(ExpressionSyntax expr) => expr switch
{
LiteralExpressionSyntax literal when literal.IsKind(SyntaxKind.NullLiteralExpression) => true,
CastExpressionSyntax cast => IsNull(cast.Expression),
ParenthesizedExpressionSyntax parenthesized => IsNull(parenthesized.Expression),
_ => false
};
}

/// <inheritdoc />
Expand Down Expand Up @@ -2702,7 +2710,13 @@ private ExpressionSyntax[] TranslateList(IReadOnlyList<Expression> list)

var liftedStatementsPosition = _liftedState.Statements.Count;

var translated = Translate<ExpressionSyntax>(expression);
var translated = expression switch
{
// Add an explicit cast to avoid overload resolution ambiguity
ConstantExpression c
when c.Value is null => (ExpressionSyntax)_g.ConvertExpression(Generate(c.Type), GenerateValue(c.Value)),
_ => Translate<ExpressionSyntax>(expression)
};

if (_liftedState.Statements.Count > liftedStatementsPosition)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity5 = ((CompiledModelTestBase.PrincipalBase)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity5), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity5), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 15,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity5 = ((CompiledModelTestBase.PrincipalDerived<CompiledModelTestBase.DependentBase<byte?>>)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object, object, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity5), null, PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.Dependent(entity5), SnapshotFactoryFactory.SnapshotCollection(PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.ManyOwned(entity5)), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object, object, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity5), (object)(null), PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.Dependent(entity5), SnapshotFactoryFactory.SnapshotCollection(PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.ManyOwned(entity5)), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 15,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ public void Method_call_extension_with_null_this()
Call(
LinqExpressionToRoslynTranslatorExtensions.SomeExtensionMethod,
Constant(null, typeof(LinqExpressionToRoslynTranslatorExtensionType))),
"LinqExpressionToRoslynTranslatorExtensions.SomeExtension(null)");
"LinqExpressionToRoslynTranslatorExtensions.SomeExtension((LinqExpressionToRoslynTranslatorExtensionType)(null))");

[Fact]
public void Method_call_generic()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Text.Json;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.ChangeTracking.Internal;
using Microsoft.EntityFrameworkCore.InMemory.Storage.Internal;
Expand Down Expand Up @@ -92,18 +93,18 @@ public static RuntimeEntityType Create(RuntimeModel model, RuntimeEntityType bas
bool (int v1, int v2) => v1 == v2,
int (int v) => v,
int (int v) => v),
providerValueComparer: new ValueComparer<int>(
bool (int v1, int v2) => v1 == v2,
int (int v) => v,
int (int v) => v),
converter: new ValueConverter<int, int>(
int (int i) => i,
int (int i) => i),
jsonValueReaderWriter: new JsonConvertedValueReaderWriter<int, int>(
JsonInt32ReaderWriter.Instance,
new ValueConverter<int, int>(
int (int i) => i,
int (int i) => i)));
providerValueComparer: new ValueComparer<string>(
bool (string v1, string v2) => v1 == v2,
int (string v) => ((object)v).GetHashCode(),
string (string v) => v),
converter: new ValueConverter<int, string>(
string (int i) => JsonSerializer.Serialize(i, (JsonSerializerOptions)(null)),
int (string i) => JsonSerializer.Deserialize<int>(i, (JsonSerializerOptions)(null))),
jsonValueReaderWriter: new JsonConvertedValueReaderWriter<int, string>(
JsonStringReaderWriter.Instance,
new ValueConverter<int, string>(
string (int i) => JsonSerializer.Serialize(i, (JsonSerializerOptions)(null)),
int (string i) => JsonSerializer.Deserialize<int>(i, (JsonSerializerOptions)(null)))));
id.SetCurrentValueComparer(new EntryCurrentValueComparer<int>(id));

var key = runtimeEntityType.AddKey(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#nullable enable

using System.Runtime.CompilerServices;
using System.Text.Json;
using Microsoft.EntityFrameworkCore.Design.Internal;
using Microsoft.EntityFrameworkCore.InMemory.Storage.Internal;
using Microsoft.EntityFrameworkCore.Internal;
Expand Down Expand Up @@ -342,16 +343,19 @@ public virtual Task Custom_value_converter()
modelBuilder => modelBuilder.Entity(
"MyEntity", e =>
{
e.Property<int>("Id").HasConversion(i => i, i => i);
e.Property<int>("Id").HasConversion(
i => JsonSerializer.Serialize(i, (JsonSerializerOptions?)default),
i => JsonSerializer.Deserialize<int>(i, (JsonSerializerOptions?)null));
e.HasKey("Id");
}),
model =>
{
var entityType = model.GetEntityTypes().Single();

var converter = entityType.FindProperty("Id")!.GetTypeMapping().Converter!;
Assert.Equal(1, converter.ConvertToProvider(1));
});
Assert.Equal("1", converter.ConvertToProvider(1));
},
options: new CompiledModelCodeGenerationOptions { UseNullableReferenceTypes = true, ForNativeAot = true });

[ConditionalFact]
public virtual Task Custom_value_comparer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity7 = ((CompiledModelTestBase.PrincipalBase)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 14,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity7 = ((CompiledModelTestBase.PrincipalDerived<CompiledModelTestBase.DependentBase<byte?>>)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object, object, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), null, PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.Dependent(entity7), SnapshotFactoryFactory.SnapshotCollection(PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.ManyOwned(entity7)), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object, object, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), (object)(null), PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.Dependent(entity7), SnapshotFactoryFactory.SnapshotCollection(PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.ManyOwned(entity7)), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 14,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity7 = ((CompiledModelTestBase.PrincipalBase)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 15,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity7 = ((CompiledModelTestBase.PrincipalDerived<CompiledModelTestBase.DependentBase<byte?>>)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object, object, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), null, PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.Dependent(entity7), SnapshotFactoryFactory.SnapshotCollection(PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.ManyOwned(entity7)), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object, object, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), (object)(null), PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.Dependent(entity7), SnapshotFactoryFactory.SnapshotCollection(PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.ManyOwned(entity7)), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 15,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity7 = ((CompiledModelTestBase.PrincipalBase)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 15,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity7 = ((CompiledModelTestBase.PrincipalDerived<CompiledModelTestBase.DependentBase<byte?>>)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object, object, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), null, PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.Dependent(entity7), SnapshotFactoryFactory.SnapshotCollection(PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.ManyOwned(entity7)), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object, object, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), (object)(null), PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.Dependent(entity7), SnapshotFactoryFactory.SnapshotCollection(PrincipalDerivedUnsafeAccessors<CompiledModelTestBase.DependentBase<byte?>>.ManyOwned(entity7)), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 15,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ public static void CreateAnnotations(RuntimeEntityType runtimeEntityType)
ISnapshot (InternalEntityEntry source) =>
{
var entity7 = ((CompiledModelTestBase.PrincipalBase)(source.Entity));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), null)));
return ((ISnapshot)(new Snapshot<long?, Guid, object, object>((source.GetCurrentValue<long?>(id) == null ? null : ((ValueComparer<long?>)(((IProperty)id).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<long?>(id))), ((ValueComparer<Guid>)(((IProperty)alternateId).GetKeyValueComparer())).Snapshot(source.GetCurrentValue<Guid>(alternateId)), PrincipalBaseUnsafeAccessors._ownedField(entity7), (object)(null))));
});
runtimeEntityType.Counts = new PropertyCounts(
propertyCount: 16,
Expand Down
Loading

0 comments on commit b59aa3a

Please sign in to comment.