diff --git a/eng/MSBuild/LegacySupport.props b/eng/MSBuild/LegacySupport.props
index 2cfe7b73964..842951ab867 100644
--- a/eng/MSBuild/LegacySupport.props
+++ b/eng/MSBuild/LegacySupport.props
@@ -43,6 +43,10 @@
+
+
+
+
diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props
index 2bde3b34e05..4c78b8dcbe8 100644
--- a/eng/packages/TestOnly.props
+++ b/eng/packages/TestOnly.props
@@ -7,6 +7,7 @@
+
@@ -20,6 +21,7 @@
+
diff --git a/eng/spellchecking_exclusions.dic b/eng/spellchecking_exclusions.dic
index 2fc9b74699b..72596816516 100644
Binary files a/eng/spellchecking_exclusions.dic and b/eng/spellchecking_exclusions.dic differ
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs
index 616ad284198..4a681d4679a 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs
@@ -4,13 +4,21 @@
using System;
using System.Collections;
using System.Collections.Generic;
+using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
+using Microsoft.Shared.Diagnostics;
+
+#pragma warning disable S1144 // Unused private types or members should be removed
+#pragma warning disable S2365 // Properties should not make collection or array copies
+#pragma warning disable S3604 // Member initializer values should not be redundant
namespace Microsoft.Extensions.AI;
/// Provides a dictionary used as the AdditionalProperties dictionary on Microsoft.Extensions.AI objects.
+[DebuggerTypeProxy(typeof(DebugView))]
+[DebuggerDisplay("Count = {Count}")]
public sealed class AdditionalPropertiesDictionary : IDictionary, IReadOnlyDictionary
{
/// The underlying dictionary.
@@ -77,6 +85,25 @@ public object? this[string key]
///
public void Add(string key, object? value) => _dictionary.Add(key, value);
+ /// Attempts to add the specified key and value to the dictionary.
+ /// The key of the element to add.
+ /// The value of the element to add.
+ /// if the key/value pair was added to the dictionary successfully; otherwise, .
+ public bool TryAdd(string key, object? value)
+ {
+#if NET
+ return _dictionary.TryAdd(key, value);
+#else
+ if (!_dictionary.ContainsKey(key))
+ {
+ _dictionary.Add(key, value);
+ return true;
+ }
+
+ return false;
+#endif
+ }
+
///
void ICollection>.Add(KeyValuePair item) => ((ICollection>)_dictionary).Add(item);
@@ -93,11 +120,17 @@ public object? this[string key]
void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) =>
((ICollection>)_dictionary).CopyTo(array, arrayIndex);
+ ///
+ /// Returns an enumerator that iterates through the .
+ ///
+ /// An that enumerates the contents of the .
+ public Enumerator GetEnumerator() => new(_dictionary.GetEnumerator());
+
///
- public IEnumerator> GetEnumerator() => _dictionary.GetEnumerator();
+ IEnumerator> IEnumerable>.GetEnumerator() => GetEnumerator();
///
- IEnumerator IEnumerable.GetEnumerator() => _dictionary.GetEnumerator();
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
///
public bool Remove(string key) => _dictionary.Remove(key);
@@ -156,4 +189,59 @@ public bool TryGetValue(string key, [NotNullWhen(true)] out T? value)
value = default;
return false;
}
+
+ /// Enumerates the elements of an .
+ public struct Enumerator : IEnumerator>
+ {
+ /// The wrapped dictionary enumerator.
+ private Dictionary.Enumerator _dictionaryEnumerator;
+
+ /// Initializes a new instance of the struct with the dictionary enumerator to wrap.
+ /// The dictionary enumerator to wrap.
+ internal Enumerator(Dictionary.Enumerator dictionaryEnumerator)
+ {
+ _dictionaryEnumerator = dictionaryEnumerator;
+ }
+
+ ///
+ public KeyValuePair Current => _dictionaryEnumerator.Current;
+
+ ///
+ object IEnumerator.Current => Current;
+
+ ///
+ public void Dispose() => _dictionaryEnumerator.Dispose();
+
+ ///
+ public bool MoveNext() => _dictionaryEnumerator.MoveNext();
+
+ ///
+ public void Reset() => Reset(ref _dictionaryEnumerator);
+
+ /// Calls on an enumerator.
+ private static void Reset(ref TEnumerator enumerator)
+ where TEnumerator : struct, IEnumerator
+ {
+ enumerator.Reset();
+ }
+ }
+
+ /// Provides a debugger view for the collection.
+ private sealed class DebugView(AdditionalPropertiesDictionary properties)
+ {
+ private readonly AdditionalPropertiesDictionary _properties = Throw.IfNull(properties);
+
+ [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
+ public AdditionalProperty[] Items => (from p in _properties select new AdditionalProperty(p.Key, p.Value)).ToArray();
+
+ [DebuggerDisplay("{Value}", Name = "[{Key}]")]
+ public readonly struct AdditionalProperty(string key, object? value)
+ {
+ [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)]
+ public string Key { get; } = key;
+
+ [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)]
+ public object? Value { get; } = value;
+ }
+ }
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md
new file mode 100644
index 00000000000..6b347a8c09d
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md
@@ -0,0 +1,19 @@
+# Release History
+
+## 9.0.0-preview.9.24525.1
+
+- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older.
+- Annotated `FunctionCallContent.Exception` and `FunctionResultContent.Exception` as `[JsonIgnore]`, such that they're ignored when serializing instances with `JsonSerializer`. The corresponding constructors accepting an `Exception` were removed.
+- Annotated `ChatCompletion.Message` as `[JsonIgnore]`, such that it's ignored when serializing instances with `JsonSerializer`.
+- Added the `FunctionCallContent.CreateFromParsedArguments` method.
+- Added the `AdditionalPropertiesDictionary.TryGetValue` method.
+- Added the `StreamingChatCompletionUpdate.ModelId` property and removed the `AIContent.ModelId` property.
+- Renamed the `GenerateAsync` extension method on `IEmbeddingGenerator<,>` to `GenerateEmbeddingsAsync` and updated it to return `Embedding` rather than `GeneratedEmbeddings`.
+- Added `GenerateAndZipAsync` and `GenerateEmbeddingVectorAsync` extension methods for `IEmbeddingGenerator<,>`.
+- Added the `EmbeddingGeneratorOptions.Dimensions` property.
+- Added the `ChatOptions.TopK` property.
+- Normalized `null` inputs in `TextContent` to be empty strings.
+
+## 9.0.0-preview.9.24507.7
+
+Initial Preview
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs
index 4edbed900b4..0a4f6f58296 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs
@@ -27,6 +27,9 @@ public class ChatOptions
/// Gets or sets the presence penalty for generating chat responses.
public float? PresencePenalty { get; set; }
+ /// Gets or sets a seed value used by a service to control the reproducability of results.
+ public long? Seed { get; set; }
+
///
/// Gets or sets the response format for the chat request.
///
@@ -74,6 +77,7 @@ public virtual ChatOptions Clone()
TopK = TopK,
FrequencyPenalty = FrequencyPenalty,
PresencePenalty = PresencePenalty,
+ Seed = Seed,
ResponseFormat = ResponseFormat,
ModelId = ModelId,
ToolMode = ToolMode,
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj
index 8f00d6b9271..b96b4dca920 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj
@@ -17,9 +17,11 @@
$(TargetFrameworks);netstandard2.0
$(NoWarn);CA2227;CA1034;SA1316;S3253
true
+ true
+ true
true
true
true
diff --git a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonSchemaCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs
similarity index 100%
rename from src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonSchemaCreateOptions.cs
rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs
diff --git a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Defaults.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs
similarity index 98%
rename from src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Defaults.cs
rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs
index 94340160cb1..de2c2a695b6 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Defaults.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs
@@ -23,11 +23,11 @@ private static JsonSerializerOptions CreateDefaultOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
- // Otherwise, use the source-generated options to enable Native AOT.
+ // Otherwise, use the source-generated options to enable trimming and Native AOT.
if (JsonSerializer.IsReflectionEnabledByDefault)
{
- // Keep in sync with the JsonSourceGenerationOptions on JsonContext below.
+ // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext below.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
diff --git a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs
similarity index 81%
rename from src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs
rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs
index 46fe45342f2..b555148df8b 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs
@@ -5,17 +5,22 @@
using System.Collections.Concurrent;
using System.ComponentModel;
using System.Diagnostics;
+#if !NET9_0_OR_GREATER
+using System.Diagnostics.CodeAnalysis;
+#endif
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Schema;
+using System.Text.Json.Serialization;
using Microsoft.Shared.Diagnostics;
#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
#pragma warning disable S107 // Methods should not have too many parameters
#pragma warning disable S1075 // URIs should not be hardcoded
+#pragma warning disable SA1118 // Parameter should not span multiple lines
using FunctionParameterKey = (
System.Type? Type,
@@ -138,8 +143,6 @@ public static JsonElement CreateJsonSchema(
JsonSerializerOptions? serializerOptions = null,
AIJsonSchemaCreateOptions? inferenceOptions = null)
{
- _ = Throw.IfNull(serializerOptions);
-
serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
@@ -176,6 +179,11 @@ private static JsonElement GetJsonSchemaCached(JsonSerializerOptions options, Fu
#endif
}
+#if !NET9_0_OR_GREATER
+ [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access",
+ Justification = "Pre STJ-9 schema extraction can fail with a runtime exception if certain reflection metadata have been trimmed. " +
+ "The exception message will guide users to turn off 'IlcTrimMetadata' which resolves all issues.")]
+#endif
private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, FunctionParameterKey key)
{
_ = Throw.IfNull(options);
@@ -238,16 +246,9 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
const string DefaultPropertyName = "default";
const string RefPropertyName = "$ref";
- // Find the first DescriptionAttribute, starting first from the property, then the parameter, and finally the type itself.
- Type descAttrType = typeof(DescriptionAttribute);
- var descriptionAttribute =
- GetAttrs(descAttrType, ctx.PropertyInfo?.AttributeProvider)?.FirstOrDefault() ??
- GetAttrs(descAttrType, ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider)?.FirstOrDefault() ??
- GetAttrs(descAttrType, ctx.TypeInfo.Type)?.FirstOrDefault();
-
- if (descriptionAttribute is DescriptionAttribute attr)
+ if (ctx.ResolveAttribute() is { } attr)
{
- ConvertSchemaToObject(ref schema).Insert(0, DescriptionPropertyName, (JsonNode)attr.Description);
+ ConvertSchemaToObject(ref schema).InsertAtStart(DescriptionPropertyName, (JsonNode)attr.Description);
}
if (schema is JsonObject objSchema)
@@ -270,7 +271,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
// Include the type keyword in enum types
if (key.IncludeTypeInEnumSchemas && ctx.TypeInfo.Type.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName))
{
- objSchema.Insert(0, TypePropertyName, "string");
+ objSchema.InsertAtStart(TypePropertyName, "string");
}
// Disallow additional properties in object schemas
@@ -278,24 +279,24 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
{
objSchema.Add(AdditionalPropertiesPropertyName, (JsonNode)false);
}
- }
-
- if (ctx.Path.IsEmpty)
- {
- // We are at the root-level schema node, update/append parameter-specific metadata
// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// STJ represents .NET integer types as ["string", "integer"], which will then lead to an error.
- if (TypeIsArrayContainingInteger(schema))
+ if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema))
{
// We don't want to emit any array for "type". In this case we know it contains "integer"
// so reduce the type to that alone, assuming it's the most specific type.
- // This makes schemas for Int32 (etc) work with Ollama
+ // This makes schemas for Int32 (etc) work with Ollama.
JsonObject obj = ConvertSchemaToObject(ref schema);
obj[TypePropertyName] = "integer";
_ = obj.Remove(PatternPropertyName);
}
+ }
+
+ if (ctx.Path.IsEmpty)
+ {
+ // We are at the root-level schema node, update/append parameter-specific metadata
if (!string.IsNullOrWhiteSpace(key.Description))
{
@@ -305,7 +306,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
if (index < 0)
{
// If there's no description property, insert it at the beginning of the doc.
- obj.Insert(0, DescriptionPropertyName, (JsonNode)key.Description!);
+ obj.InsertAtStart(DescriptionPropertyName, (JsonNode)key.Description!);
}
else
{
@@ -323,15 +324,12 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
if (key.IncludeSchemaUri)
{
// The $schema property must be the first keyword in the object
- ConvertSchemaToObject(ref schema).Insert(0, SchemaPropertyName, (JsonNode)SchemaKeywordUri);
+ ConvertSchemaToObject(ref schema).InsertAtStart(SchemaPropertyName, (JsonNode)SchemaKeywordUri);
}
}
return schema;
- static object[]? GetAttrs(Type attrType, ICustomAttributeProvider? provider) =>
- provider?.GetCustomAttributes(attrType, inherit: false);
-
static JsonObject ConvertSchemaToObject(ref JsonNode schema)
{
JsonObject obj;
@@ -354,22 +352,82 @@ static JsonObject ConvertSchemaToObject(ref JsonNode schema)
}
}
- private static bool TypeIsArrayContainingInteger(JsonNode schema)
+ private static bool TypeIsIntegerWithStringNumberHandling(JsonSchemaExporterContext ctx, JsonObject schema)
{
- if (schema["type"] is JsonArray typeArray)
+ if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray)
{
- foreach (var entry in typeArray)
+ int count = 0;
+ foreach (JsonNode? entry in typeArray)
{
- if (entry?.GetValueKind() == JsonValueKind.String && entry.GetValue() == "integer")
+ if (entry?.GetValueKind() is JsonValueKind.String &&
+ entry.GetValue() is "integer" or "string")
{
- return true;
+ count++;
}
}
+
+ return count == typeArray.Count;
}
return false;
}
+ private static void InsertAtStart(this JsonObject jsonObject, string key, JsonNode value)
+ {
+#if NET9_0_OR_GREATER
+ jsonObject.Insert(0, key, value);
+#else
+ jsonObject.Remove(key);
+ var copiedEntries = jsonObject.ToArray();
+ jsonObject.Clear();
+
+ jsonObject.Add(key, value);
+ foreach (var entry in copiedEntries)
+ {
+ jsonObject[entry.Key] = entry.Value;
+ }
+#endif
+ }
+
+#if !NET9_0_OR_GREATER
+ private static int IndexOf(this JsonObject jsonObject, string key)
+ {
+ int i = 0;
+ foreach (var entry in jsonObject)
+ {
+ if (string.Equals(entry.Key, key, StringComparison.Ordinal))
+ {
+ return i;
+ }
+
+ i++;
+ }
+
+ return -1;
+ }
+#endif
+
+ private static TAttribute? ResolveAttribute(this JsonSchemaExporterContext ctx)
+ where TAttribute : Attribute
+ {
+ // Resolve attributes from locations in the following order:
+ // 1. Property-level attributes
+ // 2. Parameter-level attributes and
+ // 3. Type-level attributes.
+ return
+#if NET9_0_OR_GREATER
+ GetAttrs(ctx.PropertyInfo?.AttributeProvider) ??
+ GetAttrs(ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider) ??
+#else
+ GetAttrs(ctx.PropertyAttributeProvider) ??
+ GetAttrs(ctx.ParameterInfo) ??
+#endif
+ GetAttrs(ctx.TypeInfo.Type);
+
+ static TAttribute? GetAttrs(ICustomAttributeProvider? provider) =>
+ (TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit: false).FirstOrDefault();
+ }
+
private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json)
{
Utf8JsonReader reader = new(utf8Json);
diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs
index ecc41140b27..ba76f5c3c90 100644
--- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs
@@ -285,6 +285,7 @@ private ChatCompletionsOptions ToAzureAIOptions(IList chatContents,
result.NucleusSamplingFactor = options.TopP;
result.PresencePenalty = options.PresencePenalty;
result.Temperature = options.Temperature;
+ result.Seed = options.Seed;
if (options.StopSequences is { Count: > 0 } stopSequences)
{
@@ -306,11 +307,6 @@ private ChatCompletionsOptions ToAzureAIOptions(IList chatContents,
{
switch (prop.Key)
{
- // These properties are strongly-typed on the ChatCompletionsOptions class but not on the ChatOptions class.
- case nameof(result.Seed) when prop.Value is long seed:
- result.Seed = seed;
- break;
-
// Propagate everything else to the ChatCompletionOptions' AdditionalProperties.
default:
if (prop.Value is not null)
diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs
index 84198e6b2cc..866e55ad87a 100644
--- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs
@@ -156,7 +156,7 @@ private EmbeddingsOptions ToAzureAIOptions(IEnumerable inputs, Embedding
{
EmbeddingsOptions result = new(inputs)
{
- Dimensions = _dimensions,
+ Dimensions = options?.Dimensions ?? _dimensions,
Model = options?.ModelId ?? Metadata.ModelId,
EncodingFormat = format,
};
diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/CHANGELOG.md
new file mode 100644
index 00000000000..7929cc7e8b2
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/CHANGELOG.md
@@ -0,0 +1,12 @@
+# Release History
+
+## 9.0.0-preview.9.24525.1
+
+- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older.
+- Updated to use Azure.AI.Inference 1.0.0-beta.2.
+- Added `AzureAIInferenceEmbeddingGenerator` and corresponding `AsEmbeddingGenerator` extension method.
+- Improved handling of assistant messages that include both text and function call content.
+
+## 9.0.0-preview.9.24507.7
+
+Initial Preview
diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/JsonContext.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/JsonContext.cs
index 5576cbf134a..1e1dabffab7 100644
--- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/JsonContext.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/JsonContext.cs
@@ -48,11 +48,11 @@ private static JsonSerializerOptions CreateDefaultToolJsonOptions()
{
// If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
// and we want to be flexible in terms of what can be put into the various collections in the object model.
- // Otherwise, use the source-generated options to enable Native AOT.
+ // Otherwise, use the source-generated options to enable trimming and Native AOT.
if (JsonSerializer.IsReflectionEnabledByDefault)
{
- // Keep in sync with the JsonSourceGenerationOptions on JsonContext below.
+ // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
{
TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj
index 3a66e7837f2..0e3f60b8db3 100644
--- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj
+++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj
@@ -17,6 +17,7 @@
$(TargetFrameworks);netstandard2.0
$(NoWarn);CA1063;CA2227;SA1316;S1067;S1121;S3358
true
+ true
@@ -29,6 +30,7 @@
+
diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.Ollama/CHANGELOG.md
new file mode 100644
index 00000000000..ffb35814039
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/CHANGELOG.md
@@ -0,0 +1,10 @@
+# Release History
+
+## 9.0.0-preview.9.24525.1
+
+- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older.
+- Added additional constructors to `OllamaChatClient` and `OllamaEmbeddingGenerator` that accept `string` endpoints, in addition to the existing ones accepting `Uri` endpoints.
+
+## 9.0.0-preview.9.24507.7
+
+Initial Preview
diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj
index ad3064c8a66..018184d6bf0 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj
+++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj
@@ -17,6 +17,7 @@
$(TargetFrameworks);netstandard2.0
$(NoWarn);CA2227;SA1316;S1121;EA0002
true
+ true
diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs
index 72ddb13b2ac..18ff5d50b7c 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs
@@ -273,7 +273,6 @@ private OllamaChatRequest ToOllamaChatRequest(IList chatMessages, C
TransferMetadataValue(nameof(OllamaRequestOptions.penalize_newline), (options, value) => options.penalize_newline = value);
TransferMetadataValue(nameof(OllamaRequestOptions.repeat_last_n), (options, value) => options.repeat_last_n = value);
TransferMetadataValue(nameof(OllamaRequestOptions.repeat_penalty), (options, value) => options.repeat_penalty = value);
- TransferMetadataValue(nameof(OllamaRequestOptions.seed), (options, value) => options.seed = value);
TransferMetadataValue(nameof(OllamaRequestOptions.tfs_z), (options, value) => options.tfs_z = value);
TransferMetadataValue(nameof(OllamaRequestOptions.typical_p), (options, value) => options.typical_p = value);
TransferMetadataValue(nameof(OllamaRequestOptions.use_mmap), (options, value) => options.use_mmap = value);
@@ -314,6 +313,11 @@ private OllamaChatRequest ToOllamaChatRequest(IList chatMessages, C
{
(request.Options ??= new()).top_k = topK;
}
+
+ if (options.Seed is long seed)
+ {
+ (request.Options ??= new()).seed = seed;
+ }
}
return request;
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md
new file mode 100644
index 00000000000..179da41a0b0
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md
@@ -0,0 +1,12 @@
+# Release History
+
+## 9.0.0-preview.9.24525.1
+
+- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older.
+- Improved handling of system messages that include multiple content items.
+- Improved handling of assistant messages that include both text and function call content.
+- Fixed handling of streaming updates containing empty payloads.
+
+## 9.0.0-preview.9.24507.7
+
+Initial Preview
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj
index 76930738579..f2e2e9c0f52 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj
@@ -17,6 +17,7 @@
$(TargetFrameworks);netstandard2.0
$(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002
true
+ true
@@ -26,6 +27,7 @@
+
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
index 935bb88f812..985060256f7 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
@@ -3,11 +3,13 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
+using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
@@ -265,8 +267,7 @@ public async IAsyncEnumerable CompleteStreamingAs
existing.CallId ??= toolCallUpdate.ToolCallId;
existing.Name ??= toolCallUpdate.FunctionName;
- if (toolCallUpdate.FunctionArgumentsUpdate is { } update &&
- !update.ToMemory().IsEmpty) // workaround for https://github.com/dotnet/runtime/issues/68262 in 6.0.0 package
+ if (toolCallUpdate.FunctionArgumentsUpdate is { } update && !update.ToMemory().IsEmpty)
{
_ = (existing.Arguments ??= new()).Append(update.ToString());
}
@@ -391,6 +392,9 @@ private static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
result.TopP = options.TopP;
result.PresencePenalty = options.PresencePenalty;
result.Temperature = options.Temperature;
+#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates.
+ result.Seed = options.Seed;
+#pragma warning restore OPENAI001
if (options.StopSequences is { Count: > 0 } stopSequences)
{
@@ -425,13 +429,6 @@ private static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
result.AllowParallelToolCalls = allowParallelToolCalls;
}
-#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
- if (additionalProperties.TryGetValue(nameof(result.Seed), out long seed))
- {
- result.Seed = seed;
- }
-#pragma warning restore OPENAI001
-
if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt))
{
result.TopLogProbabilityCount = topLogProbabilityCountInt;
@@ -587,10 +584,9 @@ private sealed class OpenAIChatToolJson
string? result = resultContent.Result as string;
if (result is null && resultContent.Result is not null)
{
- JsonSerializerOptions options = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
try
{
- result = JsonSerializer.Serialize(resultContent.Result, options.GetTypeInfo(typeof(object)));
+ result = JsonSerializer.Serialize(resultContent.Result, JsonContext.GetTypeInfo(typeof(object), ToolCallJsonSerializerOptions));
}
catch (NotSupportedException)
{
@@ -617,7 +613,9 @@ private sealed class OpenAIChatToolJson
ChatToolCall.CreateFunctionToolCall(
callRequest.CallId,
callRequest.Name,
- BinaryData.FromObjectAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions)));
+ new(JsonSerializer.SerializeToUtf8Bytes(
+ callRequest.Arguments,
+ JsonContext.GetTypeInfo(typeof(IDictionary), ToolCallJsonSerializerOptions)))));
}
}
@@ -670,8 +668,53 @@ private static FunctionCallContent ParseCallContentFromBinaryData(BinaryData ut8
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);
/// Source-generated JSON type information.
+ [JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
+ UseStringEnumConverter = true,
+ DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
+ WriteIndented = true)]
[JsonSerializable(typeof(OpenAIChatToolJson))]
[JsonSerializable(typeof(IDictionary))]
[JsonSerializable(typeof(JsonElement))]
- private sealed partial class JsonContext : JsonSerializerContext;
+ private sealed partial class JsonContext : JsonSerializerContext
+ {
+ /// Gets the singleton used as the default in JSON serialization operations.
+ private static readonly JsonSerializerOptions _defaultToolJsonOptions = CreateDefaultToolJsonOptions();
+
+ /// Gets JSON type information for the specified type.
+ ///
+ /// This first tries to get the type information from ,
+ /// falling back to if it can't.
+ ///
+ public static JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions? firstOptions) =>
+ firstOptions?.TryGetTypeInfo(type, out JsonTypeInfo? info) is true ?
+ info :
+ _defaultToolJsonOptions.GetTypeInfo(type);
+
+ /// Creates the default to use for serialization-related operations.
+ [UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
+ [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")]
+ private static JsonSerializerOptions CreateDefaultToolJsonOptions()
+ {
+ // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize,
+ // and we want to be flexible in terms of what can be put into the various collections in the object model.
+ // Otherwise, use the source-generated options to enable trimming and Native AOT.
+
+ if (JsonSerializer.IsReflectionEnabledByDefault)
+ {
+ // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above.
+ JsonSerializerOptions options = new(JsonSerializerDefaults.Web)
+ {
+ TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
+ Converters = { new JsonStringEnumConverter() },
+ DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
+ WriteIndented = true,
+ };
+
+ options.MakeReadOnly();
+ return options;
+ }
+
+ return Default.Options;
+ }
+ }
}
diff --git a/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md
new file mode 100644
index 00000000000..e2dae2e6e37
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md
@@ -0,0 +1,17 @@
+# Release History
+
+## 9.0.0-preview.9.24525.1
+
+- Added new `AIJsonUtilities` and `AIJsonSchemaCreateOptions` classes.
+- Made `AIFunctionFactory.Create` safe for use with Native AOT.
+- Simplified the set of `AIFunctionFactory.Create` overloads.
+- Changed the default for `FunctionInvokingChatClient.ConcurrentInvocation` from `true` to `false`.
+- Improved the readability of JSON generated as part of logging.
+- Fixed handling of generated JSON schema names when using arrays or generic types.
+- Improved `CachingChatClient`'s coalescing of streaming updates, including reduced memory allocation and enhanced metadata propagation.
+- Updated `OpenTelemetryChatClient` and `OpenTelemetryEmbeddingGenerator` to conform to the latest 1.28.0 draft specification of the Semantic Conventions for Generative AI systems.
+- Improved `CompleteAsync`'s structured output support to handle primitive types, enums, and arrays.
+
+## 9.0.0-preview.9.24507.7
+
+Initial Preview
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs
index 895bf8873df..990c92d3ad9 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs
@@ -17,7 +17,7 @@ namespace Microsoft.Extensions.AI;
///
/// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example _ => new ChatOptions() { MaxTokens = 1000 }. To provide
-/// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example
+/// a new instance only if the caller-supplied instance is , the callback may conditionally return a new instance, for example
/// options => options ?? new ChatOptions() { MaxTokens = 1000 }. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example:
@@ -31,6 +31,9 @@ namespace Microsoft.Extensions.AI;
///
///
///
+/// The callback may return , in which case a options will be passed to the next client in the pipeline.
+///
+///
/// The provided implementation of is thread-safe for concurrent use so long as the employed configuration
/// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the
/// configuration callback, as multiple calls to it may end up running in parallel with the same options instance.
@@ -39,7 +42,7 @@ namespace Microsoft.Extensions.AI;
public sealed class ConfigureOptionsChatClient : DelegatingChatClient
{
/// The callback delegate used to configure options.
- private readonly Func _configureOptions;
+ private readonly Func _configureOptions;
/// Initializes a new instance of the class with the specified callback.
/// The inner client.
@@ -47,7 +50,7 @@ public sealed class ConfigureOptionsChatClient : DelegatingChatClient
/// The delegate to invoke to configure the instance. It is passed the caller-supplied
/// instance and should return the configured instance to use.
///
- public ConfigureOptionsChatClient(IChatClient innerClient, Func configureOptions)
+ public ConfigureOptionsChatClient(IChatClient innerClient, Func configureOptions)
: base(innerClient)
{
_configureOptions = Throw.IfNull(configureOptions);
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs
index 12b903c0dac..2d98fbd9003 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs
@@ -21,9 +21,10 @@ public static class ConfigureOptionsChatClientBuilderExtensions
///
/// The .
///
+ ///
/// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options
/// with a new instance, the callback may simply return that new instance, for example _ => new ChatOptions() { MaxTokens = 1000 }. To provide
- /// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example
+ /// a new instance only if the caller-supplied instance is , the callback may conditionally return a new instance, for example
/// options => options ?? new ChatOptions() { MaxTokens = 1000 }. Any changes to the caller-provided options instance will persist on the
/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
/// and mutating the clone, for example:
@@ -35,9 +36,13 @@ public static class ConfigureOptionsChatClientBuilderExtensions
/// return newOptions;
/// }
///
+ ///
+ ///
+ /// The callback may return , in which case a options will be passed to the next client in the pipeline.
+ ///
///
public static ChatClientBuilder UseChatOptions(
- this ChatClientBuilder builder, Func configureOptions)
+ this ChatClientBuilder builder, Func configureOptions)
{
_ = Throw.IfNull(builder);
_ = Throw.IfNull(configureOptions);
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs
index 905e756e246..a6dfe53adf5 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs
@@ -322,7 +322,7 @@ private static ChatCompletion ComposeStreamingUpdatesIntoChatCompletion(
_ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "response_format"), responseFormat);
}
- if (options.AdditionalProperties?.TryGetValue("seed", out long seed) is true)
+ if (options.Seed is long seed)
{
_ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "seed"), seed);
}
diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs
new file mode 100644
index 00000000000..9068ac41caa
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs
@@ -0,0 +1,75 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Shared.Diagnostics;
+
+#pragma warning disable SA1629 // Documentation text should end with a period
+
+namespace Microsoft.Extensions.AI;
+
+/// A delegating embedding generator that updates or replaces the used by the remainder of the pipeline.
+/// Specifies the type of the input passed to the generator.
+/// Specifies the type of the embedding instance produced by the generator.
+///
+///
+/// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options
+/// with a new instance, the callback may simply return that new instance, for example _ => new EmbeddingGenerationOptions() { Dimensions = 100 }. To provide
+/// a new instance only if the caller-supplied instance is , the callback may conditionally return a new instance, for example
+/// options => options ?? new EmbeddingGenerationOptions() { Dimensions = 100 }. Any changes to the caller-provided options instance will persist on the
+/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
+/// and mutating the clone, for example:
+///
+/// options =>
+/// {
+/// var newOptions = options?.Clone() ?? new();
+/// newOptions.Dimensions = 100;
+/// return newOptions;
+/// }
+///
+///
+///
+/// The callback may return , in which case a options will be passed to the next generator in the pipeline.
+///
+///
+/// The provided implementation of is thread-safe for concurrent use so long as the employed configuration
+/// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the
+/// configuration callback, as multiple calls to it may end up running in parallel with the same options instance.
+///
+///
+public sealed class ConfigureOptionsEmbeddingGenerator : DelegatingEmbeddingGenerator
+ where TEmbedding : Embedding
+{
+ /// The callback delegate used to configure options.
+ private readonly Func _configureOptions;
+
+ ///
+ /// Initializes a new instance of the class with the
+ /// specified callback.
+ ///
+ /// The inner generator.
+ ///
+ /// The delegate to invoke to configure the instance. It is passed the caller-supplied
+ /// instance and should return the configured instance to use.
+ ///
+ public ConfigureOptionsEmbeddingGenerator(
+ IEmbeddingGenerator innerGenerator,
+ Func configureOptions)
+ : base(innerGenerator)
+ {
+ _configureOptions = Throw.IfNull(configureOptions);
+ }
+
+ ///
+ public override async Task> GenerateAsync(
+ IEnumerable values,
+ EmbeddingGenerationOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ return await base.GenerateAsync(values, _configureOptions(options), cancellationToken).ConfigureAwait(false);
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs
new file mode 100644
index 00000000000..011f4c058e9
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs
@@ -0,0 +1,56 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using Microsoft.Shared.Diagnostics;
+
+#pragma warning disable SA1629 // Documentation text should end with a period
+
+namespace Microsoft.Extensions.AI;
+
+/// Provides extensions for configuring instances.
+public static class ConfigureOptionsEmbeddingGeneratorBuilderExtensions
+{
+ ///
+ /// Adds a callback that updates or replaces . This can be used to set default options.
+ ///
+ /// Specifies the type of the input passed to the generator.
+ /// Specifies the type of the embedding instance produced by the generator.
+ /// The .
+ ///
+ /// The delegate to invoke to configure the instance. It is passed the caller-supplied
+ /// instance and should return the configured instance to use.
+ ///
+ /// The .
+ ///
+ ///
+ /// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options
+ /// with a new instance, the callback may simply return that new instance, for example _ => new EmbeddingGenerationOptions() { Dimensions = 100 }. To provide
+ /// a new instance only if the caller-supplied instance is , the callback may conditionally return a new instance, for example
+ /// options => options ?? new EmbeddingGenerationOptions() { Dimensions = 100 }. Any changes to the caller-provided options instance will persist on the
+ /// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance
+ /// and mutating the clone, for example:
+ ///
+ /// options =>
+ /// {
+ /// var newOptions = options?.Clone() ?? new();
+ /// newOptions.Dimensions = 100;
+ /// return newOptions;
+ /// }
+ ///
+ ///
+ ///
+ /// The callback may return , in which case a options will be passed to the next generator in the pipeline.
+ ///
+ ///
+ public static EmbeddingGeneratorBuilder UseEmbeddingGenerationOptions(
+ this EmbeddingGeneratorBuilder builder,
+ Func configureOptions)
+ where TEmbedding : Embedding
+ {
+ _ = Throw.IfNull(builder);
+ _ = Throw.IfNull(configureOptions);
+
+ return builder.Use(innerGenerator => new ConfigureOptionsEmbeddingGenerator(innerGenerator, configureOptions));
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
index 2dfd7347ea8..e4ebd6198a7 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
+++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj
@@ -19,6 +19,7 @@
$(TargetFrameworks);netstandard2.0
$(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253
true
+ true
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.CacheItem.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.CacheItem.cs
index 5585b9b2a29..05edc65dc06 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.CacheItem.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.CacheItem.cs
@@ -5,6 +5,7 @@
using System.Diagnostics;
using System.Threading;
using Microsoft.Extensions.Caching.Memory;
+using Microsoft.Extensions.Logging;
namespace Microsoft.Extensions.Caching.Hybrid.Internal;
@@ -22,7 +23,7 @@ internal abstract class CacheItem
// zero.
// This counter also drives cache lifetime, with the cache itself incrementing the count by one. In the
// case of mutable data, cache eviction may reduce this to zero (in cooperation with any concurrent readers,
- // who incr/decr around their fetch), allowing safe buffer recycling.
+ // who increment/decrement around their fetch), allowing safe buffer recycling.
internal int RefCount => Volatile.Read(ref _refCount);
@@ -89,13 +90,18 @@ internal abstract class CacheItem : CacheItem
{
public abstract bool TryGetSize(out long size);
- // attempt to get a value that was *not* previously reserved
- public abstract bool TryGetValue(out T value);
+ // Attempt to get a value that was *not* previously reserved.
+ // Note on ILogger usage: we don't want to propagate and store this everywhere.
+ // It is used for reporting deserialization problems - pass it as needed.
+ // (CacheItem gets into the IMemoryCache - let's minimize the onward reachable set
+ // of that cache, by only handing it leaf nodes of a "tree", not a "graph" with
+ // backwards access - we can also limit object size at the same time)
+ public abstract bool TryGetValue(ILogger log, out T value);
// get a value that *was* reserved, countermanding our reservation in the process
- public T GetReservedValue()
+ public T GetReservedValue(ILogger log)
{
- if (!TryGetValue(out var value))
+ if (!TryGetValue(log, out var value))
{
Throw();
}
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.ImmutableCacheItem.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.ImmutableCacheItem.cs
index 9ae8468ba29..2e803d87ad6 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.ImmutableCacheItem.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.ImmutableCacheItem.cs
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Threading;
+using Microsoft.Extensions.Logging;
namespace Microsoft.Extensions.Caching.Hybrid.Internal;
@@ -38,7 +39,7 @@ public void SetValue(T value, long size)
Size = size;
}
- public override bool TryGetValue(out T value)
+ public override bool TryGetValue(ILogger log, out T value)
{
value = _value;
return true; // always available
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs
index 1e694448737..230a657bdc3 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs
@@ -16,12 +16,16 @@ internal partial class DefaultHybridCache
{
[SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Manual sync check")]
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Manual sync check")]
+ [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Explicit async exception handling")]
+ [SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", Justification = "Deliberate recycle only on success")]
internal ValueTask GetFromL2Async(string key, CancellationToken token)
{
switch (GetFeatures(CacheFeatures.BackendCache | CacheFeatures.BackendBuffers))
{
case CacheFeatures.BackendCache: // legacy byte[]-based
+
var pendingLegacy = _backendCache!.GetAsync(key, token);
+
#if NETCOREAPP2_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER
if (!pendingLegacy.IsCompletedSuccessfully)
#else
@@ -36,6 +40,7 @@ internal ValueTask GetFromL2Async(string key, CancellationToken tok
case CacheFeatures.BackendCache | CacheFeatures.BackendBuffers: // IBufferWriter-based
RecyclableArrayBufferWriter writer = RecyclableArrayBufferWriter.Create(MaximumPayloadBytes);
var cache = Unsafe.As(_backendCache!); // type-checked already
+
var pendingBuffers = cache.TryGetAsync(key, writer, token);
if (!pendingBuffers.IsCompletedSuccessfully)
{
@@ -49,7 +54,7 @@ internal ValueTask GetFromL2Async(string key, CancellationToken tok
return new(result);
}
- return default;
+ return default; // treat as a "miss"
static async Task AwaitedLegacyAsync(Task pending, DefaultHybridCache @this)
{
@@ -115,6 +120,11 @@ internal void SetL1(string key, CacheItem value, HybridCacheEntryOptions?
// commit
cacheEntry.Dispose();
+
+ if (HybridCacheEventSource.Log.IsEnabled())
+ {
+ HybridCacheEventSource.Log.LocalCacheWrite();
+ }
}
}
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs
index 2d02c23b6d8..db95e8c4590 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs
@@ -1,14 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System;
+using Microsoft.Extensions.Logging;
+
namespace Microsoft.Extensions.Caching.Hybrid.Internal;
internal partial class DefaultHybridCache
{
private sealed partial class MutableCacheItem : CacheItem // used to hold types that require defensive copies
{
- private IHybridCacheSerializer _serializer = null!; // deferred until SetValue
+ private IHybridCacheSerializer? _serializer;
private BufferChunk _buffer;
+ private T? _fallbackValue; // only used in the case of serialization failures
public override bool NeedsEvictionCallback => _buffer.ReturnToPool;
@@ -21,16 +25,27 @@ public void SetValue(ref BufferChunk buffer, IHybridCacheSerializer serialize
buffer = default; // we're taking over the lifetime; the caller no longer has it!
}
- public override bool TryGetValue(out T value)
+ public void SetFallbackValue(T fallbackValue)
+ {
+ _fallbackValue = fallbackValue;
+ }
+
+ public override bool TryGetValue(ILogger log, out T value)
{
// only if we haven't already burned
if (TryReserve())
{
try
{
- value = _serializer.Deserialize(_buffer.AsSequence());
+ var serializer = _serializer;
+ value = serializer is null ? _fallbackValue! : serializer.Deserialize(_buffer.AsSequence());
return true;
}
+ catch (Exception ex)
+ {
+ log.DeserializationFailure(ex);
+ throw;
+ }
finally
{
_ = Release();
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.Serialization.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.Serialization.cs
index 523a95e279a..d12b2cce592 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.Serialization.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.Serialization.cs
@@ -3,7 +3,7 @@
using System;
using System.Collections.Concurrent;
-using System.Reflection;
+using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.DependencyInjection;
@@ -51,4 +51,54 @@ static IHybridCacheSerializer ResolveAndAddSerializer(DefaultHybridCache @thi
return serializer;
}
}
+
+ [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Intentional for logged failure mode")]
+ private bool TrySerialize(T value, out BufferChunk buffer, out IHybridCacheSerializer? serializer)
+ {
+ // note: also returns the serializer we resolved, because most-any time we want to serialize, we'll also want
+ // to make sure we use that same instance later (without needing to re-resolve and/or store the entire HC machinery)
+
+ RecyclableArrayBufferWriter? writer = null;
+ buffer = default;
+ try
+ {
+ writer = RecyclableArrayBufferWriter.Create(MaximumPayloadBytes); // note this lifetime spans the SetL2Async
+ serializer = GetSerializer();
+
+ serializer.Serialize(value, writer);
+
+ buffer = new(writer.DetachCommitted(out var length), length, returnToPool: true); // remove buffer ownership from the writer
+ writer.Dispose(); // we're done with the writer
+ return true;
+ }
+ catch (Exception ex)
+ {
+ bool knownCause = false;
+
+ // ^^^ if we know what happened, we can record directly via cause-specific events
+ // and treat as a handled failure (i.e. return false) - otherwise, we'll bubble
+ // the fault up a few layers *in addition to* logging in a failure event
+
+ if (writer is not null)
+ {
+ if (writer.QuotaExceeded)
+ {
+ _logger.MaximumPayloadBytesExceeded(ex, MaximumPayloadBytes);
+ knownCause = true;
+ }
+
+ writer.Dispose();
+ }
+
+ if (!knownCause)
+ {
+ _logger.SerializationFailure(ex);
+ throw;
+ }
+
+ buffer = default;
+ serializer = null;
+ return false;
+ }
+ }
}
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeState.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeState.cs
index eba71774395..e2439357f26 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeState.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeState.cs
@@ -74,8 +74,6 @@ protected StampedeState(DefaultHybridCache cache, in StampedeKey key, CacheItem
public abstract void Execute();
- protected int MaximumPayloadBytes => _cache.MaximumPayloadBytes;
-
public override string ToString() => Key.ToString();
public abstract void SetCanceled();
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs
index 4e45acae930..4be5b351485 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs
@@ -6,6 +6,7 @@
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.Extensions.Logging;
using static Microsoft.Extensions.Caching.Hybrid.Internal.DefaultHybridCache;
namespace Microsoft.Extensions.Caching.Hybrid.Internal;
@@ -14,7 +15,8 @@ internal partial class DefaultHybridCache
{
internal sealed class StampedeState : StampedeState
{
- private const HybridCacheEntryFlags FlagsDisableL1AndL2 = HybridCacheEntryFlags.DisableLocalCacheWrite | HybridCacheEntryFlags.DisableDistributedCacheWrite;
+ // note on terminology: L1 and L2 are, for brevity, used interchangeably with "local" and "distributed" cache, i.e. `IMemoryCache` and `IDistributedCache`
+ private const HybridCacheEntryFlags FlagsDisableL1AndL2Write = HybridCacheEntryFlags.DisableLocalCacheWrite | HybridCacheEntryFlags.DisableDistributedCacheWrite;
private readonly TaskCompletionSource>? _result;
private TState? _state;
@@ -76,13 +78,13 @@ public Task ExecuteDirectAsync(in TState state, Func _result?.TrySetCanceled(SharedToken);
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Custom task management")]
- public ValueTask JoinAsync(CancellationToken token)
+ public ValueTask JoinAsync(ILogger log, CancellationToken token)
{
// If the underlying has already completed, and/or our local token can't cancel: we
// can simply wrap the shared task; otherwise, we need our own cancellation state.
- return token.CanBeCanceled && !Task.IsCompleted ? WithCancellationAsync(this, token) : UnwrapReservedAsync();
+ return token.CanBeCanceled && !Task.IsCompleted ? WithCancellationAsync(log, this, token) : UnwrapReservedAsync(log);
- static async ValueTask WithCancellationAsync(StampedeState stampede, CancellationToken token)
+ static async ValueTask WithCancellationAsync(ILogger log, StampedeState stampede, CancellationToken token)
{
var cancelStub = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
using var reg = token.Register(static obj =>
@@ -112,7 +114,7 @@ static async ValueTask WithCancellationAsync(StampedeState stamped
}
// outside the catch, so we know we only decrement one way or the other
- return result.GetReservedValue();
+ return result.GetReservedValue(log);
}
}
@@ -133,7 +135,7 @@ static Task> InvalidAsync() => System.Threading.Tasks.Task.FromExce
[SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Checked manual unwrap")]
[SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Checked manual unwrap")]
[SuppressMessage("Major Code Smell", "S1121:Assignments should not be made from within sub-expressions", Justification = "Unusual, but legit here")]
- internal ValueTask UnwrapReservedAsync()
+ internal ValueTask UnwrapReservedAsync(ILogger log)
{
var task = Task;
#if NETCOREAPP2_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER
@@ -142,16 +144,16 @@ internal ValueTask UnwrapReservedAsync()
if (task.Status == TaskStatus.RanToCompletion)
#endif
{
- return new(task.Result.GetReservedValue());
+ return new(task.Result.GetReservedValue(log));
}
// if the type is immutable, callers can share the final step too (this may leave dangling
// reservation counters, but that's OK)
- var result = ImmutableTypeCache.IsImmutable ? (_sharedUnwrap ??= AwaitedAsync(Task)) : AwaitedAsync(Task);
+ var result = ImmutableTypeCache.IsImmutable ? (_sharedUnwrap ??= AwaitedAsync(log, Task)) : AwaitedAsync(log, Task);
return new(result);
- static async Task AwaitedAsync(Task> task)
- => (await task.ConfigureAwait(false)).GetReservedValue();
+ static async Task AwaitedAsync(ILogger log, Task> task)
+ => (await task.ConfigureAwait(false)).GetReservedValue(log);
}
[DoesNotReturn]
@@ -161,12 +163,43 @@ static async Task AwaitedAsync(Task> task)
[SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Exception is passed through to faulted task result")]
private async Task BackgroundFetchAsync()
{
+ bool eventSourceEnabled = HybridCacheEventSource.Log.IsEnabled();
try
{
// read from L2 if appropriate
if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheRead) == 0)
{
- var result = await Cache.GetFromL2Async(Key.Key, SharedToken).ConfigureAwait(false);
+ BufferChunk result;
+ try
+ {
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.DistributedCacheGet();
+ }
+
+ result = await Cache.GetFromL2Async(Key.Key, SharedToken).ConfigureAwait(false);
+ if (eventSourceEnabled)
+ {
+ if (result.Array is not null)
+ {
+ HybridCacheEventSource.Log.DistributedCacheHit();
+ }
+ else
+ {
+ HybridCacheEventSource.Log.DistributedCacheMiss();
+ }
+ }
+ }
+ catch (Exception ex)
+ {
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.DistributedCacheFailed();
+ }
+
+ Cache._logger.CacheUnderlyingDataQueryFailure(ex);
+ result = default; // treat as "miss"
+ }
if (result.Array is not null)
{
@@ -179,7 +212,30 @@ private async Task BackgroundFetchAsync()
if ((Key.Flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0)
{
// invoke the callback supplied by the caller
- T newValue = await _underlying!(_state!, SharedToken).ConfigureAwait(false);
+ T newValue;
+ try
+ {
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.UnderlyingDataQueryStart();
+ }
+
+ newValue = await _underlying!(_state!, SharedToken).ConfigureAwait(false);
+
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.UnderlyingDataQueryComplete();
+ }
+ }
+ catch
+ {
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.UnderlyingDataQueryFailed();
+ }
+
+ throw;
+ }
// If we're writing this value *anywhere*, we're going to need to serialize; this is obvious
// in the case of L2, but we also need it for L1, because MemoryCache might be enforcing
@@ -187,11 +243,11 @@ private async Task BackgroundFetchAsync()
// Likewise, if we're writing to a MutableCacheItem, we'll be serializing *anyway* for the payload.
//
// Rephrasing that: the only scenario in which we *do not* need to serialize is if:
- // - it is an ImmutableCacheItem
- // - we're writing neither to L1 nor L2
+ // - it is an ImmutableCacheItem (so we don't need bytes for the CacheItem, L1)
+ // - we're not writing to L2
CacheItem cacheItem = CacheItem;
- bool skipSerialize = cacheItem is ImmutableCacheItem && (Key.Flags & FlagsDisableL1AndL2) == FlagsDisableL1AndL2;
+ bool skipSerialize = cacheItem is ImmutableCacheItem && (Key.Flags & FlagsDisableL1AndL2Write) == FlagsDisableL1AndL2Write;
if (skipSerialize)
{
@@ -202,33 +258,55 @@ private async Task BackgroundFetchAsync()
// ^^^ The first thing we need to do is make sure we're not getting into a thread race over buffer disposal.
// In particular, if this cache item is somehow so short-lived that the buffers would be released *before* we're
// done writing them to L2, which happens *after* we've provided the value to consumers.
- RecyclableArrayBufferWriter writer = RecyclableArrayBufferWriter.Create(MaximumPayloadBytes); // note this lifetime spans the SetL2Async
- IHybridCacheSerializer serializer = Cache.GetSerializer();
- serializer.Serialize(newValue, writer);
- BufferChunk buffer = new(writer.DetachCommitted(out var length), length, returnToPool: true); // remove buffer ownership from the writer
- writer.Dispose(); // we're done with the writer
-
- // protect "buffer" (this is why we "reserved") for writing to L2 if needed; SetResultPreSerialized
- // *may* (depending on context) claim this buffer, in which case "bufferToRelease" gets reset, and
- // the final RecycleIfAppropriate() is a no-op; however, the buffer is valid in either event,
- // (with TryReserve above guaranteeing that we aren't in a race condition).
- BufferChunk bufferToRelease = buffer;
-
- // and since "bufferToRelease" is the thing that will be returned at some point, we can make it explicit
- // that we do not need or want "buffer" to do any recycling (they're the same memory)
- buffer = buffer.DoNotReturnToPool();
-
- // set the underlying result for this operation (includes L1 write if appropriate)
- SetResultPreSerialized(newValue, ref bufferToRelease, serializer);
-
- // Note that at this point we've already released most or all of the waiting callers. Everything
- // from this point onwards happens in the background, from the perspective of the calling code.
-
- // Write to L2 if appropriate.
- if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheWrite) == 0)
+
+ BufferChunk bufferToRelease = default;
+ if (Cache.TrySerialize(newValue, out var buffer, out var serializer))
{
- // We already have the payload serialized, so this is trivial to do.
- await Cache.SetL2Async(Key.Key, in buffer, _options, SharedToken).ConfigureAwait(false);
+ // note we also capture the resolved serializer ^^^ - we'll need it again later
+
+ // protect "buffer" (this is why we "reserved") for writing to L2 if needed; SetResultPreSerialized
+ // *may* (depending on context) claim this buffer, in which case "bufferToRelease" gets reset, and
+ // the final RecycleIfAppropriate() is a no-op; however, the buffer is valid in either event,
+ // (with TryReserve above guaranteeing that we aren't in a race condition).
+ bufferToRelease = buffer;
+
+ // and since "bufferToRelease" is the thing that will be returned at some point, we can make it explicit
+ // that we do not need or want "buffer" to do any recycling (they're the same memory)
+ buffer = buffer.DoNotReturnToPool();
+
+ // set the underlying result for this operation (includes L1 write if appropriate)
+ SetResultPreSerialized(newValue, ref bufferToRelease, serializer);
+
+ // Note that at this point we've already released most or all of the waiting callers. Everything
+ // from this point onwards happens in the background, from the perspective of the calling code.
+
+ // Write to L2 if appropriate.
+ if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheWrite) == 0)
+ {
+ // We already have the payload serialized, so this is trivial to do.
+ try
+ {
+ await Cache.SetL2Async(Key.Key, in buffer, _options, SharedToken).ConfigureAwait(false);
+
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.DistributedCacheWrite();
+ }
+ }
+ catch (Exception ex)
+ {
+ // log the L2 write failure, but that doesn't need to interrupt the app flow (so:
+ // don't rethrow); L1 will still reduce impact, and L1 without L2 is better than
+ // hard failure every time
+ Cache._logger.CacheBackendWriteFailure(ex);
+ }
+ }
+ }
+ else
+ {
+ // unable to serialize (or quota exceeded); try to at least store the onwards value; this is
+ // especially useful for immutable data types
+ SetResultPreSerialized(newValue, ref bufferToRelease, serializer);
}
// Release our hook on the CacheItem (only really important for "mutable").
@@ -309,7 +387,7 @@ private void SetResultAndRecycleIfAppropriate(ref BufferChunk value)
private void SetImmutableResultWithoutSerialize(T value)
{
- Debug.Assert((Key.Flags & FlagsDisableL1AndL2) == FlagsDisableL1AndL2, "Only expected if L1+L2 disabled");
+ Debug.Assert((Key.Flags & FlagsDisableL1AndL2Write) == FlagsDisableL1AndL2Write, "Only expected if L1+L2 disabled");
// set a result from a value we calculated directly
CacheItem cacheItem;
@@ -328,7 +406,7 @@ private void SetImmutableResultWithoutSerialize(T value)
SetResult(cacheItem);
}
- private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCacheSerializer serializer)
+ private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCacheSerializer? serializer)
{
// set a result from a value we calculated directly that
// has ALREADY BEEN SERIALIZED (we can optionally consume this buffer)
@@ -343,8 +421,17 @@ private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCach
// (but leave the buffer alone)
break;
case MutableCacheItem mutable:
- mutable.SetValue(ref buffer, serializer);
- mutable.DebugOnlyTrackBuffer(Cache);
+ if (serializer is null)
+ {
+ // serialization is failing; set fallback value
+ mutable.SetFallbackValue(value);
+ }
+ else
+ {
+ mutable.SetValue(ref buffer, serializer);
+ mutable.DebugOnlyTrackBuffer(Cache);
+ }
+
cacheItem = mutable;
break;
default:
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs
index c789e7c6652..71dbf71fd54 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs
@@ -22,6 +22,9 @@ namespace Microsoft.Extensions.Caching.Hybrid.Internal;
///
internal sealed partial class DefaultHybridCache : HybridCache
{
+ // reserve non-printable characters from keys, to prevent potential L2 abuse
+ private static readonly char[] _keyReservedCharacters = Enumerable.Range(0, 32).Select(i => (char)i).ToArray();
+
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")]
private readonly IDistributedCache? _backendCache;
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")]
@@ -37,6 +40,7 @@ internal sealed partial class DefaultHybridCache : HybridCache
private readonly HybridCacheEntryFlags _defaultFlags; // note this already includes hardFlags
private readonly TimeSpan _defaultExpiration;
private readonly TimeSpan _defaultLocalCacheExpiration;
+ private readonly int _maximumKeyLength;
private readonly DistributedCacheEntryOptions _defaultDistributedCacheExpiration;
@@ -90,6 +94,7 @@ public DefaultHybridCache(IOptions options, IServiceProvider
_serializerFactories = factories;
MaximumPayloadBytes = checked((int)_options.MaximumPayloadBytes); // for now hard-limit to 2GiB
+ _maximumKeyLength = _options.MaximumKeyLength;
var defaultEntryOptions = _options.DefaultEntryOptions;
@@ -119,11 +124,33 @@ public override ValueTask GetOrCreateAsync(string key, TState stat
}
var flags = GetEffectiveFlags(options);
- if ((flags & HybridCacheEntryFlags.DisableLocalCacheRead) == 0 && _localCache.TryGetValue(key, out var untyped)
- && untyped is CacheItem typed && typed.TryGetValue(out var value))
+ if (!ValidateKey(key))
{
- // short-circuit
- return new(value);
+ // we can't use cache, but we can still provide the data
+ return RunWithoutCacheAsync(flags, state, underlyingDataCallback, cancellationToken);
+ }
+
+ bool eventSourceEnabled = HybridCacheEventSource.Log.IsEnabled();
+ if ((flags & HybridCacheEntryFlags.DisableLocalCacheRead) == 0)
+ {
+ if (_localCache.TryGetValue(key, out var untyped)
+ && untyped is CacheItem typed && typed.TryGetValue(_logger, out var value))
+ {
+ // short-circuit
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.LocalCacheHit();
+ }
+
+ return new(value);
+ }
+ else
+ {
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.LocalCacheMiss();
+ }
+ }
}
if (GetOrCreateStampedeState(key, flags, out var stampede, canBeCanceled))
@@ -139,11 +166,19 @@ public override ValueTask GetOrCreateAsync(string key, TState stat
{
// we're going to run to completion; no need to get complicated
_ = stampede.ExecuteDirectAsync(in state, underlyingDataCallback, options); // this larger task includes L2 write etc
- return stampede.UnwrapReservedAsync();
+ return stampede.UnwrapReservedAsync(_logger);
+ }
+ }
+ else
+ {
+ // pre-existing query
+ if (eventSourceEnabled)
+ {
+ HybridCacheEventSource.Log.StampedeJoin();
}
}
- return stampede.JoinAsync(cancellationToken);
+ return stampede.JoinAsync(_logger, cancellationToken);
}
public override ValueTask RemoveAsync(string key, CancellationToken token = default)
@@ -164,7 +199,39 @@ public override ValueTask SetAsync(string key, T value, HybridCacheEntryOptio
return new(state.ExecuteDirectAsync(value, static (state, _) => new(state), options)); // note this spans L2 write etc
}
+ private static ValueTask RunWithoutCacheAsync(HybridCacheEntryFlags flags, TState state,
+ Func> underlyingDataCallback,
+ CancellationToken cancellationToken)
+ {
+ return (flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0
+ ? underlyingDataCallback(state, cancellationToken) : default;
+ }
+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private HybridCacheEntryFlags GetEffectiveFlags(HybridCacheEntryOptions? options)
- => (options?.Flags | _hardFlags) ?? _defaultFlags;
+ => (options?.Flags | _hardFlags) ?? _defaultFlags;
+
+ private bool ValidateKey(string key)
+ {
+ if (string.IsNullOrWhiteSpace(key))
+ {
+ _logger.KeyEmptyOrWhitespace();
+ return false;
+ }
+
+ if (key.Length > _maximumKeyLength)
+ {
+ _logger.MaximumKeyLengthExceeded(_maximumKeyLength, key.Length);
+ return false;
+ }
+
+ if (key.IndexOfAny(_keyReservedCharacters) >= 0)
+ {
+ _logger.KeyInvalidContent();
+ return false;
+ }
+
+ // nothing to complain about
+ return true;
+ }
}
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCacheEventSource.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCacheEventSource.cs
new file mode 100644
index 00000000000..92a5d729e57
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCacheEventSource.cs
@@ -0,0 +1,203 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Diagnostics.Tracing;
+using System.Runtime.CompilerServices;
+using System.Threading;
+
+namespace Microsoft.Extensions.Caching.Hybrid.Internal;
+
+[EventSource(Name = "Microsoft-Extensions-HybridCache")]
+internal sealed class HybridCacheEventSource : EventSource
+{
+ public static readonly HybridCacheEventSource Log = new();
+
+ internal const int EventIdLocalCacheHit = 1;
+ internal const int EventIdLocalCacheMiss = 2;
+ internal const int EventIdDistributedCacheGet = 3;
+ internal const int EventIdDistributedCacheHit = 4;
+ internal const int EventIdDistributedCacheMiss = 5;
+ internal const int EventIdDistributedCacheFailed = 6;
+ internal const int EventIdUnderlyingDataQueryStart = 7;
+ internal const int EventIdUnderlyingDataQueryComplete = 8;
+ internal const int EventIdUnderlyingDataQueryFailed = 9;
+ internal const int EventIdLocalCacheWrite = 10;
+ internal const int EventIdDistributedCacheWrite = 11;
+ internal const int EventIdStampedeJoin = 12;
+
+ // fast local counters
+ private long _totalLocalCacheHit;
+ private long _totalLocalCacheMiss;
+ private long _totalDistributedCacheHit;
+ private long _totalDistributedCacheMiss;
+ private long _totalUnderlyingDataQuery;
+ private long _currentUnderlyingDataQuery;
+ private long _currentDistributedFetch;
+ private long _totalLocalCacheWrite;
+ private long _totalDistributedCacheWrite;
+ private long _totalStampedeJoin;
+
+#if !(NETSTANDARD2_0 || NET462)
+ // full Counter infrastructure
+ private DiagnosticCounter[]? _counters;
+#endif
+
+ [NonEvent]
+ public void ResetCounters()
+ {
+ Debug.WriteLine($"{nameof(HybridCacheEventSource)} counters reset!");
+
+ Volatile.Write(ref _totalLocalCacheHit, 0);
+ Volatile.Write(ref _totalLocalCacheMiss, 0);
+ Volatile.Write(ref _totalDistributedCacheHit, 0);
+ Volatile.Write(ref _totalDistributedCacheMiss, 0);
+ Volatile.Write(ref _totalUnderlyingDataQuery, 0);
+ Volatile.Write(ref _currentUnderlyingDataQuery, 0);
+ Volatile.Write(ref _currentDistributedFetch, 0);
+ Volatile.Write(ref _totalLocalCacheWrite, 0);
+ Volatile.Write(ref _totalDistributedCacheWrite, 0);
+ Volatile.Write(ref _totalStampedeJoin, 0);
+ }
+
+ [Event(EventIdLocalCacheHit, Level = EventLevel.Verbose)]
+ public void LocalCacheHit()
+ {
+ DebugAssertEnabled();
+ _ = Interlocked.Increment(ref _totalLocalCacheHit);
+ WriteEvent(EventIdLocalCacheHit);
+ }
+
+ [Event(EventIdLocalCacheMiss, Level = EventLevel.Verbose)]
+ public void LocalCacheMiss()
+ {
+ DebugAssertEnabled();
+ _ = Interlocked.Increment(ref _totalLocalCacheMiss);
+ WriteEvent(EventIdLocalCacheMiss);
+ }
+
+ [Event(EventIdDistributedCacheGet, Level = EventLevel.Verbose)]
+ public void DistributedCacheGet()
+ {
+ // should be followed by DistributedCacheHit, DistributedCacheMiss or DistributedCacheFailed
+ DebugAssertEnabled();
+ _ = Interlocked.Increment(ref _currentDistributedFetch);
+ WriteEvent(EventIdDistributedCacheGet);
+ }
+
+ [Event(EventIdDistributedCacheHit, Level = EventLevel.Verbose)]
+ public void DistributedCacheHit()
+ {
+ DebugAssertEnabled();
+
+ // note: not concerned about off-by-one here, i.e. don't panic
+ // about these two being atomic ref each-other - just the overall shape
+ _ = Interlocked.Increment(ref _totalDistributedCacheHit);
+ _ = Interlocked.Decrement(ref _currentDistributedFetch);
+ WriteEvent(EventIdDistributedCacheHit);
+ }
+
+ [Event(EventIdDistributedCacheMiss, Level = EventLevel.Verbose)]
+ public void DistributedCacheMiss()
+ {
+ DebugAssertEnabled();
+
+ // note: not concerned about off-by-one here, i.e. don't panic
+ // about these two being atomic ref each-other - just the overall shape
+ _ = Interlocked.Increment(ref _totalDistributedCacheMiss);
+ _ = Interlocked.Decrement(ref _currentDistributedFetch);
+ WriteEvent(EventIdDistributedCacheMiss);
+ }
+
+ [Event(EventIdDistributedCacheFailed, Level = EventLevel.Error)]
+ public void DistributedCacheFailed()
+ {
+ DebugAssertEnabled();
+ _ = Interlocked.Decrement(ref _currentDistributedFetch);
+ WriteEvent(EventIdDistributedCacheFailed);
+ }
+
+ [Event(EventIdUnderlyingDataQueryStart, Level = EventLevel.Verbose)]
+ public void UnderlyingDataQueryStart()
+ {
+ // should be followed by UnderlyingDataQueryComplete or UnderlyingDataQueryFailed
+ DebugAssertEnabled();
+ _ = Interlocked.Increment(ref _totalUnderlyingDataQuery);
+ _ = Interlocked.Increment(ref _currentUnderlyingDataQuery);
+ WriteEvent(EventIdUnderlyingDataQueryStart);
+ }
+
+ [Event(EventIdUnderlyingDataQueryComplete, Level = EventLevel.Verbose)]
+ public void UnderlyingDataQueryComplete()
+ {
+ DebugAssertEnabled();
+ _ = Interlocked.Decrement(ref _currentUnderlyingDataQuery);
+ WriteEvent(EventIdUnderlyingDataQueryComplete);
+ }
+
+ [Event(EventIdUnderlyingDataQueryFailed, Level = EventLevel.Error)]
+ public void UnderlyingDataQueryFailed()
+ {
+ DebugAssertEnabled();
+ _ = Interlocked.Decrement(ref _currentUnderlyingDataQuery);
+ WriteEvent(EventIdUnderlyingDataQueryFailed);
+ }
+
+ [Event(EventIdLocalCacheWrite, Level = EventLevel.Verbose)]
+ public void LocalCacheWrite()
+ {
+ DebugAssertEnabled();
+ _ = Interlocked.Increment(ref _totalLocalCacheWrite);
+ WriteEvent(EventIdLocalCacheWrite);
+ }
+
+ [Event(EventIdDistributedCacheWrite, Level = EventLevel.Verbose)]
+ public void DistributedCacheWrite()
+ {
+ DebugAssertEnabled();
+ _ = Interlocked.Increment(ref _totalDistributedCacheWrite);
+ WriteEvent(EventIdDistributedCacheWrite);
+ }
+
+ [Event(EventIdStampedeJoin, Level = EventLevel.Verbose)]
+ internal void StampedeJoin()
+ {
+ DebugAssertEnabled();
+ _ = Interlocked.Increment(ref _totalStampedeJoin);
+ WriteEvent(EventIdStampedeJoin);
+ }
+
+#if !(NETSTANDARD2_0 || NET462)
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", Justification = "Lifetime exceeds obvious scope; handed to event source")]
+ [NonEvent]
+ protected override void OnEventCommand(EventCommandEventArgs command)
+ {
+ if (command.Command == EventCommand.Enable)
+ {
+ // lazily create counters on first Enable
+ _counters ??= [
+ new PollingCounter("total-local-cache-hits", this, () => Volatile.Read(ref _totalLocalCacheHit)) { DisplayName = "Total Local Cache Hits" },
+ new PollingCounter("total-local-cache-misses", this, () => Volatile.Read(ref _totalLocalCacheMiss)) { DisplayName = "Total Local Cache Misses" },
+ new PollingCounter("total-distributed-cache-hits", this, () => Volatile.Read(ref _totalDistributedCacheHit)) { DisplayName = "Total Distributed Cache Hits" },
+ new PollingCounter("total-distributed-cache-misses", this, () => Volatile.Read(ref _totalDistributedCacheMiss)) { DisplayName = "Total Distributed Cache Misses" },
+ new PollingCounter("total-data-query", this, () => Volatile.Read(ref _totalUnderlyingDataQuery)) { DisplayName = "Total Data Queries" },
+ new PollingCounter("current-data-query", this, () => Volatile.Read(ref _currentUnderlyingDataQuery)) { DisplayName = "Current Data Queries" },
+ new PollingCounter("current-distributed-cache-fetches", this, () => Volatile.Read(ref _currentDistributedFetch)) { DisplayName = "Current Distributed Cache Fetches" },
+ new PollingCounter("total-local-cache-writes", this, () => Volatile.Read(ref _totalLocalCacheWrite)) { DisplayName = "Total Local Cache Writes" },
+ new PollingCounter("total-distributed-cache-writes", this, () => Volatile.Read(ref _totalDistributedCacheWrite)) { DisplayName = "Total Distributed Cache Writes" },
+ new PollingCounter("total-stampede-joins", this, () => Volatile.Read(ref _totalStampedeJoin)) { DisplayName = "Total Stampede Joins" },
+ ];
+ }
+
+ base.OnEventCommand(command);
+ }
+#endif
+
+ [NonEvent]
+ [Conditional("DEBUG")]
+ private void DebugAssertEnabled([CallerMemberName] string caller = "")
+ {
+ Debug.Assert(IsEnabled(), $"Missing check to {nameof(HybridCacheEventSource)}.{nameof(Log)}.{nameof(IsEnabled)} from {caller}");
+ Debug.WriteLine($"{nameof(HybridCacheEventSource)}: {caller}"); // also log all event calls, for visibility
+ }
+}
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/InbuiltTypeSerializer.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/InbuiltTypeSerializer.cs
index 3ef26341433..4800428a88f 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/InbuiltTypeSerializer.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/InbuiltTypeSerializer.cs
@@ -17,6 +17,18 @@ internal sealed class InbuiltTypeSerializer : IHybridCacheSerializer, IH
public static InbuiltTypeSerializer Instance { get; } = new();
string IHybridCacheSerializer.Deserialize(ReadOnlySequence source)
+ => DeserializeString(source);
+
+ void IHybridCacheSerializer.Serialize(string value, IBufferWriter target)
+ => SerializeString(value, target);
+
+ byte[] IHybridCacheSerializer.Deserialize(ReadOnlySequence source)
+ => source.ToArray();
+
+ void IHybridCacheSerializer.Serialize(byte[] value, IBufferWriter target)
+ => target.Write(value);
+
+ internal static string DeserializeString(ReadOnlySequence source)
{
#if NET5_0_OR_GREATER
return Encoding.UTF8.GetString(source);
@@ -36,7 +48,7 @@ string IHybridCacheSerializer.Deserialize(ReadOnlySequence source)
#endif
}
- void IHybridCacheSerializer.Serialize(string value, IBufferWriter target)
+ internal static void SerializeString(string value, IBufferWriter target)
{
#if NET5_0_OR_GREATER
Encoding.UTF8.GetBytes(value, target);
@@ -49,10 +61,4 @@ void IHybridCacheSerializer.Serialize(string value, IBufferWriter
ArrayPool.Shared.Return(oversized);
#endif
}
-
- byte[] IHybridCacheSerializer.Deserialize(ReadOnlySequence source)
- => source.ToArray();
-
- void IHybridCacheSerializer.Serialize(byte[] value, IBufferWriter target)
- => target.Write(value);
}
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/Log.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/Log.cs
new file mode 100644
index 00000000000..785107c32ec
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/Log.cs
@@ -0,0 +1,49 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.Extensions.Caching.Hybrid.Internal;
+
+internal static partial class Log
+{
+ internal const int IdMaximumPayloadBytesExceeded = 1;
+ internal const int IdSerializationFailure = 2;
+ internal const int IdDeserializationFailure = 3;
+ internal const int IdKeyEmptyOrWhitespace = 4;
+ internal const int IdMaximumKeyLengthExceeded = 5;
+ internal const int IdCacheBackendReadFailure = 6;
+ internal const int IdCacheBackendWriteFailure = 7;
+ internal const int IdKeyInvalidContent = 8;
+
+ [LoggerMessage(LogLevel.Error, "Cache MaximumPayloadBytes ({Bytes}) exceeded.", EventName = "MaximumPayloadBytesExceeded", EventId = IdMaximumPayloadBytesExceeded, SkipEnabledCheck = false)]
+ internal static partial void MaximumPayloadBytesExceeded(this ILogger logger, Exception e, int bytes);
+
+ // note that serialization is critical enough that we perform hard failures in addition to logging; serialization
+ // failures are unlikely to be transient (i.e. connectivity); we would rather this shows up in QA, rather than
+ // being invisible and people *thinking* they're using cache, when actually they are not
+
+ [LoggerMessage(LogLevel.Error, "Cache serialization failure.", EventName = "SerializationFailure", EventId = IdSerializationFailure, SkipEnabledCheck = false)]
+ internal static partial void SerializationFailure(this ILogger logger, Exception e);
+
+ // (see same notes per SerializationFailure)
+ [LoggerMessage(LogLevel.Error, "Cache deserialization failure.", EventName = "DeserializationFailure", EventId = IdDeserializationFailure, SkipEnabledCheck = false)]
+ internal static partial void DeserializationFailure(this ILogger logger, Exception e);
+
+ [LoggerMessage(LogLevel.Error, "Cache key empty or whitespace.", EventName = "KeyEmptyOrWhitespace", EventId = IdKeyEmptyOrWhitespace, SkipEnabledCheck = false)]
+ internal static partial void KeyEmptyOrWhitespace(this ILogger logger);
+
+ [LoggerMessage(LogLevel.Error, "Cache key maximum length exceeded (maximum: {MaxLength}, actual: {KeyLength}).", EventName = "MaximumKeyLengthExceeded",
+ EventId = IdMaximumKeyLengthExceeded, SkipEnabledCheck = false)]
+ internal static partial void MaximumKeyLengthExceeded(this ILogger logger, int maxLength, int keyLength);
+
+ [LoggerMessage(LogLevel.Error, "Cache backend read failure.", EventName = "CacheBackendReadFailure", EventId = IdCacheBackendReadFailure, SkipEnabledCheck = false)]
+ internal static partial void CacheUnderlyingDataQueryFailure(this ILogger logger, Exception ex);
+
+ [LoggerMessage(LogLevel.Error, "Cache backend write failure.", EventName = "CacheBackendWriteFailure", EventId = IdCacheBackendWriteFailure, SkipEnabledCheck = false)]
+ internal static partial void CacheBackendWriteFailure(this ILogger logger, Exception ex);
+
+ [LoggerMessage(LogLevel.Error, "Cache key contains invalid content.", EventName = "KeyInvalidContent", EventId = IdKeyInvalidContent, SkipEnabledCheck = false)]
+ internal static partial void KeyInvalidContent(this ILogger logger); // for PII etc reasons, we won't include the actual key
+}
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/RecyclableArrayBufferWriter.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/RecyclableArrayBufferWriter.cs
index 2f2da2c7019..985d55c9f0e 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/RecyclableArrayBufferWriter.cs
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/RecyclableArrayBufferWriter.cs
@@ -46,20 +46,20 @@ internal sealed class RecyclableArrayBufferWriter : IBufferWriter, IDispos
public int CommittedBytes => _index;
public int FreeCapacity => _buffer.Length - _index;
+ public bool QuotaExceeded { get; private set; }
+
private static RecyclableArrayBufferWriter? _spare;
+
public static RecyclableArrayBufferWriter Create(int maxLength)
{
var obj = Interlocked.Exchange(ref _spare, null) ?? new();
- Debug.Assert(obj._index == 0, "index should be zero initially");
- obj._maxLength = maxLength;
+ obj.Initialize(maxLength);
return obj;
}
private RecyclableArrayBufferWriter()
{
_buffer = [];
- _index = 0;
- _maxLength = int.MaxValue;
}
public void Dispose()
@@ -91,6 +91,7 @@ public void Advance(int count)
if (_index + count > _maxLength)
{
+ QuotaExceeded = true;
ThrowQuota();
}
@@ -199,4 +200,12 @@ private void CheckAndResizeBuffer(int sizeHint)
static void ThrowOutOfMemoryException() => throw new InvalidOperationException("Unable to grow buffer as requested");
}
+
+ private void Initialize(int maxLength)
+ {
+ // think .ctor, but with pooled object re-use
+ _index = 0;
+ _maxLength = maxLength;
+ QuotaExceeded = false;
+ }
}
diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Microsoft.Extensions.Caching.Hybrid.csproj b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Microsoft.Extensions.Caching.Hybrid.csproj
index ec8946d2f9d..d3029266b57 100644
--- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Microsoft.Extensions.Caching.Hybrid.csproj
+++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Microsoft.Extensions.Caching.Hybrid.csproj
@@ -4,7 +4,7 @@
Multi-level caching implementation building on and extending IDistributedCache
$(NetCoreTargetFrameworks)$(ConditionalNet462);netstandard2.0;netstandard2.1
true
- cache;distributedcache;hybrid
+ cache;distributedcache;hybridcache
true
true
true
@@ -21,6 +21,11 @@
true
+ true
+ true
+
+
+ false
diff --git a/src/Libraries/Microsoft.Extensions.Compliance.Abstractions/Microsoft.Extensions.Compliance.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.Compliance.Abstractions/Microsoft.Extensions.Compliance.Abstractions.csproj
index 5a6c93e1dc7..c83b7284da5 100644
--- a/src/Libraries/Microsoft.Extensions.Compliance.Abstractions/Microsoft.Extensions.Compliance.Abstractions.csproj
+++ b/src/Libraries/Microsoft.Extensions.Compliance.Abstractions/Microsoft.Extensions.Compliance.Abstractions.csproj
@@ -1,6 +1,7 @@
Microsoft.Extensions.Compliance
+ $(NetCoreTargetFrameworks);netstandard2.0;
Abstractions to help ensure compliant data management.
Fundamentals
diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.JsonSchema.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.JsonSchema.cs
new file mode 100644
index 00000000000..0f1044fc6eb
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.JsonSchema.cs
@@ -0,0 +1,545 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#if !NET9_0_OR_GREATER
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text.Json.Nodes;
+
+namespace System.Text.Json.Schema;
+
+#pragma warning disable SA1204 // Static elements should appear before instance elements
+#pragma warning disable S1144 // Unused private types or members should be removed
+
+internal static partial class JsonSchemaExporter
+{
+ // Simple JSON schema representation taken from System.Text.Json
+ // https://github.com/dotnet/runtime/blob/50d6cad649aad2bfa4069268eddd16fd51ec5cf3/src/libraries/System.Text.Json/src/System/Text/Json/Schema/JsonSchema.cs
+ private sealed class JsonSchema
+ {
+ public static JsonSchema False { get; } = new(false);
+ public static JsonSchema True { get; } = new(true);
+
+ public JsonSchema()
+ {
+ }
+
+ private JsonSchema(bool trueOrFalse)
+ {
+ _trueOrFalse = trueOrFalse;
+ }
+
+ public bool IsTrue => _trueOrFalse is true;
+ public bool IsFalse => _trueOrFalse is false;
+ private readonly bool? _trueOrFalse;
+
+ public string? Schema
+ {
+ get => _schema;
+ set
+ {
+ VerifyMutable();
+ _schema = value;
+ }
+ }
+
+ private string? _schema;
+
+ public string? Title
+ {
+ get => _title;
+ set
+ {
+ VerifyMutable();
+ _title = value;
+ }
+ }
+
+ private string? _title;
+
+ public string? Description
+ {
+ get => _description;
+ set
+ {
+ VerifyMutable();
+ _description = value;
+ }
+ }
+
+ private string? _description;
+
+ public string? Ref
+ {
+ get => _ref;
+ set
+ {
+ VerifyMutable();
+ _ref = value;
+ }
+ }
+
+ private string? _ref;
+
+ public string? Comment
+ {
+ get => _comment;
+ set
+ {
+ VerifyMutable();
+ _comment = value;
+ }
+ }
+
+ private string? _comment;
+
+ public JsonSchemaType Type
+ {
+ get => _type;
+ set
+ {
+ VerifyMutable();
+ _type = value;
+ }
+ }
+
+ private JsonSchemaType _type = JsonSchemaType.Any;
+
+ public string? Format
+ {
+ get => _format;
+ set
+ {
+ VerifyMutable();
+ _format = value;
+ }
+ }
+
+ private string? _format;
+
+ public string? Pattern
+ {
+ get => _pattern;
+ set
+ {
+ VerifyMutable();
+ _pattern = value;
+ }
+ }
+
+ private string? _pattern;
+
+ public JsonNode? Constant
+ {
+ get => _constant;
+ set
+ {
+ VerifyMutable();
+ _constant = value;
+ }
+ }
+
+ private JsonNode? _constant;
+
+ public List>? Properties
+ {
+ get => _properties;
+ set
+ {
+ VerifyMutable();
+ _properties = value;
+ }
+ }
+
+ private List>? _properties;
+
+ public List? Required
+ {
+ get => _required;
+ set
+ {
+ VerifyMutable();
+ _required = value;
+ }
+ }
+
+ private List? _required;
+
+ public JsonSchema? Items
+ {
+ get => _items;
+ set
+ {
+ VerifyMutable();
+ _items = value;
+ }
+ }
+
+ private JsonSchema? _items;
+
+ public JsonSchema? AdditionalProperties
+ {
+ get => _additionalProperties;
+ set
+ {
+ VerifyMutable();
+ _additionalProperties = value;
+ }
+ }
+
+ private JsonSchema? _additionalProperties;
+
+ public JsonArray? Enum
+ {
+ get => _enum;
+ set
+ {
+ VerifyMutable();
+ _enum = value;
+ }
+ }
+
+ private JsonArray? _enum;
+
+ public JsonSchema? Not
+ {
+ get => _not;
+ set
+ {
+ VerifyMutable();
+ _not = value;
+ }
+ }
+
+ private JsonSchema? _not;
+
+ public List? AnyOf
+ {
+ get => _anyOf;
+ set
+ {
+ VerifyMutable();
+ _anyOf = value;
+ }
+ }
+
+ private List? _anyOf;
+
+ public bool HasDefaultValue
+ {
+ get => _hasDefaultValue;
+ set
+ {
+ VerifyMutable();
+ _hasDefaultValue = value;
+ }
+ }
+
+ private bool _hasDefaultValue;
+
+ public JsonNode? DefaultValue
+ {
+ get => _defaultValue;
+ set
+ {
+ VerifyMutable();
+ _defaultValue = value;
+ }
+ }
+
+ private JsonNode? _defaultValue;
+
+ public int? MinLength
+ {
+ get => _minLength;
+ set
+ {
+ VerifyMutable();
+ _minLength = value;
+ }
+ }
+
+ private int? _minLength;
+
+ public int? MaxLength
+ {
+ get => _maxLength;
+ set
+ {
+ VerifyMutable();
+ _maxLength = value;
+ }
+ }
+
+ private int? _maxLength;
+
+ public JsonSchemaExporterContext? GenerationContext { get; set; }
+
+ public int KeywordCount
+ {
+ get
+ {
+ if (_trueOrFalse != null)
+ {
+ return 0;
+ }
+
+ int count = 0;
+ Count(Schema != null);
+ Count(Ref != null);
+ Count(Comment != null);
+ Count(Title != null);
+ Count(Description != null);
+ Count(Type != JsonSchemaType.Any);
+ Count(Format != null);
+ Count(Pattern != null);
+ Count(Constant != null);
+ Count(Properties != null);
+ Count(Required != null);
+ Count(Items != null);
+ Count(AdditionalProperties != null);
+ Count(Enum != null);
+ Count(Not != null);
+ Count(AnyOf != null);
+ Count(HasDefaultValue);
+ Count(MinLength != null);
+ Count(MaxLength != null);
+
+ return count;
+
+ void Count(bool isKeywordSpecified) => count += isKeywordSpecified ? 1 : 0;
+ }
+ }
+
+ public void MakeNullable()
+ {
+ if (_trueOrFalse != null)
+ {
+ return;
+ }
+
+ if (Type != JsonSchemaType.Any)
+ {
+ Type |= JsonSchemaType.Null;
+ }
+ }
+
+ public JsonNode ToJsonNode(JsonSchemaExporterOptions options)
+ {
+ if (_trueOrFalse is { } boolSchema)
+ {
+ return CompleteSchema((JsonNode)boolSchema);
+ }
+
+ var objSchema = new JsonObject();
+
+ if (Schema != null)
+ {
+ objSchema.Add(JsonSchemaConstants.SchemaPropertyName, Schema);
+ }
+
+ if (Title != null)
+ {
+ objSchema.Add(JsonSchemaConstants.TitlePropertyName, Title);
+ }
+
+ if (Description != null)
+ {
+ objSchema.Add(JsonSchemaConstants.DescriptionPropertyName, Description);
+ }
+
+ if (Ref != null)
+ {
+ objSchema.Add(JsonSchemaConstants.RefPropertyName, Ref);
+ }
+
+ if (Comment != null)
+ {
+ objSchema.Add(JsonSchemaConstants.CommentPropertyName, Comment);
+ }
+
+ if (MapSchemaType(Type) is JsonNode type)
+ {
+ objSchema.Add(JsonSchemaConstants.TypePropertyName, type);
+ }
+
+ if (Format != null)
+ {
+ objSchema.Add(JsonSchemaConstants.FormatPropertyName, Format);
+ }
+
+ if (Pattern != null)
+ {
+ objSchema.Add(JsonSchemaConstants.PatternPropertyName, Pattern);
+ }
+
+ if (Constant != null)
+ {
+ objSchema.Add(JsonSchemaConstants.ConstPropertyName, Constant);
+ }
+
+ if (Properties != null)
+ {
+ var properties = new JsonObject();
+ foreach (KeyValuePair property in Properties)
+ {
+ properties.Add(property.Key, property.Value.ToJsonNode(options));
+ }
+
+ objSchema.Add(JsonSchemaConstants.PropertiesPropertyName, properties);
+ }
+
+ if (Required != null)
+ {
+ var requiredArray = new JsonArray();
+ foreach (string requiredProperty in Required)
+ {
+ requiredArray.Add((JsonNode)requiredProperty);
+ }
+
+ objSchema.Add(JsonSchemaConstants.RequiredPropertyName, requiredArray);
+ }
+
+ if (Items != null)
+ {
+ objSchema.Add(JsonSchemaConstants.ItemsPropertyName, Items.ToJsonNode(options));
+ }
+
+ if (AdditionalProperties != null)
+ {
+ objSchema.Add(JsonSchemaConstants.AdditionalPropertiesPropertyName, AdditionalProperties.ToJsonNode(options));
+ }
+
+ if (Enum != null)
+ {
+ objSchema.Add(JsonSchemaConstants.EnumPropertyName, Enum);
+ }
+
+ if (Not != null)
+ {
+ objSchema.Add(JsonSchemaConstants.NotPropertyName, Not.ToJsonNode(options));
+ }
+
+ if (AnyOf != null)
+ {
+ JsonArray anyOfArray = new();
+ foreach (JsonSchema schema in AnyOf)
+ {
+ anyOfArray.Add(schema.ToJsonNode(options));
+ }
+
+ objSchema.Add(JsonSchemaConstants.AnyOfPropertyName, anyOfArray);
+ }
+
+ if (HasDefaultValue)
+ {
+ objSchema.Add(JsonSchemaConstants.DefaultPropertyName, DefaultValue);
+ }
+
+ if (MinLength is int minLength)
+ {
+ objSchema.Add(JsonSchemaConstants.MinLengthPropertyName, (JsonNode)minLength);
+ }
+
+ if (MaxLength is int maxLength)
+ {
+ objSchema.Add(JsonSchemaConstants.MaxLengthPropertyName, (JsonNode)maxLength);
+ }
+
+ return CompleteSchema(objSchema);
+
+ JsonNode CompleteSchema(JsonNode schema)
+ {
+ if (GenerationContext is { } context)
+ {
+ Debug.Assert(options.TransformSchemaNode != null, "context should only be populated if a callback is present.");
+
+ // Apply any user-defined transformations to the schema.
+ return options.TransformSchemaNode!(context, schema);
+ }
+
+ return schema;
+ }
+ }
+
+ public static void EnsureMutable(ref JsonSchema schema)
+ {
+ switch (schema._trueOrFalse)
+ {
+ case false:
+ schema = new JsonSchema { Not = JsonSchema.True };
+ break;
+ case true:
+ schema = new JsonSchema();
+ break;
+ }
+ }
+
+ private static readonly JsonSchemaType[] _schemaValues = new JsonSchemaType[]
+ {
+ // NB the order of these values influences order of types in the rendered schema
+ JsonSchemaType.String,
+ JsonSchemaType.Integer,
+ JsonSchemaType.Number,
+ JsonSchemaType.Boolean,
+ JsonSchemaType.Array,
+ JsonSchemaType.Object,
+ JsonSchemaType.Null,
+ };
+
+ private void VerifyMutable()
+ {
+ Debug.Assert(_trueOrFalse is null, "Schema is not mutable");
+ }
+
+ private static JsonNode? MapSchemaType(JsonSchemaType schemaType)
+ {
+ if (schemaType is JsonSchemaType.Any)
+ {
+ return null;
+ }
+
+ if (ToIdentifier(schemaType) is string identifier)
+ {
+ return identifier;
+ }
+
+ var array = new JsonArray();
+ foreach (JsonSchemaType type in _schemaValues)
+ {
+ if ((schemaType & type) != 0)
+ {
+ array.Add((JsonNode)ToIdentifier(type)!);
+ }
+ }
+
+ return array;
+
+ static string? ToIdentifier(JsonSchemaType schemaType) => schemaType switch
+ {
+ JsonSchemaType.Null => "null",
+ JsonSchemaType.Boolean => "boolean",
+ JsonSchemaType.Integer => "integer",
+ JsonSchemaType.Number => "number",
+ JsonSchemaType.String => "string",
+ JsonSchemaType.Array => "array",
+ JsonSchemaType.Object => "object",
+ _ => null,
+ };
+ }
+ }
+
+ [Flags]
+ private enum JsonSchemaType
+ {
+ Any = 0, // No type declared on the schema
+ Null = 1,
+ Boolean = 2,
+ Integer = 4,
+ Number = 8,
+ String = 16,
+ Array = 32,
+ Object = 64,
+ }
+}
+#endif
diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs
new file mode 100644
index 00000000000..481e5f75753
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs
@@ -0,0 +1,427 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#if !NET9_0_OR_GREATER
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+#if !NET
+using System.Linq;
+#endif
+using System.Reflection;
+using System.Text.Json.Serialization;
+using System.Text.Json.Serialization.Metadata;
+using Microsoft.Shared.Diagnostics;
+
+#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
+
+namespace System.Text.Json.Schema;
+
+internal static partial class JsonSchemaExporter
+{
+ private static class ReflectionHelpers
+ {
+ private const BindingFlags AllInstance = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
+ private static PropertyInfo? _jsonTypeInfo_ElementType;
+ private static PropertyInfo? _jsonPropertyInfo_MemberName;
+ private static FieldInfo? _nullableConverter_ElementConverter_Generic;
+ private static FieldInfo? _enumConverter_Options_Generic;
+ private static FieldInfo? _enumConverter_NamingPolicy_Generic;
+
+ public static bool IsBuiltInConverter(JsonConverter converter) =>
+ converter.GetType().Assembly == typeof(JsonConverter).Assembly;
+
+ public static bool CanBeNull(Type type) => !type.IsValueType || Nullable.GetUnderlyingType(type) is not null;
+
+ public static Type GetElementType(JsonTypeInfo typeInfo)
+ {
+ Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Enumerable or JsonTypeInfoKind.Dictionary, "TypeInfo must be of collection type");
+
+ // Uses reflection to access the element type encapsulated by a JsonTypeInfo.
+ if (_jsonTypeInfo_ElementType is null)
+ {
+ PropertyInfo? elementTypeProperty = typeof(JsonTypeInfo).GetProperty("ElementType", AllInstance);
+ _jsonTypeInfo_ElementType = Throw.IfNull(elementTypeProperty);
+ }
+
+ return (Type)_jsonTypeInfo_ElementType.GetValue(typeInfo)!;
+ }
+
+ public static string? GetMemberName(JsonPropertyInfo propertyInfo)
+ {
+ // Uses reflection to the member name encapsulated by a JsonPropertyInfo.
+ if (_jsonPropertyInfo_MemberName is null)
+ {
+ PropertyInfo? memberName = typeof(JsonPropertyInfo).GetProperty("MemberName", AllInstance);
+ _jsonPropertyInfo_MemberName = Throw.IfNull(memberName);
+ }
+
+ return (string?)_jsonPropertyInfo_MemberName.GetValue(propertyInfo);
+ }
+
+ public static JsonConverter GetElementConverter(JsonConverter nullableConverter)
+ {
+ // Uses reflection to access the element converter encapsulated by a nullable converter.
+ if (_nullableConverter_ElementConverter_Generic is null)
+ {
+ FieldInfo? genericFieldInfo = Type
+ .GetType("System.Text.Json.Serialization.Converters.NullableConverter`1, System.Text.Json")!
+ .GetField("_elementConverter", AllInstance);
+
+ _nullableConverter_ElementConverter_Generic = Throw.IfNull(genericFieldInfo);
+ }
+
+ Type converterType = nullableConverter.GetType();
+ var thisFieldInfo = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_nullableConverter_ElementConverter_Generic);
+ return (JsonConverter)thisFieldInfo.GetValue(nullableConverter)!;
+ }
+
+ public static void GetEnumConverterConfig(JsonConverter enumConverter, out JsonNamingPolicy? namingPolicy, out bool allowString)
+ {
+ // Uses reflection to access configuration encapsulated by an enum converter.
+ if (_enumConverter_Options_Generic is null)
+ {
+ FieldInfo? genericFieldInfo = Type
+ .GetType("System.Text.Json.Serialization.Converters.EnumConverter`1, System.Text.Json")!
+ .GetField("_converterOptions", AllInstance);
+
+ _enumConverter_Options_Generic = Throw.IfNull(genericFieldInfo);
+ }
+
+ if (_enumConverter_NamingPolicy_Generic is null)
+ {
+ FieldInfo? genericFieldInfo = Type
+ .GetType("System.Text.Json.Serialization.Converters.EnumConverter`1, System.Text.Json")!
+ .GetField("_namingPolicy", AllInstance);
+
+ _enumConverter_NamingPolicy_Generic = Throw.IfNull(genericFieldInfo);
+ }
+
+ const int EnumConverterOptionsAllowStrings = 1;
+ Type converterType = enumConverter.GetType();
+ var converterOptionsField = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_enumConverter_Options_Generic);
+ var namingPolicyField = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_enumConverter_NamingPolicy_Generic);
+
+ namingPolicy = (JsonNamingPolicy?)namingPolicyField.GetValue(enumConverter);
+ int converterOptions = (int)converterOptionsField.GetValue(enumConverter)!;
+ allowString = (converterOptions & EnumConverterOptionsAllowStrings) != 0;
+ }
+
+ // The .NET 8 source generator doesn't populate attribute providers for properties
+ // cf. https://github.com/dotnet/runtime/issues/100095
+ // Work around the issue by running a query for the relevant MemberInfo using the internal MemberName property
+ // https://github.com/dotnet/runtime/blob/de774ff9ee1a2c06663ab35be34b755cd8d29731/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs#L206
+ public static ICustomAttributeProvider? ResolveAttributeProvider(
+ [DynamicallyAccessedMembers(
+ DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.NonPublicProperties |
+ DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.NonPublicFields)]
+ Type? declaringType,
+ JsonPropertyInfo? propertyInfo)
+ {
+ if (declaringType is null || propertyInfo is null)
+ {
+ return null;
+ }
+
+ if (propertyInfo.AttributeProvider is { } provider)
+ {
+ return provider;
+ }
+
+ string? memberName = ReflectionHelpers.GetMemberName(propertyInfo);
+ if (memberName is not null)
+ {
+ return (MemberInfo?)declaringType.GetProperty(memberName, AllInstance) ??
+ declaringType.GetField(memberName, AllInstance);
+ }
+
+ return null;
+ }
+
+ // Resolves the parameters of the deserialization constructor for a type, if they exist.
+ public static Func? ResolveJsonConstructorParameterMapper(
+ [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)]
+ Type type,
+ JsonTypeInfo typeInfo)
+ {
+ Debug.Assert(type == typeInfo.Type, "The declaring type must match the typeInfo type.");
+ Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Object, "Should only be passed object JSON kinds.");
+
+ if (typeInfo.Properties.Count > 0 &&
+ typeInfo.CreateObject is null && // Ensure that a default constructor isn't being used
+ TryGetDeserializationConstructor(type, useDefaultCtorInAnnotatedStructs: true, out ConstructorInfo? ctor))
+ {
+ ParameterInfo[]? parameters = ctor?.GetParameters();
+ if (parameters?.Length > 0)
+ {
+ Dictionary dict = new(parameters.Length);
+ foreach (ParameterInfo parameter in parameters)
+ {
+ if (parameter.Name is not null)
+ {
+ // We don't care about null parameter names or conflicts since they
+ // would have already been rejected by JsonTypeInfo exporterOptions.
+ dict[new(parameter.Name, parameter.ParameterType)] = parameter;
+ }
+ }
+
+ return prop => dict.TryGetValue(new(prop.Name, prop.PropertyType), out ParameterInfo? parameter) ? parameter : null;
+ }
+ }
+
+ return null;
+ }
+
+ // Resolves the nullable reference type annotations for a property or field,
+ // additionally addressing a few known bugs of the NullabilityInfo pre .NET 9.
+ public static NullabilityInfo GetMemberNullability(NullabilityInfoContext context, MemberInfo memberInfo)
+ {
+ Debug.Assert(memberInfo is PropertyInfo or FieldInfo, "Member must be property or field.");
+ return memberInfo is PropertyInfo prop
+ ? context.Create(prop)
+ : context.Create((FieldInfo)memberInfo);
+ }
+
+ public static NullabilityState GetParameterNullability(NullabilityInfoContext context, ParameterInfo parameterInfo)
+ {
+#if NET8_0
+ // Workaround for https://github.com/dotnet/runtime/issues/92487
+ // The fix has been incorporated into .NET 9 (and the polyfilled implementations in netfx).
+ // Should be removed once .NET 8 support is dropped.
+ if (GetGenericParameterDefinition(parameterInfo) is { ParameterType: { IsGenericParameter: true } typeParam })
+ {
+ // Step 1. Look for nullable annotations on the type parameter.
+ if (GetNullableFlags(typeParam) is byte[] flags)
+ {
+ return TranslateByte(flags[0]);
+ }
+
+ // Step 2. Look for nullable annotations on the generic method declaration.
+ if (typeParam.DeclaringMethod != null && GetNullableContextFlag(typeParam.DeclaringMethod) is byte flag)
+ {
+ return TranslateByte(flag);
+ }
+
+ // Step 3. Look for nullable annotations on the generic method declaration.
+ if (GetNullableContextFlag(typeParam.DeclaringType!) is byte flag2)
+ {
+ return TranslateByte(flag2);
+ }
+
+ // Default to nullable.
+ return NullabilityState.Nullable;
+
+ static byte[]? GetNullableFlags(MemberInfo member)
+ {
+ foreach (CustomAttributeData attr in member.GetCustomAttributesData())
+ {
+ Type attrType = attr.AttributeType;
+ if (attrType.Name == "NullableAttribute" && attrType.Namespace == "System.Runtime.CompilerServices")
+ {
+ foreach (CustomAttributeTypedArgument ctorArg in attr.ConstructorArguments)
+ {
+ switch (ctorArg.Value)
+ {
+ case byte flag:
+ return [flag];
+ case byte[] flags:
+ return flags;
+ }
+ }
+ }
+ }
+
+ return null;
+ }
+
+ static byte? GetNullableContextFlag(MemberInfo member)
+ {
+ foreach (CustomAttributeData attr in member.GetCustomAttributesData())
+ {
+ Type attrType = attr.AttributeType;
+ if (attrType.Name == "NullableContextAttribute" && attrType.Namespace == "System.Runtime.CompilerServices")
+ {
+ foreach (CustomAttributeTypedArgument ctorArg in attr.ConstructorArguments)
+ {
+ if (ctorArg.Value is byte flag)
+ {
+ return flag;
+ }
+ }
+ }
+ }
+
+ return null;
+ }
+
+#pragma warning disable S109 // Magic numbers should not be used
+ static NullabilityState TranslateByte(byte b) => b switch
+ {
+ 1 => NullabilityState.NotNull,
+ 2 => NullabilityState.Nullable,
+ _ => NullabilityState.Unknown
+ };
+#pragma warning restore S109 // Magic numbers should not be used
+ }
+
+ static ParameterInfo GetGenericParameterDefinition(ParameterInfo parameter)
+ {
+ if (parameter.Member is { DeclaringType.IsConstructedGenericType: true }
+ or MethodInfo { IsGenericMethod: true, IsGenericMethodDefinition: false })
+ {
+ var genericMethod = (MethodBase)GetGenericMemberDefinition(parameter.Member);
+ return genericMethod.GetParameters()[parameter.Position];
+ }
+
+ return parameter;
+ }
+
+ static MemberInfo GetGenericMemberDefinition(MemberInfo member)
+ {
+ if (member is Type type)
+ {
+ return type.IsConstructedGenericType ? type.GetGenericTypeDefinition() : type;
+ }
+
+ if (member.DeclaringType?.IsConstructedGenericType is true)
+ {
+ return member.DeclaringType.GetGenericTypeDefinition().GetMemberWithSameMetadataDefinitionAs(member);
+ }
+
+ if (member is MethodInfo { IsGenericMethod: true, IsGenericMethodDefinition: false } method)
+ {
+ return method.GetGenericMethodDefinition();
+ }
+
+ return member;
+ }
+#endif
+ return context.Create(parameterInfo).WriteState;
+ }
+
+ // Taken from https://github.com/dotnet/runtime/blob/903bc019427ca07080530751151ea636168ad334/src/libraries/System.Text.Json/Common/ReflectionExtensions.cs#L288-L317
+ public static object? GetNormalizedDefaultValue(ParameterInfo parameterInfo)
+ {
+ Type parameterType = parameterInfo.ParameterType;
+ object? defaultValue = parameterInfo.DefaultValue;
+
+ if (defaultValue is null)
+ {
+ return null;
+ }
+
+ // DBNull.Value is sometimes used as the default value (returned by reflection) of nullable params in place of null.
+ if (defaultValue == DBNull.Value && parameterType != typeof(DBNull))
+ {
+ return null;
+ }
+
+ // Default values of enums or nullable enums are represented using the underlying type and need to be cast explicitly
+ // cf. https://github.com/dotnet/runtime/issues/68647
+ if (parameterType.IsEnum)
+ {
+ return Enum.ToObject(parameterType, defaultValue);
+ }
+
+ if (Nullable.GetUnderlyingType(parameterType) is Type underlyingType && underlyingType.IsEnum)
+ {
+ return Enum.ToObject(underlyingType, defaultValue);
+ }
+
+ return defaultValue;
+ }
+
+ // Resolves the deserialization constructor for a type using logic copied from
+ // https://github.com/dotnet/runtime/blob/e12e2fa6cbdd1f4b0c8ad1b1e2d960a480c21703/src/libraries/System.Text.Json/Common/ReflectionExtensions.cs#L227-L286
+ private static bool TryGetDeserializationConstructor(
+ [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)]
+ Type type,
+ bool useDefaultCtorInAnnotatedStructs,
+ out ConstructorInfo? deserializationCtor)
+ {
+ ConstructorInfo? ctorWithAttribute = null;
+ ConstructorInfo? publicParameterlessCtor = null;
+ ConstructorInfo? lonePublicCtor = null;
+
+ ConstructorInfo[] constructors = type.GetConstructors(BindingFlags.Public | BindingFlags.Instance);
+
+ if (constructors.Length == 1)
+ {
+ lonePublicCtor = constructors[0];
+ }
+
+ foreach (ConstructorInfo constructor in constructors)
+ {
+ if (HasJsonConstructorAttribute(constructor))
+ {
+ if (ctorWithAttribute != null)
+ {
+ deserializationCtor = null;
+ return false;
+ }
+
+ ctorWithAttribute = constructor;
+ }
+ else if (constructor.GetParameters().Length == 0)
+ {
+ publicParameterlessCtor = constructor;
+ }
+ }
+
+ // Search for non-public ctors with [JsonConstructor].
+ foreach (ConstructorInfo constructor in type.GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance))
+ {
+ if (HasJsonConstructorAttribute(constructor))
+ {
+ if (ctorWithAttribute != null)
+ {
+ deserializationCtor = null;
+ return false;
+ }
+
+ ctorWithAttribute = constructor;
+ }
+ }
+
+ // Structs will use default constructor if attribute isn't used.
+ if (useDefaultCtorInAnnotatedStructs && type.IsValueType && ctorWithAttribute == null)
+ {
+ deserializationCtor = null;
+ return true;
+ }
+
+ deserializationCtor = ctorWithAttribute ?? publicParameterlessCtor ?? lonePublicCtor;
+ return true;
+
+ static bool HasJsonConstructorAttribute(ConstructorInfo constructorInfo) =>
+ constructorInfo.GetCustomAttribute() != null;
+ }
+
+ // Parameter to property matching semantics as declared in
+ // https://github.com/dotnet/runtime/blob/12d96ccfaed98e23c345188ee08f8cfe211c03e7/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs#L1007-L1030
+ private readonly struct ParameterLookupKey : IEquatable
+ {
+ public ParameterLookupKey(string name, Type type)
+ {
+ Name = name;
+ Type = type;
+ }
+
+ public string Name { get; }
+ public Type Type { get; }
+
+ public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Name);
+ public bool Equals(ParameterLookupKey other) => Type == other.Type && string.Equals(Name, other.Name, StringComparison.OrdinalIgnoreCase);
+ public override bool Equals(object? obj) => obj is ParameterLookupKey key && Equals(key);
+ }
+ }
+
+#if !NET
+ private static MemberInfo GetMemberWithSameMetadataDefinitionAs(this Type specializedType, MemberInfo member)
+ {
+ const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
+ return specializedType.GetMember(member.Name, member.MemberType, All).First(m => m.MetadataToken == member.MetadataToken);
+ }
+#endif
+}
+#endif
diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs
new file mode 100644
index 00000000000..5c6ce6d9ab7
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs
@@ -0,0 +1,801 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#if !NET9_0_OR_GREATER
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
+using System.Linq;
+using System.Reflection;
+#if NET
+using System.Runtime.InteropServices;
+#endif
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization;
+using System.Text.Json.Serialization.Metadata;
+using Microsoft.Shared.Diagnostics;
+
+#pragma warning disable LA0002 // Use 'Microsoft.Shared.Text.NumericExtensions.ToInvariantString' for improved performance
+#pragma warning disable S107 // Methods should not have too many parameters
+#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
+
+namespace System.Text.Json.Schema;
+
+///
+/// Maps .NET types to JSON schema objects using contract metadata from instances.
+///
+#if !SHARED_PROJECT
+[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+#endif
+internal static partial class JsonSchemaExporter
+{
+ // Polyfill implementation of JsonSchemaExporter for System.Text.Json version 8.0.0.
+ // Uses private reflection to access metadata not available with the older APIs of STJ.
+
+ private const string RequiresUnreferencedCodeMessage =
+ "Uses private reflection on System.Text.Json components to access converter metadata. " +
+ "If running Native AOT ensure that the 'IlcTrimMetadata' property has been disabled.";
+
+ ///
+ /// Generates a JSON schema corresponding to the contract metadata of the specified type.
+ ///
+ /// The options instance from which to resolve the contract metadata.
+ /// The root type for which to generate the JSON schema.
+ /// The exporterOptions object controlling the schema generation.
+ /// A new instance defining the JSON schema for .
+ /// One of the specified parameters is .
+ /// The parameter contains unsupported exporterOptions.
+ [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
+ public static JsonNode GetJsonSchemaAsNode(this JsonSerializerOptions options, Type type, JsonSchemaExporterOptions? exporterOptions = null)
+ {
+ _ = Throw.IfNull(options);
+ _ = Throw.IfNull(type);
+ ValidateOptions(options);
+
+ exporterOptions ??= JsonSchemaExporterOptions.Default;
+ JsonTypeInfo typeInfo = options.GetTypeInfo(type);
+ return MapRootTypeJsonSchema(typeInfo, exporterOptions);
+ }
+
+ ///
+ /// Generates a JSON schema corresponding to the specified contract metadata.
+ ///
+ /// The contract metadata for which to generate the schema.
+ /// The exporterOptions object controlling the schema generation.
+ /// A new instance defining the JSON schema for .
+ /// One of the specified parameters is .
+ /// The parameter contains unsupported exporterOptions.
+ [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
+ public static JsonNode GetJsonSchemaAsNode(this JsonTypeInfo typeInfo, JsonSchemaExporterOptions? exporterOptions = null)
+ {
+ _ = Throw.IfNull(typeInfo);
+ ValidateOptions(typeInfo.Options);
+
+ exporterOptions ??= JsonSchemaExporterOptions.Default;
+ return MapRootTypeJsonSchema(typeInfo, exporterOptions);
+ }
+
+ [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
+ private static JsonNode MapRootTypeJsonSchema(JsonTypeInfo typeInfo, JsonSchemaExporterOptions exporterOptions)
+ {
+ GenerationState state = new(exporterOptions, typeInfo.Options);
+ JsonSchema schema = MapJsonSchemaCore(ref state, typeInfo);
+ return schema.ToJsonNode(exporterOptions);
+ }
+
+ [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)]
+ private static JsonSchema MapJsonSchemaCore(
+ ref GenerationState state,
+ JsonTypeInfo typeInfo,
+ Type? parentType = null,
+ JsonPropertyInfo? propertyInfo = null,
+ ICustomAttributeProvider? propertyAttributeProvider = null,
+ ParameterInfo? parameterInfo = null,
+ bool isNonNullableType = false,
+ JsonConverter? customConverter = null,
+ JsonNumberHandling? customNumberHandling = null,
+ JsonTypeInfo? parentPolymorphicTypeInfo = null,
+ bool parentPolymorphicTypeContainsTypesWithoutDiscriminator = false,
+ bool parentPolymorphicTypeIsNonNullable = false,
+ KeyValuePair? typeDiscriminator = null,
+ bool cacheResult = true)
+ {
+ Debug.Assert(typeInfo.IsReadOnly, "The specified contract must have been made read-only.");
+
+ JsonSchemaExporterContext exporterContext = state.CreateContext(typeInfo, parentPolymorphicTypeInfo, parentType, propertyInfo, parameterInfo, propertyAttributeProvider);
+
+ if (cacheResult && typeInfo.Kind is not JsonTypeInfoKind.None &&
+ state.TryGetExistingJsonPointer(exporterContext, out string? existingJsonPointer))
+ {
+ // The schema context has already been generated in the schema document, return a reference to it.
+ return CompleteSchema(ref state, new JsonSchema { Ref = existingJsonPointer });
+ }
+
+ JsonSchema schema;
+ JsonConverter effectiveConverter = customConverter ?? typeInfo.Converter;
+ JsonNumberHandling effectiveNumberHandling = customNumberHandling ?? typeInfo.NumberHandling ?? typeInfo.Options.NumberHandling;
+
+ if (!ReflectionHelpers.IsBuiltInConverter(effectiveConverter))
+ {
+ // Return a `true` schema for types with user-defined converters.
+ return CompleteSchema(ref state, JsonSchema.True);
+ }
+
+ if (parentPolymorphicTypeInfo is null && typeInfo.PolymorphismOptions is { DerivedTypes.Count: > 0 } polyOptions)
+ {
+ // This is the base type of a polymorphic type hierarchy. The schema for this type
+ // will include an "anyOf" property with the schemas for all derived types.
+
+ string typeDiscriminatorKey = polyOptions.TypeDiscriminatorPropertyName;
+ List derivedTypes = polyOptions.DerivedTypes.ToList();
+
+ if (!typeInfo.Type.IsAbstract && !derivedTypes.Any(derived => derived.DerivedType == typeInfo.Type))
+ {
+ // For non-abstract base types that haven't been explicitly configured,
+ // add a trivial schema to the derived types since we should support it.
+ derivedTypes.Add(new JsonDerivedType(typeInfo.Type));
+ }
+
+ bool containsTypesWithoutDiscriminator = derivedTypes.Exists(static derivedTypes => derivedTypes.TypeDiscriminator is null);
+ JsonSchemaType schemaType = JsonSchemaType.Any;
+ List? anyOf = new(derivedTypes.Count);
+
+ state.PushSchemaNode(JsonSchemaConstants.AnyOfPropertyName);
+
+ foreach (JsonDerivedType derivedType in derivedTypes)
+ {
+ Debug.Assert(derivedType.TypeDiscriminator is null or int or string, "Type discriminator does not have the expected type.");
+
+ KeyValuePair? derivedTypeDiscriminator = null;
+ if (derivedType.TypeDiscriminator is { } discriminatorValue)
+ {
+ JsonNode discriminatorNode = discriminatorValue switch
+ {
+ string stringId => (JsonNode)stringId,
+ _ => (JsonNode)(int)discriminatorValue,
+ };
+
+ JsonSchema discriminatorSchema = new() { Constant = discriminatorNode };
+ derivedTypeDiscriminator = new(typeDiscriminatorKey, discriminatorSchema);
+ }
+
+ JsonTypeInfo derivedTypeInfo = typeInfo.Options.GetTypeInfo(derivedType.DerivedType);
+
+ state.PushSchemaNode(anyOf.Count.ToString(CultureInfo.InvariantCulture));
+ JsonSchema derivedSchema = MapJsonSchemaCore(
+ ref state,
+ derivedTypeInfo,
+ parentPolymorphicTypeInfo: typeInfo,
+ typeDiscriminator: derivedTypeDiscriminator,
+ parentPolymorphicTypeContainsTypesWithoutDiscriminator: containsTypesWithoutDiscriminator,
+ parentPolymorphicTypeIsNonNullable: isNonNullableType,
+ cacheResult: false);
+
+ state.PopSchemaNode();
+
+ // Determine if all derived schemas have the same type.
+ if (anyOf.Count == 0)
+ {
+ schemaType = derivedSchema.Type;
+ }
+ else if (schemaType != derivedSchema.Type)
+ {
+ schemaType = JsonSchemaType.Any;
+ }
+
+ anyOf.Add(derivedSchema);
+ }
+
+ state.PopSchemaNode();
+
+ if (schemaType is not JsonSchemaType.Any)
+ {
+ // If all derived types have the same schema type, we can simplify the schema
+ // by moving the type keyword to the base schema and removing it from the derived schemas.
+ foreach (JsonSchema derivedSchema in anyOf)
+ {
+ derivedSchema.Type = JsonSchemaType.Any;
+
+ if (derivedSchema.KeywordCount == 0)
+ {
+ // if removing the type results in an empty schema,
+ // remove the anyOf array entirely since it's always true.
+ anyOf = null;
+ break;
+ }
+ }
+ }
+
+ schema = new()
+ {
+ Type = schemaType,
+ AnyOf = anyOf,
+
+ // If all derived types have a discriminator, we can require it in the base schema.
+ Required = containsTypesWithoutDiscriminator ? null : new() { typeDiscriminatorKey },
+ };
+
+ return CompleteSchema(ref state, schema);
+ }
+
+ if (Nullable.GetUnderlyingType(typeInfo.Type) is Type nullableElementType)
+ {
+ JsonTypeInfo elementTypeInfo = typeInfo.Options.GetTypeInfo(nullableElementType);
+ customConverter = ExtractCustomNullableConverter(customConverter);
+ schema = MapJsonSchemaCore(ref state, elementTypeInfo, customConverter: customConverter, cacheResult: false);
+
+ if (schema.Enum != null)
+ {
+ Debug.Assert(elementTypeInfo.Type.IsEnum, "The enum keyword should only be populated by schemas for enum types.");
+ schema.Enum.Add(null); // Append null to the enum array.
+ }
+
+ return CompleteSchema(ref state, schema);
+ }
+
+ switch (typeInfo.Kind)
+ {
+ case JsonTypeInfoKind.Object:
+ List>? properties = null;
+ List? required = null;
+ JsonSchema? additionalProperties = null;
+
+ JsonUnmappedMemberHandling effectiveUnmappedMemberHandling = typeInfo.UnmappedMemberHandling ?? typeInfo.Options.UnmappedMemberHandling;
+ if (effectiveUnmappedMemberHandling is JsonUnmappedMemberHandling.Disallow)
+ {
+ // Disallow unspecified properties.
+ additionalProperties = JsonSchema.False;
+ }
+
+ if (typeDiscriminator is { } typeDiscriminatorPair)
+ {
+ (properties = new()).Add(typeDiscriminatorPair);
+ if (parentPolymorphicTypeContainsTypesWithoutDiscriminator)
+ {
+ // Require the discriminator here since it's not common to all derived types.
+ (required = new()).Add(typeDiscriminatorPair.Key);
+ }
+ }
+
+ Func? parameterInfoMapper =
+ ReflectionHelpers.ResolveJsonConstructorParameterMapper(typeInfo.Type, typeInfo);
+
+ state.PushSchemaNode(JsonSchemaConstants.PropertiesPropertyName);
+ foreach (JsonPropertyInfo property in typeInfo.Properties)
+ {
+ if (property is { Get: null, Set: null } or { IsExtensionData: true })
+ {
+ continue; // Skip JsonIgnored properties and extension data
+ }
+
+ JsonNumberHandling? propertyNumberHandling = property.NumberHandling ?? effectiveNumberHandling;
+ JsonTypeInfo propertyTypeInfo = typeInfo.Options.GetTypeInfo(property.PropertyType);
+
+ // Resolve the attribute provider for the property.
+ ICustomAttributeProvider? attributeProvider = ReflectionHelpers.ResolveAttributeProvider(typeInfo.Type, property);
+
+ // Declare the property as nullable if either getter or setter are nullable.
+ bool isNonNullableProperty = false;
+ if (attributeProvider is MemberInfo memberInfo)
+ {
+ NullabilityInfo nullabilityInfo = ReflectionHelpers.GetMemberNullability(state.NullabilityInfoContext, memberInfo);
+ isNonNullableProperty =
+ (property.Get is null || nullabilityInfo.ReadState is NullabilityState.NotNull) &&
+ (property.Set is null || nullabilityInfo.WriteState is NullabilityState.NotNull);
+ }
+
+ bool isRequired = property.IsRequired;
+ bool hasDefaultValue = false;
+ JsonNode? defaultValue = null;
+
+ ParameterInfo? associatedParameter = parameterInfoMapper?.Invoke(property);
+ if (associatedParameter != null)
+ {
+ ResolveParameterInfo(
+ associatedParameter,
+ propertyTypeInfo,
+ state.NullabilityInfoContext,
+ out hasDefaultValue,
+ out defaultValue,
+ out bool isNonNullableParameter,
+ ref isRequired);
+
+ isNonNullableProperty &= isNonNullableParameter;
+ }
+
+ state.PushSchemaNode(property.Name);
+ JsonSchema propertySchema = MapJsonSchemaCore(
+ ref state,
+ propertyTypeInfo,
+ parentType: typeInfo.Type,
+ propertyInfo: property,
+ parameterInfo: associatedParameter,
+ propertyAttributeProvider: attributeProvider,
+ isNonNullableType: isNonNullableProperty,
+ customConverter: property.CustomConverter,
+ customNumberHandling: propertyNumberHandling);
+
+ state.PopSchemaNode();
+
+ if (hasDefaultValue)
+ {
+ JsonSchema.EnsureMutable(ref propertySchema);
+ propertySchema.DefaultValue = defaultValue;
+ propertySchema.HasDefaultValue = true;
+ }
+
+ (properties ??= new()).Add(new(property.Name, propertySchema));
+
+ if (isRequired)
+ {
+ (required ??= new()).Add(property.Name);
+ }
+ }
+
+ state.PopSchemaNode();
+ return CompleteSchema(ref state, new()
+ {
+ Type = JsonSchemaType.Object,
+ Properties = properties,
+ Required = required,
+ AdditionalProperties = additionalProperties,
+ });
+
+ case JsonTypeInfoKind.Enumerable:
+ Type elementType = ReflectionHelpers.GetElementType(typeInfo);
+ JsonTypeInfo elementTypeInfo = typeInfo.Options.GetTypeInfo(elementType);
+
+ if (typeDiscriminator is null)
+ {
+ state.PushSchemaNode(JsonSchemaConstants.ItemsPropertyName);
+ JsonSchema items = MapJsonSchemaCore(ref state, elementTypeInfo, customNumberHandling: effectiveNumberHandling);
+ state.PopSchemaNode();
+
+ return CompleteSchema(ref state, new()
+ {
+ Type = JsonSchemaType.Array,
+ Items = items.IsTrue ? null : items,
+ });
+ }
+ else
+ {
+ // Polymorphic enumerable types are represented using a wrapping object:
+ // { "$type" : "discriminator", "$values" : [element1, element2, ...] }
+ // Which corresponds to the schema
+ // { "properties" : { "$type" : { "const" : "discriminator" }, "$values" : { "type" : "array", "items" : { ... } } } }
+ const string ValuesKeyword = "$values";
+
+ state.PushSchemaNode(JsonSchemaConstants.PropertiesPropertyName);
+ state.PushSchemaNode(ValuesKeyword);
+ state.PushSchemaNode(JsonSchemaConstants.ItemsPropertyName);
+
+ JsonSchema items = MapJsonSchemaCore(ref state, elementTypeInfo, customNumberHandling: effectiveNumberHandling);
+
+ state.PopSchemaNode();
+ state.PopSchemaNode();
+ state.PopSchemaNode();
+
+ return CompleteSchema(ref state, new()
+ {
+ Type = JsonSchemaType.Object,
+ Properties = new()
+ {
+ typeDiscriminator.Value,
+ new(ValuesKeyword,
+ new JsonSchema
+ {
+ Type = JsonSchemaType.Array,
+ Items = items.IsTrue ? null : items,
+ }),
+ },
+ Required = parentPolymorphicTypeContainsTypesWithoutDiscriminator ? new() { typeDiscriminator.Value.Key } : null,
+ });
+ }
+
+ case JsonTypeInfoKind.Dictionary:
+ Type valueType = ReflectionHelpers.GetElementType(typeInfo);
+ JsonTypeInfo valueTypeInfo = typeInfo.Options.GetTypeInfo(valueType);
+
+ List>? dictProps = null;
+ List? dictRequired = null;
+
+ if (typeDiscriminator is { } dictDiscriminator)
+ {
+ dictProps = new() { dictDiscriminator };
+ if (parentPolymorphicTypeContainsTypesWithoutDiscriminator)
+ {
+ // Require the discriminator here since it's not common to all derived types.
+ dictRequired = new() { dictDiscriminator.Key };
+ }
+ }
+
+ state.PushSchemaNode(JsonSchemaConstants.AdditionalPropertiesPropertyName);
+ JsonSchema valueSchema = MapJsonSchemaCore(ref state, valueTypeInfo, customNumberHandling: effectiveNumberHandling);
+ state.PopSchemaNode();
+
+ return CompleteSchema(ref state, new()
+ {
+ Type = JsonSchemaType.Object,
+ Properties = dictProps,
+ Required = dictRequired,
+ AdditionalProperties = valueSchema.IsTrue ? null : valueSchema,
+ });
+
+ default:
+ Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.None, "The default case should handle unrecognize type kinds.");
+
+ if (_simpleTypeSchemaFactories.TryGetValue(typeInfo.Type, out Func? simpleTypeSchemaFactory))
+ {
+ schema = simpleTypeSchemaFactory(effectiveNumberHandling);
+ }
+ else if (typeInfo.Type.IsEnum)
+ {
+ schema = GetEnumConverterSchema(typeInfo, effectiveConverter);
+ }
+ else
+ {
+ schema = JsonSchema.True;
+ }
+
+ return CompleteSchema(ref state, schema);
+ }
+
+ JsonSchema CompleteSchema(ref GenerationState state, JsonSchema schema)
+ {
+ if (schema.Ref is null)
+ {
+ if (IsNullableSchema(ref state))
+ {
+ schema.MakeNullable();
+ }
+
+ bool IsNullableSchema(ref GenerationState state)
+ {
+ // A schema is marked as nullable if either
+ // 1. We have a schema for a property where either the getter or setter are marked as nullable.
+ // 2. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable
+
+ if (propertyInfo != null || parameterInfo != null)
+ {
+ return !isNonNullableType;
+ }
+ else
+ {
+ return ReflectionHelpers.CanBeNull(typeInfo.Type) &&
+ !parentPolymorphicTypeIsNonNullable &&
+ !state.ExporterOptions.TreatNullObliviousAsNonNullable;
+ }
+ }
+ }
+
+ if (state.ExporterOptions.TransformSchemaNode != null)
+ {
+ // Prime the schema for invocation by the JsonNode transformer.
+ schema.GenerationContext = exporterContext;
+ }
+
+ return schema;
+ }
+ }
+
+ private readonly ref struct GenerationState
+ {
+ private const int DefaultMaxDepth = 64;
+ private readonly List _currentPath = new();
+ private readonly Dictionary<(JsonTypeInfo, JsonPropertyInfo?), string[]> _generated = new();
+ private readonly int _maxDepth;
+
+ public GenerationState(JsonSchemaExporterOptions exporterOptions, JsonSerializerOptions options, NullabilityInfoContext? nullabilityInfoContext = null)
+ {
+ ExporterOptions = exporterOptions;
+ NullabilityInfoContext = nullabilityInfoContext ?? new();
+ _maxDepth = options.MaxDepth is 0 ? DefaultMaxDepth : options.MaxDepth;
+ }
+
+ public JsonSchemaExporterOptions ExporterOptions { get; }
+ public NullabilityInfoContext NullabilityInfoContext { get; }
+ public int CurrentDepth => _currentPath.Count;
+
+ public void PushSchemaNode(string nodeId)
+ {
+ if (CurrentDepth == _maxDepth)
+ {
+ ThrowHelpers.ThrowInvalidOperationException_MaxDepthReached();
+ }
+
+ _currentPath.Add(nodeId);
+ }
+
+ public void PopSchemaNode()
+ {
+ _currentPath.RemoveAt(_currentPath.Count - 1);
+ }
+
+ ///
+ /// Registers the current schema node generation context; if it has already been generated return a JSON pointer to its location.
+ ///
+ public bool TryGetExistingJsonPointer(in JsonSchemaExporterContext context, [NotNullWhen(true)] out string? existingJsonPointer)
+ {
+ (JsonTypeInfo, JsonPropertyInfo?) key = (context.TypeInfo, context.PropertyInfo);
+#if NET
+ ref string[]? pathToSchema = ref CollectionsMarshal.GetValueRefOrAddDefault(_generated, key, out bool exists);
+#else
+ bool exists = _generated.TryGetValue(key, out string[]? pathToSchema);
+#endif
+ if (exists)
+ {
+ existingJsonPointer = FormatJsonPointer(pathToSchema);
+ return true;
+ }
+#if NET
+ pathToSchema = context._path;
+#else
+ _generated[key] = context._path;
+#endif
+ existingJsonPointer = null;
+ return false;
+ }
+
+ public JsonSchemaExporterContext CreateContext(
+ JsonTypeInfo typeInfo,
+ JsonTypeInfo? baseTypeInfo,
+ Type? declaringType,
+ JsonPropertyInfo? propertyInfo,
+ ParameterInfo? parameterInfo,
+ ICustomAttributeProvider? propertyAttributeProvider)
+ {
+ return new JsonSchemaExporterContext(typeInfo, baseTypeInfo, declaringType, propertyInfo, parameterInfo, propertyAttributeProvider, _currentPath.ToArray());
+ }
+
+ private static string FormatJsonPointer(ReadOnlySpan path)
+ {
+ if (path.IsEmpty)
+ {
+ return "#";
+ }
+
+ StringBuilder sb = new();
+ _ = sb.Append('#');
+
+ for (int i = 0; i < path.Length; i++)
+ {
+ string segment = path[i];
+ if (segment.AsSpan().IndexOfAny('~', '/') != -1)
+ {
+#pragma warning disable CA1307 // Specify StringComparison for clarity
+ segment = segment.Replace("~", "~0").Replace("/", "~1");
+#pragma warning restore CA1307
+ }
+
+ _ = sb.Append('/');
+ _ = sb.Append(segment);
+ }
+
+ return sb.ToString();
+ }
+ }
+
+ private static readonly Dictionary> _simpleTypeSchemaFactories = new()
+ {
+ [typeof(object)] = _ => JsonSchema.True,
+ [typeof(bool)] = _ => new JsonSchema { Type = JsonSchemaType.Boolean },
+ [typeof(byte)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(ushort)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(uint)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(ulong)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(sbyte)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(short)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(int)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(long)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(float)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true),
+ [typeof(double)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true),
+ [typeof(decimal)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling),
+#if NET6_0_OR_GREATER
+ [typeof(Half)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true),
+#endif
+#if NET7_0_OR_GREATER
+ [typeof(UInt128)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+ [typeof(Int128)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling),
+#endif
+ [typeof(char)] = _ => new JsonSchema { Type = JsonSchemaType.String, MinLength = 1, MaxLength = 1 },
+ [typeof(string)] = _ => new JsonSchema { Type = JsonSchemaType.String },
+ [typeof(byte[])] = _ => new JsonSchema { Type = JsonSchemaType.String },
+ [typeof(Memory)] = _ => new JsonSchema { Type = JsonSchemaType.String },
+ [typeof(ReadOnlyMemory)] = _ => new JsonSchema { Type = JsonSchemaType.String },
+ [typeof(DateTime)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date-time" },
+ [typeof(DateTimeOffset)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date-time" },
+ [typeof(TimeSpan)] = _ => new JsonSchema
+ {
+ Comment = "Represents a System.TimeSpan value.",
+ Type = JsonSchemaType.String,
+ Pattern = @"^-?(\d+\.)?\d{2}:\d{2}:\d{2}(\.\d{1,7})?$",
+ },
+
+#if NET6_0_OR_GREATER
+ [typeof(DateOnly)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date" },
+ [typeof(TimeOnly)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "time" },
+#endif
+ [typeof(Guid)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "uuid" },
+ [typeof(Uri)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "uri" },
+ [typeof(Version)] = _ => new JsonSchema
+ {
+ Comment = "Represents a version string.",
+ Type = JsonSchemaType.String,
+ Pattern = @"^\d+(\.\d+){1,3}$",
+ },
+
+ [typeof(JsonDocument)] = _ => JsonSchema.True,
+ [typeof(JsonElement)] = _ => JsonSchema.True,
+ [typeof(JsonNode)] = _ => JsonSchema.True,
+ [typeof(JsonValue)] = _ => JsonSchema.True,
+ [typeof(JsonObject)] = _ => new JsonSchema { Type = JsonSchemaType.Object },
+ [typeof(JsonArray)] = _ => new JsonSchema { Type = JsonSchemaType.Array },
+ };
+
+ // Adapted from https://github.com/dotnet/runtime/blob/release/9.0/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Value/JsonPrimitiveConverter.cs#L36-L69
+ private static JsonSchema GetSchemaForNumericType(JsonSchemaType schemaType, JsonNumberHandling numberHandling, bool isIeeeFloatingPoint = false)
+ {
+ Debug.Assert(schemaType is JsonSchemaType.Integer or JsonSchemaType.Number, "schema type must be number or integer");
+ Debug.Assert(!isIeeeFloatingPoint || schemaType is JsonSchemaType.Number, "If specifying IEEE the schema type must be number");
+
+ string? pattern = null;
+
+ if ((numberHandling & (JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)) != 0)
+ {
+ if (schemaType is JsonSchemaType.Integer)
+ {
+ pattern = @"^-?(?:0|[1-9]\d*)$";
+ }
+ else if (isIeeeFloatingPoint)
+ {
+ pattern = @"^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$";
+ }
+ else
+ {
+ pattern = @"^-?(?:0|[1-9]\d*)(?:\.\d+)?$";
+ }
+
+ schemaType |= JsonSchemaType.String;
+ }
+
+ if (isIeeeFloatingPoint && (numberHandling & JsonNumberHandling.AllowNamedFloatingPointLiterals) != 0)
+ {
+ return new JsonSchema
+ {
+ AnyOf = new()
+ {
+ new JsonSchema { Type = schemaType, Pattern = pattern },
+ new JsonSchema { Enum = new() { (JsonNode)"NaN", (JsonNode)"Infinity", (JsonNode)"-Infinity" } },
+ },
+ };
+ }
+
+ return new JsonSchema { Type = schemaType, Pattern = pattern };
+ }
+
+ private static JsonConverter? ExtractCustomNullableConverter(JsonConverter? converter)
+ {
+ Debug.Assert(converter is null || ReflectionHelpers.IsBuiltInConverter(converter), "If specified the converter must be built-in.");
+
+ if (converter is null)
+ {
+ return null;
+ }
+
+ return ReflectionHelpers.GetElementConverter(converter);
+ }
+
+ private static void ValidateOptions(JsonSerializerOptions options)
+ {
+ if (options.ReferenceHandler == ReferenceHandler.Preserve)
+ {
+ ThrowHelpers.ThrowNotSupportedException_ReferenceHandlerPreserveNotSupported();
+ }
+
+ options.MakeReadOnly();
+ }
+
+ private static void ResolveParameterInfo(
+ ParameterInfo parameter,
+ JsonTypeInfo parameterTypeInfo,
+ NullabilityInfoContext nullabilityInfoContext,
+ out bool hasDefaultValue,
+ out JsonNode? defaultValue,
+ out bool isNonNullable,
+ ref bool isRequired)
+ {
+ Debug.Assert(parameterTypeInfo.Type == parameter.ParameterType, "The typeInfo type must match the ParameterInfo type.");
+
+ // Incorporate the nullability information from the parameter.
+ isNonNullable = ReflectionHelpers.GetParameterNullability(nullabilityInfoContext, parameter) is NullabilityState.NotNull;
+
+ if (parameter.HasDefaultValue)
+ {
+ // Append the default value to the description.
+ object? defaultVal = ReflectionHelpers.GetNormalizedDefaultValue(parameter);
+ defaultValue = JsonSerializer.SerializeToNode(defaultVal, parameterTypeInfo);
+ hasDefaultValue = true;
+ }
+ else
+ {
+ // Parameter is not optional, mark as required.
+ isRequired = true;
+ defaultValue = null;
+ hasDefaultValue = false;
+ }
+ }
+
+ // Adapted from https://github.com/dotnet/runtime/blob/release/9.0/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Value/EnumConverter.cs#L498-L521
+ private static JsonSchema GetEnumConverterSchema(JsonTypeInfo typeInfo, JsonConverter converter)
+ {
+ Debug.Assert(typeInfo.Type.IsEnum && ReflectionHelpers.IsBuiltInConverter(converter), "must be using a built-in enum converter.");
+
+ if (converter is JsonConverterFactory factory)
+ {
+ converter = factory.CreateConverter(typeInfo.Type, typeInfo.Options)!;
+ }
+
+ ReflectionHelpers.GetEnumConverterConfig(converter, out JsonNamingPolicy? namingPolicy, out bool allowString);
+
+ if (allowString)
+ {
+ // This explicitly ignores the integer component in converters configured as AllowNumbers | AllowStrings
+ // which is the default for JsonStringEnumConverter. This sacrifices some precision in the schema for simplicity.
+
+ if (typeInfo.Type.GetCustomAttribute() is not null)
+ {
+ // Do not report enum values in case of flags.
+ return new() { Type = JsonSchemaType.String };
+ }
+
+ JsonArray enumValues = new();
+ foreach (string name in Enum.GetNames(typeInfo.Type))
+ {
+ // This does not account for custom names specified via the new
+ // JsonStringEnumMemberNameAttribute introduced in .NET 9.
+ string effectiveName = namingPolicy?.ConvertName(name) ?? name;
+ enumValues.Add((JsonNode)effectiveName);
+ }
+
+ return new() { Enum = enumValues };
+ }
+
+ return new() { Type = JsonSchemaType.Integer };
+ }
+
+ private static class JsonSchemaConstants
+ {
+ public const string SchemaPropertyName = "$schema";
+ public const string RefPropertyName = "$ref";
+ public const string CommentPropertyName = "$comment";
+ public const string TitlePropertyName = "title";
+ public const string DescriptionPropertyName = "description";
+ public const string TypePropertyName = "type";
+ public const string FormatPropertyName = "format";
+ public const string PatternPropertyName = "pattern";
+ public const string PropertiesPropertyName = "properties";
+ public const string RequiredPropertyName = "required";
+ public const string ItemsPropertyName = "items";
+ public const string AdditionalPropertiesPropertyName = "additionalProperties";
+ public const string EnumPropertyName = "enum";
+ public const string NotPropertyName = "not";
+ public const string AnyOfPropertyName = "anyOf";
+ public const string ConstPropertyName = "const";
+ public const string DefaultPropertyName = "default";
+ public const string MinLengthPropertyName = "minLength";
+ public const string MaxLengthPropertyName = "maxLength";
+ }
+
+ private static class ThrowHelpers
+ {
+ [DoesNotReturn]
+ public static void ThrowInvalidOperationException_MaxDepthReached() =>
+ throw new InvalidOperationException("The depth of the generated JSON schema exceeds the JsonSerializerOptions.MaxDepth setting.");
+
+ [DoesNotReturn]
+ public static void ThrowNotSupportedException_ReferenceHandlerPreserveNotSupported() =>
+ throw new NotSupportedException("Schema generation not supported with ReferenceHandler.Preserve enabled.");
+ }
+}
+#endif
diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporterContext.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporterContext.cs
new file mode 100644
index 00000000000..3602ee46df4
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporterContext.cs
@@ -0,0 +1,77 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#if !NET9_0_OR_GREATER
+using System;
+using System.Reflection;
+using System.Text.Json.Serialization.Metadata;
+
+namespace System.Text.Json.Schema;
+
+///
+/// Defines the context in which a JSON schema within a type graph is being generated.
+///
+#if !SHARED_PROJECT
+[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+#endif
+internal readonly struct JsonSchemaExporterContext
+{
+#pragma warning disable IDE1006 // Naming Styles
+ internal readonly string[] _path;
+#pragma warning restore IDE1006 // Naming Styles
+
+ internal JsonSchemaExporterContext(
+ JsonTypeInfo typeInfo,
+ JsonTypeInfo? baseTypeInfo,
+ Type? declaringType,
+ JsonPropertyInfo? propertyInfo,
+ ParameterInfo? parameterInfo,
+ ICustomAttributeProvider? propertyAttributeProvider,
+ string[] path)
+ {
+ TypeInfo = typeInfo;
+ DeclaringType = declaringType;
+ BaseTypeInfo = baseTypeInfo;
+ PropertyInfo = propertyInfo;
+ ParameterInfo = parameterInfo;
+ PropertyAttributeProvider = propertyAttributeProvider;
+ _path = path;
+ }
+
+ ///
+ /// Gets the path to the schema document currently being generated.
+ ///
+ public ReadOnlySpan Path => _path;
+
+ ///
+ /// Gets the for the type being processed.
+ ///
+ public JsonTypeInfo TypeInfo { get; }
+
+ ///
+ /// Gets the declaring type of the property or parameter being processed.
+ ///
+ public Type? DeclaringType { get; }
+
+ ///
+ /// Gets the type info for the polymorphic base type if generated as a derived type.
+ ///
+ public JsonTypeInfo? BaseTypeInfo { get; }
+
+ ///
+ /// Gets the if the schema is being generated for a property.
+ ///
+ public JsonPropertyInfo? PropertyInfo { get; }
+
+ ///
+ /// Gets the if a constructor parameter
+ /// has been associated with the accompanying .
+ ///
+ public ParameterInfo? ParameterInfo { get; }
+
+ ///
+ /// Gets the corresponding to the property or field being processed.
+ ///
+ public ICustomAttributeProvider? PropertyAttributeProvider { get; }
+}
+#endif
diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporterOptions.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporterOptions.cs
new file mode 100644
index 00000000000..53a269ea612
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporterOptions.cs
@@ -0,0 +1,38 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#if !NET9_0_OR_GREATER
+using System;
+using System.Text.Json.Nodes;
+
+namespace System.Text.Json.Schema;
+
+///
+/// Controls the behavior of the class.
+///
+#if !SHARED_PROJECT
+[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+#endif
+internal sealed class JsonSchemaExporterOptions
+{
+ ///
+ /// Gets the default configuration object used by .
+ ///
+ public static JsonSchemaExporterOptions Default { get; } = new();
+
+ ///
+ /// Gets a value indicating whether non-nullable schemas should be generated for null oblivious reference types.
+ ///
+ ///
+ /// Defaults to . Due to restrictions in the run-time representation of nullable reference types
+ /// most occurrences are null oblivious and are treated as nullable by the serializer. A notable exception to that rule
+ /// are nullability annotations of field, property and constructor parameters which are represented in the contract metadata.
+ ///
+ public bool TreatNullObliviousAsNonNullable { get; init; }
+
+ ///
+ /// Gets a callback that is invoked for every schema that is generated within the type graph.
+ ///
+ public Func? TransformSchemaNode { get; init; }
+}
+#endif
diff --git a/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfo.cs b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfo.cs
new file mode 100644
index 00000000000..bd9b132cd0f
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfo.cs
@@ -0,0 +1,75 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#if !NET6_0_OR_GREATER
+using System.Diagnostics.CodeAnalysis;
+
+#pragma warning disable SA1623 // Property summary documentation should match accessors
+
+namespace System.Reflection
+{
+ ///
+ /// A class that represents nullability info.
+ ///
+ [ExcludeFromCodeCoverage]
+ internal sealed class NullabilityInfo
+ {
+ internal NullabilityInfo(Type type, NullabilityState readState, NullabilityState writeState,
+ NullabilityInfo? elementType, NullabilityInfo[] typeArguments)
+ {
+ Type = type;
+ ReadState = readState;
+ WriteState = writeState;
+ ElementType = elementType;
+ GenericTypeArguments = typeArguments;
+ }
+
+ ///
+ /// The of the member or generic parameter
+ /// to which this NullabilityInfo belongs.
+ ///
+ public Type Type { get; }
+
+ ///
+ /// The nullability read state of the member.
+ ///
+ public NullabilityState ReadState { get; internal set; }
+
+ ///
+ /// The nullability write state of the member.
+ ///
+ public NullabilityState WriteState { get; internal set; }
+
+ ///
+ /// If the member type is an array, gives the of the elements of the array, null otherwise.
+ ///
+ public NullabilityInfo? ElementType { get; }
+
+ ///
+ /// If the member type is a generic type, gives the array of for each type parameter.
+ ///
+ public NullabilityInfo[] GenericTypeArguments { get; }
+ }
+
+ ///
+ /// An enum that represents nullability state.
+ ///
+ internal enum NullabilityState
+ {
+ ///
+ /// Nullability context not enabled (oblivious).
+ ///
+ Unknown,
+
+ ///
+ /// Non nullable value or reference type.
+ ///
+ NotNull,
+
+ ///
+ /// Nullable value or reference type.
+ ///
+ Nullable,
+ }
+}
+#endif
diff --git a/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoContext.cs b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoContext.cs
new file mode 100644
index 00000000000..3edee1b9cb8
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoContext.cs
@@ -0,0 +1,661 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#if !NET6_0_OR_GREATER
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+
+#pragma warning disable SA1204 // Static elements should appear before instance elements
+#pragma warning disable S109 // Magic numbers should not be used
+#pragma warning disable S1067 // Expressions should not be too complex
+#pragma warning disable S4136 // Method overloads should be grouped together
+#pragma warning disable SA1202 // Elements should be ordered by access
+#pragma warning disable IDE1006 // Naming Styles
+
+namespace System.Reflection
+{
+ ///
+ /// Provides APIs for populating nullability information/context from reflection members:
+ /// , , and .
+ ///
+ [ExcludeFromCodeCoverage]
+ internal sealed class NullabilityInfoContext
+ {
+ private const string CompilerServicesNameSpace = "System.Runtime.CompilerServices";
+ private readonly Dictionary _publicOnlyModules = new();
+ private readonly Dictionary _context = new();
+
+ [Flags]
+ private enum NotAnnotatedStatus
+ {
+ None = 0x0, // no restriction, all members annotated
+ Private = 0x1, // private members not annotated
+ Internal = 0x2, // internal members not annotated
+ }
+
+ private NullabilityState? GetNullableContext(MemberInfo? memberInfo)
+ {
+ while (memberInfo != null)
+ {
+ if (_context.TryGetValue(memberInfo, out NullabilityState state))
+ {
+ return state;
+ }
+
+ foreach (CustomAttributeData attribute in memberInfo.GetCustomAttributesData())
+ {
+ if (attribute.AttributeType.Name == "NullableContextAttribute" &&
+ attribute.AttributeType.Namespace == CompilerServicesNameSpace &&
+ attribute.ConstructorArguments.Count == 1)
+ {
+ state = TranslateByte(attribute.ConstructorArguments[0].Value);
+ _context.Add(memberInfo, state);
+ return state;
+ }
+ }
+
+ memberInfo = memberInfo.DeclaringType;
+ }
+
+ return null;
+ }
+
+ ///
+ /// Populates for the given .
+ /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's
+ /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state.
+ ///
+ /// The parameter which nullability info gets populated.
+ /// If the parameterInfo parameter is null.
+ /// .
+ public NullabilityInfo Create(ParameterInfo parameterInfo)
+ {
+ IList attributes = parameterInfo.GetCustomAttributesData();
+ NullableAttributeStateParser parser = parameterInfo.Member is MethodBase method && IsPrivateOrInternalMethodAndAnnotationDisabled(method)
+ ? NullableAttributeStateParser.Unknown
+ : CreateParser(attributes);
+ NullabilityInfo nullability = GetNullabilityInfo(parameterInfo.Member, parameterInfo.ParameterType, parser);
+
+ if (nullability.ReadState != NullabilityState.Unknown)
+ {
+ CheckParameterMetadataType(parameterInfo, nullability);
+ }
+
+ CheckNullabilityAttributes(nullability, attributes);
+ return nullability;
+ }
+
+ private void CheckParameterMetadataType(ParameterInfo parameter, NullabilityInfo nullability)
+ {
+ ParameterInfo? metaParameter;
+ MemberInfo metaMember;
+
+ switch (parameter.Member)
+ {
+ case ConstructorInfo ctor:
+ var metaCtor = (ConstructorInfo)GetMemberMetadataDefinition(ctor);
+ metaMember = metaCtor;
+ metaParameter = GetMetaParameter(metaCtor, parameter);
+ break;
+
+ case MethodInfo method:
+ MethodInfo metaMethod = GetMethodMetadataDefinition(method);
+ metaMember = metaMethod;
+ metaParameter = string.IsNullOrEmpty(parameter.Name) ? metaMethod.ReturnParameter : GetMetaParameter(metaMethod, parameter);
+ break;
+
+ default:
+ return;
+ }
+
+ if (metaParameter != null)
+ {
+ CheckGenericParameters(nullability, metaMember, metaParameter.ParameterType, parameter.Member.ReflectedType);
+ }
+ }
+
+ private static ParameterInfo? GetMetaParameter(MethodBase metaMethod, ParameterInfo parameter)
+ {
+ var parameters = metaMethod.GetParameters();
+ for (int i = 0; i < parameters.Length; i++)
+ {
+ if (parameter.Position == i &&
+ parameter.Name == parameters[i].Name)
+ {
+ return parameters[i];
+ }
+ }
+
+ return null;
+ }
+
+ private static MethodInfo GetMethodMetadataDefinition(MethodInfo method)
+ {
+ if (method.IsGenericMethod && !method.IsGenericMethodDefinition)
+ {
+ method = method.GetGenericMethodDefinition();
+ }
+
+ return (MethodInfo)GetMemberMetadataDefinition(method);
+ }
+
+ private static void CheckNullabilityAttributes(NullabilityInfo nullability, IList attributes)
+ {
+ var codeAnalysisReadState = NullabilityState.Unknown;
+ var codeAnalysisWriteState = NullabilityState.Unknown;
+
+ foreach (CustomAttributeData attribute in attributes)
+ {
+ if (attribute.AttributeType.Namespace == "System.Diagnostics.CodeAnalysis")
+ {
+ if (attribute.AttributeType.Name == "NotNullAttribute")
+ {
+ codeAnalysisReadState = NullabilityState.NotNull;
+ }
+ else if ((attribute.AttributeType.Name == "MaybeNullAttribute" ||
+ attribute.AttributeType.Name == "MaybeNullWhenAttribute") &&
+ codeAnalysisReadState == NullabilityState.Unknown &&
+ !IsValueTypeOrValueTypeByRef(nullability.Type))
+ {
+ codeAnalysisReadState = NullabilityState.Nullable;
+ }
+ else if (attribute.AttributeType.Name == "DisallowNullAttribute")
+ {
+ codeAnalysisWriteState = NullabilityState.NotNull;
+ }
+ else if (attribute.AttributeType.Name == "AllowNullAttribute" &&
+ codeAnalysisWriteState == NullabilityState.Unknown &&
+ !IsValueTypeOrValueTypeByRef(nullability.Type))
+ {
+ codeAnalysisWriteState = NullabilityState.Nullable;
+ }
+ }
+ }
+
+ if (codeAnalysisReadState != NullabilityState.Unknown)
+ {
+ nullability.ReadState = codeAnalysisReadState;
+ }
+
+ if (codeAnalysisWriteState != NullabilityState.Unknown)
+ {
+ nullability.WriteState = codeAnalysisWriteState;
+ }
+ }
+
+ ///
+ /// Populates for the given .
+ /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's
+ /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state.
+ ///
+ /// The parameter which nullability info gets populated.
+ /// If the propertyInfo parameter is null.
+ /// .
+ public NullabilityInfo Create(PropertyInfo propertyInfo)
+ {
+ MethodInfo? getter = propertyInfo.GetGetMethod(true);
+ MethodInfo? setter = propertyInfo.GetSetMethod(true);
+ bool annotationsDisabled = (getter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(getter))
+ && (setter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(setter));
+ NullableAttributeStateParser parser = annotationsDisabled ? NullableAttributeStateParser.Unknown : CreateParser(propertyInfo.GetCustomAttributesData());
+ NullabilityInfo nullability = GetNullabilityInfo(propertyInfo, propertyInfo.PropertyType, parser);
+
+ if (getter != null)
+ {
+ CheckNullabilityAttributes(nullability, getter.ReturnParameter.GetCustomAttributesData());
+ }
+ else
+ {
+ nullability.ReadState = NullabilityState.Unknown;
+ }
+
+ if (setter != null)
+ {
+ CheckNullabilityAttributes(nullability, setter.GetParameters().Last().GetCustomAttributesData());
+ }
+ else
+ {
+ nullability.WriteState = NullabilityState.Unknown;
+ }
+
+ return nullability;
+ }
+
+ private bool IsPrivateOrInternalMethodAndAnnotationDisabled(MethodBase method)
+ {
+ if ((method.IsPrivate || method.IsFamilyAndAssembly || method.IsAssembly) &&
+ IsPublicOnly(method.IsPrivate, method.IsFamilyAndAssembly, method.IsAssembly, method.Module))
+ {
+ return true;
+ }
+
+ return false;
+ }
+
+ ///
+ /// Populates for the given .
+ /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's
+ /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state.
+ ///
+ /// The parameter which nullability info gets populated.
+ /// If the eventInfo parameter is null.
+ /// .
+ public NullabilityInfo Create(EventInfo eventInfo)
+ {
+ return GetNullabilityInfo(eventInfo, eventInfo.EventHandlerType!, CreateParser(eventInfo.GetCustomAttributesData()));
+ }
+
+ ///
+ /// Populates for the given
+ /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's
+ /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state.
+ ///
+ /// The parameter which nullability info gets populated.
+ /// If the fieldInfo parameter is null.
+ /// .
+ public NullabilityInfo Create(FieldInfo fieldInfo)
+ {
+ IList attributes = fieldInfo.GetCustomAttributesData();
+ NullableAttributeStateParser parser = IsPrivateOrInternalFieldAndAnnotationDisabled(fieldInfo) ? NullableAttributeStateParser.Unknown : CreateParser(attributes);
+ NullabilityInfo nullability = GetNullabilityInfo(fieldInfo, fieldInfo.FieldType, parser);
+ CheckNullabilityAttributes(nullability, attributes);
+ return nullability;
+ }
+
+ private bool IsPrivateOrInternalFieldAndAnnotationDisabled(FieldInfo fieldInfo)
+ {
+ if ((fieldInfo.IsPrivate || fieldInfo.IsFamilyAndAssembly || fieldInfo.IsAssembly) &&
+ IsPublicOnly(fieldInfo.IsPrivate, fieldInfo.IsFamilyAndAssembly, fieldInfo.IsAssembly, fieldInfo.Module))
+ {
+ return true;
+ }
+
+ return false;
+ }
+
+ private bool IsPublicOnly(bool isPrivate, bool isFamilyAndAssembly, bool isAssembly, Module module)
+ {
+ if (!_publicOnlyModules.TryGetValue(module, out NotAnnotatedStatus value))
+ {
+ value = PopulateAnnotationInfo(module.GetCustomAttributesData());
+ _publicOnlyModules.Add(module, value);
+ }
+
+ if (value == NotAnnotatedStatus.None)
+ {
+ return false;
+ }
+
+ if (((isPrivate || isFamilyAndAssembly) && value.HasFlag(NotAnnotatedStatus.Private)) ||
+ (isAssembly && value.HasFlag(NotAnnotatedStatus.Internal)))
+ {
+ return true;
+ }
+
+ return false;
+ }
+
+ private static NotAnnotatedStatus PopulateAnnotationInfo(IList customAttributes)
+ {
+ foreach (CustomAttributeData attribute in customAttributes)
+ {
+ if (attribute.AttributeType.Name == "NullablePublicOnlyAttribute" &&
+ attribute.AttributeType.Namespace == CompilerServicesNameSpace &&
+ attribute.ConstructorArguments.Count == 1)
+ {
+ if (attribute.ConstructorArguments[0].Value is bool boolValue && boolValue)
+ {
+ return NotAnnotatedStatus.Internal | NotAnnotatedStatus.Private;
+ }
+ else
+ {
+ return NotAnnotatedStatus.Private;
+ }
+ }
+ }
+
+ return NotAnnotatedStatus.None;
+ }
+
+ private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser)
+ {
+ int index = 0;
+ NullabilityInfo nullability = GetNullabilityInfo(memberInfo, type, parser, ref index);
+
+ if (nullability.ReadState != NullabilityState.Unknown)
+ {
+ TryLoadGenericMetaTypeNullability(memberInfo, nullability);
+ }
+
+ return nullability;
+ }
+
+ private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser, ref int index)
+ {
+ NullabilityState state = NullabilityState.Unknown;
+ NullabilityInfo? elementState = null;
+ NullabilityInfo[] genericArgumentsState = Array.Empty();
+ Type underlyingType = type;
+
+ if (underlyingType.IsByRef || underlyingType.IsPointer)
+ {
+ underlyingType = underlyingType.GetElementType()!;
+ }
+
+ if (underlyingType.IsValueType)
+ {
+ if (Nullable.GetUnderlyingType(underlyingType) is { } nullableUnderlyingType)
+ {
+ underlyingType = nullableUnderlyingType;
+ state = NullabilityState.Nullable;
+ }
+ else
+ {
+ state = NullabilityState.NotNull;
+ }
+
+ if (underlyingType.IsGenericType)
+ {
+ ++index;
+ }
+ }
+ else
+ {
+ if (!parser.ParseNullableState(index++, ref state)
+ && GetNullableContext(memberInfo) is { } contextState)
+ {
+ state = contextState;
+ }
+
+ if (underlyingType.IsArray)
+ {
+ elementState = GetNullabilityInfo(memberInfo, underlyingType.GetElementType()!, parser, ref index);
+ }
+ }
+
+ if (underlyingType.IsGenericType)
+ {
+ Type[] genericArguments = underlyingType.GetGenericArguments();
+ genericArgumentsState = new NullabilityInfo[genericArguments.Length];
+
+ for (int i = 0; i < genericArguments.Length; i++)
+ {
+ genericArgumentsState[i] = GetNullabilityInfo(memberInfo, genericArguments[i], parser, ref index);
+ }
+ }
+
+ return new NullabilityInfo(type, state, state, elementState, genericArgumentsState);
+ }
+
+ private static NullableAttributeStateParser CreateParser(IList customAttributes)
+ {
+ foreach (CustomAttributeData attribute in customAttributes)
+ {
+ if (attribute.AttributeType.Name == "NullableAttribute" &&
+ attribute.AttributeType.Namespace == CompilerServicesNameSpace &&
+ attribute.ConstructorArguments.Count == 1)
+ {
+ return new NullableAttributeStateParser(attribute.ConstructorArguments[0].Value);
+ }
+ }
+
+ return new NullableAttributeStateParser(null);
+ }
+
+ private void TryLoadGenericMetaTypeNullability(MemberInfo memberInfo, NullabilityInfo nullability)
+ {
+ MemberInfo? metaMember = GetMemberMetadataDefinition(memberInfo);
+ Type? metaType = null;
+ if (metaMember is FieldInfo field)
+ {
+ metaType = field.FieldType;
+ }
+ else if (metaMember is PropertyInfo property)
+ {
+ metaType = GetPropertyMetaType(property);
+ }
+
+ if (metaType != null)
+ {
+ CheckGenericParameters(nullability, metaMember!, metaType, memberInfo.ReflectedType);
+ }
+ }
+
+ private static MemberInfo GetMemberMetadataDefinition(MemberInfo member)
+ {
+ Type? type = member.DeclaringType;
+ if ((type != null) && type.IsGenericType && !type.IsGenericTypeDefinition)
+ {
+ return NullabilityInfoHelpers.GetMemberWithSameMetadataDefinitionAs(type.GetGenericTypeDefinition(), member);
+ }
+
+ return member;
+ }
+
+ private static Type GetPropertyMetaType(PropertyInfo property)
+ {
+ if (property.GetGetMethod(true) is MethodInfo method)
+ {
+ return method.ReturnType;
+ }
+
+ return property.GetSetMethod(true)!.GetParameters()[0].ParameterType;
+ }
+
+ private void CheckGenericParameters(NullabilityInfo nullability, MemberInfo metaMember, Type metaType, Type? reflectedType)
+ {
+ if (metaType.IsGenericParameter)
+ {
+ if (nullability.ReadState == NullabilityState.NotNull)
+ {
+ _ = TryUpdateGenericParameterNullability(nullability, metaType, reflectedType);
+ }
+ }
+ else if (metaType.ContainsGenericParameters)
+ {
+ if (nullability.GenericTypeArguments.Length > 0)
+ {
+ Type[] genericArguments = metaType.GetGenericArguments();
+
+ for (int i = 0; i < genericArguments.Length; i++)
+ {
+ CheckGenericParameters(nullability.GenericTypeArguments[i], metaMember, genericArguments[i], reflectedType);
+ }
+ }
+ else if (nullability.ElementType is { } elementNullability && metaType.IsArray)
+ {
+ CheckGenericParameters(elementNullability, metaMember, metaType.GetElementType()!, reflectedType);
+ }
+
+ // We could also follow this branch for metaType.IsPointer, but since pointers must be unmanaged this
+ // will be a no-op regardless
+ else if (metaType.IsByRef)
+ {
+ CheckGenericParameters(nullability, metaMember, metaType.GetElementType()!, reflectedType);
+ }
+ }
+ }
+
+ private bool TryUpdateGenericParameterNullability(NullabilityInfo nullability, Type genericParameter, Type? reflectedType)
+ {
+ Debug.Assert(genericParameter.IsGenericParameter, "must be generic parameter");
+
+ if (reflectedType is not null
+ && !genericParameter.IsGenericMethodParameter()
+ && TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, reflectedType, reflectedType))
+ {
+ return true;
+ }
+
+ if (IsValueTypeOrValueTypeByRef(nullability.Type))
+ {
+ return true;
+ }
+
+ var state = NullabilityState.Unknown;
+ if (CreateParser(genericParameter.GetCustomAttributesData()).ParseNullableState(0, ref state))
+ {
+ nullability.ReadState = state;
+ nullability.WriteState = state;
+ return true;
+ }
+
+ if (GetNullableContext(genericParameter) is { } contextState)
+ {
+ nullability.ReadState = contextState;
+ nullability.WriteState = contextState;
+ return true;
+ }
+
+ return false;
+ }
+
+ private bool TryUpdateGenericTypeParameterNullabilityFromReflectedType(NullabilityInfo nullability, Type genericParameter, Type context, Type reflectedType)
+ {
+ Debug.Assert(genericParameter.IsGenericParameter && !genericParameter.IsGenericMethodParameter(), "must be generic parameter");
+
+ Type contextTypeDefinition = context.IsGenericType && !context.IsGenericTypeDefinition ? context.GetGenericTypeDefinition() : context;
+ if (genericParameter.DeclaringType == contextTypeDefinition)
+ {
+ return false;
+ }
+
+ Type? baseType = contextTypeDefinition.BaseType;
+ if (baseType is null)
+ {
+ return false;
+ }
+
+ if (!baseType.IsGenericType
+ || (baseType.IsGenericTypeDefinition ? baseType : baseType.GetGenericTypeDefinition()) != genericParameter.DeclaringType)
+ {
+ return TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, baseType, reflectedType);
+ }
+
+ Type[] genericArguments = baseType.GetGenericArguments();
+ Type genericArgument = genericArguments[genericParameter.GenericParameterPosition];
+ if (genericArgument.IsGenericParameter)
+ {
+ return TryUpdateGenericParameterNullability(nullability, genericArgument, reflectedType);
+ }
+
+ NullableAttributeStateParser parser = CreateParser(contextTypeDefinition.GetCustomAttributesData());
+ int nullabilityStateIndex = 1; // start at 1 since index 0 is the type itself
+ for (int i = 0; i < genericParameter.GenericParameterPosition; i++)
+ {
+ nullabilityStateIndex += CountNullabilityStates(genericArguments[i]);
+ }
+
+ return TryPopulateNullabilityInfo(nullability, parser, ref nullabilityStateIndex);
+
+ static int CountNullabilityStates(Type type)
+ {
+ Type underlyingType = Nullable.GetUnderlyingType(type) ?? type;
+ if (underlyingType.IsGenericType)
+ {
+ int count = 1;
+ foreach (Type genericArgument in underlyingType.GetGenericArguments())
+ {
+ count += CountNullabilityStates(genericArgument);
+ }
+
+ return count;
+ }
+
+ if (underlyingType.HasElementType)
+ {
+ return (underlyingType.IsArray ? 1 : 0) + CountNullabilityStates(underlyingType.GetElementType()!);
+ }
+
+ return type.IsValueType ? 0 : 1;
+ }
+ }
+
+#pragma warning disable SA1204 // Static elements should appear before instance elements
+ private static bool TryPopulateNullabilityInfo(NullabilityInfo nullability, NullableAttributeStateParser parser, ref int index)
+#pragma warning restore SA1204 // Static elements should appear before instance elements
+ {
+ bool isValueType = IsValueTypeOrValueTypeByRef(nullability.Type);
+ if (!isValueType)
+ {
+ var state = NullabilityState.Unknown;
+ if (!parser.ParseNullableState(index, ref state))
+ {
+ return false;
+ }
+
+ nullability.ReadState = state;
+ nullability.WriteState = state;
+ }
+
+ if (!isValueType || (Nullable.GetUnderlyingType(nullability.Type) ?? nullability.Type).IsGenericType)
+ {
+ index++;
+ }
+
+ if (nullability.GenericTypeArguments.Length > 0)
+ {
+ foreach (NullabilityInfo genericTypeArgumentNullability in nullability.GenericTypeArguments)
+ {
+ _ = TryPopulateNullabilityInfo(genericTypeArgumentNullability, parser, ref index);
+ }
+ }
+ else if (nullability.ElementType is { } elementTypeNullability)
+ {
+ _ = TryPopulateNullabilityInfo(elementTypeNullability, parser, ref index);
+ }
+
+ return true;
+ }
+
+ private static NullabilityState TranslateByte(object? value)
+ {
+ return value is byte b ? TranslateByte(b) : NullabilityState.Unknown;
+ }
+
+ private static NullabilityState TranslateByte(byte b) =>
+ b switch
+ {
+ 1 => NullabilityState.NotNull,
+ 2 => NullabilityState.Nullable,
+ _ => NullabilityState.Unknown
+ };
+
+ private static bool IsValueTypeOrValueTypeByRef(Type type) =>
+ type.IsValueType || ((type.IsByRef || type.IsPointer) && type.GetElementType()!.IsValueType);
+
+ private readonly struct NullableAttributeStateParser
+ {
+ private static readonly object UnknownByte = (byte)0;
+
+ private readonly object? _nullableAttributeArgument;
+
+ public NullableAttributeStateParser(object? nullableAttributeArgument)
+ {
+ _nullableAttributeArgument = nullableAttributeArgument;
+ }
+
+ public static NullableAttributeStateParser Unknown => new(UnknownByte);
+
+ public bool ParseNullableState(int index, ref NullabilityState state)
+ {
+ switch (_nullableAttributeArgument)
+ {
+ case byte b:
+ state = TranslateByte(b);
+ return true;
+ case ReadOnlyCollection args
+ when index < args.Count && args[index].Value is byte elementB:
+ state = TranslateByte(elementB);
+ return true;
+ default:
+ return false;
+ }
+ }
+ }
+ }
+}
+#endif
diff --git a/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoHelpers.cs b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoHelpers.cs
new file mode 100644
index 00000000000..1ee573a0020
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoHelpers.cs
@@ -0,0 +1,47 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#if !NET6_0_OR_GREATER
+using System.Diagnostics.CodeAnalysis;
+
+#pragma warning disable IDE1006 // Naming Styles
+#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
+
+namespace System.Reflection
+{
+ ///
+ /// Polyfills for System.Private.CoreLib internals.
+ ///
+ [ExcludeFromCodeCoverage]
+ internal static class NullabilityInfoHelpers
+ {
+ public static MemberInfo GetMemberWithSameMetadataDefinitionAs(Type type, MemberInfo member)
+ {
+ const BindingFlags all = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
+ foreach (var info in type.GetMembers(all))
+ {
+ if (info.HasSameMetadataDefinitionAs(member))
+ {
+ return info;
+ }
+ }
+
+ throw new MissingMemberException(type.FullName, member.Name);
+ }
+
+ // https://github.com/dotnet/runtime/blob/main/src/coreclr/System.Private.CoreLib/src/System/Reflection/MemberInfo.Internal.cs
+ public static bool HasSameMetadataDefinitionAs(this MemberInfo target, MemberInfo other)
+ {
+ return target.MetadataToken == other.MetadataToken &&
+ target.Module.Equals(other.Module);
+ }
+
+ // https://github.com/dotnet/runtime/issues/23493
+ public static bool IsGenericMethodParameter(this Type target)
+ {
+ return target.IsGenericParameter &&
+ target.DeclaringMethod != null;
+ }
+ }
+}
+#endif
diff --git a/src/Shared/JsonSchemaExporter/README.md b/src/Shared/JsonSchemaExporter/README.md
new file mode 100644
index 00000000000..1a4d13c5841
--- /dev/null
+++ b/src/Shared/JsonSchemaExporter/README.md
@@ -0,0 +1,11 @@
+# JsonSchemaExporter
+
+Provides a polyfill for the [.NET 9 `JsonSchemaExporter` component](https://learn.microsoft.com/dotnet/standard/serialization/system-text-json/extract-schema) that is compatible with all supported targets using System.Text.Json version 8.
+
+To use this in your project, add the following to your `.csproj` file:
+
+```xml
+
+ true
+
+```
diff --git a/src/Shared/Shared.csproj b/src/Shared/Shared.csproj
index f6cbb03ea83..439c3788557 100644
--- a/src/Shared/Shared.csproj
+++ b/src/Shared/Shared.csproj
@@ -12,11 +12,12 @@
true
true
true
- true
+ true
true
true
true
true
+ true
@@ -33,6 +34,10 @@
+
+
+
+
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs
index a9a544c8ca8..09f515fa066 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs
@@ -90,4 +90,45 @@ static void AssertNotFound(T1 input)
Assert.Equal(default(T2), value);
}
}
+
+ [Fact]
+ public void TryAdd_AddsOnlyIfNonExistent()
+ {
+ AdditionalPropertiesDictionary d = [];
+
+ Assert.False(d.ContainsKey("key"));
+ Assert.True(d.TryAdd("key", "value"));
+ Assert.True(d.ContainsKey("key"));
+ Assert.Equal("value", d["key"]);
+
+ Assert.False(d.TryAdd("key", "value2"));
+ Assert.True(d.ContainsKey("key"));
+ Assert.Equal("value", d["key"]);
+ }
+
+ [Fact]
+ public void Enumerator_EnumeratesAllItems()
+ {
+ AdditionalPropertiesDictionary d = [];
+
+ const int NumProperties = 10;
+ for (int i = 0; i < NumProperties; i++)
+ {
+ d.Add($"key{i}", $"value{i}");
+ }
+
+ Assert.Equal(NumProperties, d.Count);
+
+ // This depends on an implementation detail of the ordering in which the dictionary
+ // enumerates items. If that ever changes, this test will need to be updated.
+ int count = 0;
+ foreach (KeyValuePair item in d)
+ {
+ Assert.Equal($"key{count}", item.Key);
+ Assert.Equal($"value{count}", item.Value);
+ count++;
+ }
+
+ Assert.Equal(NumProperties, count);
+ }
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs
index f83169712c3..fcd40a2f446 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs
@@ -19,6 +19,7 @@ public void Constructor_Parameterless_PropsDefaulted()
Assert.Null(options.TopK);
Assert.Null(options.FrequencyPenalty);
Assert.Null(options.PresencePenalty);
+ Assert.Null(options.Seed);
Assert.Null(options.ResponseFormat);
Assert.Null(options.ModelId);
Assert.Null(options.StopSequences);
@@ -33,6 +34,7 @@ public void Constructor_Parameterless_PropsDefaulted()
Assert.Null(clone.TopK);
Assert.Null(clone.FrequencyPenalty);
Assert.Null(clone.PresencePenalty);
+ Assert.Null(options.Seed);
Assert.Null(clone.ResponseFormat);
Assert.Null(clone.ModelId);
Assert.Null(clone.StopSequences);
@@ -69,6 +71,7 @@ public void Properties_Roundtrip()
options.TopK = 42;
options.FrequencyPenalty = 0.4f;
options.PresencePenalty = 0.5f;
+ options.Seed = 12345;
options.ResponseFormat = ChatResponseFormat.Json;
options.ModelId = "modelId";
options.StopSequences = stopSequences;
@@ -82,6 +85,7 @@ public void Properties_Roundtrip()
Assert.Equal(42, options.TopK);
Assert.Equal(0.4f, options.FrequencyPenalty);
Assert.Equal(0.5f, options.PresencePenalty);
+ Assert.Equal(12345, options.Seed);
Assert.Same(ChatResponseFormat.Json, options.ResponseFormat);
Assert.Equal("modelId", options.ModelId);
Assert.Same(stopSequences, options.StopSequences);
@@ -96,6 +100,7 @@ public void Properties_Roundtrip()
Assert.Equal(42, clone.TopK);
Assert.Equal(0.4f, clone.FrequencyPenalty);
Assert.Equal(0.5f, clone.PresencePenalty);
+ Assert.Equal(12345, options.Seed);
Assert.Same(ChatResponseFormat.Json, clone.ResponseFormat);
Assert.Equal("modelId", clone.ModelId);
Assert.Equal(stopSequences, clone.StopSequences);
@@ -126,6 +131,7 @@ public void JsonSerialization_Roundtrips()
options.TopK = 42;
options.FrequencyPenalty = 0.4f;
options.PresencePenalty = 0.5f;
+ options.Seed = 12345;
options.ResponseFormat = ChatResponseFormat.Json;
options.ModelId = "modelId";
options.StopSequences = stopSequences;
@@ -148,6 +154,7 @@ public void JsonSerialization_Roundtrips()
Assert.Equal(42, deserialized.TopK);
Assert.Equal(0.4f, deserialized.FrequencyPenalty);
Assert.Equal(0.5f, deserialized.PresencePenalty);
+ Assert.Equal(12345, deserialized.Seed);
Assert.Equal(ChatResponseFormat.Json, deserialized.ResponseFormat);
Assert.NotSame(ChatResponseFormat.Json, deserialized.ResponseFormat);
Assert.Equal("modelId", deserialized.ModelId);
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj
index 0d4d5fbfa96..911ce1b2bf8 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj
@@ -5,16 +5,27 @@
- $(NoWarn);CA1063;CA1861;CA2201;VSTHRD003
+ $(NoWarn);CA1063;CA1861;CA2201;VSTHRD003;S104
true
+ true
+ true
+ true
true
+ true
+
+
+
+
+
+
+
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs
similarity index 77%
rename from test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs
rename to test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs
index db482d26804..52f9cad246d 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs
@@ -3,7 +3,9 @@
using System.ComponentModel;
using System.Text.Json;
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
+using Microsoft.Extensions.AI.JsonSchemaExporter;
using Xunit;
namespace Microsoft.Extensions.AI;
@@ -130,7 +132,7 @@ public static void ResolveParameterJsonSchema_ReturnsExpectedValue()
}
[Fact]
- public static void ResolveParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString()
+ public static void CreateParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString()
{
JsonElement expected = JsonDocument.Parse("""
{
@@ -158,4 +160,38 @@ public enum MyEnumValue
A = 1,
B = 2
}
+
+ [Fact]
+ public static void CreateJsonSchema_CanBeBoolean()
+ {
+ JsonElement schema = AIJsonUtilities.CreateJsonSchema(typeof(object));
+ Assert.Equal(JsonValueKind.True, schema.ValueKind);
+ }
+
+ [Theory]
+ [MemberData(nameof(TestTypes.GetTestDataUsingAllValues), MemberType = typeof(TestTypes))]
+ public static void CreateJsonSchema_ValidateWithTestData(ITestData testData)
+ {
+ // Stress tests the schema generation method using types from the JsonSchemaExporter test battery.
+
+ JsonSerializerOptions options = testData.Options is { } opts
+ ? new(opts) { TypeInfoResolver = TestTypes.TestTypesContext.Default }
+ : TestTypes.TestTypesContext.Default.Options;
+
+ JsonElement schema = AIJsonUtilities.CreateJsonSchema(testData.Type, serializerOptions: options);
+ JsonNode? schemaAsNode = JsonSerializer.SerializeToNode(schema, options);
+
+ Assert.NotNull(schemaAsNode);
+ Assert.Equal(testData.ExpectedJsonSchema.GetValueKind(), schemaAsNode.GetValueKind());
+
+ if (testData.Value is null || testData.WritesNumbersAsStrings)
+ {
+ // By design, our generated schema does not accept null root values
+ // or numbers formatted as strings, so we skip schema validation.
+ return;
+ }
+
+ JsonNode? serializedValue = JsonSerializer.SerializeToNode(testData.Value, testData.Type, options);
+ SchemaTestHelpers.AssertDocumentMatchesSchema(schemaAsNode, serializedValue);
+ }
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Microsoft.Extensions.AI.AotCompatibility.TestApp.csproj b/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Microsoft.Extensions.AI.AotCompatibility.TestApp.csproj
new file mode 100644
index 00000000000..183cd150937
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Microsoft.Extensions.AI.AotCompatibility.TestApp.csproj
@@ -0,0 +1,26 @@
+
+
+
+ Exe
+ $(LatestTargetFramework)
+ true
+ false
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Program.cs b/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Program.cs
new file mode 100644
index 00000000000..b518dfa7739
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Program.cs
@@ -0,0 +1,22 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#pragma warning disable S125 // Remove this commented out code
+
+using Microsoft.Extensions.AI;
+
+// Use types from each library.
+
+// Microsoft.Extensions.AI.Ollama
+using var b = new OllamaChatClient("http://localhost:11434", "llama3.2");
+
+// Microsoft.Extensions.AI.AzureAIInference
+// using var a = new Azure.AI.Inference.ChatCompletionClient(new Uri("http://localhost"), new("apikey")); // uncomment once warnings in Azure.AI.Inference are addressed
+
+// Microsoft.Extensions.AI.OpenAI
+// using var c = new OpenAI.OpenAIClient("apikey").AsChatClient("gpt-4o-mini"); // uncomment once warnings in OpenAI are addressed
+
+// Microsoft.Extensions.AI
+AIFunctionFactory.Create(() => { });
+
+System.Console.WriteLine("Success!");
diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs
index 4fb5122cc93..f404f5e61ef 100644
--- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs
@@ -247,8 +247,8 @@ public async Task MultipleMessages_NonStreaming()
],
"presence_penalty": 0.5,
"frequency_penalty": 0.75,
- "model": "gpt-4o-mini",
- "seed": 42
+ "seed": 42,
+ "model": "gpt-4o-mini"
}
""";
@@ -303,7 +303,7 @@ public async Task MultipleMessages_NonStreaming()
FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f,
StopSequences = ["great"],
- AdditionalProperties = new() { ["seed"] = 42L },
+ Seed = 42,
});
Assert.NotNull(response);
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
index 0863e31db37..e9c2bd57d65 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs
@@ -6,6 +6,7 @@
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
@@ -132,6 +133,27 @@ public virtual async Task CompleteStreamingAsync_UsageDataAvailable()
Assert.Equal(usage.Details.InputTokenCount + usage.Details.OutputTokenCount, usage.Details.TotalTokenCount);
}
+ protected virtual string? GetModel_MultiModal_DescribeImage() => null;
+
+ [ConditionalFact]
+ public virtual async Task MultiModal_DescribeImage()
+ {
+ SkipIfNotEnabled();
+
+ var response = await _chatClient.CompleteAsync(
+ [
+ new(ChatRole.User,
+ [
+ new TextContent("What does this logo say?"),
+ new ImageContent(GetImageDataUri()),
+ ])
+ ],
+ new() { ModelId = GetModel_MultiModal_DescribeImage() });
+
+ Assert.Single(response.Choices);
+ Assert.True(response.Message.Text?.IndexOf("net", StringComparison.OrdinalIgnoreCase) >= 0, response.Message.Text);
+ }
+
[ConditionalFact]
public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Parameterless()
{
@@ -714,6 +736,15 @@ private enum JobType
Unknown,
}
+ private static Uri GetImageDataUri()
+ {
+ using Stream? s = typeof(ChatClientIntegrationTests).Assembly.GetManifestResourceStream("Microsoft.Extensions.AI.dotnet.png");
+ Assert.NotNull(s);
+ MemoryStream ms = new();
+ s.CopyTo(ms);
+ return new Uri($"data:image/png;base64,{Convert.ToBase64String(ms.ToArray())}");
+ }
+
[MemberNotNull(nameof(_chatClient))]
protected void SkipIfNotEnabled()
{
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj
index e38ccd3268b..04d9bc6d29f 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj
@@ -15,6 +15,10 @@
true
+
+
+
+
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/dotnet.png b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/dotnet.png
new file mode 100644
index 00000000000..fb00ecf91e4
Binary files /dev/null and b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/dotnet.png differ
diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs
index 891378c0e86..4c71690baaf 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs
@@ -30,6 +30,8 @@ public override Task FunctionInvocation_RequireAny() =>
public override Task FunctionInvocation_RequireSpecific() =>
throw new SkipTestException("Ollama does not currently support requiring function invocation.");
+ protected override string? GetModel_MultiModal_DescribeImage() => "llava";
+
[ConditionalFact]
public async Task PromptBasedFunctionCalling_NoArgs()
{
@@ -47,7 +49,7 @@ public async Task PromptBasedFunctionCalling_NoArgs()
ModelId = "llama3:8b",
Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")],
Temperature = 0,
- AdditionalProperties = new() { ["seed"] = 0L },
+ Seed = 0,
});
Assert.Single(response.Choices);
@@ -81,7 +83,7 @@ public async Task PromptBasedFunctionCalling_WithArgs()
{
Tools = [stockPriceTool, irrelevantTool],
Temperature = 0,
- AdditionalProperties = new() { ["seed"] = 0L },
+ Seed = 0,
});
Assert.Single(response.Choices);
diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs
index 3e281173c8b..67b10e3f24b 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs
@@ -254,7 +254,7 @@ public async Task MultipleMessages_NonStreaming()
FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f,
StopSequences = ["great"],
- AdditionalProperties = new() { ["seed"] = 42 },
+ Seed = 42,
});
Assert.NotNull(response);
diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs
index 691804e5fb8..05d2f5a22ff 100644
--- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs
@@ -348,7 +348,7 @@ public async Task MultipleMessages_NonStreaming()
FrequencyPenalty = 0.75f,
PresencePenalty = 0.5f,
StopSequences = ["great"],
- AdditionalProperties = new() { ["seed"] = 42 },
+ Seed = 42,
});
Assert.NotNull(response);
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs
index a27761c99ec..a911340813f 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs
@@ -26,11 +26,13 @@ public void UseChatOptions_InvalidArgs_Throws()
Assert.Throws("configureOptions", () => builder.UseChatOptions(null!));
}
- [Fact]
- public async Task ConfigureOptions_ReturnedInstancePassedToNextClient()
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullReturned)
{
ChatOptions providedOptions = new();
- ChatOptions returnedOptions = new();
+ ChatOptions? returnedOptions = nullReturned ? null : new();
ChatCompletion expectedCompletion = new(Array.Empty());
var expectedUpdates = Enumerable.Range(0, 3).Select(i => new StreamingChatCompletionUpdate()).ToArray();
using CancellationTokenSource cts = new();
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/ConfigureOptionsEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/ConfigureOptionsEmbeddingGeneratorTests.cs
new file mode 100644
index 00000000000..b8a4b82cb59
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/ConfigureOptionsEmbeddingGeneratorTests.cs
@@ -0,0 +1,58 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Microsoft.Extensions.AI;
+
+public class ConfigureOptionsEmbeddingGeneratorTests
+{
+ [Fact]
+ public void ConfigureOptionsEmbeddingGenerator_InvalidArgs_Throws()
+ {
+ Assert.Throws("innerGenerator", () => new ConfigureOptionsEmbeddingGenerator>(null!, _ => new EmbeddingGenerationOptions()));
+ Assert.Throws("configureOptions", () => new ConfigureOptionsEmbeddingGenerator>(new TestEmbeddingGenerator(), null!));
+ }
+
+ [Fact]
+ public void UseEmbeddingGenerationOptions_InvalidArgs_Throws()
+ {
+ var builder = new EmbeddingGeneratorBuilder>();
+ Assert.Throws("configureOptions", () => builder.UseEmbeddingGenerationOptions(null!));
+ }
+
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullReturned)
+ {
+ EmbeddingGenerationOptions providedOptions = new();
+ EmbeddingGenerationOptions? returnedOptions = nullReturned ? null : new();
+ GeneratedEmbeddings> expectedEmbeddings = [];
+ using CancellationTokenSource cts = new();
+
+ using IEmbeddingGenerator> innerGenerator = new TestEmbeddingGenerator
+ {
+ GenerateAsyncCallback = (inputs, options, cancellationToken) =>
+ {
+ Assert.Same(returnedOptions, options);
+ Assert.Equal(cts.Token, cancellationToken);
+ return Task.FromResult(expectedEmbeddings);
+ }
+ };
+
+ using var generator = new EmbeddingGeneratorBuilder>()
+ .UseEmbeddingGenerationOptions(options =>
+ {
+ Assert.Same(providedOptions, options);
+ return returnedOptions;
+ })
+ .Use(innerGenerator);
+
+ var embeddings = await generator.GenerateAsync([], providedOptions, cts.Token);
+ Assert.Same(expectedEmbeddings, embeddings);
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs
new file mode 100644
index 00000000000..3a266af7ce3
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs
@@ -0,0 +1,205 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics.Tracing;
+using Microsoft.Extensions.Caching.Hybrid.Internal;
+using Xunit.Abstractions;
+
+namespace Microsoft.Extensions.Caching.Hybrid.Tests;
+
+public class HybridCacheEventSourceTests(ITestOutputHelper log, TestEventListener listener) : IClassFixture
+{
+ // see notes in TestEventListener for context on fixture usage
+
+ [SkippableFact]
+ public void MatchesNameAndGuid()
+ {
+ // Assert
+ Assert.Equal("Microsoft-Extensions-HybridCache", listener.Source.Name);
+ Assert.Equal(Guid.Parse("b3aca39e-5dc9-5e21-f669-b72225b66cfc"), listener.Source.Guid); // from name
+ }
+
+ [SkippableFact]
+ public async Task LocalCacheHit()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.LocalCacheHit();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheHit, "LocalCacheHit", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-local-cache-hits", "Total Local Cache Hits", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task LocalCacheMiss()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.LocalCacheMiss();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheMiss, "LocalCacheMiss", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-local-cache-misses", "Total Local Cache Misses", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task DistributedCacheGet()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.DistributedCacheGet();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheGet, "DistributedCacheGet", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("current-distributed-cache-fetches", "Current Distributed Cache Fetches", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task DistributedCacheHit()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.DistributedCacheGet();
+ listener.Reset(resetCounters: false).Source.DistributedCacheHit();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheHit, "DistributedCacheHit", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-distributed-cache-hits", "Total Distributed Cache Hits", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task DistributedCacheMiss()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.DistributedCacheGet();
+ listener.Reset(resetCounters: false).Source.DistributedCacheMiss();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheMiss, "DistributedCacheMiss", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-distributed-cache-misses", "Total Distributed Cache Misses", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task DistributedCacheFailed()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.DistributedCacheGet();
+ listener.Reset(resetCounters: false).Source.DistributedCacheFailed();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheFailed, "DistributedCacheFailed", EventLevel.Error);
+
+ await AssertCountersAsync();
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task UnderlyingDataQueryStart()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.UnderlyingDataQueryStart();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryStart, "UnderlyingDataQueryStart", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("current-data-query", "Current Data Queries", 1);
+ listener.AssertCounter("total-data-query", "Total Data Queries", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task UnderlyingDataQueryComplete()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.UnderlyingDataQueryStart();
+ listener.Reset(resetCounters: false).Source.UnderlyingDataQueryComplete();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryComplete, "UnderlyingDataQueryComplete", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-data-query", "Total Data Queries", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task UnderlyingDataQueryFailed()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.UnderlyingDataQueryStart();
+ listener.Reset(resetCounters: false).Source.UnderlyingDataQueryFailed();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryFailed, "UnderlyingDataQueryFailed", EventLevel.Error);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-data-query", "Total Data Queries", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task LocalCacheWrite()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.LocalCacheWrite();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheWrite, "LocalCacheWrite", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-local-cache-writes", "Total Local Cache Writes", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task DistributedCacheWrite()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.DistributedCacheWrite();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheWrite, "DistributedCacheWrite", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-distributed-cache-writes", "Total Distributed Cache Writes", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ [SkippableFact]
+ public async Task StampedeJoin()
+ {
+ AssertEnabled();
+
+ listener.Reset().Source.StampedeJoin();
+ listener.AssertSingleEvent(HybridCacheEventSource.EventIdStampedeJoin, "StampedeJoin", EventLevel.Verbose);
+
+ await AssertCountersAsync();
+ listener.AssertCounter("total-stampede-joins", "Total Stampede Joins", 1);
+ listener.AssertRemainingCountersZero();
+ }
+
+ private void AssertEnabled()
+ {
+ // including this data for visibility when tests fail - ETW subsystem can be ... weird
+ log.WriteLine($".NET {Environment.Version} on {Environment.OSVersion}, {IntPtr.Size * 8}-bit");
+
+ Skip.IfNot(listener.Source.IsEnabled(), "Event source not enabled");
+ }
+
+ private async Task AssertCountersAsync()
+ {
+ var count = await listener.TryAwaitCountersAsync();
+
+ // ETW counters timing can be painfully unpredictable; generally
+ // it'll work fine locally, especially on modern .NET, but:
+ // CI servers and netfx in particular - not so much. The tests
+ // can still observe and validate the simple events, though, which
+ // should be enough to be credible that the eventing system is
+ // fundamentally working. We're not meant to be testing that
+ // the counters system *itself* works!
+
+ Skip.If(count == 0, "No counters received");
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/LogCollector.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/LogCollector.cs
new file mode 100644
index 00000000000..bdb5ff981c0
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/LogCollector.cs
@@ -0,0 +1,84 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.Extensions.Logging;
+using Xunit.Abstractions;
+
+namespace Microsoft.Extensions.Caching.Hybrid.Tests;
+
+// dummy implementation for collecting test output
+internal class LogCollector : ILoggerProvider
+{
+ private readonly List<(string categoryName, LogLevel logLevel, EventId eventId, Exception? exception, string message)> _items = [];
+
+ public (string categoryName, LogLevel logLevel, EventId eventId, Exception? exception, string message)[] ToArray()
+ {
+ lock (_items)
+ {
+ return _items.ToArray();
+ }
+ }
+
+ public void WriteTo(ITestOutputHelper log)
+ {
+ lock (_items)
+ {
+ foreach (var logItem in _items)
+ {
+ var errSuffix = logItem.exception is null ? "" : $" - {logItem.exception.Message}";
+ log.WriteLine($"{logItem.categoryName} {logItem.eventId}: {logItem.message}{errSuffix}");
+ }
+ }
+ }
+
+ public void AssertErrors(int[] errorIds)
+ {
+ lock (_items)
+ {
+ bool same;
+ if (errorIds.Length == _items.Count)
+ {
+ int index = 0;
+ same = true;
+ foreach (var item in _items)
+ {
+ if (item.eventId.Id != errorIds[index++])
+ {
+ same = false;
+ break;
+ }
+ }
+ }
+ else
+ {
+ same = false;
+ }
+
+ if (!same)
+ {
+ // we expect this to fail, then
+ Assert.Equal(string.Join(",", errorIds), string.Join(",", _items.Select(static x => x.eventId.Id)));
+ }
+ }
+ }
+
+ ILogger ILoggerProvider.CreateLogger(string categoryName) => new TypedLogCollector(this, categoryName);
+
+ void IDisposable.Dispose()
+ {
+ // nothing to do
+ }
+
+ private sealed class TypedLogCollector(LogCollector parent, string categoryName) : ILogger
+ {
+ IDisposable? ILogger.BeginScope(TState state) => null;
+ bool ILogger.IsEnabled(LogLevel logLevel) => true;
+ void ILogger.Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter)
+ {
+ lock (parent._items)
+ {
+ parent._items.Add((categoryName, logLevel, eventId, exception, formatter(state, exception)));
+ }
+ }
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/Microsoft.Extensions.Caching.Hybrid.Tests.csproj b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/Microsoft.Extensions.Caching.Hybrid.Tests.csproj
index ef80a84eee9..fb8863cf776 100644
--- a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/Microsoft.Extensions.Caching.Hybrid.Tests.csproj
+++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/Microsoft.Extensions.Caching.Hybrid.Tests.csproj
@@ -12,13 +12,15 @@
+
-
+
+
diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/NullDistributedCache.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/NullDistributedCache.cs
new file mode 100644
index 00000000000..d07cb51bb93
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/NullDistributedCache.cs
@@ -0,0 +1,31 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.Extensions.Caching.Distributed;
+
+namespace Microsoft.Extensions.Caching.Hybrid.Tests;
+
+// dummy L2 that doesn't actually store anything
+internal class NullDistributedCache : IDistributedCache
+{
+ byte[]? IDistributedCache.Get(string key) => null;
+ Task IDistributedCache.GetAsync(string key, CancellationToken token) => Task.FromResult(null);
+ void IDistributedCache.Refresh(string key)
+ {
+ // nothing to do
+ }
+
+ Task IDistributedCache.RefreshAsync(string key, CancellationToken token) => Task.CompletedTask;
+ void IDistributedCache.Remove(string key)
+ {
+ // nothing to do
+ }
+
+ Task IDistributedCache.RemoveAsync(string key, CancellationToken token) => Task.CompletedTask;
+ void IDistributedCache.Set(string key, byte[] value, DistributedCacheEntryOptions options)
+ {
+ // nothing to do
+ }
+
+ Task IDistributedCache.SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token) => Task.CompletedTask;
+}
diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/SizeTests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/SizeTests.cs
index 119c2297882..66f4fc7628d 100644
--- a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/SizeTests.cs
+++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/SizeTests.cs
@@ -1,31 +1,60 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Buffers;
+using System.ComponentModel;
+using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Hybrid.Internal;
using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Xunit.Abstractions;
namespace Microsoft.Extensions.Caching.Hybrid.Tests;
-public class SizeTests
+public class SizeTests(ITestOutputHelper log)
{
[Theory]
- [InlineData(null, true)] // does not enforce size limits
- [InlineData(8L, false)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time
- [InlineData(1024L, true)] // reasonable size limit
- public async Task ValidateSizeLimit_Immutable(long? sizeLimit, bool expectFromL1)
+ [InlineData("abc", null, true, null, null)] // does not enforce size limits
+ [InlineData("", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
+ [InlineData(" ", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
+ [InlineData(null, null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
+ [InlineData("abc", 8L, false, null, null)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time
+ [InlineData("abc", 1024L, true, null, null)] // reasonable size limit
+ [InlineData("abc", 1024L, true, 8L, null, Log.IdMaximumPayloadBytesExceeded)] // reasonable size limit, small HC quota
+ [InlineData("abc", null, false, null, 2, Log.IdMaximumKeyLengthExceeded, Log.IdMaximumKeyLengthExceeded)] // key limit exceeded
+ [InlineData("a\u0000c", null, false, null, null, Log.IdKeyInvalidContent, Log.IdKeyInvalidContent)] // invalid key
+ [InlineData("a\u001Fc", null, false, null, null, Log.IdKeyInvalidContent, Log.IdKeyInvalidContent)] // invalid key
+ [InlineData("a\u0020c", null, true, null, null)] // fine (this is just space)
+ public async Task ValidateSizeLimit_Immutable(string? key, long? sizeLimit, bool expectFromL1, long? maximumPayloadBytes, int? maximumKeyLength,
+ params int[] errorIds)
{
+ using var collector = new LogCollector();
var services = new ServiceCollection();
services.AddMemoryCache(options => options.SizeLimit = sizeLimit);
- services.AddHybridCache();
+ services.AddHybridCache(options =>
+ {
+ if (maximumKeyLength.HasValue)
+ {
+ options.MaximumKeyLength = maximumKeyLength.GetValueOrDefault();
+ }
+
+ if (maximumPayloadBytes.HasValue)
+ {
+ options.MaximumPayloadBytes = maximumPayloadBytes.GetValueOrDefault();
+ }
+ });
+ services.AddLogging(options =>
+ {
+ options.ClearProviders();
+ options.AddProvider(collector);
+ });
using ServiceProvider provider = services.BuildServiceProvider();
var cache = Assert.IsType(provider.GetRequiredService());
- const string Key = "abc";
-
// this looks weird; it is intentionally not a const - we want to check
// same instance without worrying about interning from raw literals
string expected = new("simple value".ToArray());
- var actual = await cache.GetOrCreateAsync(Key, ct => new(expected));
+ var actual = await cache.GetOrCreateAsync(key!, ct => new(expected));
// expect same contents
Assert.Equal(expected, actual);
@@ -35,7 +64,7 @@ public async Task ValidateSizeLimit_Immutable(long? sizeLimit, bool expectFromL1
Assert.Same(expected, actual);
// rinse and repeat, to check we get the value from L1
- actual = await cache.GetOrCreateAsync(Key, ct => new(Guid.NewGuid().ToString()));
+ actual = await cache.GetOrCreateAsync(key!, ct => new(Guid.NewGuid().ToString()));
if (expectFromL1)
{
@@ -51,30 +80,54 @@ public async Task ValidateSizeLimit_Immutable(long? sizeLimit, bool expectFromL1
// L1 cache not used
Assert.NotEqual(expected, actual);
}
+
+ collector.WriteTo(log);
+ collector.AssertErrors(errorIds);
}
[Theory]
- [InlineData(null, true)] // does not enforce size limits
- [InlineData(8L, false)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time
- [InlineData(1024L, true)] // reasonable size limit
- public async Task ValidateSizeLimit_Mutable(long? sizeLimit, bool expectFromL1)
+ [InlineData("abc", null, true, null, null)] // does not enforce size limits
+ [InlineData("", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
+ [InlineData(" ", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
+ [InlineData(null, null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key
+ [InlineData("abc", 8L, false, null, null)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time
+ [InlineData("abc", 1024L, true, null, null)] // reasonable size limit
+ [InlineData("abc", 1024L, true, 8L, null, Log.IdMaximumPayloadBytesExceeded)] // reasonable size limit, small HC quota
+ [InlineData("abc", null, false, null, 2, Log.IdMaximumKeyLengthExceeded, Log.IdMaximumKeyLengthExceeded)] // key limit exceeded
+ public async Task ValidateSizeLimit_Mutable(string? key, long? sizeLimit, bool expectFromL1, long? maximumPayloadBytes, int? maximumKeyLength,
+ params int[] errorIds)
{
+ using var collector = new LogCollector();
var services = new ServiceCollection();
services.AddMemoryCache(options => options.SizeLimit = sizeLimit);
- services.AddHybridCache();
+ services.AddHybridCache(options =>
+ {
+ if (maximumKeyLength.HasValue)
+ {
+ options.MaximumKeyLength = maximumKeyLength.GetValueOrDefault();
+ }
+
+ if (maximumPayloadBytes.HasValue)
+ {
+ options.MaximumPayloadBytes = maximumPayloadBytes.GetValueOrDefault();
+ }
+ });
+ services.AddLogging(options =>
+ {
+ options.ClearProviders();
+ options.AddProvider(collector);
+ });
using ServiceProvider provider = services.BuildServiceProvider();
var cache = Assert.IsType(provider.GetRequiredService());
- const string Key = "abc";
-
string expected = "simple value";
- var actual = await cache.GetOrCreateAsync(Key, ct => new(new MutablePoco { Value = expected }));
+ var actual = await cache.GetOrCreateAsync(key!, ct => new(new MutablePoco { Value = expected }));
// expect same contents
Assert.Equal(expected, actual.Value);
// rinse and repeat, to check we get the value from L1
- actual = await cache.GetOrCreateAsync(Key, ct => new(new MutablePoco { Value = Guid.NewGuid().ToString() }));
+ actual = await cache.GetOrCreateAsync(key!, ct => new(new MutablePoco { Value = Guid.NewGuid().ToString() }));
if (expectFromL1)
{
@@ -86,10 +139,217 @@ public async Task ValidateSizeLimit_Mutable(long? sizeLimit, bool expectFromL1)
// L1 cache not used
Assert.NotEqual(expected, actual.Value);
}
+
+ collector.WriteTo(log);
+ collector.AssertErrors(errorIds);
+ }
+
+ [Theory]
+ [InlineData("some value", false, 1, 1, 2, false)]
+ [InlineData("read fail", false, 1, 1, 1, true, Log.IdDeserializationFailure)]
+ [InlineData("write fail", true, 1, 1, 0, true, Log.IdSerializationFailure)]
+ public async Task BrokenSerializer_Mutable(string value, bool same, int runCount, int serializeCount, int deserializeCount, bool expectKnownFailure, params int[] errorIds)
+ {
+ using var collector = new LogCollector();
+ var services = new ServiceCollection();
+ services.AddMemoryCache();
+ services.AddSingleton();
+ var serializer = new MutablePoco.Serializer();
+ services.AddHybridCache().AddSerializer(serializer);
+ services.AddLogging(options =>
+ {
+ options.ClearProviders();
+ options.AddProvider(collector);
+ });
+ using ServiceProvider provider = services.BuildServiceProvider();
+ var cache = Assert.IsType(provider.GetRequiredService());
+
+ int actualRunCount = 0;
+ Func> func = _ =>
+ {
+ Interlocked.Increment(ref actualRunCount);
+ return new(new MutablePoco { Value = value });
+ };
+
+ if (expectKnownFailure)
+ {
+ await Assert.ThrowsAsync(async () => await cache.GetOrCreateAsync("key", func));
+ }
+ else
+ {
+ var first = await cache.GetOrCreateAsync("key", func);
+ var second = await cache.GetOrCreateAsync("key", func);
+ Assert.Equal(value, first.Value);
+ Assert.Equal(value, second.Value);
+
+ if (same)
+ {
+ Assert.Same(first, second);
+ }
+ else
+ {
+ Assert.NotSame(first, second);
+ }
+ }
+
+ Assert.Equal(runCount, Volatile.Read(ref actualRunCount));
+ Assert.Equal(serializeCount, serializer.WriteCount);
+ Assert.Equal(deserializeCount, serializer.ReadCount);
+ collector.WriteTo(log);
+ collector.AssertErrors(errorIds);
+ }
+
+ [Theory]
+ [InlineData("some value", true, 1, 1, 0, false, true)]
+ [InlineData("read fail", true, 1, 1, 0, false, true)]
+ [InlineData("write fail", true, 1, 1, 0, true, true, Log.IdSerializationFailure)]
+
+ // without L2, we only need the serializer for sizing purposes (L1), not used for deserialize
+ [InlineData("some value", true, 1, 1, 0, false, false)]
+ [InlineData("read fail", true, 1, 1, 0, false, false)]
+ [InlineData("write fail", true, 1, 1, 0, true, false, Log.IdSerializationFailure)]
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Major Code Smell", "S107:Methods should not have too many parameters", Justification = "Test scenario range; reducing duplication")]
+ public async Task BrokenSerializer_Immutable(string value, bool same, int runCount, int serializeCount, int deserializeCount, bool expectKnownFailure, bool withL2,
+ params int[] errorIds)
+ {
+ using var collector = new LogCollector();
+ var services = new ServiceCollection();
+ services.AddMemoryCache();
+ if (withL2)
+ {
+ services.AddSingleton();
+ }
+
+ var serializer = new ImmutablePoco.Serializer();
+ services.AddHybridCache().AddSerializer(serializer);
+ services.AddLogging(options =>
+ {
+ options.ClearProviders();
+ options.AddProvider(collector);
+ });
+ using ServiceProvider provider = services.BuildServiceProvider();
+ var cache = Assert.IsType(provider.GetRequiredService());
+
+ int actualRunCount = 0;
+ Func> func = _ =>
+ {
+ Interlocked.Increment(ref actualRunCount);
+ return new(new ImmutablePoco(value));
+ };
+
+ if (expectKnownFailure)
+ {
+ await Assert.ThrowsAsync(async () => await cache.GetOrCreateAsync("key", func));
+ }
+ else
+ {
+ var first = await cache.GetOrCreateAsync("key", func);
+ var second = await cache.GetOrCreateAsync("key", func);
+ Assert.Equal(value, first.Value);
+ Assert.Equal(value, second.Value);
+
+ if (same)
+ {
+ Assert.Same(first, second);
+ }
+ else
+ {
+ Assert.NotSame(first, second);
+ }
+ }
+
+ Assert.Equal(runCount, Volatile.Read(ref actualRunCount));
+ Assert.Equal(serializeCount, serializer.WriteCount);
+ Assert.Equal(deserializeCount, serializer.ReadCount);
+ collector.WriteTo(log);
+ collector.AssertErrors(errorIds);
+ }
+
+ public class KnownFailureException : Exception
+ {
+ public KnownFailureException(string message)
+ : base(message)
+ {
+ }
}
public class MutablePoco
{
public string Value { get; set; } = "";
+
+ public sealed class Serializer : IHybridCacheSerializer
+ {
+ private int _readCount;
+ private int _writeCount;
+
+ public int ReadCount => Volatile.Read(ref _readCount);
+ public int WriteCount => Volatile.Read(ref _writeCount);
+
+ public MutablePoco Deserialize(ReadOnlySequence source)
+ {
+ Interlocked.Increment(ref _readCount);
+ var value = InbuiltTypeSerializer.DeserializeString(source);
+ if (value == "read fail")
+ {
+ throw new KnownFailureException("read failure");
+ }
+
+ return new MutablePoco { Value = value };
+ }
+
+ public void Serialize(MutablePoco value, IBufferWriter target)
+ {
+ Interlocked.Increment(ref _writeCount);
+ if (value.Value == "write fail")
+ {
+ throw new KnownFailureException("write failure");
+ }
+
+ InbuiltTypeSerializer.SerializeString(value.Value, target);
+ }
+ }
+ }
+
+ [ImmutableObject(true)]
+ public sealed class ImmutablePoco
+ {
+ public ImmutablePoco(string value)
+ {
+ Value = value;
+ }
+
+ public string Value { get; }
+
+ public sealed class Serializer : IHybridCacheSerializer
+ {
+ private int _readCount;
+ private int _writeCount;
+
+ public int ReadCount => Volatile.Read(ref _readCount);
+ public int WriteCount => Volatile.Read(ref _writeCount);
+
+ public ImmutablePoco Deserialize(ReadOnlySequence source)
+ {
+ Interlocked.Increment(ref _readCount);
+ var value = InbuiltTypeSerializer.DeserializeString(source);
+ if (value == "read fail")
+ {
+ throw new KnownFailureException("read failure");
+ }
+
+ return new ImmutablePoco(value);
+ }
+
+ public void Serialize(ImmutablePoco value, IBufferWriter target)
+ {
+ Interlocked.Increment(ref _writeCount);
+ if (value.Value == "write fail")
+ {
+ throw new KnownFailureException("write failure");
+ }
+
+ InbuiltTypeSerializer.SerializeString(value.Value, target);
+ }
+ }
}
}
diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/TestEventListener.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/TestEventListener.cs
new file mode 100644
index 00000000000..ecb97ef3c7e
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/TestEventListener.cs
@@ -0,0 +1,189 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Diagnostics.Tracing;
+using System.Globalization;
+using Microsoft.Extensions.Caching.Hybrid.Internal;
+
+namespace Microsoft.Extensions.Caching.Hybrid.Tests;
+
+public sealed class TestEventListener : EventListener
+{
+ // captures both event and counter data
+
+ // this is used as a class fixture from HybridCacheEventSourceTests, because there
+ // seems to be some unpredictable behaviours if multiple event sources/listeners are
+ // casually created etc
+ private const double EventCounterIntervalSec = 0.25;
+
+ private readonly List<(int id, string name, EventLevel level)> _events = [];
+ private readonly Dictionary _counters = [];
+
+ private object SyncLock => _events;
+
+ internal HybridCacheEventSource Source { get; } = new();
+
+ public TestEventListener Reset(bool resetCounters = true)
+ {
+ lock (SyncLock)
+ {
+ _events.Clear();
+ _counters.Clear();
+
+ if (resetCounters)
+ {
+ Source.ResetCounters();
+ }
+ }
+
+ Assert.True(Source.IsEnabled(), "should report as enabled");
+
+ return this;
+ }
+
+ protected override void OnEventSourceCreated(EventSource eventSource)
+ {
+ if (ReferenceEquals(eventSource, Source))
+ {
+ var args = new Dictionary
+ {
+ ["EventCounterIntervalSec"] = EventCounterIntervalSec.ToString("G", CultureInfo.InvariantCulture),
+ };
+ EnableEvents(Source, EventLevel.Verbose, EventKeywords.All, args);
+ }
+
+ base.OnEventSourceCreated(eventSource);
+ }
+
+ protected override void OnEventWritten(EventWrittenEventArgs eventData)
+ {
+ if (ReferenceEquals(eventData.EventSource, Source))
+ {
+ // capture counters/events
+ lock (SyncLock)
+ {
+ if (eventData.EventName == "EventCounters"
+ && eventData.Payload is { Count: > 0 })
+ {
+ foreach (var payload in eventData.Payload)
+ {
+ if (payload is IDictionary map)
+ {
+ string? name = null;
+ string? displayName = null;
+ double? value = null;
+ bool isIncrement = false;
+ foreach (var pair in map)
+ {
+ switch (pair.Key)
+ {
+ case "Name" when pair.Value is string:
+ name = (string)pair.Value;
+ break;
+ case "DisplayName" when pair.Value is string s:
+ displayName = s;
+ break;
+ case "Mean":
+ isIncrement = false;
+ value = Convert.ToDouble(pair.Value);
+ break;
+ case "Increment":
+ isIncrement = true;
+ value = Convert.ToDouble(pair.Value);
+ break;
+ }
+ }
+
+ if (name is not null && value is not null)
+ {
+ if (isIncrement && _counters.TryGetValue(name, out var oldPair))
+ {
+ value += oldPair.value; // treat as delta from old
+ }
+
+ Debug.WriteLine($"{name}={value}");
+ _counters[name] = (displayName, value.Value);
+ }
+ }
+ }
+ }
+ else
+ {
+ _events.Add((eventData.EventId, eventData.EventName ?? "", eventData.Level));
+ }
+ }
+ }
+
+ base.OnEventWritten(eventData);
+ }
+
+ public (int id, string name, EventLevel level) SingleEvent()
+ {
+ (int id, string name, EventLevel level) evt;
+ lock (SyncLock)
+ {
+ evt = Assert.Single(_events);
+ }
+
+ return evt;
+ }
+
+ public void AssertSingleEvent(int id, string name, EventLevel level)
+ {
+ var evt = SingleEvent();
+ Assert.Equal(name, evt.name);
+ Assert.Equal(id, evt.id);
+ Assert.Equal(level, evt.level);
+ }
+
+ public double AssertCounter(string name, string displayName)
+ {
+ lock (SyncLock)
+ {
+ Assert.True(_counters.TryGetValue(name, out var pair), $"counter not found: {name}");
+ Assert.Equal(displayName, pair.displayName);
+
+ _counters.Remove(name); // count as validated
+ return pair.value;
+ }
+ }
+
+ public void AssertCounter(string name, string displayName, double expected)
+ {
+ var actual = AssertCounter(name, displayName);
+ if (!Equals(expected, actual))
+ {
+ Assert.Fail($"{name}: expected {expected}, actual {actual}");
+ }
+ }
+
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Major Bug", "S1244:Floating point numbers should not be tested for equality", Justification = "Test expects exact zero")]
+ public void AssertRemainingCountersZero()
+ {
+ lock (SyncLock)
+ {
+ foreach (var pair in _counters)
+ {
+ if (pair.Value.value != 0)
+ {
+ Assert.Fail($"{pair.Key}: expected 0, actual {pair.Value.value}");
+ }
+ }
+ }
+ }
+
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1822:Mark members as static", Justification = "Clarity and usability")]
+ public async Task TryAwaitCountersAsync()
+ {
+ // allow 2 cycles because if we only allow 1, we run the risk of a
+ // snapshot being captured mid-cycle when we were setting up the test
+ // (ok, that's an unlikely race condition, but!)
+ await Task.Delay(TimeSpan.FromSeconds(EventCounterIntervalSec * 2));
+
+ lock (SyncLock)
+ {
+ return _counters.Count;
+ }
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/UnreliableL2Tests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/UnreliableL2Tests.cs
new file mode 100644
index 00000000000..7af85f9cba2
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/UnreliableL2Tests.cs
@@ -0,0 +1,251 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.Extensions.Caching.Distributed;
+using Microsoft.Extensions.Caching.Hybrid.Internal;
+using Microsoft.Extensions.Caching.Memory;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Xunit.Abstractions;
+
+namespace Microsoft.Extensions.Caching.Hybrid.Tests;
+
+// validate HC stability when the L2 is unreliable
+public class UnreliableL2Tests(ITestOutputHelper testLog)
+{
+ [Theory]
+ [InlineData(BreakType.None)]
+ [InlineData(BreakType.Synchronous, Log.IdCacheBackendWriteFailure)]
+ [InlineData(BreakType.Asynchronous, Log.IdCacheBackendWriteFailure)]
+ [InlineData(BreakType.AsynchronousYield, Log.IdCacheBackendWriteFailure)]
+ [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")]
+ public async Task WriteFailureInvisible(BreakType writeBreak, params int[] errorIds)
+ {
+ using (GetServices(out var hc, out var l1, out var l2, out var log))
+ using (log)
+ {
+ // normal behaviour when working fine
+ var x = await hc.GetOrCreateAsync("x", NewGuid);
+ Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
+ Assert.NotNull(l2.Tail.Get("x")); // exists
+
+ l2.WriteBreak = writeBreak;
+ var y = await hc.GetOrCreateAsync("y", NewGuid);
+ Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
+ if (writeBreak == BreakType.None)
+ {
+ Assert.NotNull(l2.Tail.Get("y")); // exists
+ }
+ else
+ {
+ Assert.Null(l2.Tail.Get("y")); // does not exist
+ }
+
+ await l2.LastWrite; // allows out-of-band write to complete
+ await Task.Delay(150); // even then: thread jitter can cause problems
+
+ log.WriteTo(testLog);
+ log.AssertErrors(errorIds);
+ }
+ }
+
+ [Theory]
+ [InlineData(BreakType.None)]
+ [InlineData(BreakType.Synchronous, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)]
+ [InlineData(BreakType.Asynchronous, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)]
+ [InlineData(BreakType.AsynchronousYield, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)]
+ public async Task ReadFailureInvisible(BreakType readBreak, params int[] errorIds)
+ {
+ using (GetServices(out var hc, out var l1, out var l2, out var log))
+ using (log)
+ {
+ // create two new values via HC; this should go down to l2
+ var x = await hc.GetOrCreateAsync("x", NewGuid);
+ var y = await hc.GetOrCreateAsync("y", NewGuid);
+
+ // this should be reliable and repeatable
+ Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
+ Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
+
+ // even if we clean L1, causing new L2 fetches
+ l1.Clear();
+ Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
+ Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
+
+ // now we break L2 in some predictable way, *without* clearing L1 - the
+ // values should still be available via L1
+ l2.ReadBreak = readBreak;
+ Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
+ Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
+
+ // but if we clear L1 to force L2 hits, we anticipate problems
+ l1.Clear();
+ if (readBreak == BreakType.None)
+ {
+ Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid));
+ Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid));
+ }
+ else
+ {
+ // because L2 is unavailable and L1 is empty, we expect the callback
+ // to be used again, generating new values
+ var a = await hc.GetOrCreateAsync("x", NewGuid, NoL2Write);
+ var b = await hc.GetOrCreateAsync("y", NewGuid, NoL2Write);
+
+ Assert.NotEqual(x, a);
+ Assert.NotEqual(y, b);
+
+ // but those *new* values are at least reliable inside L1
+ Assert.Equal(a, await hc.GetOrCreateAsync("x", NewGuid));
+ Assert.Equal(b, await hc.GetOrCreateAsync("y", NewGuid));
+ }
+
+ log.WriteTo(testLog);
+ log.AssertErrors(errorIds);
+ }
+ }
+
+ private static HybridCacheEntryOptions NoL2Write { get; } = new HybridCacheEntryOptions { Flags = HybridCacheEntryFlags.DisableDistributedCacheWrite };
+
+ public enum BreakType
+ {
+ None, // async API works correctly
+ Synchronous, // async API faults directly rather than return a faulted task
+ Asynchronous, // async API returns a completed asynchronous fault
+ AsynchronousYield, // async API returns an incomplete asynchronous fault
+ }
+
+ private static ValueTask NewGuid(CancellationToken cancellationToken) => new(Guid.NewGuid());
+
+ private static IDisposable GetServices(out HybridCache hc, out MemoryCache l1,
+ out UnreliableDistributedCache l2, out LogCollector log)
+ {
+ // we need an entirely separate MC for the dummy backend, not connected to our
+ // "real" services
+ var services = new ServiceCollection();
+ services.AddDistributedMemoryCache();
+ var backend = services.BuildServiceProvider().GetRequiredService();
+
+ // now create the "real" services
+ l2 = new UnreliableDistributedCache(backend);
+ var collector = new LogCollector();
+ log = collector;
+ services = new ServiceCollection();
+ services.AddSingleton(l2);
+ services.AddHybridCache();
+ services.AddLogging(options =>
+ {
+ options.ClearProviders();
+ options.AddProvider(collector);
+ });
+ var lifetime = services.BuildServiceProvider();
+ hc = lifetime.GetRequiredService();
+ l1 = Assert.IsType(lifetime.GetRequiredService());
+ return lifetime;
+ }
+
+ private sealed class UnreliableDistributedCache : IDistributedCache
+ {
+ public UnreliableDistributedCache(IDistributedCache tail)
+ {
+ Tail = tail;
+ }
+
+ public IDistributedCache Tail { get; }
+ public BreakType ReadBreak { get; set; }
+ public BreakType WriteBreak { get; set; }
+
+ public Task LastWrite { get; private set; } = Task.CompletedTask;
+
+ public byte[]? Get(string key) => throw new NotSupportedException(); // only async API in use
+
+ public Task GetAsync(string key, CancellationToken token = default)
+ => TrackLast(ThrowIfBrokenAsync(ReadBreak) ?? Tail.GetAsync(key, token));
+
+ public void Refresh(string key) => throw new NotSupportedException(); // only async API in use
+
+ public Task RefreshAsync(string key, CancellationToken token = default)
+ => TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.RefreshAsync(key, token));
+
+ public void Remove(string key) => throw new NotSupportedException(); // only async API in use
+
+ public Task RemoveAsync(string key, CancellationToken token = default)
+ => TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.RemoveAsync(key, token));
+
+ public void Set(string key, byte[] value, DistributedCacheEntryOptions options) => throw new NotSupportedException(); // only async API in use
+
+ public Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default)
+ => TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.SetAsync(key, value, options, token));
+
+ [DoesNotReturn]
+ private static void Throw() => throw new IOException("L2 offline");
+
+ private static async Task ThrowAsync(bool yield)
+ {
+ if (yield)
+ {
+ await Task.Yield();
+ }
+
+ Throw();
+ return default; // never reached
+ }
+
+ private static Task? ThrowIfBrokenAsync(BreakType breakType) => ThrowIfBrokenAsync(breakType);
+
+ [SuppressMessage("Critical Bug", "S4586:Non-async \"Task/Task\" methods should not return null", Justification = "Intentional for propagation")]
+ private static Task? ThrowIfBrokenAsync(BreakType breakType)
+ {
+ switch (breakType)
+ {
+ case BreakType.Asynchronous:
+ return ThrowAsync(false);
+ case BreakType.AsynchronousYield:
+ return ThrowAsync(true);
+ case BreakType.None:
+ return null;
+ default:
+ // includes BreakType.Synchronous and anything unknown
+ Throw();
+ break;
+ }
+
+ return null;
+ }
+
+ [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")]
+ [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "We don't need the failure type - just the timing")]
+ private static Task IgnoreFailure(Task task)
+ {
+ return task.Status == TaskStatus.RanToCompletion
+ ? Task.CompletedTask : IgnoreAsync(task);
+
+ static async Task IgnoreAsync(Task task)
+ {
+ try
+ {
+ await task;
+ }
+ catch
+ {
+ // we only care about the "when"; failure is fine
+ }
+ }
+ }
+
+ [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")]
+ private Task TrackLast(Task lastWrite)
+ {
+ LastWrite = IgnoreFailure(lastWrite);
+ return lastWrite;
+ }
+
+ [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")]
+ private Task TrackLast(Task lastWrite)
+ {
+ LastWrite = IgnoreFailure(lastWrite);
+ return lastWrite;
+ }
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.Telemetry.Abstractions.Tests/Microsoft.Extensions.Telemetry.Abstractions.Tests.csproj b/test/Libraries/Microsoft.Extensions.Telemetry.Abstractions.Tests/Microsoft.Extensions.Telemetry.Abstractions.Tests.csproj
index ac284fee861..387cec3c5c0 100644
--- a/test/Libraries/Microsoft.Extensions.Telemetry.Abstractions.Tests/Microsoft.Extensions.Telemetry.Abstractions.Tests.csproj
+++ b/test/Libraries/Microsoft.Extensions.Telemetry.Abstractions.Tests/Microsoft.Extensions.Telemetry.Abstractions.Tests.csproj
@@ -12,4 +12,8 @@
+
+
+
+
diff --git a/test/Shared/JsonSchemaExporter/JsonSchemaExporterConfigurationTests.cs b/test/Shared/JsonSchemaExporter/JsonSchemaExporterConfigurationTests.cs
new file mode 100644
index 00000000000..1d2b6caa74e
--- /dev/null
+++ b/test/Shared/JsonSchemaExporter/JsonSchemaExporterConfigurationTests.cs
@@ -0,0 +1,35 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Text.Json.Schema;
+using Xunit;
+
+namespace Microsoft.Extensions.AI.JsonSchemaExporter;
+
+public static class JsonSchemaExporterConfigurationTests
+{
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public static void JsonSchemaExporterOptions_DefaultValues(bool useSingleton)
+ {
+ JsonSchemaExporterOptions configuration = useSingleton ? JsonSchemaExporterOptions.Default : new();
+ Assert.False(configuration.TreatNullObliviousAsNonNullable);
+ Assert.Null(configuration.TransformSchemaNode);
+ }
+
+ [Fact]
+ public static void JsonSchemaExporterOptions_Singleton_ReturnsSameInstance()
+ {
+ Assert.Same(JsonSchemaExporterOptions.Default, JsonSchemaExporterOptions.Default);
+ }
+
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public static void JsonSchemaExporterOptions_TreatNullObliviousAsNonNullable(bool treatNullObliviousAsNonNullable)
+ {
+ JsonSchemaExporterOptions configuration = new() { TreatNullObliviousAsNonNullable = treatNullObliviousAsNonNullable };
+ Assert.Equal(treatNullObliviousAsNonNullable, configuration.TreatNullObliviousAsNonNullable);
+ }
+}
diff --git a/test/Shared/JsonSchemaExporter/JsonSchemaExporterTests.cs b/test/Shared/JsonSchemaExporter/JsonSchemaExporterTests.cs
new file mode 100644
index 00000000000..2ec81987dc2
--- /dev/null
+++ b/test/Shared/JsonSchemaExporter/JsonSchemaExporterTests.cs
@@ -0,0 +1,147 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Schema;
+using System.Text.Json.Serialization;
+using System.Text.Json.Serialization.Metadata;
+#if !NET9_0_OR_GREATER
+using System.Xml.Linq;
+#endif
+using Xunit;
+
+#pragma warning disable SA1402 // File may only contain a single type
+
+namespace Microsoft.Extensions.AI.JsonSchemaExporter;
+
+public abstract class JsonSchemaExporterTests
+{
+ protected abstract JsonSerializerOptions Options { get; }
+
+ [Theory]
+ [MemberData(nameof(TestTypes.GetTestData), MemberType = typeof(TestTypes))]
+ public void TestTypes_GeneratesExpectedJsonSchema(ITestData testData)
+ {
+ JsonSerializerOptions options = testData.Options is { } opts
+ ? new(opts) { TypeInfoResolver = Options.TypeInfoResolver }
+ : Options;
+
+ JsonNode schema = options.GetJsonSchemaAsNode(testData.Type, (JsonSchemaExporterOptions?)testData.ExporterOptions);
+ SchemaTestHelpers.AssertEqualJsonSchema(testData.ExpectedJsonSchema, schema);
+ }
+
+ [Theory]
+ [MemberData(nameof(TestTypes.GetTestDataUsingAllValues), MemberType = typeof(TestTypes))]
+ public void TestTypes_SerializedValueMatchesGeneratedSchema(ITestData testData)
+ {
+ JsonSerializerOptions options = testData.Options is { } opts
+ ? new(opts) { TypeInfoResolver = Options.TypeInfoResolver }
+ : Options;
+
+ JsonNode schema = options.GetJsonSchemaAsNode(testData.Type, (JsonSchemaExporterOptions?)testData.ExporterOptions);
+ JsonNode? instance = JsonSerializer.SerializeToNode(testData.Value, testData.Type, options);
+ SchemaTestHelpers.AssertDocumentMatchesSchema(schema, instance);
+ }
+
+ [Theory]
+ [InlineData(typeof(string), "string")]
+ [InlineData(typeof(int[]), "array")]
+ [InlineData(typeof(Dictionary), "object")]
+ [InlineData(typeof(TestTypes.SimplePoco), "object")]
+ public void TreatNullObliviousAsNonNullable_True_MarksAllReferenceTypesAsNonNullable(Type referenceType, string expectedType)
+ {
+ Assert.True(!referenceType.IsValueType);
+ var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true };
+ JsonNode schema = Options.GetJsonSchemaAsNode(referenceType, config);
+ JsonValue type = Assert.IsAssignableFrom(schema["type"]);
+ Assert.Equal(expectedType, (string)type!);
+ }
+
+ [Theory]
+ [InlineData(typeof(int), "integer")]
+ [InlineData(typeof(double), "number")]
+ [InlineData(typeof(bool), "boolean")]
+ [InlineData(typeof(ImmutableArray), "array")]
+ [InlineData(typeof(TestTypes.StructDictionary), "object")]
+ [InlineData(typeof(TestTypes.SimpleRecordStruct), "object")]
+ public void TreatNullObliviousAsNonNullable_True_DoesNotImpactNonReferenceTypes(Type referenceType, string expectedType)
+ {
+ Assert.True(referenceType.IsValueType);
+ var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true };
+ JsonNode schema = Options.GetJsonSchemaAsNode(referenceType, config);
+ JsonValue value = Assert.IsAssignableFrom(schema["type"]);
+ Assert.Equal(expectedType, (string)value!);
+ }
+
+#if !NET9_0 // Disable until https://github.com/dotnet/runtime/pull/108764 gets backported
+ [Fact]
+ public void CanGenerateXElementSchema()
+ {
+ JsonNode schema = Options.GetJsonSchemaAsNode(typeof(XElement));
+ Assert.True(schema.ToJsonString().Length < 100_000);
+ }
+#endif
+
+ [Fact]
+ public void TreatNullObliviousAsNonNullable_True_DoesNotImpactObjectType()
+ {
+ var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true };
+ JsonNode schema = Options.GetJsonSchemaAsNode(typeof(object), config);
+ Assert.False(schema is JsonObject jObj && jObj.ContainsKey("type"));
+ }
+
+ [Fact]
+ public void TypeWithDisallowUnmappedMembers_AdditionalPropertiesFailValidation()
+ {
+ JsonNode schema = Options.GetJsonSchemaAsNode(typeof(TestTypes.PocoDisallowingUnmappedMembers));
+ JsonNode? jsonWithUnmappedProperties = JsonNode.Parse("""{ "UnmappedProperty" : {} }""");
+ SchemaTestHelpers.AssertDoesNotMatchSchema(schema, jsonWithUnmappedProperties);
+ }
+
+ [Fact]
+ public void GetJsonSchema_NullInputs_ThrowsArgumentNullException()
+ {
+ Assert.Throws(() => ((JsonSerializerOptions)null!).GetJsonSchemaAsNode(typeof(int)));
+ Assert.Throws(() => Options.GetJsonSchemaAsNode(type: null!));
+ Assert.Throws(() => ((JsonTypeInfo)null!).GetJsonSchemaAsNode());
+ }
+
+ [Fact]
+ public void GetJsonSchema_NoResolver_ThrowInvalidOperationException()
+ {
+ var options = new JsonSerializerOptions();
+ Assert.Throws(() => options.GetJsonSchemaAsNode(typeof(int)));
+ }
+
+ [Fact]
+ public void MaxDepth_SetToZero_NonTrivialSchema_ThrowsInvalidOperationException()
+ {
+ JsonSerializerOptions options = new(Options) { MaxDepth = 1 };
+ var ex = Assert.Throws(() => options.GetJsonSchemaAsNode(typeof(TestTypes.SimplePoco)));
+ Assert.Contains("The depth of the generated JSON schema exceeds the JsonSerializerOptions.MaxDepth setting.", ex.Message);
+ }
+
+ [Fact]
+ public void ReferenceHandlePreserve_Enabled_ThrowsNotSupportedException()
+ {
+ var options = new JsonSerializerOptions(Options) { ReferenceHandler = ReferenceHandler.Preserve };
+ options.MakeReadOnly();
+
+ var ex = Assert.Throws(() => options.GetJsonSchemaAsNode(typeof(TestTypes.SimplePoco)));
+ Assert.Contains("ReferenceHandler.Preserve", ex.Message);
+ }
+}
+
+public sealed class ReflectionJsonSchemaExporterTests : JsonSchemaExporterTests
+{
+ protected override JsonSerializerOptions Options => JsonSerializerOptions.Default;
+}
+
+public sealed class SourceGenJsonSchemaExporterTests : JsonSchemaExporterTests
+{
+ protected override JsonSerializerOptions Options => TestTypes.TestTypesContext.Default.Options;
+}
diff --git a/test/Shared/JsonSchemaExporter/SchemaTestHelpers.cs b/test/Shared/JsonSchemaExporter/SchemaTestHelpers.cs
new file mode 100644
index 00000000000..02e659a27aa
--- /dev/null
+++ b/test/Shared/JsonSchemaExporter/SchemaTestHelpers.cs
@@ -0,0 +1,82 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization;
+using Json.Schema;
+using Xunit.Sdk;
+
+namespace Microsoft.Extensions.AI.JsonSchemaExporter;
+
+internal static partial class SchemaTestHelpers
+{
+ public static void AssertEqualJsonSchema(JsonNode expectedJsonSchema, JsonNode actualJsonSchema)
+ {
+ if (!JsonNode.DeepEquals(expectedJsonSchema, actualJsonSchema))
+ {
+ throw new XunitException($"""
+ Generated schema does not match the expected specification.
+ Expected:
+ {FormatJson(expectedJsonSchema)}
+ Actual:
+ {FormatJson(actualJsonSchema)}
+ """);
+ }
+ }
+
+ public static void AssertDocumentMatchesSchema(JsonNode schema, JsonNode? instance)
+ {
+ EvaluationResults results = EvaluateSchemaCore(schema, instance);
+ if (!results.IsValid)
+ {
+ IEnumerable errors = results.Details
+ .Where(d => d.HasErrors)
+ .SelectMany(d => d.Errors!.Select(error => $"Path:${d.InstanceLocation} {error.Key}:{error.Value}"));
+
+ throw new XunitException($"""
+ Instance JSON document does not match the specified schema.
+ Schema:
+ {FormatJson(schema)}
+ Instance:
+ {FormatJson(instance)}
+ Errors:
+ {string.Join(Environment.NewLine, errors)}
+ """);
+ }
+ }
+
+ public static void AssertDoesNotMatchSchema(JsonNode schema, JsonNode? instance)
+ {
+ EvaluationResults results = EvaluateSchemaCore(schema, instance);
+ if (results.IsValid)
+ {
+ throw new XunitException($"""
+ Instance JSON document matches the specified schema.
+ Schema:
+ {FormatJson(schema)}
+ Instance:
+ {FormatJson(instance)}
+ """);
+ }
+ }
+
+ private static EvaluationResults EvaluateSchemaCore(JsonNode schema, JsonNode? instance)
+ {
+ JsonSchema jsonSchema = JsonSerializer.Deserialize(schema, Context.Default.JsonSchema)!;
+ EvaluationOptions options = new() { OutputFormat = OutputFormat.List };
+ return jsonSchema.Evaluate(instance, options);
+ }
+
+ private static string FormatJson(JsonNode? node) =>
+ JsonSerializer.Serialize(node, Context.Default.JsonNode!);
+
+ [JsonSerializable(typeof(string))]
+ [JsonSerializable(typeof(JsonNode))]
+ [JsonSerializable(typeof(JsonSchema))]
+ [JsonSourceGenerationOptions(WriteIndented = true)]
+ private partial class Context : JsonSerializerContext;
+}
diff --git a/test/Shared/JsonSchemaExporter/TestData.cs b/test/Shared/JsonSchemaExporter/TestData.cs
new file mode 100644
index 00000000000..0254a62b144
--- /dev/null
+++ b/test/Shared/JsonSchemaExporter/TestData.cs
@@ -0,0 +1,67 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Schema;
+
+namespace Microsoft.Extensions.AI.JsonSchemaExporter;
+
+internal sealed record TestData(
+ T? Value,
+ [StringSyntax(StringSyntaxAttribute.Json)] string ExpectedJsonSchema,
+ IEnumerable? AdditionalValues = null,
+ JsonSchemaExporterOptions? ExporterOptions = null,
+ JsonSerializerOptions? Options = null,
+ bool WritesNumbersAsStrings = false)
+ : ITestData
+{
+ private static readonly JsonDocumentOptions _schemaParseOptions = new() { CommentHandling = JsonCommentHandling.Skip };
+
+ public Type Type => typeof(T);
+ object? ITestData.Value => Value;
+ object? ITestData.ExporterOptions => ExporterOptions;
+ JsonNode ITestData.ExpectedJsonSchema { get; } =
+ JsonNode.Parse(ExpectedJsonSchema, documentOptions: _schemaParseOptions)
+ ?? throw new ArgumentNullException("schema must not be null");
+
+ IEnumerable ITestData.GetTestDataForAllValues()
+ {
+ yield return this;
+
+ if (default(T) is null &&
+ ExporterOptions is { TreatNullObliviousAsNonNullable: false } &&
+ Value is not null)
+ {
+ yield return this with { Value = default };
+ }
+
+ if (AdditionalValues != null)
+ {
+ foreach (T? value in AdditionalValues)
+ {
+ yield return this with { Value = value, AdditionalValues = null };
+ }
+ }
+ }
+}
+
+public interface ITestData
+{
+ Type Type { get; }
+
+ object? Value { get; }
+
+ JsonNode ExpectedJsonSchema { get; }
+
+ object? ExporterOptions { get; }
+
+ JsonSerializerOptions? Options { get; }
+
+ bool WritesNumbersAsStrings { get; }
+
+ IEnumerable GetTestDataForAllValues();
+}
diff --git a/test/Shared/JsonSchemaExporter/TestTypes.cs b/test/Shared/JsonSchemaExporter/TestTypes.cs
new file mode 100644
index 00000000000..d21a40640dd
--- /dev/null
+++ b/test/Shared/JsonSchemaExporter/TestTypes.cs
@@ -0,0 +1,1291 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.ComponentModel;
+using System.ComponentModel.DataAnnotations;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using System.Reflection;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Schema;
+using System.Text.Json.Serialization;
+using System.Xml.Linq;
+
+#pragma warning disable SA1118 // Parameter should not span multiple lines
+#pragma warning disable JSON001 // Comments not allowed
+#pragma warning disable S2344 // Enumeration type names should not have "Flags" or "Enum" suffixes
+#pragma warning disable SA1502 // Element should not be on a single line
+#pragma warning disable SA1136 // Enum values should be on separate lines
+#pragma warning disable SA1133 // Do not combine attributes
+#pragma warning disable S3604 // Member initializer values should not be redundant
+#pragma warning disable SA1515 // Single-line comment should be preceded by blank line
+#pragma warning disable CA1052 // Static holder types should be Static or NotInheritable
+#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
+#pragma warning disable IDE0073 // The file header is missing or not located at the top of the file
+
+namespace Microsoft.Extensions.AI.JsonSchemaExporter;
+
+public static partial class TestTypes
+{
+ public static IEnumerable