diff --git a/eng/MSBuild/LegacySupport.props b/eng/MSBuild/LegacySupport.props index 2cfe7b73964..842951ab867 100644 --- a/eng/MSBuild/LegacySupport.props +++ b/eng/MSBuild/LegacySupport.props @@ -43,6 +43,10 @@ + + + + diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props index 2bde3b34e05..4c78b8dcbe8 100644 --- a/eng/packages/TestOnly.props +++ b/eng/packages/TestOnly.props @@ -7,6 +7,7 @@ + @@ -20,6 +21,7 @@ + diff --git a/eng/spellchecking_exclusions.dic b/eng/spellchecking_exclusions.dic index 2fc9b74699b..72596816516 100644 Binary files a/eng/spellchecking_exclusions.dic and b/eng/spellchecking_exclusions.dic differ diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs index 616ad284198..4a681d4679a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs @@ -4,13 +4,21 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable S1144 // Unused private types or members should be removed +#pragma warning disable S2365 // Properties should not make collection or array copies +#pragma warning disable S3604 // Member initializer values should not be redundant namespace Microsoft.Extensions.AI; /// Provides a dictionary used as the AdditionalProperties dictionary on Microsoft.Extensions.AI objects. +[DebuggerTypeProxy(typeof(DebugView))] +[DebuggerDisplay("Count = {Count}")] public sealed class AdditionalPropertiesDictionary : IDictionary, IReadOnlyDictionary { /// The underlying dictionary. @@ -77,6 +85,25 @@ public object? this[string key] /// public void Add(string key, object? value) => _dictionary.Add(key, value); + /// Attempts to add the specified key and value to the dictionary. + /// The key of the element to add. + /// The value of the element to add. + /// if the key/value pair was added to the dictionary successfully; otherwise, . + public bool TryAdd(string key, object? value) + { +#if NET + return _dictionary.TryAdd(key, value); +#else + if (!_dictionary.ContainsKey(key)) + { + _dictionary.Add(key, value); + return true; + } + + return false; +#endif + } + /// void ICollection>.Add(KeyValuePair item) => ((ICollection>)_dictionary).Add(item); @@ -93,11 +120,17 @@ public object? this[string key] void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) => ((ICollection>)_dictionary).CopyTo(array, arrayIndex); + /// + /// Returns an enumerator that iterates through the . + /// + /// An that enumerates the contents of the . + public Enumerator GetEnumerator() => new(_dictionary.GetEnumerator()); + /// - public IEnumerator> GetEnumerator() => _dictionary.GetEnumerator(); + IEnumerator> IEnumerable>.GetEnumerator() => GetEnumerator(); /// - IEnumerator IEnumerable.GetEnumerator() => _dictionary.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); /// public bool Remove(string key) => _dictionary.Remove(key); @@ -156,4 +189,59 @@ public bool TryGetValue(string key, [NotNullWhen(true)] out T? value) value = default; return false; } + + /// Enumerates the elements of an . + public struct Enumerator : IEnumerator> + { + /// The wrapped dictionary enumerator. + private Dictionary.Enumerator _dictionaryEnumerator; + + /// Initializes a new instance of the struct with the dictionary enumerator to wrap. + /// The dictionary enumerator to wrap. + internal Enumerator(Dictionary.Enumerator dictionaryEnumerator) + { + _dictionaryEnumerator = dictionaryEnumerator; + } + + /// + public KeyValuePair Current => _dictionaryEnumerator.Current; + + /// + object IEnumerator.Current => Current; + + /// + public void Dispose() => _dictionaryEnumerator.Dispose(); + + /// + public bool MoveNext() => _dictionaryEnumerator.MoveNext(); + + /// + public void Reset() => Reset(ref _dictionaryEnumerator); + + /// Calls on an enumerator. + private static void Reset(ref TEnumerator enumerator) + where TEnumerator : struct, IEnumerator + { + enumerator.Reset(); + } + } + + /// Provides a debugger view for the collection. + private sealed class DebugView(AdditionalPropertiesDictionary properties) + { + private readonly AdditionalPropertiesDictionary _properties = Throw.IfNull(properties); + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public AdditionalProperty[] Items => (from p in _properties select new AdditionalProperty(p.Key, p.Value)).ToArray(); + + [DebuggerDisplay("{Value}", Name = "[{Key}]")] + public readonly struct AdditionalProperty(string key, object? value) + { + [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)] + public string Key { get; } = key; + + [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)] + public object? Value { get; } = value; + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md new file mode 100644 index 00000000000..6b347a8c09d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/CHANGELOG.md @@ -0,0 +1,19 @@ +# Release History + +## 9.0.0-preview.9.24525.1 + +- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older. +- Annotated `FunctionCallContent.Exception` and `FunctionResultContent.Exception` as `[JsonIgnore]`, such that they're ignored when serializing instances with `JsonSerializer`. The corresponding constructors accepting an `Exception` were removed. +- Annotated `ChatCompletion.Message` as `[JsonIgnore]`, such that it's ignored when serializing instances with `JsonSerializer`. +- Added the `FunctionCallContent.CreateFromParsedArguments` method. +- Added the `AdditionalPropertiesDictionary.TryGetValue` method. +- Added the `StreamingChatCompletionUpdate.ModelId` property and removed the `AIContent.ModelId` property. +- Renamed the `GenerateAsync` extension method on `IEmbeddingGenerator<,>` to `GenerateEmbeddingsAsync` and updated it to return `Embedding` rather than `GeneratedEmbeddings`. +- Added `GenerateAndZipAsync` and `GenerateEmbeddingVectorAsync` extension methods for `IEmbeddingGenerator<,>`. +- Added the `EmbeddingGeneratorOptions.Dimensions` property. +- Added the `ChatOptions.TopK` property. +- Normalized `null` inputs in `TextContent` to be empty strings. + +## 9.0.0-preview.9.24507.7 + +Initial Preview diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs index 4edbed900b4..0a4f6f58296 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs @@ -27,6 +27,9 @@ public class ChatOptions /// Gets or sets the presence penalty for generating chat responses. public float? PresencePenalty { get; set; } + /// Gets or sets a seed value used by a service to control the reproducability of results. + public long? Seed { get; set; } + /// /// Gets or sets the response format for the chat request. /// @@ -74,6 +77,7 @@ public virtual ChatOptions Clone() TopK = TopK, FrequencyPenalty = FrequencyPenalty, PresencePenalty = PresencePenalty, + Seed = Seed, ResponseFormat = ResponseFormat, ModelId = ModelId, ToolMode = ToolMode, diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj index 8f00d6b9271..b96b4dca920 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj @@ -17,9 +17,11 @@ $(TargetFrameworks);netstandard2.0 $(NoWarn);CA2227;CA1034;SA1316;S3253 true + true + true true true true diff --git a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonSchemaCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs similarity index 100% rename from src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonSchemaCreateOptions.cs rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs diff --git a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Defaults.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs similarity index 98% rename from src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Defaults.cs rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs index 94340160cb1..de2c2a695b6 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Defaults.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs @@ -23,11 +23,11 @@ private static JsonSerializerOptions CreateDefaultOptions() { // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, // and we want to be flexible in terms of what can be put into the various collections in the object model. - // Otherwise, use the source-generated options to enable Native AOT. + // Otherwise, use the source-generated options to enable trimming and Native AOT. if (JsonSerializer.IsReflectionEnabledByDefault) { - // Keep in sync with the JsonSourceGenerationOptions on JsonContext below. + // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext below. JsonSerializerOptions options = new(JsonSerializerDefaults.Web) { TypeInfoResolver = new DefaultJsonTypeInfoResolver(), diff --git a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs similarity index 81% rename from src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs rename to src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index 46fe45342f2..b555148df8b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -5,17 +5,22 @@ using System.Collections.Concurrent; using System.ComponentModel; using System.Diagnostics; +#if !NET9_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; +#endif using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Schema; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; #pragma warning disable S1121 // Assignments should not be made from within sub-expressions #pragma warning disable S107 // Methods should not have too many parameters #pragma warning disable S1075 // URIs should not be hardcoded +#pragma warning disable SA1118 // Parameter should not span multiple lines using FunctionParameterKey = ( System.Type? Type, @@ -138,8 +143,6 @@ public static JsonElement CreateJsonSchema( JsonSerializerOptions? serializerOptions = null, AIJsonSchemaCreateOptions? inferenceOptions = null) { - _ = Throw.IfNull(serializerOptions); - serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; @@ -176,6 +179,11 @@ private static JsonElement GetJsonSchemaCached(JsonSerializerOptions options, Fu #endif } +#if !NET9_0_OR_GREATER + [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access", + Justification = "Pre STJ-9 schema extraction can fail with a runtime exception if certain reflection metadata have been trimmed. " + + "The exception message will guide users to turn off 'IlcTrimMetadata' which resolves all issues.")] +#endif private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, FunctionParameterKey key) { _ = Throw.IfNull(options); @@ -238,16 +246,9 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) const string DefaultPropertyName = "default"; const string RefPropertyName = "$ref"; - // Find the first DescriptionAttribute, starting first from the property, then the parameter, and finally the type itself. - Type descAttrType = typeof(DescriptionAttribute); - var descriptionAttribute = - GetAttrs(descAttrType, ctx.PropertyInfo?.AttributeProvider)?.FirstOrDefault() ?? - GetAttrs(descAttrType, ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider)?.FirstOrDefault() ?? - GetAttrs(descAttrType, ctx.TypeInfo.Type)?.FirstOrDefault(); - - if (descriptionAttribute is DescriptionAttribute attr) + if (ctx.ResolveAttribute() is { } attr) { - ConvertSchemaToObject(ref schema).Insert(0, DescriptionPropertyName, (JsonNode)attr.Description); + ConvertSchemaToObject(ref schema).InsertAtStart(DescriptionPropertyName, (JsonNode)attr.Description); } if (schema is JsonObject objSchema) @@ -270,7 +271,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) // Include the type keyword in enum types if (key.IncludeTypeInEnumSchemas && ctx.TypeInfo.Type.IsEnum && objSchema.ContainsKey(EnumPropertyName) && !objSchema.ContainsKey(TypePropertyName)) { - objSchema.Insert(0, TypePropertyName, "string"); + objSchema.InsertAtStart(TypePropertyName, "string"); } // Disallow additional properties in object schemas @@ -278,24 +279,24 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) { objSchema.Add(AdditionalPropertiesPropertyName, (JsonNode)false); } - } - - if (ctx.Path.IsEmpty) - { - // We are at the root-level schema node, update/append parameter-specific metadata // Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand // schemas with "type": [...], and only understand "type" being a single value. // STJ represents .NET integer types as ["string", "integer"], which will then lead to an error. - if (TypeIsArrayContainingInteger(schema)) + if (TypeIsIntegerWithStringNumberHandling(ctx, objSchema)) { // We don't want to emit any array for "type". In this case we know it contains "integer" // so reduce the type to that alone, assuming it's the most specific type. - // This makes schemas for Int32 (etc) work with Ollama + // This makes schemas for Int32 (etc) work with Ollama. JsonObject obj = ConvertSchemaToObject(ref schema); obj[TypePropertyName] = "integer"; _ = obj.Remove(PatternPropertyName); } + } + + if (ctx.Path.IsEmpty) + { + // We are at the root-level schema node, update/append parameter-specific metadata if (!string.IsNullOrWhiteSpace(key.Description)) { @@ -305,7 +306,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) if (index < 0) { // If there's no description property, insert it at the beginning of the doc. - obj.Insert(0, DescriptionPropertyName, (JsonNode)key.Description!); + obj.InsertAtStart(DescriptionPropertyName, (JsonNode)key.Description!); } else { @@ -323,15 +324,12 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) if (key.IncludeSchemaUri) { // The $schema property must be the first keyword in the object - ConvertSchemaToObject(ref schema).Insert(0, SchemaPropertyName, (JsonNode)SchemaKeywordUri); + ConvertSchemaToObject(ref schema).InsertAtStart(SchemaPropertyName, (JsonNode)SchemaKeywordUri); } } return schema; - static object[]? GetAttrs(Type attrType, ICustomAttributeProvider? provider) => - provider?.GetCustomAttributes(attrType, inherit: false); - static JsonObject ConvertSchemaToObject(ref JsonNode schema) { JsonObject obj; @@ -354,22 +352,82 @@ static JsonObject ConvertSchemaToObject(ref JsonNode schema) } } - private static bool TypeIsArrayContainingInteger(JsonNode schema) + private static bool TypeIsIntegerWithStringNumberHandling(JsonSchemaExporterContext ctx, JsonObject schema) { - if (schema["type"] is JsonArray typeArray) + if (ctx.TypeInfo.NumberHandling is not JsonNumberHandling.Strict && schema["type"] is JsonArray typeArray) { - foreach (var entry in typeArray) + int count = 0; + foreach (JsonNode? entry in typeArray) { - if (entry?.GetValueKind() == JsonValueKind.String && entry.GetValue() == "integer") + if (entry?.GetValueKind() is JsonValueKind.String && + entry.GetValue() is "integer" or "string") { - return true; + count++; } } + + return count == typeArray.Count; } return false; } + private static void InsertAtStart(this JsonObject jsonObject, string key, JsonNode value) + { +#if NET9_0_OR_GREATER + jsonObject.Insert(0, key, value); +#else + jsonObject.Remove(key); + var copiedEntries = jsonObject.ToArray(); + jsonObject.Clear(); + + jsonObject.Add(key, value); + foreach (var entry in copiedEntries) + { + jsonObject[entry.Key] = entry.Value; + } +#endif + } + +#if !NET9_0_OR_GREATER + private static int IndexOf(this JsonObject jsonObject, string key) + { + int i = 0; + foreach (var entry in jsonObject) + { + if (string.Equals(entry.Key, key, StringComparison.Ordinal)) + { + return i; + } + + i++; + } + + return -1; + } +#endif + + private static TAttribute? ResolveAttribute(this JsonSchemaExporterContext ctx) + where TAttribute : Attribute + { + // Resolve attributes from locations in the following order: + // 1. Property-level attributes + // 2. Parameter-level attributes and + // 3. Type-level attributes. + return +#if NET9_0_OR_GREATER + GetAttrs(ctx.PropertyInfo?.AttributeProvider) ?? + GetAttrs(ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider) ?? +#else + GetAttrs(ctx.PropertyAttributeProvider) ?? + GetAttrs(ctx.ParameterInfo) ?? +#endif + GetAttrs(ctx.TypeInfo.Type); + + static TAttribute? GetAttrs(ICustomAttributeProvider? provider) => + (TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit: false).FirstOrDefault(); + } + private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json) { Utf8JsonReader reader = new(utf8Json); diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index ecc41140b27..ba76f5c3c90 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -285,6 +285,7 @@ private ChatCompletionsOptions ToAzureAIOptions(IList chatContents, result.NucleusSamplingFactor = options.TopP; result.PresencePenalty = options.PresencePenalty; result.Temperature = options.Temperature; + result.Seed = options.Seed; if (options.StopSequences is { Count: > 0 } stopSequences) { @@ -306,11 +307,6 @@ private ChatCompletionsOptions ToAzureAIOptions(IList chatContents, { switch (prop.Key) { - // These properties are strongly-typed on the ChatCompletionsOptions class but not on the ChatOptions class. - case nameof(result.Seed) when prop.Value is long seed: - result.Seed = seed; - break; - // Propagate everything else to the ChatCompletionOptions' AdditionalProperties. default: if (prop.Value is not null) diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs index 84198e6b2cc..866e55ad87a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs @@ -156,7 +156,7 @@ private EmbeddingsOptions ToAzureAIOptions(IEnumerable inputs, Embedding { EmbeddingsOptions result = new(inputs) { - Dimensions = _dimensions, + Dimensions = options?.Dimensions ?? _dimensions, Model = options?.ModelId ?? Metadata.ModelId, EncodingFormat = format, }; diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/CHANGELOG.md new file mode 100644 index 00000000000..7929cc7e8b2 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/CHANGELOG.md @@ -0,0 +1,12 @@ +# Release History + +## 9.0.0-preview.9.24525.1 + +- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older. +- Updated to use Azure.AI.Inference 1.0.0-beta.2. +- Added `AzureAIInferenceEmbeddingGenerator` and corresponding `AsEmbeddingGenerator` extension method. +- Improved handling of assistant messages that include both text and function call content. + +## 9.0.0-preview.9.24507.7 + +Initial Preview diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/JsonContext.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/JsonContext.cs index 5576cbf134a..1e1dabffab7 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/JsonContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/JsonContext.cs @@ -48,11 +48,11 @@ private static JsonSerializerOptions CreateDefaultToolJsonOptions() { // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, // and we want to be flexible in terms of what can be put into the various collections in the object model. - // Otherwise, use the source-generated options to enable Native AOT. + // Otherwise, use the source-generated options to enable trimming and Native AOT. if (JsonSerializer.IsReflectionEnabledByDefault) { - // Keep in sync with the JsonSourceGenerationOptions on JsonContext below. + // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above. JsonSerializerOptions options = new(JsonSerializerDefaults.Web) { TypeInfoResolver = new DefaultJsonTypeInfoResolver(), diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj index 3a66e7837f2..0e3f60b8db3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj @@ -17,6 +17,7 @@ $(TargetFrameworks);netstandard2.0 $(NoWarn);CA1063;CA2227;SA1316;S1067;S1121;S3358 true + true @@ -29,6 +30,7 @@ + diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.Ollama/CHANGELOG.md new file mode 100644 index 00000000000..ffb35814039 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/CHANGELOG.md @@ -0,0 +1,10 @@ +# Release History + +## 9.0.0-preview.9.24525.1 + +- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older. +- Added additional constructors to `OllamaChatClient` and `OllamaEmbeddingGenerator` that accept `string` endpoints, in addition to the existing ones accepting `Uri` endpoints. + +## 9.0.0-preview.9.24507.7 + +Initial Preview diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj index ad3064c8a66..018184d6bf0 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj @@ -17,6 +17,7 @@ $(TargetFrameworks);netstandard2.0 $(NoWarn);CA2227;SA1316;S1121;EA0002 true + true diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index 72ddb13b2ac..18ff5d50b7c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -273,7 +273,6 @@ private OllamaChatRequest ToOllamaChatRequest(IList chatMessages, C TransferMetadataValue(nameof(OllamaRequestOptions.penalize_newline), (options, value) => options.penalize_newline = value); TransferMetadataValue(nameof(OllamaRequestOptions.repeat_last_n), (options, value) => options.repeat_last_n = value); TransferMetadataValue(nameof(OllamaRequestOptions.repeat_penalty), (options, value) => options.repeat_penalty = value); - TransferMetadataValue(nameof(OllamaRequestOptions.seed), (options, value) => options.seed = value); TransferMetadataValue(nameof(OllamaRequestOptions.tfs_z), (options, value) => options.tfs_z = value); TransferMetadataValue(nameof(OllamaRequestOptions.typical_p), (options, value) => options.typical_p = value); TransferMetadataValue(nameof(OllamaRequestOptions.use_mmap), (options, value) => options.use_mmap = value); @@ -314,6 +313,11 @@ private OllamaChatRequest ToOllamaChatRequest(IList chatMessages, C { (request.Options ??= new()).top_k = topK; } + + if (options.Seed is long seed) + { + (request.Options ??= new()).seed = seed; + } } return request; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md new file mode 100644 index 00000000000..179da41a0b0 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/CHANGELOG.md @@ -0,0 +1,12 @@ +# Release History + +## 9.0.0-preview.9.24525.1 + +- Lowered the required version of System.Text.Json to 8.0.5 when targeting net8.0 or older. +- Improved handling of system messages that include multiple content items. +- Improved handling of assistant messages that include both text and function call content. +- Fixed handling of streaming updates containing empty payloads. + +## 9.0.0-preview.9.24507.7 + +Initial Preview diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj index 76930738579..f2e2e9c0f52 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj @@ -17,6 +17,7 @@ $(TargetFrameworks);netstandard2.0 $(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002 true + true @@ -26,6 +27,7 @@ + diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index 935bb88f812..985060256f7 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -3,11 +3,13 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -265,8 +267,7 @@ public async IAsyncEnumerable CompleteStreamingAs existing.CallId ??= toolCallUpdate.ToolCallId; existing.Name ??= toolCallUpdate.FunctionName; - if (toolCallUpdate.FunctionArgumentsUpdate is { } update && - !update.ToMemory().IsEmpty) // workaround for https://github.com/dotnet/runtime/issues/68262 in 6.0.0 package + if (toolCallUpdate.FunctionArgumentsUpdate is { } update && !update.ToMemory().IsEmpty) { _ = (existing.Arguments ??= new()).Append(update.ToString()); } @@ -391,6 +392,9 @@ private static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) result.TopP = options.TopP; result.PresencePenalty = options.PresencePenalty; result.Temperature = options.Temperature; +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. + result.Seed = options.Seed; +#pragma warning restore OPENAI001 if (options.StopSequences is { Count: > 0 } stopSequences) { @@ -425,13 +429,6 @@ private static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) result.AllowParallelToolCalls = allowParallelToolCalls; } -#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - if (additionalProperties.TryGetValue(nameof(result.Seed), out long seed)) - { - result.Seed = seed; - } -#pragma warning restore OPENAI001 - if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt)) { result.TopLogProbabilityCount = topLogProbabilityCountInt; @@ -587,10 +584,9 @@ private sealed class OpenAIChatToolJson string? result = resultContent.Result as string; if (result is null && resultContent.Result is not null) { - JsonSerializerOptions options = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options; try { - result = JsonSerializer.Serialize(resultContent.Result, options.GetTypeInfo(typeof(object))); + result = JsonSerializer.Serialize(resultContent.Result, JsonContext.GetTypeInfo(typeof(object), ToolCallJsonSerializerOptions)); } catch (NotSupportedException) { @@ -617,7 +613,9 @@ private sealed class OpenAIChatToolJson ChatToolCall.CreateFunctionToolCall( callRequest.CallId, callRequest.Name, - BinaryData.FromObjectAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions))); + new(JsonSerializer.SerializeToUtf8Bytes( + callRequest.Arguments, + JsonContext.GetTypeInfo(typeof(IDictionary), ToolCallJsonSerializerOptions))))); } } @@ -670,8 +668,53 @@ private static FunctionCallContent ParseCallContentFromBinaryData(BinaryData ut8 argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!); /// Source-generated JSON type information. + [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, + UseStringEnumConverter = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = true)] [JsonSerializable(typeof(OpenAIChatToolJson))] [JsonSerializable(typeof(IDictionary))] [JsonSerializable(typeof(JsonElement))] - private sealed partial class JsonContext : JsonSerializerContext; + private sealed partial class JsonContext : JsonSerializerContext + { + /// Gets the singleton used as the default in JSON serialization operations. + private static readonly JsonSerializerOptions _defaultToolJsonOptions = CreateDefaultToolJsonOptions(); + + /// Gets JSON type information for the specified type. + /// + /// This first tries to get the type information from , + /// falling back to if it can't. + /// + public static JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions? firstOptions) => + firstOptions?.TryGetTypeInfo(type, out JsonTypeInfo? info) is true ? + info : + _defaultToolJsonOptions.GetTypeInfo(type); + + /// Creates the default to use for serialization-related operations. + [UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")] + private static JsonSerializerOptions CreateDefaultToolJsonOptions() + { + // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, + // and we want to be flexible in terms of what can be put into the various collections in the object model. + // Otherwise, use the source-generated options to enable trimming and Native AOT. + + if (JsonSerializer.IsReflectionEnabledByDefault) + { + // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above. + JsonSerializerOptions options = new(JsonSerializerDefaults.Web) + { + TypeInfoResolver = new DefaultJsonTypeInfoResolver(), + Converters = { new JsonStringEnumConverter() }, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = true, + }; + + options.MakeReadOnly(); + return options; + } + + return Default.Options; + } + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md b/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md new file mode 100644 index 00000000000..e2dae2e6e37 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/CHANGELOG.md @@ -0,0 +1,17 @@ +# Release History + +## 9.0.0-preview.9.24525.1 + +- Added new `AIJsonUtilities` and `AIJsonSchemaCreateOptions` classes. +- Made `AIFunctionFactory.Create` safe for use with Native AOT. +- Simplified the set of `AIFunctionFactory.Create` overloads. +- Changed the default for `FunctionInvokingChatClient.ConcurrentInvocation` from `true` to `false`. +- Improved the readability of JSON generated as part of logging. +- Fixed handling of generated JSON schema names when using arrays or generic types. +- Improved `CachingChatClient`'s coalescing of streaming updates, including reduced memory allocation and enhanced metadata propagation. +- Updated `OpenTelemetryChatClient` and `OpenTelemetryEmbeddingGenerator` to conform to the latest 1.28.0 draft specification of the Semantic Conventions for Generative AI systems. +- Improved `CompleteAsync`'s structured output support to handle primitive types, enums, and arrays. + +## 9.0.0-preview.9.24507.7 + +Initial Preview diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs index 895bf8873df..990c92d3ad9 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs @@ -17,7 +17,7 @@ namespace Microsoft.Extensions.AI; /// /// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options /// with a new instance, the callback may simply return that new instance, for example _ => new ChatOptions() { MaxTokens = 1000 }. To provide -/// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example +/// a new instance only if the caller-supplied instance is , the callback may conditionally return a new instance, for example /// options => options ?? new ChatOptions() { MaxTokens = 1000 }. Any changes to the caller-provided options instance will persist on the /// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance /// and mutating the clone, for example: @@ -31,6 +31,9 @@ namespace Microsoft.Extensions.AI; /// /// /// +/// The callback may return , in which case a options will be passed to the next client in the pipeline. +/// +/// /// The provided implementation of is thread-safe for concurrent use so long as the employed configuration /// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the /// configuration callback, as multiple calls to it may end up running in parallel with the same options instance. @@ -39,7 +42,7 @@ namespace Microsoft.Extensions.AI; public sealed class ConfigureOptionsChatClient : DelegatingChatClient { /// The callback delegate used to configure options. - private readonly Func _configureOptions; + private readonly Func _configureOptions; /// Initializes a new instance of the class with the specified callback. /// The inner client. @@ -47,7 +50,7 @@ public sealed class ConfigureOptionsChatClient : DelegatingChatClient /// The delegate to invoke to configure the instance. It is passed the caller-supplied /// instance and should return the configured instance to use. /// - public ConfigureOptionsChatClient(IChatClient innerClient, Func configureOptions) + public ConfigureOptionsChatClient(IChatClient innerClient, Func configureOptions) : base(innerClient) { _configureOptions = Throw.IfNull(configureOptions); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs index 12b903c0dac..2d98fbd9003 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs @@ -21,9 +21,10 @@ public static class ConfigureOptionsChatClientBuilderExtensions /// /// The . /// + /// /// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options /// with a new instance, the callback may simply return that new instance, for example _ => new ChatOptions() { MaxTokens = 1000 }. To provide - /// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example + /// a new instance only if the caller-supplied instance is , the callback may conditionally return a new instance, for example /// options => options ?? new ChatOptions() { MaxTokens = 1000 }. Any changes to the caller-provided options instance will persist on the /// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance /// and mutating the clone, for example: @@ -35,9 +36,13 @@ public static class ConfigureOptionsChatClientBuilderExtensions /// return newOptions; /// } /// + /// + /// + /// The callback may return , in which case a options will be passed to the next client in the pipeline. + /// /// public static ChatClientBuilder UseChatOptions( - this ChatClientBuilder builder, Func configureOptions) + this ChatClientBuilder builder, Func configureOptions) { _ = Throw.IfNull(builder); _ = Throw.IfNull(configureOptions); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index 905e756e246..a6dfe53adf5 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -322,7 +322,7 @@ private static ChatCompletion ComposeStreamingUpdatesIntoChatCompletion( _ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "response_format"), responseFormat); } - if (options.AdditionalProperties?.TryGetValue("seed", out long seed) is true) + if (options.Seed is long seed) { _ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.PerProvider(_system, "seed"), seed); } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs new file mode 100644 index 00000000000..9068ac41caa --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable SA1629 // Documentation text should end with a period + +namespace Microsoft.Extensions.AI; + +/// A delegating embedding generator that updates or replaces the used by the remainder of the pipeline. +/// Specifies the type of the input passed to the generator. +/// Specifies the type of the embedding instance produced by the generator. +/// +/// +/// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options +/// with a new instance, the callback may simply return that new instance, for example _ => new EmbeddingGenerationOptions() { Dimensions = 100 }. To provide +/// a new instance only if the caller-supplied instance is , the callback may conditionally return a new instance, for example +/// options => options ?? new EmbeddingGenerationOptions() { Dimensions = 100 }. Any changes to the caller-provided options instance will persist on the +/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance +/// and mutating the clone, for example: +/// +/// options => +/// { +/// var newOptions = options?.Clone() ?? new(); +/// newOptions.Dimensions = 100; +/// return newOptions; +/// } +/// +/// +/// +/// The callback may return , in which case a options will be passed to the next generator in the pipeline. +/// +/// +/// The provided implementation of is thread-safe for concurrent use so long as the employed configuration +/// callback is also thread-safe for concurrent requests. If callers employ a shared options instance, care should be taken in the +/// configuration callback, as multiple calls to it may end up running in parallel with the same options instance. +/// +/// +public sealed class ConfigureOptionsEmbeddingGenerator : DelegatingEmbeddingGenerator + where TEmbedding : Embedding +{ + /// The callback delegate used to configure options. + private readonly Func _configureOptions; + + /// + /// Initializes a new instance of the class with the + /// specified callback. + /// + /// The inner generator. + /// + /// The delegate to invoke to configure the instance. It is passed the caller-supplied + /// instance and should return the configured instance to use. + /// + public ConfigureOptionsEmbeddingGenerator( + IEmbeddingGenerator innerGenerator, + Func configureOptions) + : base(innerGenerator) + { + _configureOptions = Throw.IfNull(configureOptions); + } + + /// + public override async Task> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + return await base.GenerateAsync(values, _configureOptions(options), cancellationToken).ConfigureAwait(false); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs new file mode 100644 index 00000000000..011f4c058e9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable SA1629 // Documentation text should end with a period + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class ConfigureOptionsEmbeddingGeneratorBuilderExtensions +{ + /// + /// Adds a callback that updates or replaces . This can be used to set default options. + /// + /// Specifies the type of the input passed to the generator. + /// Specifies the type of the embedding instance produced by the generator. + /// The . + /// + /// The delegate to invoke to configure the instance. It is passed the caller-supplied + /// instance and should return the configured instance to use. + /// + /// The . + /// + /// + /// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options + /// with a new instance, the callback may simply return that new instance, for example _ => new EmbeddingGenerationOptions() { Dimensions = 100 }. To provide + /// a new instance only if the caller-supplied instance is , the callback may conditionally return a new instance, for example + /// options => options ?? new EmbeddingGenerationOptions() { Dimensions = 100 }. Any changes to the caller-provided options instance will persist on the + /// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance + /// and mutating the clone, for example: + /// + /// options => + /// { + /// var newOptions = options?.Clone() ?? new(); + /// newOptions.Dimensions = 100; + /// return newOptions; + /// } + /// + /// + /// + /// The callback may return , in which case a options will be passed to the next generator in the pipeline. + /// + /// + public static EmbeddingGeneratorBuilder UseEmbeddingGenerationOptions( + this EmbeddingGeneratorBuilder builder, + Func configureOptions) + where TEmbedding : Embedding + { + _ = Throw.IfNull(builder); + _ = Throw.IfNull(configureOptions); + + return builder.Use(innerGenerator => new ConfigureOptionsEmbeddingGenerator(innerGenerator, configureOptions)); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index 2dfd7347ea8..e4ebd6198a7 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -19,6 +19,7 @@ $(TargetFrameworks);netstandard2.0 $(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253 true + true diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.CacheItem.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.CacheItem.cs index 5585b9b2a29..05edc65dc06 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.CacheItem.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.CacheItem.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Threading; using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Logging; namespace Microsoft.Extensions.Caching.Hybrid.Internal; @@ -22,7 +23,7 @@ internal abstract class CacheItem // zero. // This counter also drives cache lifetime, with the cache itself incrementing the count by one. In the // case of mutable data, cache eviction may reduce this to zero (in cooperation with any concurrent readers, - // who incr/decr around their fetch), allowing safe buffer recycling. + // who increment/decrement around their fetch), allowing safe buffer recycling. internal int RefCount => Volatile.Read(ref _refCount); @@ -89,13 +90,18 @@ internal abstract class CacheItem : CacheItem { public abstract bool TryGetSize(out long size); - // attempt to get a value that was *not* previously reserved - public abstract bool TryGetValue(out T value); + // Attempt to get a value that was *not* previously reserved. + // Note on ILogger usage: we don't want to propagate and store this everywhere. + // It is used for reporting deserialization problems - pass it as needed. + // (CacheItem gets into the IMemoryCache - let's minimize the onward reachable set + // of that cache, by only handing it leaf nodes of a "tree", not a "graph" with + // backwards access - we can also limit object size at the same time) + public abstract bool TryGetValue(ILogger log, out T value); // get a value that *was* reserved, countermanding our reservation in the process - public T GetReservedValue() + public T GetReservedValue(ILogger log) { - if (!TryGetValue(out var value)) + if (!TryGetValue(log, out var value)) { Throw(); } diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.ImmutableCacheItem.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.ImmutableCacheItem.cs index 9ae8468ba29..2e803d87ad6 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.ImmutableCacheItem.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.ImmutableCacheItem.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Threading; +using Microsoft.Extensions.Logging; namespace Microsoft.Extensions.Caching.Hybrid.Internal; @@ -38,7 +39,7 @@ public void SetValue(T value, long size) Size = size; } - public override bool TryGetValue(out T value) + public override bool TryGetValue(ILogger log, out T value) { value = _value; return true; // always available diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs index 1e694448737..230a657bdc3 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.L2.cs @@ -16,12 +16,16 @@ internal partial class DefaultHybridCache { [SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Manual sync check")] [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Manual sync check")] + [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Explicit async exception handling")] + [SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", Justification = "Deliberate recycle only on success")] internal ValueTask GetFromL2Async(string key, CancellationToken token) { switch (GetFeatures(CacheFeatures.BackendCache | CacheFeatures.BackendBuffers)) { case CacheFeatures.BackendCache: // legacy byte[]-based + var pendingLegacy = _backendCache!.GetAsync(key, token); + #if NETCOREAPP2_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER if (!pendingLegacy.IsCompletedSuccessfully) #else @@ -36,6 +40,7 @@ internal ValueTask GetFromL2Async(string key, CancellationToken tok case CacheFeatures.BackendCache | CacheFeatures.BackendBuffers: // IBufferWriter-based RecyclableArrayBufferWriter writer = RecyclableArrayBufferWriter.Create(MaximumPayloadBytes); var cache = Unsafe.As(_backendCache!); // type-checked already + var pendingBuffers = cache.TryGetAsync(key, writer, token); if (!pendingBuffers.IsCompletedSuccessfully) { @@ -49,7 +54,7 @@ internal ValueTask GetFromL2Async(string key, CancellationToken tok return new(result); } - return default; + return default; // treat as a "miss" static async Task AwaitedLegacyAsync(Task pending, DefaultHybridCache @this) { @@ -115,6 +120,11 @@ internal void SetL1(string key, CacheItem value, HybridCacheEntryOptions? // commit cacheEntry.Dispose(); + + if (HybridCacheEventSource.Log.IsEnabled()) + { + HybridCacheEventSource.Log.LocalCacheWrite(); + } } } diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs index 2d02c23b6d8..db95e8c4590 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.MutableCacheItem.cs @@ -1,14 +1,18 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using Microsoft.Extensions.Logging; + namespace Microsoft.Extensions.Caching.Hybrid.Internal; internal partial class DefaultHybridCache { private sealed partial class MutableCacheItem : CacheItem // used to hold types that require defensive copies { - private IHybridCacheSerializer _serializer = null!; // deferred until SetValue + private IHybridCacheSerializer? _serializer; private BufferChunk _buffer; + private T? _fallbackValue; // only used in the case of serialization failures public override bool NeedsEvictionCallback => _buffer.ReturnToPool; @@ -21,16 +25,27 @@ public void SetValue(ref BufferChunk buffer, IHybridCacheSerializer serialize buffer = default; // we're taking over the lifetime; the caller no longer has it! } - public override bool TryGetValue(out T value) + public void SetFallbackValue(T fallbackValue) + { + _fallbackValue = fallbackValue; + } + + public override bool TryGetValue(ILogger log, out T value) { // only if we haven't already burned if (TryReserve()) { try { - value = _serializer.Deserialize(_buffer.AsSequence()); + var serializer = _serializer; + value = serializer is null ? _fallbackValue! : serializer.Deserialize(_buffer.AsSequence()); return true; } + catch (Exception ex) + { + log.DeserializationFailure(ex); + throw; + } finally { _ = Release(); diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.Serialization.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.Serialization.cs index 523a95e279a..d12b2cce592 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.Serialization.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.Serialization.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Concurrent; -using System.Reflection; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using Microsoft.Extensions.DependencyInjection; @@ -51,4 +51,54 @@ static IHybridCacheSerializer ResolveAndAddSerializer(DefaultHybridCache @thi return serializer; } } + + [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Intentional for logged failure mode")] + private bool TrySerialize(T value, out BufferChunk buffer, out IHybridCacheSerializer? serializer) + { + // note: also returns the serializer we resolved, because most-any time we want to serialize, we'll also want + // to make sure we use that same instance later (without needing to re-resolve and/or store the entire HC machinery) + + RecyclableArrayBufferWriter? writer = null; + buffer = default; + try + { + writer = RecyclableArrayBufferWriter.Create(MaximumPayloadBytes); // note this lifetime spans the SetL2Async + serializer = GetSerializer(); + + serializer.Serialize(value, writer); + + buffer = new(writer.DetachCommitted(out var length), length, returnToPool: true); // remove buffer ownership from the writer + writer.Dispose(); // we're done with the writer + return true; + } + catch (Exception ex) + { + bool knownCause = false; + + // ^^^ if we know what happened, we can record directly via cause-specific events + // and treat as a handled failure (i.e. return false) - otherwise, we'll bubble + // the fault up a few layers *in addition to* logging in a failure event + + if (writer is not null) + { + if (writer.QuotaExceeded) + { + _logger.MaximumPayloadBytesExceeded(ex, MaximumPayloadBytes); + knownCause = true; + } + + writer.Dispose(); + } + + if (!knownCause) + { + _logger.SerializationFailure(ex); + throw; + } + + buffer = default; + serializer = null; + return false; + } + } } diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeState.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeState.cs index eba71774395..e2439357f26 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeState.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeState.cs @@ -74,8 +74,6 @@ protected StampedeState(DefaultHybridCache cache, in StampedeKey key, CacheItem public abstract void Execute(); - protected int MaximumPayloadBytes => _cache.MaximumPayloadBytes; - public override string ToString() => Key.ToString(); public abstract void SetCanceled(); diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs index 4e45acae930..4be5b351485 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.StampedeStateT.cs @@ -6,6 +6,7 @@ using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; using static Microsoft.Extensions.Caching.Hybrid.Internal.DefaultHybridCache; namespace Microsoft.Extensions.Caching.Hybrid.Internal; @@ -14,7 +15,8 @@ internal partial class DefaultHybridCache { internal sealed class StampedeState : StampedeState { - private const HybridCacheEntryFlags FlagsDisableL1AndL2 = HybridCacheEntryFlags.DisableLocalCacheWrite | HybridCacheEntryFlags.DisableDistributedCacheWrite; + // note on terminology: L1 and L2 are, for brevity, used interchangeably with "local" and "distributed" cache, i.e. `IMemoryCache` and `IDistributedCache` + private const HybridCacheEntryFlags FlagsDisableL1AndL2Write = HybridCacheEntryFlags.DisableLocalCacheWrite | HybridCacheEntryFlags.DisableDistributedCacheWrite; private readonly TaskCompletionSource>? _result; private TState? _state; @@ -76,13 +78,13 @@ public Task ExecuteDirectAsync(in TState state, Func _result?.TrySetCanceled(SharedToken); [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Custom task management")] - public ValueTask JoinAsync(CancellationToken token) + public ValueTask JoinAsync(ILogger log, CancellationToken token) { // If the underlying has already completed, and/or our local token can't cancel: we // can simply wrap the shared task; otherwise, we need our own cancellation state. - return token.CanBeCanceled && !Task.IsCompleted ? WithCancellationAsync(this, token) : UnwrapReservedAsync(); + return token.CanBeCanceled && !Task.IsCompleted ? WithCancellationAsync(log, this, token) : UnwrapReservedAsync(log); - static async ValueTask WithCancellationAsync(StampedeState stampede, CancellationToken token) + static async ValueTask WithCancellationAsync(ILogger log, StampedeState stampede, CancellationToken token) { var cancelStub = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); using var reg = token.Register(static obj => @@ -112,7 +114,7 @@ static async ValueTask WithCancellationAsync(StampedeState stamped } // outside the catch, so we know we only decrement one way or the other - return result.GetReservedValue(); + return result.GetReservedValue(log); } } @@ -133,7 +135,7 @@ static Task> InvalidAsync() => System.Threading.Tasks.Task.FromExce [SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Checked manual unwrap")] [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Checked manual unwrap")] [SuppressMessage("Major Code Smell", "S1121:Assignments should not be made from within sub-expressions", Justification = "Unusual, but legit here")] - internal ValueTask UnwrapReservedAsync() + internal ValueTask UnwrapReservedAsync(ILogger log) { var task = Task; #if NETCOREAPP2_0_OR_GREATER || NETSTANDARD2_1_OR_GREATER @@ -142,16 +144,16 @@ internal ValueTask UnwrapReservedAsync() if (task.Status == TaskStatus.RanToCompletion) #endif { - return new(task.Result.GetReservedValue()); + return new(task.Result.GetReservedValue(log)); } // if the type is immutable, callers can share the final step too (this may leave dangling // reservation counters, but that's OK) - var result = ImmutableTypeCache.IsImmutable ? (_sharedUnwrap ??= AwaitedAsync(Task)) : AwaitedAsync(Task); + var result = ImmutableTypeCache.IsImmutable ? (_sharedUnwrap ??= AwaitedAsync(log, Task)) : AwaitedAsync(log, Task); return new(result); - static async Task AwaitedAsync(Task> task) - => (await task.ConfigureAwait(false)).GetReservedValue(); + static async Task AwaitedAsync(ILogger log, Task> task) + => (await task.ConfigureAwait(false)).GetReservedValue(log); } [DoesNotReturn] @@ -161,12 +163,43 @@ static async Task AwaitedAsync(Task> task) [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "Exception is passed through to faulted task result")] private async Task BackgroundFetchAsync() { + bool eventSourceEnabled = HybridCacheEventSource.Log.IsEnabled(); try { // read from L2 if appropriate if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheRead) == 0) { - var result = await Cache.GetFromL2Async(Key.Key, SharedToken).ConfigureAwait(false); + BufferChunk result; + try + { + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.DistributedCacheGet(); + } + + result = await Cache.GetFromL2Async(Key.Key, SharedToken).ConfigureAwait(false); + if (eventSourceEnabled) + { + if (result.Array is not null) + { + HybridCacheEventSource.Log.DistributedCacheHit(); + } + else + { + HybridCacheEventSource.Log.DistributedCacheMiss(); + } + } + } + catch (Exception ex) + { + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.DistributedCacheFailed(); + } + + Cache._logger.CacheUnderlyingDataQueryFailure(ex); + result = default; // treat as "miss" + } if (result.Array is not null) { @@ -179,7 +212,30 @@ private async Task BackgroundFetchAsync() if ((Key.Flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0) { // invoke the callback supplied by the caller - T newValue = await _underlying!(_state!, SharedToken).ConfigureAwait(false); + T newValue; + try + { + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.UnderlyingDataQueryStart(); + } + + newValue = await _underlying!(_state!, SharedToken).ConfigureAwait(false); + + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.UnderlyingDataQueryComplete(); + } + } + catch + { + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.UnderlyingDataQueryFailed(); + } + + throw; + } // If we're writing this value *anywhere*, we're going to need to serialize; this is obvious // in the case of L2, but we also need it for L1, because MemoryCache might be enforcing @@ -187,11 +243,11 @@ private async Task BackgroundFetchAsync() // Likewise, if we're writing to a MutableCacheItem, we'll be serializing *anyway* for the payload. // // Rephrasing that: the only scenario in which we *do not* need to serialize is if: - // - it is an ImmutableCacheItem - // - we're writing neither to L1 nor L2 + // - it is an ImmutableCacheItem (so we don't need bytes for the CacheItem, L1) + // - we're not writing to L2 CacheItem cacheItem = CacheItem; - bool skipSerialize = cacheItem is ImmutableCacheItem && (Key.Flags & FlagsDisableL1AndL2) == FlagsDisableL1AndL2; + bool skipSerialize = cacheItem is ImmutableCacheItem && (Key.Flags & FlagsDisableL1AndL2Write) == FlagsDisableL1AndL2Write; if (skipSerialize) { @@ -202,33 +258,55 @@ private async Task BackgroundFetchAsync() // ^^^ The first thing we need to do is make sure we're not getting into a thread race over buffer disposal. // In particular, if this cache item is somehow so short-lived that the buffers would be released *before* we're // done writing them to L2, which happens *after* we've provided the value to consumers. - RecyclableArrayBufferWriter writer = RecyclableArrayBufferWriter.Create(MaximumPayloadBytes); // note this lifetime spans the SetL2Async - IHybridCacheSerializer serializer = Cache.GetSerializer(); - serializer.Serialize(newValue, writer); - BufferChunk buffer = new(writer.DetachCommitted(out var length), length, returnToPool: true); // remove buffer ownership from the writer - writer.Dispose(); // we're done with the writer - - // protect "buffer" (this is why we "reserved") for writing to L2 if needed; SetResultPreSerialized - // *may* (depending on context) claim this buffer, in which case "bufferToRelease" gets reset, and - // the final RecycleIfAppropriate() is a no-op; however, the buffer is valid in either event, - // (with TryReserve above guaranteeing that we aren't in a race condition). - BufferChunk bufferToRelease = buffer; - - // and since "bufferToRelease" is the thing that will be returned at some point, we can make it explicit - // that we do not need or want "buffer" to do any recycling (they're the same memory) - buffer = buffer.DoNotReturnToPool(); - - // set the underlying result for this operation (includes L1 write if appropriate) - SetResultPreSerialized(newValue, ref bufferToRelease, serializer); - - // Note that at this point we've already released most or all of the waiting callers. Everything - // from this point onwards happens in the background, from the perspective of the calling code. - - // Write to L2 if appropriate. - if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheWrite) == 0) + + BufferChunk bufferToRelease = default; + if (Cache.TrySerialize(newValue, out var buffer, out var serializer)) { - // We already have the payload serialized, so this is trivial to do. - await Cache.SetL2Async(Key.Key, in buffer, _options, SharedToken).ConfigureAwait(false); + // note we also capture the resolved serializer ^^^ - we'll need it again later + + // protect "buffer" (this is why we "reserved") for writing to L2 if needed; SetResultPreSerialized + // *may* (depending on context) claim this buffer, in which case "bufferToRelease" gets reset, and + // the final RecycleIfAppropriate() is a no-op; however, the buffer is valid in either event, + // (with TryReserve above guaranteeing that we aren't in a race condition). + bufferToRelease = buffer; + + // and since "bufferToRelease" is the thing that will be returned at some point, we can make it explicit + // that we do not need or want "buffer" to do any recycling (they're the same memory) + buffer = buffer.DoNotReturnToPool(); + + // set the underlying result for this operation (includes L1 write if appropriate) + SetResultPreSerialized(newValue, ref bufferToRelease, serializer); + + // Note that at this point we've already released most or all of the waiting callers. Everything + // from this point onwards happens in the background, from the perspective of the calling code. + + // Write to L2 if appropriate. + if ((Key.Flags & HybridCacheEntryFlags.DisableDistributedCacheWrite) == 0) + { + // We already have the payload serialized, so this is trivial to do. + try + { + await Cache.SetL2Async(Key.Key, in buffer, _options, SharedToken).ConfigureAwait(false); + + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.DistributedCacheWrite(); + } + } + catch (Exception ex) + { + // log the L2 write failure, but that doesn't need to interrupt the app flow (so: + // don't rethrow); L1 will still reduce impact, and L1 without L2 is better than + // hard failure every time + Cache._logger.CacheBackendWriteFailure(ex); + } + } + } + else + { + // unable to serialize (or quota exceeded); try to at least store the onwards value; this is + // especially useful for immutable data types + SetResultPreSerialized(newValue, ref bufferToRelease, serializer); } // Release our hook on the CacheItem (only really important for "mutable"). @@ -309,7 +387,7 @@ private void SetResultAndRecycleIfAppropriate(ref BufferChunk value) private void SetImmutableResultWithoutSerialize(T value) { - Debug.Assert((Key.Flags & FlagsDisableL1AndL2) == FlagsDisableL1AndL2, "Only expected if L1+L2 disabled"); + Debug.Assert((Key.Flags & FlagsDisableL1AndL2Write) == FlagsDisableL1AndL2Write, "Only expected if L1+L2 disabled"); // set a result from a value we calculated directly CacheItem cacheItem; @@ -328,7 +406,7 @@ private void SetImmutableResultWithoutSerialize(T value) SetResult(cacheItem); } - private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCacheSerializer serializer) + private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCacheSerializer? serializer) { // set a result from a value we calculated directly that // has ALREADY BEEN SERIALIZED (we can optionally consume this buffer) @@ -343,8 +421,17 @@ private void SetResultPreSerialized(T value, ref BufferChunk buffer, IHybridCach // (but leave the buffer alone) break; case MutableCacheItem mutable: - mutable.SetValue(ref buffer, serializer); - mutable.DebugOnlyTrackBuffer(Cache); + if (serializer is null) + { + // serialization is failing; set fallback value + mutable.SetFallbackValue(value); + } + else + { + mutable.SetValue(ref buffer, serializer); + mutable.DebugOnlyTrackBuffer(Cache); + } + cacheItem = mutable; break; default: diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs index c789e7c6652..71dbf71fd54 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/DefaultHybridCache.cs @@ -22,6 +22,9 @@ namespace Microsoft.Extensions.Caching.Hybrid.Internal; /// internal sealed partial class DefaultHybridCache : HybridCache { + // reserve non-printable characters from keys, to prevent potential L2 abuse + private static readonly char[] _keyReservedCharacters = Enumerable.Range(0, 32).Select(i => (char)i).ToArray(); + [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] private readonly IDistributedCache? _backendCache; [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0032:Use auto property", Justification = "Keep usage explicit")] @@ -37,6 +40,7 @@ internal sealed partial class DefaultHybridCache : HybridCache private readonly HybridCacheEntryFlags _defaultFlags; // note this already includes hardFlags private readonly TimeSpan _defaultExpiration; private readonly TimeSpan _defaultLocalCacheExpiration; + private readonly int _maximumKeyLength; private readonly DistributedCacheEntryOptions _defaultDistributedCacheExpiration; @@ -90,6 +94,7 @@ public DefaultHybridCache(IOptions options, IServiceProvider _serializerFactories = factories; MaximumPayloadBytes = checked((int)_options.MaximumPayloadBytes); // for now hard-limit to 2GiB + _maximumKeyLength = _options.MaximumKeyLength; var defaultEntryOptions = _options.DefaultEntryOptions; @@ -119,11 +124,33 @@ public override ValueTask GetOrCreateAsync(string key, TState stat } var flags = GetEffectiveFlags(options); - if ((flags & HybridCacheEntryFlags.DisableLocalCacheRead) == 0 && _localCache.TryGetValue(key, out var untyped) - && untyped is CacheItem typed && typed.TryGetValue(out var value)) + if (!ValidateKey(key)) { - // short-circuit - return new(value); + // we can't use cache, but we can still provide the data + return RunWithoutCacheAsync(flags, state, underlyingDataCallback, cancellationToken); + } + + bool eventSourceEnabled = HybridCacheEventSource.Log.IsEnabled(); + if ((flags & HybridCacheEntryFlags.DisableLocalCacheRead) == 0) + { + if (_localCache.TryGetValue(key, out var untyped) + && untyped is CacheItem typed && typed.TryGetValue(_logger, out var value)) + { + // short-circuit + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.LocalCacheHit(); + } + + return new(value); + } + else + { + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.LocalCacheMiss(); + } + } } if (GetOrCreateStampedeState(key, flags, out var stampede, canBeCanceled)) @@ -139,11 +166,19 @@ public override ValueTask GetOrCreateAsync(string key, TState stat { // we're going to run to completion; no need to get complicated _ = stampede.ExecuteDirectAsync(in state, underlyingDataCallback, options); // this larger task includes L2 write etc - return stampede.UnwrapReservedAsync(); + return stampede.UnwrapReservedAsync(_logger); + } + } + else + { + // pre-existing query + if (eventSourceEnabled) + { + HybridCacheEventSource.Log.StampedeJoin(); } } - return stampede.JoinAsync(cancellationToken); + return stampede.JoinAsync(_logger, cancellationToken); } public override ValueTask RemoveAsync(string key, CancellationToken token = default) @@ -164,7 +199,39 @@ public override ValueTask SetAsync(string key, T value, HybridCacheEntryOptio return new(state.ExecuteDirectAsync(value, static (state, _) => new(state), options)); // note this spans L2 write etc } + private static ValueTask RunWithoutCacheAsync(HybridCacheEntryFlags flags, TState state, + Func> underlyingDataCallback, + CancellationToken cancellationToken) + { + return (flags & HybridCacheEntryFlags.DisableUnderlyingData) == 0 + ? underlyingDataCallback(state, cancellationToken) : default; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private HybridCacheEntryFlags GetEffectiveFlags(HybridCacheEntryOptions? options) - => (options?.Flags | _hardFlags) ?? _defaultFlags; + => (options?.Flags | _hardFlags) ?? _defaultFlags; + + private bool ValidateKey(string key) + { + if (string.IsNullOrWhiteSpace(key)) + { + _logger.KeyEmptyOrWhitespace(); + return false; + } + + if (key.Length > _maximumKeyLength) + { + _logger.MaximumKeyLengthExceeded(_maximumKeyLength, key.Length); + return false; + } + + if (key.IndexOfAny(_keyReservedCharacters) >= 0) + { + _logger.KeyInvalidContent(); + return false; + } + + // nothing to complain about + return true; + } } diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCacheEventSource.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCacheEventSource.cs new file mode 100644 index 00000000000..92a5d729e57 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/HybridCacheEventSource.cs @@ -0,0 +1,203 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.Tracing; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace Microsoft.Extensions.Caching.Hybrid.Internal; + +[EventSource(Name = "Microsoft-Extensions-HybridCache")] +internal sealed class HybridCacheEventSource : EventSource +{ + public static readonly HybridCacheEventSource Log = new(); + + internal const int EventIdLocalCacheHit = 1; + internal const int EventIdLocalCacheMiss = 2; + internal const int EventIdDistributedCacheGet = 3; + internal const int EventIdDistributedCacheHit = 4; + internal const int EventIdDistributedCacheMiss = 5; + internal const int EventIdDistributedCacheFailed = 6; + internal const int EventIdUnderlyingDataQueryStart = 7; + internal const int EventIdUnderlyingDataQueryComplete = 8; + internal const int EventIdUnderlyingDataQueryFailed = 9; + internal const int EventIdLocalCacheWrite = 10; + internal const int EventIdDistributedCacheWrite = 11; + internal const int EventIdStampedeJoin = 12; + + // fast local counters + private long _totalLocalCacheHit; + private long _totalLocalCacheMiss; + private long _totalDistributedCacheHit; + private long _totalDistributedCacheMiss; + private long _totalUnderlyingDataQuery; + private long _currentUnderlyingDataQuery; + private long _currentDistributedFetch; + private long _totalLocalCacheWrite; + private long _totalDistributedCacheWrite; + private long _totalStampedeJoin; + +#if !(NETSTANDARD2_0 || NET462) + // full Counter infrastructure + private DiagnosticCounter[]? _counters; +#endif + + [NonEvent] + public void ResetCounters() + { + Debug.WriteLine($"{nameof(HybridCacheEventSource)} counters reset!"); + + Volatile.Write(ref _totalLocalCacheHit, 0); + Volatile.Write(ref _totalLocalCacheMiss, 0); + Volatile.Write(ref _totalDistributedCacheHit, 0); + Volatile.Write(ref _totalDistributedCacheMiss, 0); + Volatile.Write(ref _totalUnderlyingDataQuery, 0); + Volatile.Write(ref _currentUnderlyingDataQuery, 0); + Volatile.Write(ref _currentDistributedFetch, 0); + Volatile.Write(ref _totalLocalCacheWrite, 0); + Volatile.Write(ref _totalDistributedCacheWrite, 0); + Volatile.Write(ref _totalStampedeJoin, 0); + } + + [Event(EventIdLocalCacheHit, Level = EventLevel.Verbose)] + public void LocalCacheHit() + { + DebugAssertEnabled(); + _ = Interlocked.Increment(ref _totalLocalCacheHit); + WriteEvent(EventIdLocalCacheHit); + } + + [Event(EventIdLocalCacheMiss, Level = EventLevel.Verbose)] + public void LocalCacheMiss() + { + DebugAssertEnabled(); + _ = Interlocked.Increment(ref _totalLocalCacheMiss); + WriteEvent(EventIdLocalCacheMiss); + } + + [Event(EventIdDistributedCacheGet, Level = EventLevel.Verbose)] + public void DistributedCacheGet() + { + // should be followed by DistributedCacheHit, DistributedCacheMiss or DistributedCacheFailed + DebugAssertEnabled(); + _ = Interlocked.Increment(ref _currentDistributedFetch); + WriteEvent(EventIdDistributedCacheGet); + } + + [Event(EventIdDistributedCacheHit, Level = EventLevel.Verbose)] + public void DistributedCacheHit() + { + DebugAssertEnabled(); + + // note: not concerned about off-by-one here, i.e. don't panic + // about these two being atomic ref each-other - just the overall shape + _ = Interlocked.Increment(ref _totalDistributedCacheHit); + _ = Interlocked.Decrement(ref _currentDistributedFetch); + WriteEvent(EventIdDistributedCacheHit); + } + + [Event(EventIdDistributedCacheMiss, Level = EventLevel.Verbose)] + public void DistributedCacheMiss() + { + DebugAssertEnabled(); + + // note: not concerned about off-by-one here, i.e. don't panic + // about these two being atomic ref each-other - just the overall shape + _ = Interlocked.Increment(ref _totalDistributedCacheMiss); + _ = Interlocked.Decrement(ref _currentDistributedFetch); + WriteEvent(EventIdDistributedCacheMiss); + } + + [Event(EventIdDistributedCacheFailed, Level = EventLevel.Error)] + public void DistributedCacheFailed() + { + DebugAssertEnabled(); + _ = Interlocked.Decrement(ref _currentDistributedFetch); + WriteEvent(EventIdDistributedCacheFailed); + } + + [Event(EventIdUnderlyingDataQueryStart, Level = EventLevel.Verbose)] + public void UnderlyingDataQueryStart() + { + // should be followed by UnderlyingDataQueryComplete or UnderlyingDataQueryFailed + DebugAssertEnabled(); + _ = Interlocked.Increment(ref _totalUnderlyingDataQuery); + _ = Interlocked.Increment(ref _currentUnderlyingDataQuery); + WriteEvent(EventIdUnderlyingDataQueryStart); + } + + [Event(EventIdUnderlyingDataQueryComplete, Level = EventLevel.Verbose)] + public void UnderlyingDataQueryComplete() + { + DebugAssertEnabled(); + _ = Interlocked.Decrement(ref _currentUnderlyingDataQuery); + WriteEvent(EventIdUnderlyingDataQueryComplete); + } + + [Event(EventIdUnderlyingDataQueryFailed, Level = EventLevel.Error)] + public void UnderlyingDataQueryFailed() + { + DebugAssertEnabled(); + _ = Interlocked.Decrement(ref _currentUnderlyingDataQuery); + WriteEvent(EventIdUnderlyingDataQueryFailed); + } + + [Event(EventIdLocalCacheWrite, Level = EventLevel.Verbose)] + public void LocalCacheWrite() + { + DebugAssertEnabled(); + _ = Interlocked.Increment(ref _totalLocalCacheWrite); + WriteEvent(EventIdLocalCacheWrite); + } + + [Event(EventIdDistributedCacheWrite, Level = EventLevel.Verbose)] + public void DistributedCacheWrite() + { + DebugAssertEnabled(); + _ = Interlocked.Increment(ref _totalDistributedCacheWrite); + WriteEvent(EventIdDistributedCacheWrite); + } + + [Event(EventIdStampedeJoin, Level = EventLevel.Verbose)] + internal void StampedeJoin() + { + DebugAssertEnabled(); + _ = Interlocked.Increment(ref _totalStampedeJoin); + WriteEvent(EventIdStampedeJoin); + } + +#if !(NETSTANDARD2_0 || NET462) + [System.Diagnostics.CodeAnalysis.SuppressMessage("Reliability", "CA2000:Dispose objects before losing scope", Justification = "Lifetime exceeds obvious scope; handed to event source")] + [NonEvent] + protected override void OnEventCommand(EventCommandEventArgs command) + { + if (command.Command == EventCommand.Enable) + { + // lazily create counters on first Enable + _counters ??= [ + new PollingCounter("total-local-cache-hits", this, () => Volatile.Read(ref _totalLocalCacheHit)) { DisplayName = "Total Local Cache Hits" }, + new PollingCounter("total-local-cache-misses", this, () => Volatile.Read(ref _totalLocalCacheMiss)) { DisplayName = "Total Local Cache Misses" }, + new PollingCounter("total-distributed-cache-hits", this, () => Volatile.Read(ref _totalDistributedCacheHit)) { DisplayName = "Total Distributed Cache Hits" }, + new PollingCounter("total-distributed-cache-misses", this, () => Volatile.Read(ref _totalDistributedCacheMiss)) { DisplayName = "Total Distributed Cache Misses" }, + new PollingCounter("total-data-query", this, () => Volatile.Read(ref _totalUnderlyingDataQuery)) { DisplayName = "Total Data Queries" }, + new PollingCounter("current-data-query", this, () => Volatile.Read(ref _currentUnderlyingDataQuery)) { DisplayName = "Current Data Queries" }, + new PollingCounter("current-distributed-cache-fetches", this, () => Volatile.Read(ref _currentDistributedFetch)) { DisplayName = "Current Distributed Cache Fetches" }, + new PollingCounter("total-local-cache-writes", this, () => Volatile.Read(ref _totalLocalCacheWrite)) { DisplayName = "Total Local Cache Writes" }, + new PollingCounter("total-distributed-cache-writes", this, () => Volatile.Read(ref _totalDistributedCacheWrite)) { DisplayName = "Total Distributed Cache Writes" }, + new PollingCounter("total-stampede-joins", this, () => Volatile.Read(ref _totalStampedeJoin)) { DisplayName = "Total Stampede Joins" }, + ]; + } + + base.OnEventCommand(command); + } +#endif + + [NonEvent] + [Conditional("DEBUG")] + private void DebugAssertEnabled([CallerMemberName] string caller = "") + { + Debug.Assert(IsEnabled(), $"Missing check to {nameof(HybridCacheEventSource)}.{nameof(Log)}.{nameof(IsEnabled)} from {caller}"); + Debug.WriteLine($"{nameof(HybridCacheEventSource)}: {caller}"); // also log all event calls, for visibility + } +} diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/InbuiltTypeSerializer.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/InbuiltTypeSerializer.cs index 3ef26341433..4800428a88f 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/InbuiltTypeSerializer.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/InbuiltTypeSerializer.cs @@ -17,6 +17,18 @@ internal sealed class InbuiltTypeSerializer : IHybridCacheSerializer, IH public static InbuiltTypeSerializer Instance { get; } = new(); string IHybridCacheSerializer.Deserialize(ReadOnlySequence source) + => DeserializeString(source); + + void IHybridCacheSerializer.Serialize(string value, IBufferWriter target) + => SerializeString(value, target); + + byte[] IHybridCacheSerializer.Deserialize(ReadOnlySequence source) + => source.ToArray(); + + void IHybridCacheSerializer.Serialize(byte[] value, IBufferWriter target) + => target.Write(value); + + internal static string DeserializeString(ReadOnlySequence source) { #if NET5_0_OR_GREATER return Encoding.UTF8.GetString(source); @@ -36,7 +48,7 @@ string IHybridCacheSerializer.Deserialize(ReadOnlySequence source) #endif } - void IHybridCacheSerializer.Serialize(string value, IBufferWriter target) + internal static void SerializeString(string value, IBufferWriter target) { #if NET5_0_OR_GREATER Encoding.UTF8.GetBytes(value, target); @@ -49,10 +61,4 @@ void IHybridCacheSerializer.Serialize(string value, IBufferWriter ArrayPool.Shared.Return(oversized); #endif } - - byte[] IHybridCacheSerializer.Deserialize(ReadOnlySequence source) - => source.ToArray(); - - void IHybridCacheSerializer.Serialize(byte[] value, IBufferWriter target) - => target.Write(value); } diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/Log.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/Log.cs new file mode 100644 index 00000000000..785107c32ec --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/Log.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Extensions.Caching.Hybrid.Internal; + +internal static partial class Log +{ + internal const int IdMaximumPayloadBytesExceeded = 1; + internal const int IdSerializationFailure = 2; + internal const int IdDeserializationFailure = 3; + internal const int IdKeyEmptyOrWhitespace = 4; + internal const int IdMaximumKeyLengthExceeded = 5; + internal const int IdCacheBackendReadFailure = 6; + internal const int IdCacheBackendWriteFailure = 7; + internal const int IdKeyInvalidContent = 8; + + [LoggerMessage(LogLevel.Error, "Cache MaximumPayloadBytes ({Bytes}) exceeded.", EventName = "MaximumPayloadBytesExceeded", EventId = IdMaximumPayloadBytesExceeded, SkipEnabledCheck = false)] + internal static partial void MaximumPayloadBytesExceeded(this ILogger logger, Exception e, int bytes); + + // note that serialization is critical enough that we perform hard failures in addition to logging; serialization + // failures are unlikely to be transient (i.e. connectivity); we would rather this shows up in QA, rather than + // being invisible and people *thinking* they're using cache, when actually they are not + + [LoggerMessage(LogLevel.Error, "Cache serialization failure.", EventName = "SerializationFailure", EventId = IdSerializationFailure, SkipEnabledCheck = false)] + internal static partial void SerializationFailure(this ILogger logger, Exception e); + + // (see same notes per SerializationFailure) + [LoggerMessage(LogLevel.Error, "Cache deserialization failure.", EventName = "DeserializationFailure", EventId = IdDeserializationFailure, SkipEnabledCheck = false)] + internal static partial void DeserializationFailure(this ILogger logger, Exception e); + + [LoggerMessage(LogLevel.Error, "Cache key empty or whitespace.", EventName = "KeyEmptyOrWhitespace", EventId = IdKeyEmptyOrWhitespace, SkipEnabledCheck = false)] + internal static partial void KeyEmptyOrWhitespace(this ILogger logger); + + [LoggerMessage(LogLevel.Error, "Cache key maximum length exceeded (maximum: {MaxLength}, actual: {KeyLength}).", EventName = "MaximumKeyLengthExceeded", + EventId = IdMaximumKeyLengthExceeded, SkipEnabledCheck = false)] + internal static partial void MaximumKeyLengthExceeded(this ILogger logger, int maxLength, int keyLength); + + [LoggerMessage(LogLevel.Error, "Cache backend read failure.", EventName = "CacheBackendReadFailure", EventId = IdCacheBackendReadFailure, SkipEnabledCheck = false)] + internal static partial void CacheUnderlyingDataQueryFailure(this ILogger logger, Exception ex); + + [LoggerMessage(LogLevel.Error, "Cache backend write failure.", EventName = "CacheBackendWriteFailure", EventId = IdCacheBackendWriteFailure, SkipEnabledCheck = false)] + internal static partial void CacheBackendWriteFailure(this ILogger logger, Exception ex); + + [LoggerMessage(LogLevel.Error, "Cache key contains invalid content.", EventName = "KeyInvalidContent", EventId = IdKeyInvalidContent, SkipEnabledCheck = false)] + internal static partial void KeyInvalidContent(this ILogger logger); // for PII etc reasons, we won't include the actual key +} diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/RecyclableArrayBufferWriter.cs b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/RecyclableArrayBufferWriter.cs index 2f2da2c7019..985d55c9f0e 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/RecyclableArrayBufferWriter.cs +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Internal/RecyclableArrayBufferWriter.cs @@ -46,20 +46,20 @@ internal sealed class RecyclableArrayBufferWriter : IBufferWriter, IDispos public int CommittedBytes => _index; public int FreeCapacity => _buffer.Length - _index; + public bool QuotaExceeded { get; private set; } + private static RecyclableArrayBufferWriter? _spare; + public static RecyclableArrayBufferWriter Create(int maxLength) { var obj = Interlocked.Exchange(ref _spare, null) ?? new(); - Debug.Assert(obj._index == 0, "index should be zero initially"); - obj._maxLength = maxLength; + obj.Initialize(maxLength); return obj; } private RecyclableArrayBufferWriter() { _buffer = []; - _index = 0; - _maxLength = int.MaxValue; } public void Dispose() @@ -91,6 +91,7 @@ public void Advance(int count) if (_index + count > _maxLength) { + QuotaExceeded = true; ThrowQuota(); } @@ -199,4 +200,12 @@ private void CheckAndResizeBuffer(int sizeHint) static void ThrowOutOfMemoryException() => throw new InvalidOperationException("Unable to grow buffer as requested"); } + + private void Initialize(int maxLength) + { + // think .ctor, but with pooled object re-use + _index = 0; + _maxLength = maxLength; + QuotaExceeded = false; + } } diff --git a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Microsoft.Extensions.Caching.Hybrid.csproj b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Microsoft.Extensions.Caching.Hybrid.csproj index ec8946d2f9d..d3029266b57 100644 --- a/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Microsoft.Extensions.Caching.Hybrid.csproj +++ b/src/Libraries/Microsoft.Extensions.Caching.Hybrid/Microsoft.Extensions.Caching.Hybrid.csproj @@ -4,7 +4,7 @@ Multi-level caching implementation building on and extending IDistributedCache $(NetCoreTargetFrameworks)$(ConditionalNet462);netstandard2.0;netstandard2.1 true - cache;distributedcache;hybrid + cache;distributedcache;hybridcache true true true @@ -21,6 +21,11 @@ true + true + true + + + false diff --git a/src/Libraries/Microsoft.Extensions.Compliance.Abstractions/Microsoft.Extensions.Compliance.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.Compliance.Abstractions/Microsoft.Extensions.Compliance.Abstractions.csproj index 5a6c93e1dc7..c83b7284da5 100644 --- a/src/Libraries/Microsoft.Extensions.Compliance.Abstractions/Microsoft.Extensions.Compliance.Abstractions.csproj +++ b/src/Libraries/Microsoft.Extensions.Compliance.Abstractions/Microsoft.Extensions.Compliance.Abstractions.csproj @@ -1,6 +1,7 @@  Microsoft.Extensions.Compliance + $(NetCoreTargetFrameworks);netstandard2.0; Abstractions to help ensure compliant data management. Fundamentals diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.JsonSchema.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.JsonSchema.cs new file mode 100644 index 00000000000..0f1044fc6eb --- /dev/null +++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.JsonSchema.cs @@ -0,0 +1,545 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET9_0_OR_GREATER +using System.Collections.Generic; +using System.Diagnostics; +using System.Text.Json.Nodes; + +namespace System.Text.Json.Schema; + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable S1144 // Unused private types or members should be removed + +internal static partial class JsonSchemaExporter +{ + // Simple JSON schema representation taken from System.Text.Json + // https://github.com/dotnet/runtime/blob/50d6cad649aad2bfa4069268eddd16fd51ec5cf3/src/libraries/System.Text.Json/src/System/Text/Json/Schema/JsonSchema.cs + private sealed class JsonSchema + { + public static JsonSchema False { get; } = new(false); + public static JsonSchema True { get; } = new(true); + + public JsonSchema() + { + } + + private JsonSchema(bool trueOrFalse) + { + _trueOrFalse = trueOrFalse; + } + + public bool IsTrue => _trueOrFalse is true; + public bool IsFalse => _trueOrFalse is false; + private readonly bool? _trueOrFalse; + + public string? Schema + { + get => _schema; + set + { + VerifyMutable(); + _schema = value; + } + } + + private string? _schema; + + public string? Title + { + get => _title; + set + { + VerifyMutable(); + _title = value; + } + } + + private string? _title; + + public string? Description + { + get => _description; + set + { + VerifyMutable(); + _description = value; + } + } + + private string? _description; + + public string? Ref + { + get => _ref; + set + { + VerifyMutable(); + _ref = value; + } + } + + private string? _ref; + + public string? Comment + { + get => _comment; + set + { + VerifyMutable(); + _comment = value; + } + } + + private string? _comment; + + public JsonSchemaType Type + { + get => _type; + set + { + VerifyMutable(); + _type = value; + } + } + + private JsonSchemaType _type = JsonSchemaType.Any; + + public string? Format + { + get => _format; + set + { + VerifyMutable(); + _format = value; + } + } + + private string? _format; + + public string? Pattern + { + get => _pattern; + set + { + VerifyMutable(); + _pattern = value; + } + } + + private string? _pattern; + + public JsonNode? Constant + { + get => _constant; + set + { + VerifyMutable(); + _constant = value; + } + } + + private JsonNode? _constant; + + public List>? Properties + { + get => _properties; + set + { + VerifyMutable(); + _properties = value; + } + } + + private List>? _properties; + + public List? Required + { + get => _required; + set + { + VerifyMutable(); + _required = value; + } + } + + private List? _required; + + public JsonSchema? Items + { + get => _items; + set + { + VerifyMutable(); + _items = value; + } + } + + private JsonSchema? _items; + + public JsonSchema? AdditionalProperties + { + get => _additionalProperties; + set + { + VerifyMutable(); + _additionalProperties = value; + } + } + + private JsonSchema? _additionalProperties; + + public JsonArray? Enum + { + get => _enum; + set + { + VerifyMutable(); + _enum = value; + } + } + + private JsonArray? _enum; + + public JsonSchema? Not + { + get => _not; + set + { + VerifyMutable(); + _not = value; + } + } + + private JsonSchema? _not; + + public List? AnyOf + { + get => _anyOf; + set + { + VerifyMutable(); + _anyOf = value; + } + } + + private List? _anyOf; + + public bool HasDefaultValue + { + get => _hasDefaultValue; + set + { + VerifyMutable(); + _hasDefaultValue = value; + } + } + + private bool _hasDefaultValue; + + public JsonNode? DefaultValue + { + get => _defaultValue; + set + { + VerifyMutable(); + _defaultValue = value; + } + } + + private JsonNode? _defaultValue; + + public int? MinLength + { + get => _minLength; + set + { + VerifyMutable(); + _minLength = value; + } + } + + private int? _minLength; + + public int? MaxLength + { + get => _maxLength; + set + { + VerifyMutable(); + _maxLength = value; + } + } + + private int? _maxLength; + + public JsonSchemaExporterContext? GenerationContext { get; set; } + + public int KeywordCount + { + get + { + if (_trueOrFalse != null) + { + return 0; + } + + int count = 0; + Count(Schema != null); + Count(Ref != null); + Count(Comment != null); + Count(Title != null); + Count(Description != null); + Count(Type != JsonSchemaType.Any); + Count(Format != null); + Count(Pattern != null); + Count(Constant != null); + Count(Properties != null); + Count(Required != null); + Count(Items != null); + Count(AdditionalProperties != null); + Count(Enum != null); + Count(Not != null); + Count(AnyOf != null); + Count(HasDefaultValue); + Count(MinLength != null); + Count(MaxLength != null); + + return count; + + void Count(bool isKeywordSpecified) => count += isKeywordSpecified ? 1 : 0; + } + } + + public void MakeNullable() + { + if (_trueOrFalse != null) + { + return; + } + + if (Type != JsonSchemaType.Any) + { + Type |= JsonSchemaType.Null; + } + } + + public JsonNode ToJsonNode(JsonSchemaExporterOptions options) + { + if (_trueOrFalse is { } boolSchema) + { + return CompleteSchema((JsonNode)boolSchema); + } + + var objSchema = new JsonObject(); + + if (Schema != null) + { + objSchema.Add(JsonSchemaConstants.SchemaPropertyName, Schema); + } + + if (Title != null) + { + objSchema.Add(JsonSchemaConstants.TitlePropertyName, Title); + } + + if (Description != null) + { + objSchema.Add(JsonSchemaConstants.DescriptionPropertyName, Description); + } + + if (Ref != null) + { + objSchema.Add(JsonSchemaConstants.RefPropertyName, Ref); + } + + if (Comment != null) + { + objSchema.Add(JsonSchemaConstants.CommentPropertyName, Comment); + } + + if (MapSchemaType(Type) is JsonNode type) + { + objSchema.Add(JsonSchemaConstants.TypePropertyName, type); + } + + if (Format != null) + { + objSchema.Add(JsonSchemaConstants.FormatPropertyName, Format); + } + + if (Pattern != null) + { + objSchema.Add(JsonSchemaConstants.PatternPropertyName, Pattern); + } + + if (Constant != null) + { + objSchema.Add(JsonSchemaConstants.ConstPropertyName, Constant); + } + + if (Properties != null) + { + var properties = new JsonObject(); + foreach (KeyValuePair property in Properties) + { + properties.Add(property.Key, property.Value.ToJsonNode(options)); + } + + objSchema.Add(JsonSchemaConstants.PropertiesPropertyName, properties); + } + + if (Required != null) + { + var requiredArray = new JsonArray(); + foreach (string requiredProperty in Required) + { + requiredArray.Add((JsonNode)requiredProperty); + } + + objSchema.Add(JsonSchemaConstants.RequiredPropertyName, requiredArray); + } + + if (Items != null) + { + objSchema.Add(JsonSchemaConstants.ItemsPropertyName, Items.ToJsonNode(options)); + } + + if (AdditionalProperties != null) + { + objSchema.Add(JsonSchemaConstants.AdditionalPropertiesPropertyName, AdditionalProperties.ToJsonNode(options)); + } + + if (Enum != null) + { + objSchema.Add(JsonSchemaConstants.EnumPropertyName, Enum); + } + + if (Not != null) + { + objSchema.Add(JsonSchemaConstants.NotPropertyName, Not.ToJsonNode(options)); + } + + if (AnyOf != null) + { + JsonArray anyOfArray = new(); + foreach (JsonSchema schema in AnyOf) + { + anyOfArray.Add(schema.ToJsonNode(options)); + } + + objSchema.Add(JsonSchemaConstants.AnyOfPropertyName, anyOfArray); + } + + if (HasDefaultValue) + { + objSchema.Add(JsonSchemaConstants.DefaultPropertyName, DefaultValue); + } + + if (MinLength is int minLength) + { + objSchema.Add(JsonSchemaConstants.MinLengthPropertyName, (JsonNode)minLength); + } + + if (MaxLength is int maxLength) + { + objSchema.Add(JsonSchemaConstants.MaxLengthPropertyName, (JsonNode)maxLength); + } + + return CompleteSchema(objSchema); + + JsonNode CompleteSchema(JsonNode schema) + { + if (GenerationContext is { } context) + { + Debug.Assert(options.TransformSchemaNode != null, "context should only be populated if a callback is present."); + + // Apply any user-defined transformations to the schema. + return options.TransformSchemaNode!(context, schema); + } + + return schema; + } + } + + public static void EnsureMutable(ref JsonSchema schema) + { + switch (schema._trueOrFalse) + { + case false: + schema = new JsonSchema { Not = JsonSchema.True }; + break; + case true: + schema = new JsonSchema(); + break; + } + } + + private static readonly JsonSchemaType[] _schemaValues = new JsonSchemaType[] + { + // NB the order of these values influences order of types in the rendered schema + JsonSchemaType.String, + JsonSchemaType.Integer, + JsonSchemaType.Number, + JsonSchemaType.Boolean, + JsonSchemaType.Array, + JsonSchemaType.Object, + JsonSchemaType.Null, + }; + + private void VerifyMutable() + { + Debug.Assert(_trueOrFalse is null, "Schema is not mutable"); + } + + private static JsonNode? MapSchemaType(JsonSchemaType schemaType) + { + if (schemaType is JsonSchemaType.Any) + { + return null; + } + + if (ToIdentifier(schemaType) is string identifier) + { + return identifier; + } + + var array = new JsonArray(); + foreach (JsonSchemaType type in _schemaValues) + { + if ((schemaType & type) != 0) + { + array.Add((JsonNode)ToIdentifier(type)!); + } + } + + return array; + + static string? ToIdentifier(JsonSchemaType schemaType) => schemaType switch + { + JsonSchemaType.Null => "null", + JsonSchemaType.Boolean => "boolean", + JsonSchemaType.Integer => "integer", + JsonSchemaType.Number => "number", + JsonSchemaType.String => "string", + JsonSchemaType.Array => "array", + JsonSchemaType.Object => "object", + _ => null, + }; + } + } + + [Flags] + private enum JsonSchemaType + { + Any = 0, // No type declared on the schema + Null = 1, + Boolean = 2, + Integer = 4, + Number = 8, + String = 16, + Array = 32, + Object = 64, + } +} +#endif diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs new file mode 100644 index 00000000000..481e5f75753 --- /dev/null +++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.ReflectionHelpers.cs @@ -0,0 +1,427 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET9_0_OR_GREATER +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +#if !NET +using System.Linq; +#endif +using System.Reflection; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + +namespace System.Text.Json.Schema; + +internal static partial class JsonSchemaExporter +{ + private static class ReflectionHelpers + { + private const BindingFlags AllInstance = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; + private static PropertyInfo? _jsonTypeInfo_ElementType; + private static PropertyInfo? _jsonPropertyInfo_MemberName; + private static FieldInfo? _nullableConverter_ElementConverter_Generic; + private static FieldInfo? _enumConverter_Options_Generic; + private static FieldInfo? _enumConverter_NamingPolicy_Generic; + + public static bool IsBuiltInConverter(JsonConverter converter) => + converter.GetType().Assembly == typeof(JsonConverter).Assembly; + + public static bool CanBeNull(Type type) => !type.IsValueType || Nullable.GetUnderlyingType(type) is not null; + + public static Type GetElementType(JsonTypeInfo typeInfo) + { + Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Enumerable or JsonTypeInfoKind.Dictionary, "TypeInfo must be of collection type"); + + // Uses reflection to access the element type encapsulated by a JsonTypeInfo. + if (_jsonTypeInfo_ElementType is null) + { + PropertyInfo? elementTypeProperty = typeof(JsonTypeInfo).GetProperty("ElementType", AllInstance); + _jsonTypeInfo_ElementType = Throw.IfNull(elementTypeProperty); + } + + return (Type)_jsonTypeInfo_ElementType.GetValue(typeInfo)!; + } + + public static string? GetMemberName(JsonPropertyInfo propertyInfo) + { + // Uses reflection to the member name encapsulated by a JsonPropertyInfo. + if (_jsonPropertyInfo_MemberName is null) + { + PropertyInfo? memberName = typeof(JsonPropertyInfo).GetProperty("MemberName", AllInstance); + _jsonPropertyInfo_MemberName = Throw.IfNull(memberName); + } + + return (string?)_jsonPropertyInfo_MemberName.GetValue(propertyInfo); + } + + public static JsonConverter GetElementConverter(JsonConverter nullableConverter) + { + // Uses reflection to access the element converter encapsulated by a nullable converter. + if (_nullableConverter_ElementConverter_Generic is null) + { + FieldInfo? genericFieldInfo = Type + .GetType("System.Text.Json.Serialization.Converters.NullableConverter`1, System.Text.Json")! + .GetField("_elementConverter", AllInstance); + + _nullableConverter_ElementConverter_Generic = Throw.IfNull(genericFieldInfo); + } + + Type converterType = nullableConverter.GetType(); + var thisFieldInfo = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_nullableConverter_ElementConverter_Generic); + return (JsonConverter)thisFieldInfo.GetValue(nullableConverter)!; + } + + public static void GetEnumConverterConfig(JsonConverter enumConverter, out JsonNamingPolicy? namingPolicy, out bool allowString) + { + // Uses reflection to access configuration encapsulated by an enum converter. + if (_enumConverter_Options_Generic is null) + { + FieldInfo? genericFieldInfo = Type + .GetType("System.Text.Json.Serialization.Converters.EnumConverter`1, System.Text.Json")! + .GetField("_converterOptions", AllInstance); + + _enumConverter_Options_Generic = Throw.IfNull(genericFieldInfo); + } + + if (_enumConverter_NamingPolicy_Generic is null) + { + FieldInfo? genericFieldInfo = Type + .GetType("System.Text.Json.Serialization.Converters.EnumConverter`1, System.Text.Json")! + .GetField("_namingPolicy", AllInstance); + + _enumConverter_NamingPolicy_Generic = Throw.IfNull(genericFieldInfo); + } + + const int EnumConverterOptionsAllowStrings = 1; + Type converterType = enumConverter.GetType(); + var converterOptionsField = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_enumConverter_Options_Generic); + var namingPolicyField = (FieldInfo)converterType.GetMemberWithSameMetadataDefinitionAs(_enumConverter_NamingPolicy_Generic); + + namingPolicy = (JsonNamingPolicy?)namingPolicyField.GetValue(enumConverter); + int converterOptions = (int)converterOptionsField.GetValue(enumConverter)!; + allowString = (converterOptions & EnumConverterOptionsAllowStrings) != 0; + } + + // The .NET 8 source generator doesn't populate attribute providers for properties + // cf. https://github.com/dotnet/runtime/issues/100095 + // Work around the issue by running a query for the relevant MemberInfo using the internal MemberName property + // https://github.com/dotnet/runtime/blob/de774ff9ee1a2c06663ab35be34b755cd8d29731/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonPropertyInfo.cs#L206 + public static ICustomAttributeProvider? ResolveAttributeProvider( + [DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.NonPublicProperties | + DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.NonPublicFields)] + Type? declaringType, + JsonPropertyInfo? propertyInfo) + { + if (declaringType is null || propertyInfo is null) + { + return null; + } + + if (propertyInfo.AttributeProvider is { } provider) + { + return provider; + } + + string? memberName = ReflectionHelpers.GetMemberName(propertyInfo); + if (memberName is not null) + { + return (MemberInfo?)declaringType.GetProperty(memberName, AllInstance) ?? + declaringType.GetField(memberName, AllInstance); + } + + return null; + } + + // Resolves the parameters of the deserialization constructor for a type, if they exist. + public static Func? ResolveJsonConstructorParameterMapper( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] + Type type, + JsonTypeInfo typeInfo) + { + Debug.Assert(type == typeInfo.Type, "The declaring type must match the typeInfo type."); + Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.Object, "Should only be passed object JSON kinds."); + + if (typeInfo.Properties.Count > 0 && + typeInfo.CreateObject is null && // Ensure that a default constructor isn't being used + TryGetDeserializationConstructor(type, useDefaultCtorInAnnotatedStructs: true, out ConstructorInfo? ctor)) + { + ParameterInfo[]? parameters = ctor?.GetParameters(); + if (parameters?.Length > 0) + { + Dictionary dict = new(parameters.Length); + foreach (ParameterInfo parameter in parameters) + { + if (parameter.Name is not null) + { + // We don't care about null parameter names or conflicts since they + // would have already been rejected by JsonTypeInfo exporterOptions. + dict[new(parameter.Name, parameter.ParameterType)] = parameter; + } + } + + return prop => dict.TryGetValue(new(prop.Name, prop.PropertyType), out ParameterInfo? parameter) ? parameter : null; + } + } + + return null; + } + + // Resolves the nullable reference type annotations for a property or field, + // additionally addressing a few known bugs of the NullabilityInfo pre .NET 9. + public static NullabilityInfo GetMemberNullability(NullabilityInfoContext context, MemberInfo memberInfo) + { + Debug.Assert(memberInfo is PropertyInfo or FieldInfo, "Member must be property or field."); + return memberInfo is PropertyInfo prop + ? context.Create(prop) + : context.Create((FieldInfo)memberInfo); + } + + public static NullabilityState GetParameterNullability(NullabilityInfoContext context, ParameterInfo parameterInfo) + { +#if NET8_0 + // Workaround for https://github.com/dotnet/runtime/issues/92487 + // The fix has been incorporated into .NET 9 (and the polyfilled implementations in netfx). + // Should be removed once .NET 8 support is dropped. + if (GetGenericParameterDefinition(parameterInfo) is { ParameterType: { IsGenericParameter: true } typeParam }) + { + // Step 1. Look for nullable annotations on the type parameter. + if (GetNullableFlags(typeParam) is byte[] flags) + { + return TranslateByte(flags[0]); + } + + // Step 2. Look for nullable annotations on the generic method declaration. + if (typeParam.DeclaringMethod != null && GetNullableContextFlag(typeParam.DeclaringMethod) is byte flag) + { + return TranslateByte(flag); + } + + // Step 3. Look for nullable annotations on the generic method declaration. + if (GetNullableContextFlag(typeParam.DeclaringType!) is byte flag2) + { + return TranslateByte(flag2); + } + + // Default to nullable. + return NullabilityState.Nullable; + + static byte[]? GetNullableFlags(MemberInfo member) + { + foreach (CustomAttributeData attr in member.GetCustomAttributesData()) + { + Type attrType = attr.AttributeType; + if (attrType.Name == "NullableAttribute" && attrType.Namespace == "System.Runtime.CompilerServices") + { + foreach (CustomAttributeTypedArgument ctorArg in attr.ConstructorArguments) + { + switch (ctorArg.Value) + { + case byte flag: + return [flag]; + case byte[] flags: + return flags; + } + } + } + } + + return null; + } + + static byte? GetNullableContextFlag(MemberInfo member) + { + foreach (CustomAttributeData attr in member.GetCustomAttributesData()) + { + Type attrType = attr.AttributeType; + if (attrType.Name == "NullableContextAttribute" && attrType.Namespace == "System.Runtime.CompilerServices") + { + foreach (CustomAttributeTypedArgument ctorArg in attr.ConstructorArguments) + { + if (ctorArg.Value is byte flag) + { + return flag; + } + } + } + } + + return null; + } + +#pragma warning disable S109 // Magic numbers should not be used + static NullabilityState TranslateByte(byte b) => b switch + { + 1 => NullabilityState.NotNull, + 2 => NullabilityState.Nullable, + _ => NullabilityState.Unknown + }; +#pragma warning restore S109 // Magic numbers should not be used + } + + static ParameterInfo GetGenericParameterDefinition(ParameterInfo parameter) + { + if (parameter.Member is { DeclaringType.IsConstructedGenericType: true } + or MethodInfo { IsGenericMethod: true, IsGenericMethodDefinition: false }) + { + var genericMethod = (MethodBase)GetGenericMemberDefinition(parameter.Member); + return genericMethod.GetParameters()[parameter.Position]; + } + + return parameter; + } + + static MemberInfo GetGenericMemberDefinition(MemberInfo member) + { + if (member is Type type) + { + return type.IsConstructedGenericType ? type.GetGenericTypeDefinition() : type; + } + + if (member.DeclaringType?.IsConstructedGenericType is true) + { + return member.DeclaringType.GetGenericTypeDefinition().GetMemberWithSameMetadataDefinitionAs(member); + } + + if (member is MethodInfo { IsGenericMethod: true, IsGenericMethodDefinition: false } method) + { + return method.GetGenericMethodDefinition(); + } + + return member; + } +#endif + return context.Create(parameterInfo).WriteState; + } + + // Taken from https://github.com/dotnet/runtime/blob/903bc019427ca07080530751151ea636168ad334/src/libraries/System.Text.Json/Common/ReflectionExtensions.cs#L288-L317 + public static object? GetNormalizedDefaultValue(ParameterInfo parameterInfo) + { + Type parameterType = parameterInfo.ParameterType; + object? defaultValue = parameterInfo.DefaultValue; + + if (defaultValue is null) + { + return null; + } + + // DBNull.Value is sometimes used as the default value (returned by reflection) of nullable params in place of null. + if (defaultValue == DBNull.Value && parameterType != typeof(DBNull)) + { + return null; + } + + // Default values of enums or nullable enums are represented using the underlying type and need to be cast explicitly + // cf. https://github.com/dotnet/runtime/issues/68647 + if (parameterType.IsEnum) + { + return Enum.ToObject(parameterType, defaultValue); + } + + if (Nullable.GetUnderlyingType(parameterType) is Type underlyingType && underlyingType.IsEnum) + { + return Enum.ToObject(underlyingType, defaultValue); + } + + return defaultValue; + } + + // Resolves the deserialization constructor for a type using logic copied from + // https://github.com/dotnet/runtime/blob/e12e2fa6cbdd1f4b0c8ad1b1e2d960a480c21703/src/libraries/System.Text.Json/Common/ReflectionExtensions.cs#L227-L286 + private static bool TryGetDeserializationConstructor( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] + Type type, + bool useDefaultCtorInAnnotatedStructs, + out ConstructorInfo? deserializationCtor) + { + ConstructorInfo? ctorWithAttribute = null; + ConstructorInfo? publicParameterlessCtor = null; + ConstructorInfo? lonePublicCtor = null; + + ConstructorInfo[] constructors = type.GetConstructors(BindingFlags.Public | BindingFlags.Instance); + + if (constructors.Length == 1) + { + lonePublicCtor = constructors[0]; + } + + foreach (ConstructorInfo constructor in constructors) + { + if (HasJsonConstructorAttribute(constructor)) + { + if (ctorWithAttribute != null) + { + deserializationCtor = null; + return false; + } + + ctorWithAttribute = constructor; + } + else if (constructor.GetParameters().Length == 0) + { + publicParameterlessCtor = constructor; + } + } + + // Search for non-public ctors with [JsonConstructor]. + foreach (ConstructorInfo constructor in type.GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance)) + { + if (HasJsonConstructorAttribute(constructor)) + { + if (ctorWithAttribute != null) + { + deserializationCtor = null; + return false; + } + + ctorWithAttribute = constructor; + } + } + + // Structs will use default constructor if attribute isn't used. + if (useDefaultCtorInAnnotatedStructs && type.IsValueType && ctorWithAttribute == null) + { + deserializationCtor = null; + return true; + } + + deserializationCtor = ctorWithAttribute ?? publicParameterlessCtor ?? lonePublicCtor; + return true; + + static bool HasJsonConstructorAttribute(ConstructorInfo constructorInfo) => + constructorInfo.GetCustomAttribute() != null; + } + + // Parameter to property matching semantics as declared in + // https://github.com/dotnet/runtime/blob/12d96ccfaed98e23c345188ee08f8cfe211c03e7/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Metadata/JsonTypeInfo.cs#L1007-L1030 + private readonly struct ParameterLookupKey : IEquatable + { + public ParameterLookupKey(string name, Type type) + { + Name = name; + Type = type; + } + + public string Name { get; } + public Type Type { get; } + + public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Name); + public bool Equals(ParameterLookupKey other) => Type == other.Type && string.Equals(Name, other.Name, StringComparison.OrdinalIgnoreCase); + public override bool Equals(object? obj) => obj is ParameterLookupKey key && Equals(key); + } + } + +#if !NET + private static MemberInfo GetMemberWithSameMetadataDefinitionAs(this Type specializedType, MemberInfo member) + { + const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; + return specializedType.GetMember(member.Name, member.MemberType, All).First(m => m.MetadataToken == member.MetadataToken); + } +#endif +} +#endif diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs new file mode 100644 index 00000000000..5c6ce6d9ab7 --- /dev/null +++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporter.cs @@ -0,0 +1,801 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET9_0_OR_GREATER +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Linq; +using System.Reflection; +#if NET +using System.Runtime.InteropServices; +#endif +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable LA0002 // Use 'Microsoft.Shared.Text.NumericExtensions.ToInvariantString' for improved performance +#pragma warning disable S107 // Methods should not have too many parameters +#pragma warning disable S1121 // Assignments should not be made from within sub-expressions + +namespace System.Text.Json.Schema; + +/// +/// Maps .NET types to JSON schema objects using contract metadata from instances. +/// +#if !SHARED_PROJECT +[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] +#endif +internal static partial class JsonSchemaExporter +{ + // Polyfill implementation of JsonSchemaExporter for System.Text.Json version 8.0.0. + // Uses private reflection to access metadata not available with the older APIs of STJ. + + private const string RequiresUnreferencedCodeMessage = + "Uses private reflection on System.Text.Json components to access converter metadata. " + + "If running Native AOT ensure that the 'IlcTrimMetadata' property has been disabled."; + + /// + /// Generates a JSON schema corresponding to the contract metadata of the specified type. + /// + /// The options instance from which to resolve the contract metadata. + /// The root type for which to generate the JSON schema. + /// The exporterOptions object controlling the schema generation. + /// A new instance defining the JSON schema for . + /// One of the specified parameters is . + /// The parameter contains unsupported exporterOptions. + [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] + public static JsonNode GetJsonSchemaAsNode(this JsonSerializerOptions options, Type type, JsonSchemaExporterOptions? exporterOptions = null) + { + _ = Throw.IfNull(options); + _ = Throw.IfNull(type); + ValidateOptions(options); + + exporterOptions ??= JsonSchemaExporterOptions.Default; + JsonTypeInfo typeInfo = options.GetTypeInfo(type); + return MapRootTypeJsonSchema(typeInfo, exporterOptions); + } + + /// + /// Generates a JSON schema corresponding to the specified contract metadata. + /// + /// The contract metadata for which to generate the schema. + /// The exporterOptions object controlling the schema generation. + /// A new instance defining the JSON schema for . + /// One of the specified parameters is . + /// The parameter contains unsupported exporterOptions. + [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] + public static JsonNode GetJsonSchemaAsNode(this JsonTypeInfo typeInfo, JsonSchemaExporterOptions? exporterOptions = null) + { + _ = Throw.IfNull(typeInfo); + ValidateOptions(typeInfo.Options); + + exporterOptions ??= JsonSchemaExporterOptions.Default; + return MapRootTypeJsonSchema(typeInfo, exporterOptions); + } + + [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] + private static JsonNode MapRootTypeJsonSchema(JsonTypeInfo typeInfo, JsonSchemaExporterOptions exporterOptions) + { + GenerationState state = new(exporterOptions, typeInfo.Options); + JsonSchema schema = MapJsonSchemaCore(ref state, typeInfo); + return schema.ToJsonNode(exporterOptions); + } + + [RequiresUnreferencedCode(RequiresUnreferencedCodeMessage)] + private static JsonSchema MapJsonSchemaCore( + ref GenerationState state, + JsonTypeInfo typeInfo, + Type? parentType = null, + JsonPropertyInfo? propertyInfo = null, + ICustomAttributeProvider? propertyAttributeProvider = null, + ParameterInfo? parameterInfo = null, + bool isNonNullableType = false, + JsonConverter? customConverter = null, + JsonNumberHandling? customNumberHandling = null, + JsonTypeInfo? parentPolymorphicTypeInfo = null, + bool parentPolymorphicTypeContainsTypesWithoutDiscriminator = false, + bool parentPolymorphicTypeIsNonNullable = false, + KeyValuePair? typeDiscriminator = null, + bool cacheResult = true) + { + Debug.Assert(typeInfo.IsReadOnly, "The specified contract must have been made read-only."); + + JsonSchemaExporterContext exporterContext = state.CreateContext(typeInfo, parentPolymorphicTypeInfo, parentType, propertyInfo, parameterInfo, propertyAttributeProvider); + + if (cacheResult && typeInfo.Kind is not JsonTypeInfoKind.None && + state.TryGetExistingJsonPointer(exporterContext, out string? existingJsonPointer)) + { + // The schema context has already been generated in the schema document, return a reference to it. + return CompleteSchema(ref state, new JsonSchema { Ref = existingJsonPointer }); + } + + JsonSchema schema; + JsonConverter effectiveConverter = customConverter ?? typeInfo.Converter; + JsonNumberHandling effectiveNumberHandling = customNumberHandling ?? typeInfo.NumberHandling ?? typeInfo.Options.NumberHandling; + + if (!ReflectionHelpers.IsBuiltInConverter(effectiveConverter)) + { + // Return a `true` schema for types with user-defined converters. + return CompleteSchema(ref state, JsonSchema.True); + } + + if (parentPolymorphicTypeInfo is null && typeInfo.PolymorphismOptions is { DerivedTypes.Count: > 0 } polyOptions) + { + // This is the base type of a polymorphic type hierarchy. The schema for this type + // will include an "anyOf" property with the schemas for all derived types. + + string typeDiscriminatorKey = polyOptions.TypeDiscriminatorPropertyName; + List derivedTypes = polyOptions.DerivedTypes.ToList(); + + if (!typeInfo.Type.IsAbstract && !derivedTypes.Any(derived => derived.DerivedType == typeInfo.Type)) + { + // For non-abstract base types that haven't been explicitly configured, + // add a trivial schema to the derived types since we should support it. + derivedTypes.Add(new JsonDerivedType(typeInfo.Type)); + } + + bool containsTypesWithoutDiscriminator = derivedTypes.Exists(static derivedTypes => derivedTypes.TypeDiscriminator is null); + JsonSchemaType schemaType = JsonSchemaType.Any; + List? anyOf = new(derivedTypes.Count); + + state.PushSchemaNode(JsonSchemaConstants.AnyOfPropertyName); + + foreach (JsonDerivedType derivedType in derivedTypes) + { + Debug.Assert(derivedType.TypeDiscriminator is null or int or string, "Type discriminator does not have the expected type."); + + KeyValuePair? derivedTypeDiscriminator = null; + if (derivedType.TypeDiscriminator is { } discriminatorValue) + { + JsonNode discriminatorNode = discriminatorValue switch + { + string stringId => (JsonNode)stringId, + _ => (JsonNode)(int)discriminatorValue, + }; + + JsonSchema discriminatorSchema = new() { Constant = discriminatorNode }; + derivedTypeDiscriminator = new(typeDiscriminatorKey, discriminatorSchema); + } + + JsonTypeInfo derivedTypeInfo = typeInfo.Options.GetTypeInfo(derivedType.DerivedType); + + state.PushSchemaNode(anyOf.Count.ToString(CultureInfo.InvariantCulture)); + JsonSchema derivedSchema = MapJsonSchemaCore( + ref state, + derivedTypeInfo, + parentPolymorphicTypeInfo: typeInfo, + typeDiscriminator: derivedTypeDiscriminator, + parentPolymorphicTypeContainsTypesWithoutDiscriminator: containsTypesWithoutDiscriminator, + parentPolymorphicTypeIsNonNullable: isNonNullableType, + cacheResult: false); + + state.PopSchemaNode(); + + // Determine if all derived schemas have the same type. + if (anyOf.Count == 0) + { + schemaType = derivedSchema.Type; + } + else if (schemaType != derivedSchema.Type) + { + schemaType = JsonSchemaType.Any; + } + + anyOf.Add(derivedSchema); + } + + state.PopSchemaNode(); + + if (schemaType is not JsonSchemaType.Any) + { + // If all derived types have the same schema type, we can simplify the schema + // by moving the type keyword to the base schema and removing it from the derived schemas. + foreach (JsonSchema derivedSchema in anyOf) + { + derivedSchema.Type = JsonSchemaType.Any; + + if (derivedSchema.KeywordCount == 0) + { + // if removing the type results in an empty schema, + // remove the anyOf array entirely since it's always true. + anyOf = null; + break; + } + } + } + + schema = new() + { + Type = schemaType, + AnyOf = anyOf, + + // If all derived types have a discriminator, we can require it in the base schema. + Required = containsTypesWithoutDiscriminator ? null : new() { typeDiscriminatorKey }, + }; + + return CompleteSchema(ref state, schema); + } + + if (Nullable.GetUnderlyingType(typeInfo.Type) is Type nullableElementType) + { + JsonTypeInfo elementTypeInfo = typeInfo.Options.GetTypeInfo(nullableElementType); + customConverter = ExtractCustomNullableConverter(customConverter); + schema = MapJsonSchemaCore(ref state, elementTypeInfo, customConverter: customConverter, cacheResult: false); + + if (schema.Enum != null) + { + Debug.Assert(elementTypeInfo.Type.IsEnum, "The enum keyword should only be populated by schemas for enum types."); + schema.Enum.Add(null); // Append null to the enum array. + } + + return CompleteSchema(ref state, schema); + } + + switch (typeInfo.Kind) + { + case JsonTypeInfoKind.Object: + List>? properties = null; + List? required = null; + JsonSchema? additionalProperties = null; + + JsonUnmappedMemberHandling effectiveUnmappedMemberHandling = typeInfo.UnmappedMemberHandling ?? typeInfo.Options.UnmappedMemberHandling; + if (effectiveUnmappedMemberHandling is JsonUnmappedMemberHandling.Disallow) + { + // Disallow unspecified properties. + additionalProperties = JsonSchema.False; + } + + if (typeDiscriminator is { } typeDiscriminatorPair) + { + (properties = new()).Add(typeDiscriminatorPair); + if (parentPolymorphicTypeContainsTypesWithoutDiscriminator) + { + // Require the discriminator here since it's not common to all derived types. + (required = new()).Add(typeDiscriminatorPair.Key); + } + } + + Func? parameterInfoMapper = + ReflectionHelpers.ResolveJsonConstructorParameterMapper(typeInfo.Type, typeInfo); + + state.PushSchemaNode(JsonSchemaConstants.PropertiesPropertyName); + foreach (JsonPropertyInfo property in typeInfo.Properties) + { + if (property is { Get: null, Set: null } or { IsExtensionData: true }) + { + continue; // Skip JsonIgnored properties and extension data + } + + JsonNumberHandling? propertyNumberHandling = property.NumberHandling ?? effectiveNumberHandling; + JsonTypeInfo propertyTypeInfo = typeInfo.Options.GetTypeInfo(property.PropertyType); + + // Resolve the attribute provider for the property. + ICustomAttributeProvider? attributeProvider = ReflectionHelpers.ResolveAttributeProvider(typeInfo.Type, property); + + // Declare the property as nullable if either getter or setter are nullable. + bool isNonNullableProperty = false; + if (attributeProvider is MemberInfo memberInfo) + { + NullabilityInfo nullabilityInfo = ReflectionHelpers.GetMemberNullability(state.NullabilityInfoContext, memberInfo); + isNonNullableProperty = + (property.Get is null || nullabilityInfo.ReadState is NullabilityState.NotNull) && + (property.Set is null || nullabilityInfo.WriteState is NullabilityState.NotNull); + } + + bool isRequired = property.IsRequired; + bool hasDefaultValue = false; + JsonNode? defaultValue = null; + + ParameterInfo? associatedParameter = parameterInfoMapper?.Invoke(property); + if (associatedParameter != null) + { + ResolveParameterInfo( + associatedParameter, + propertyTypeInfo, + state.NullabilityInfoContext, + out hasDefaultValue, + out defaultValue, + out bool isNonNullableParameter, + ref isRequired); + + isNonNullableProperty &= isNonNullableParameter; + } + + state.PushSchemaNode(property.Name); + JsonSchema propertySchema = MapJsonSchemaCore( + ref state, + propertyTypeInfo, + parentType: typeInfo.Type, + propertyInfo: property, + parameterInfo: associatedParameter, + propertyAttributeProvider: attributeProvider, + isNonNullableType: isNonNullableProperty, + customConverter: property.CustomConverter, + customNumberHandling: propertyNumberHandling); + + state.PopSchemaNode(); + + if (hasDefaultValue) + { + JsonSchema.EnsureMutable(ref propertySchema); + propertySchema.DefaultValue = defaultValue; + propertySchema.HasDefaultValue = true; + } + + (properties ??= new()).Add(new(property.Name, propertySchema)); + + if (isRequired) + { + (required ??= new()).Add(property.Name); + } + } + + state.PopSchemaNode(); + return CompleteSchema(ref state, new() + { + Type = JsonSchemaType.Object, + Properties = properties, + Required = required, + AdditionalProperties = additionalProperties, + }); + + case JsonTypeInfoKind.Enumerable: + Type elementType = ReflectionHelpers.GetElementType(typeInfo); + JsonTypeInfo elementTypeInfo = typeInfo.Options.GetTypeInfo(elementType); + + if (typeDiscriminator is null) + { + state.PushSchemaNode(JsonSchemaConstants.ItemsPropertyName); + JsonSchema items = MapJsonSchemaCore(ref state, elementTypeInfo, customNumberHandling: effectiveNumberHandling); + state.PopSchemaNode(); + + return CompleteSchema(ref state, new() + { + Type = JsonSchemaType.Array, + Items = items.IsTrue ? null : items, + }); + } + else + { + // Polymorphic enumerable types are represented using a wrapping object: + // { "$type" : "discriminator", "$values" : [element1, element2, ...] } + // Which corresponds to the schema + // { "properties" : { "$type" : { "const" : "discriminator" }, "$values" : { "type" : "array", "items" : { ... } } } } + const string ValuesKeyword = "$values"; + + state.PushSchemaNode(JsonSchemaConstants.PropertiesPropertyName); + state.PushSchemaNode(ValuesKeyword); + state.PushSchemaNode(JsonSchemaConstants.ItemsPropertyName); + + JsonSchema items = MapJsonSchemaCore(ref state, elementTypeInfo, customNumberHandling: effectiveNumberHandling); + + state.PopSchemaNode(); + state.PopSchemaNode(); + state.PopSchemaNode(); + + return CompleteSchema(ref state, new() + { + Type = JsonSchemaType.Object, + Properties = new() + { + typeDiscriminator.Value, + new(ValuesKeyword, + new JsonSchema + { + Type = JsonSchemaType.Array, + Items = items.IsTrue ? null : items, + }), + }, + Required = parentPolymorphicTypeContainsTypesWithoutDiscriminator ? new() { typeDiscriminator.Value.Key } : null, + }); + } + + case JsonTypeInfoKind.Dictionary: + Type valueType = ReflectionHelpers.GetElementType(typeInfo); + JsonTypeInfo valueTypeInfo = typeInfo.Options.GetTypeInfo(valueType); + + List>? dictProps = null; + List? dictRequired = null; + + if (typeDiscriminator is { } dictDiscriminator) + { + dictProps = new() { dictDiscriminator }; + if (parentPolymorphicTypeContainsTypesWithoutDiscriminator) + { + // Require the discriminator here since it's not common to all derived types. + dictRequired = new() { dictDiscriminator.Key }; + } + } + + state.PushSchemaNode(JsonSchemaConstants.AdditionalPropertiesPropertyName); + JsonSchema valueSchema = MapJsonSchemaCore(ref state, valueTypeInfo, customNumberHandling: effectiveNumberHandling); + state.PopSchemaNode(); + + return CompleteSchema(ref state, new() + { + Type = JsonSchemaType.Object, + Properties = dictProps, + Required = dictRequired, + AdditionalProperties = valueSchema.IsTrue ? null : valueSchema, + }); + + default: + Debug.Assert(typeInfo.Kind is JsonTypeInfoKind.None, "The default case should handle unrecognize type kinds."); + + if (_simpleTypeSchemaFactories.TryGetValue(typeInfo.Type, out Func? simpleTypeSchemaFactory)) + { + schema = simpleTypeSchemaFactory(effectiveNumberHandling); + } + else if (typeInfo.Type.IsEnum) + { + schema = GetEnumConverterSchema(typeInfo, effectiveConverter); + } + else + { + schema = JsonSchema.True; + } + + return CompleteSchema(ref state, schema); + } + + JsonSchema CompleteSchema(ref GenerationState state, JsonSchema schema) + { + if (schema.Ref is null) + { + if (IsNullableSchema(ref state)) + { + schema.MakeNullable(); + } + + bool IsNullableSchema(ref GenerationState state) + { + // A schema is marked as nullable if either + // 1. We have a schema for a property where either the getter or setter are marked as nullable. + // 2. We have a schema for a reference type, unless we're explicitly treating null-oblivious types as non-nullable + + if (propertyInfo != null || parameterInfo != null) + { + return !isNonNullableType; + } + else + { + return ReflectionHelpers.CanBeNull(typeInfo.Type) && + !parentPolymorphicTypeIsNonNullable && + !state.ExporterOptions.TreatNullObliviousAsNonNullable; + } + } + } + + if (state.ExporterOptions.TransformSchemaNode != null) + { + // Prime the schema for invocation by the JsonNode transformer. + schema.GenerationContext = exporterContext; + } + + return schema; + } + } + + private readonly ref struct GenerationState + { + private const int DefaultMaxDepth = 64; + private readonly List _currentPath = new(); + private readonly Dictionary<(JsonTypeInfo, JsonPropertyInfo?), string[]> _generated = new(); + private readonly int _maxDepth; + + public GenerationState(JsonSchemaExporterOptions exporterOptions, JsonSerializerOptions options, NullabilityInfoContext? nullabilityInfoContext = null) + { + ExporterOptions = exporterOptions; + NullabilityInfoContext = nullabilityInfoContext ?? new(); + _maxDepth = options.MaxDepth is 0 ? DefaultMaxDepth : options.MaxDepth; + } + + public JsonSchemaExporterOptions ExporterOptions { get; } + public NullabilityInfoContext NullabilityInfoContext { get; } + public int CurrentDepth => _currentPath.Count; + + public void PushSchemaNode(string nodeId) + { + if (CurrentDepth == _maxDepth) + { + ThrowHelpers.ThrowInvalidOperationException_MaxDepthReached(); + } + + _currentPath.Add(nodeId); + } + + public void PopSchemaNode() + { + _currentPath.RemoveAt(_currentPath.Count - 1); + } + + /// + /// Registers the current schema node generation context; if it has already been generated return a JSON pointer to its location. + /// + public bool TryGetExistingJsonPointer(in JsonSchemaExporterContext context, [NotNullWhen(true)] out string? existingJsonPointer) + { + (JsonTypeInfo, JsonPropertyInfo?) key = (context.TypeInfo, context.PropertyInfo); +#if NET + ref string[]? pathToSchema = ref CollectionsMarshal.GetValueRefOrAddDefault(_generated, key, out bool exists); +#else + bool exists = _generated.TryGetValue(key, out string[]? pathToSchema); +#endif + if (exists) + { + existingJsonPointer = FormatJsonPointer(pathToSchema); + return true; + } +#if NET + pathToSchema = context._path; +#else + _generated[key] = context._path; +#endif + existingJsonPointer = null; + return false; + } + + public JsonSchemaExporterContext CreateContext( + JsonTypeInfo typeInfo, + JsonTypeInfo? baseTypeInfo, + Type? declaringType, + JsonPropertyInfo? propertyInfo, + ParameterInfo? parameterInfo, + ICustomAttributeProvider? propertyAttributeProvider) + { + return new JsonSchemaExporterContext(typeInfo, baseTypeInfo, declaringType, propertyInfo, parameterInfo, propertyAttributeProvider, _currentPath.ToArray()); + } + + private static string FormatJsonPointer(ReadOnlySpan path) + { + if (path.IsEmpty) + { + return "#"; + } + + StringBuilder sb = new(); + _ = sb.Append('#'); + + for (int i = 0; i < path.Length; i++) + { + string segment = path[i]; + if (segment.AsSpan().IndexOfAny('~', '/') != -1) + { +#pragma warning disable CA1307 // Specify StringComparison for clarity + segment = segment.Replace("~", "~0").Replace("/", "~1"); +#pragma warning restore CA1307 + } + + _ = sb.Append('/'); + _ = sb.Append(segment); + } + + return sb.ToString(); + } + } + + private static readonly Dictionary> _simpleTypeSchemaFactories = new() + { + [typeof(object)] = _ => JsonSchema.True, + [typeof(bool)] = _ => new JsonSchema { Type = JsonSchemaType.Boolean }, + [typeof(byte)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(ushort)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(uint)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(ulong)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(sbyte)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(short)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(int)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(long)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(float)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true), + [typeof(double)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true), + [typeof(decimal)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling), +#if NET6_0_OR_GREATER + [typeof(Half)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Number, numberHandling, isIeeeFloatingPoint: true), +#endif +#if NET7_0_OR_GREATER + [typeof(UInt128)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), + [typeof(Int128)] = numberHandling => GetSchemaForNumericType(JsonSchemaType.Integer, numberHandling), +#endif + [typeof(char)] = _ => new JsonSchema { Type = JsonSchemaType.String, MinLength = 1, MaxLength = 1 }, + [typeof(string)] = _ => new JsonSchema { Type = JsonSchemaType.String }, + [typeof(byte[])] = _ => new JsonSchema { Type = JsonSchemaType.String }, + [typeof(Memory)] = _ => new JsonSchema { Type = JsonSchemaType.String }, + [typeof(ReadOnlyMemory)] = _ => new JsonSchema { Type = JsonSchemaType.String }, + [typeof(DateTime)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date-time" }, + [typeof(DateTimeOffset)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date-time" }, + [typeof(TimeSpan)] = _ => new JsonSchema + { + Comment = "Represents a System.TimeSpan value.", + Type = JsonSchemaType.String, + Pattern = @"^-?(\d+\.)?\d{2}:\d{2}:\d{2}(\.\d{1,7})?$", + }, + +#if NET6_0_OR_GREATER + [typeof(DateOnly)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "date" }, + [typeof(TimeOnly)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "time" }, +#endif + [typeof(Guid)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "uuid" }, + [typeof(Uri)] = _ => new JsonSchema { Type = JsonSchemaType.String, Format = "uri" }, + [typeof(Version)] = _ => new JsonSchema + { + Comment = "Represents a version string.", + Type = JsonSchemaType.String, + Pattern = @"^\d+(\.\d+){1,3}$", + }, + + [typeof(JsonDocument)] = _ => JsonSchema.True, + [typeof(JsonElement)] = _ => JsonSchema.True, + [typeof(JsonNode)] = _ => JsonSchema.True, + [typeof(JsonValue)] = _ => JsonSchema.True, + [typeof(JsonObject)] = _ => new JsonSchema { Type = JsonSchemaType.Object }, + [typeof(JsonArray)] = _ => new JsonSchema { Type = JsonSchemaType.Array }, + }; + + // Adapted from https://github.com/dotnet/runtime/blob/release/9.0/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Value/JsonPrimitiveConverter.cs#L36-L69 + private static JsonSchema GetSchemaForNumericType(JsonSchemaType schemaType, JsonNumberHandling numberHandling, bool isIeeeFloatingPoint = false) + { + Debug.Assert(schemaType is JsonSchemaType.Integer or JsonSchemaType.Number, "schema type must be number or integer"); + Debug.Assert(!isIeeeFloatingPoint || schemaType is JsonSchemaType.Number, "If specifying IEEE the schema type must be number"); + + string? pattern = null; + + if ((numberHandling & (JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)) != 0) + { + if (schemaType is JsonSchemaType.Integer) + { + pattern = @"^-?(?:0|[1-9]\d*)$"; + } + else if (isIeeeFloatingPoint) + { + pattern = @"^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$"; + } + else + { + pattern = @"^-?(?:0|[1-9]\d*)(?:\.\d+)?$"; + } + + schemaType |= JsonSchemaType.String; + } + + if (isIeeeFloatingPoint && (numberHandling & JsonNumberHandling.AllowNamedFloatingPointLiterals) != 0) + { + return new JsonSchema + { + AnyOf = new() + { + new JsonSchema { Type = schemaType, Pattern = pattern }, + new JsonSchema { Enum = new() { (JsonNode)"NaN", (JsonNode)"Infinity", (JsonNode)"-Infinity" } }, + }, + }; + } + + return new JsonSchema { Type = schemaType, Pattern = pattern }; + } + + private static JsonConverter? ExtractCustomNullableConverter(JsonConverter? converter) + { + Debug.Assert(converter is null || ReflectionHelpers.IsBuiltInConverter(converter), "If specified the converter must be built-in."); + + if (converter is null) + { + return null; + } + + return ReflectionHelpers.GetElementConverter(converter); + } + + private static void ValidateOptions(JsonSerializerOptions options) + { + if (options.ReferenceHandler == ReferenceHandler.Preserve) + { + ThrowHelpers.ThrowNotSupportedException_ReferenceHandlerPreserveNotSupported(); + } + + options.MakeReadOnly(); + } + + private static void ResolveParameterInfo( + ParameterInfo parameter, + JsonTypeInfo parameterTypeInfo, + NullabilityInfoContext nullabilityInfoContext, + out bool hasDefaultValue, + out JsonNode? defaultValue, + out bool isNonNullable, + ref bool isRequired) + { + Debug.Assert(parameterTypeInfo.Type == parameter.ParameterType, "The typeInfo type must match the ParameterInfo type."); + + // Incorporate the nullability information from the parameter. + isNonNullable = ReflectionHelpers.GetParameterNullability(nullabilityInfoContext, parameter) is NullabilityState.NotNull; + + if (parameter.HasDefaultValue) + { + // Append the default value to the description. + object? defaultVal = ReflectionHelpers.GetNormalizedDefaultValue(parameter); + defaultValue = JsonSerializer.SerializeToNode(defaultVal, parameterTypeInfo); + hasDefaultValue = true; + } + else + { + // Parameter is not optional, mark as required. + isRequired = true; + defaultValue = null; + hasDefaultValue = false; + } + } + + // Adapted from https://github.com/dotnet/runtime/blob/release/9.0/src/libraries/System.Text.Json/src/System/Text/Json/Serialization/Converters/Value/EnumConverter.cs#L498-L521 + private static JsonSchema GetEnumConverterSchema(JsonTypeInfo typeInfo, JsonConverter converter) + { + Debug.Assert(typeInfo.Type.IsEnum && ReflectionHelpers.IsBuiltInConverter(converter), "must be using a built-in enum converter."); + + if (converter is JsonConverterFactory factory) + { + converter = factory.CreateConverter(typeInfo.Type, typeInfo.Options)!; + } + + ReflectionHelpers.GetEnumConverterConfig(converter, out JsonNamingPolicy? namingPolicy, out bool allowString); + + if (allowString) + { + // This explicitly ignores the integer component in converters configured as AllowNumbers | AllowStrings + // which is the default for JsonStringEnumConverter. This sacrifices some precision in the schema for simplicity. + + if (typeInfo.Type.GetCustomAttribute() is not null) + { + // Do not report enum values in case of flags. + return new() { Type = JsonSchemaType.String }; + } + + JsonArray enumValues = new(); + foreach (string name in Enum.GetNames(typeInfo.Type)) + { + // This does not account for custom names specified via the new + // JsonStringEnumMemberNameAttribute introduced in .NET 9. + string effectiveName = namingPolicy?.ConvertName(name) ?? name; + enumValues.Add((JsonNode)effectiveName); + } + + return new() { Enum = enumValues }; + } + + return new() { Type = JsonSchemaType.Integer }; + } + + private static class JsonSchemaConstants + { + public const string SchemaPropertyName = "$schema"; + public const string RefPropertyName = "$ref"; + public const string CommentPropertyName = "$comment"; + public const string TitlePropertyName = "title"; + public const string DescriptionPropertyName = "description"; + public const string TypePropertyName = "type"; + public const string FormatPropertyName = "format"; + public const string PatternPropertyName = "pattern"; + public const string PropertiesPropertyName = "properties"; + public const string RequiredPropertyName = "required"; + public const string ItemsPropertyName = "items"; + public const string AdditionalPropertiesPropertyName = "additionalProperties"; + public const string EnumPropertyName = "enum"; + public const string NotPropertyName = "not"; + public const string AnyOfPropertyName = "anyOf"; + public const string ConstPropertyName = "const"; + public const string DefaultPropertyName = "default"; + public const string MinLengthPropertyName = "minLength"; + public const string MaxLengthPropertyName = "maxLength"; + } + + private static class ThrowHelpers + { + [DoesNotReturn] + public static void ThrowInvalidOperationException_MaxDepthReached() => + throw new InvalidOperationException("The depth of the generated JSON schema exceeds the JsonSerializerOptions.MaxDepth setting."); + + [DoesNotReturn] + public static void ThrowNotSupportedException_ReferenceHandlerPreserveNotSupported() => + throw new NotSupportedException("Schema generation not supported with ReferenceHandler.Preserve enabled."); + } +} +#endif diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporterContext.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporterContext.cs new file mode 100644 index 00000000000..3602ee46df4 --- /dev/null +++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporterContext.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET9_0_OR_GREATER +using System; +using System.Reflection; +using System.Text.Json.Serialization.Metadata; + +namespace System.Text.Json.Schema; + +/// +/// Defines the context in which a JSON schema within a type graph is being generated. +/// +#if !SHARED_PROJECT +[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] +#endif +internal readonly struct JsonSchemaExporterContext +{ +#pragma warning disable IDE1006 // Naming Styles + internal readonly string[] _path; +#pragma warning restore IDE1006 // Naming Styles + + internal JsonSchemaExporterContext( + JsonTypeInfo typeInfo, + JsonTypeInfo? baseTypeInfo, + Type? declaringType, + JsonPropertyInfo? propertyInfo, + ParameterInfo? parameterInfo, + ICustomAttributeProvider? propertyAttributeProvider, + string[] path) + { + TypeInfo = typeInfo; + DeclaringType = declaringType; + BaseTypeInfo = baseTypeInfo; + PropertyInfo = propertyInfo; + ParameterInfo = parameterInfo; + PropertyAttributeProvider = propertyAttributeProvider; + _path = path; + } + + /// + /// Gets the path to the schema document currently being generated. + /// + public ReadOnlySpan Path => _path; + + /// + /// Gets the for the type being processed. + /// + public JsonTypeInfo TypeInfo { get; } + + /// + /// Gets the declaring type of the property or parameter being processed. + /// + public Type? DeclaringType { get; } + + /// + /// Gets the type info for the polymorphic base type if generated as a derived type. + /// + public JsonTypeInfo? BaseTypeInfo { get; } + + /// + /// Gets the if the schema is being generated for a property. + /// + public JsonPropertyInfo? PropertyInfo { get; } + + /// + /// Gets the if a constructor parameter + /// has been associated with the accompanying . + /// + public ParameterInfo? ParameterInfo { get; } + + /// + /// Gets the corresponding to the property or field being processed. + /// + public ICustomAttributeProvider? PropertyAttributeProvider { get; } +} +#endif diff --git a/src/Shared/JsonSchemaExporter/JsonSchemaExporterOptions.cs b/src/Shared/JsonSchemaExporter/JsonSchemaExporterOptions.cs new file mode 100644 index 00000000000..53a269ea612 --- /dev/null +++ b/src/Shared/JsonSchemaExporter/JsonSchemaExporterOptions.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET9_0_OR_GREATER +using System; +using System.Text.Json.Nodes; + +namespace System.Text.Json.Schema; + +/// +/// Controls the behavior of the class. +/// +#if !SHARED_PROJECT +[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] +#endif +internal sealed class JsonSchemaExporterOptions +{ + /// + /// Gets the default configuration object used by . + /// + public static JsonSchemaExporterOptions Default { get; } = new(); + + /// + /// Gets a value indicating whether non-nullable schemas should be generated for null oblivious reference types. + /// + /// + /// Defaults to . Due to restrictions in the run-time representation of nullable reference types + /// most occurrences are null oblivious and are treated as nullable by the serializer. A notable exception to that rule + /// are nullability annotations of field, property and constructor parameters which are represented in the contract metadata. + /// + public bool TreatNullObliviousAsNonNullable { get; init; } + + /// + /// Gets a callback that is invoked for every schema that is generated within the type graph. + /// + public Func? TransformSchemaNode { get; init; } +} +#endif diff --git a/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfo.cs b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfo.cs new file mode 100644 index 00000000000..bd9b132cd0f --- /dev/null +++ b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfo.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET6_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; + +#pragma warning disable SA1623 // Property summary documentation should match accessors + +namespace System.Reflection +{ + /// + /// A class that represents nullability info. + /// + [ExcludeFromCodeCoverage] + internal sealed class NullabilityInfo + { + internal NullabilityInfo(Type type, NullabilityState readState, NullabilityState writeState, + NullabilityInfo? elementType, NullabilityInfo[] typeArguments) + { + Type = type; + ReadState = readState; + WriteState = writeState; + ElementType = elementType; + GenericTypeArguments = typeArguments; + } + + /// + /// The of the member or generic parameter + /// to which this NullabilityInfo belongs. + /// + public Type Type { get; } + + /// + /// The nullability read state of the member. + /// + public NullabilityState ReadState { get; internal set; } + + /// + /// The nullability write state of the member. + /// + public NullabilityState WriteState { get; internal set; } + + /// + /// If the member type is an array, gives the of the elements of the array, null otherwise. + /// + public NullabilityInfo? ElementType { get; } + + /// + /// If the member type is a generic type, gives the array of for each type parameter. + /// + public NullabilityInfo[] GenericTypeArguments { get; } + } + + /// + /// An enum that represents nullability state. + /// + internal enum NullabilityState + { + /// + /// Nullability context not enabled (oblivious). + /// + Unknown, + + /// + /// Non nullable value or reference type. + /// + NotNull, + + /// + /// Nullable value or reference type. + /// + Nullable, + } +} +#endif diff --git a/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoContext.cs b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoContext.cs new file mode 100644 index 00000000000..3edee1b9cb8 --- /dev/null +++ b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoContext.cs @@ -0,0 +1,661 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET6_0_OR_GREATER +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable S109 // Magic numbers should not be used +#pragma warning disable S1067 // Expressions should not be too complex +#pragma warning disable S4136 // Method overloads should be grouped together +#pragma warning disable SA1202 // Elements should be ordered by access +#pragma warning disable IDE1006 // Naming Styles + +namespace System.Reflection +{ + /// + /// Provides APIs for populating nullability information/context from reflection members: + /// , , and . + /// + [ExcludeFromCodeCoverage] + internal sealed class NullabilityInfoContext + { + private const string CompilerServicesNameSpace = "System.Runtime.CompilerServices"; + private readonly Dictionary _publicOnlyModules = new(); + private readonly Dictionary _context = new(); + + [Flags] + private enum NotAnnotatedStatus + { + None = 0x0, // no restriction, all members annotated + Private = 0x1, // private members not annotated + Internal = 0x2, // internal members not annotated + } + + private NullabilityState? GetNullableContext(MemberInfo? memberInfo) + { + while (memberInfo != null) + { + if (_context.TryGetValue(memberInfo, out NullabilityState state)) + { + return state; + } + + foreach (CustomAttributeData attribute in memberInfo.GetCustomAttributesData()) + { + if (attribute.AttributeType.Name == "NullableContextAttribute" && + attribute.AttributeType.Namespace == CompilerServicesNameSpace && + attribute.ConstructorArguments.Count == 1) + { + state = TranslateByte(attribute.ConstructorArguments[0].Value); + _context.Add(memberInfo, state); + return state; + } + } + + memberInfo = memberInfo.DeclaringType; + } + + return null; + } + + /// + /// Populates for the given . + /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's + /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state. + /// + /// The parameter which nullability info gets populated. + /// If the parameterInfo parameter is null. + /// . + public NullabilityInfo Create(ParameterInfo parameterInfo) + { + IList attributes = parameterInfo.GetCustomAttributesData(); + NullableAttributeStateParser parser = parameterInfo.Member is MethodBase method && IsPrivateOrInternalMethodAndAnnotationDisabled(method) + ? NullableAttributeStateParser.Unknown + : CreateParser(attributes); + NullabilityInfo nullability = GetNullabilityInfo(parameterInfo.Member, parameterInfo.ParameterType, parser); + + if (nullability.ReadState != NullabilityState.Unknown) + { + CheckParameterMetadataType(parameterInfo, nullability); + } + + CheckNullabilityAttributes(nullability, attributes); + return nullability; + } + + private void CheckParameterMetadataType(ParameterInfo parameter, NullabilityInfo nullability) + { + ParameterInfo? metaParameter; + MemberInfo metaMember; + + switch (parameter.Member) + { + case ConstructorInfo ctor: + var metaCtor = (ConstructorInfo)GetMemberMetadataDefinition(ctor); + metaMember = metaCtor; + metaParameter = GetMetaParameter(metaCtor, parameter); + break; + + case MethodInfo method: + MethodInfo metaMethod = GetMethodMetadataDefinition(method); + metaMember = metaMethod; + metaParameter = string.IsNullOrEmpty(parameter.Name) ? metaMethod.ReturnParameter : GetMetaParameter(metaMethod, parameter); + break; + + default: + return; + } + + if (metaParameter != null) + { + CheckGenericParameters(nullability, metaMember, metaParameter.ParameterType, parameter.Member.ReflectedType); + } + } + + private static ParameterInfo? GetMetaParameter(MethodBase metaMethod, ParameterInfo parameter) + { + var parameters = metaMethod.GetParameters(); + for (int i = 0; i < parameters.Length; i++) + { + if (parameter.Position == i && + parameter.Name == parameters[i].Name) + { + return parameters[i]; + } + } + + return null; + } + + private static MethodInfo GetMethodMetadataDefinition(MethodInfo method) + { + if (method.IsGenericMethod && !method.IsGenericMethodDefinition) + { + method = method.GetGenericMethodDefinition(); + } + + return (MethodInfo)GetMemberMetadataDefinition(method); + } + + private static void CheckNullabilityAttributes(NullabilityInfo nullability, IList attributes) + { + var codeAnalysisReadState = NullabilityState.Unknown; + var codeAnalysisWriteState = NullabilityState.Unknown; + + foreach (CustomAttributeData attribute in attributes) + { + if (attribute.AttributeType.Namespace == "System.Diagnostics.CodeAnalysis") + { + if (attribute.AttributeType.Name == "NotNullAttribute") + { + codeAnalysisReadState = NullabilityState.NotNull; + } + else if ((attribute.AttributeType.Name == "MaybeNullAttribute" || + attribute.AttributeType.Name == "MaybeNullWhenAttribute") && + codeAnalysisReadState == NullabilityState.Unknown && + !IsValueTypeOrValueTypeByRef(nullability.Type)) + { + codeAnalysisReadState = NullabilityState.Nullable; + } + else if (attribute.AttributeType.Name == "DisallowNullAttribute") + { + codeAnalysisWriteState = NullabilityState.NotNull; + } + else if (attribute.AttributeType.Name == "AllowNullAttribute" && + codeAnalysisWriteState == NullabilityState.Unknown && + !IsValueTypeOrValueTypeByRef(nullability.Type)) + { + codeAnalysisWriteState = NullabilityState.Nullable; + } + } + } + + if (codeAnalysisReadState != NullabilityState.Unknown) + { + nullability.ReadState = codeAnalysisReadState; + } + + if (codeAnalysisWriteState != NullabilityState.Unknown) + { + nullability.WriteState = codeAnalysisWriteState; + } + } + + /// + /// Populates for the given . + /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's + /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state. + /// + /// The parameter which nullability info gets populated. + /// If the propertyInfo parameter is null. + /// . + public NullabilityInfo Create(PropertyInfo propertyInfo) + { + MethodInfo? getter = propertyInfo.GetGetMethod(true); + MethodInfo? setter = propertyInfo.GetSetMethod(true); + bool annotationsDisabled = (getter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(getter)) + && (setter == null || IsPrivateOrInternalMethodAndAnnotationDisabled(setter)); + NullableAttributeStateParser parser = annotationsDisabled ? NullableAttributeStateParser.Unknown : CreateParser(propertyInfo.GetCustomAttributesData()); + NullabilityInfo nullability = GetNullabilityInfo(propertyInfo, propertyInfo.PropertyType, parser); + + if (getter != null) + { + CheckNullabilityAttributes(nullability, getter.ReturnParameter.GetCustomAttributesData()); + } + else + { + nullability.ReadState = NullabilityState.Unknown; + } + + if (setter != null) + { + CheckNullabilityAttributes(nullability, setter.GetParameters().Last().GetCustomAttributesData()); + } + else + { + nullability.WriteState = NullabilityState.Unknown; + } + + return nullability; + } + + private bool IsPrivateOrInternalMethodAndAnnotationDisabled(MethodBase method) + { + if ((method.IsPrivate || method.IsFamilyAndAssembly || method.IsAssembly) && + IsPublicOnly(method.IsPrivate, method.IsFamilyAndAssembly, method.IsAssembly, method.Module)) + { + return true; + } + + return false; + } + + /// + /// Populates for the given . + /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's + /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state. + /// + /// The parameter which nullability info gets populated. + /// If the eventInfo parameter is null. + /// . + public NullabilityInfo Create(EventInfo eventInfo) + { + return GetNullabilityInfo(eventInfo, eventInfo.EventHandlerType!, CreateParser(eventInfo.GetCustomAttributesData())); + } + + /// + /// Populates for the given + /// If the nullablePublicOnly feature is set for an assembly, like it does in .NET SDK, the private and/or internal member's + /// nullability attributes are omitted, in this case the API will return NullabilityState.Unknown state. + /// + /// The parameter which nullability info gets populated. + /// If the fieldInfo parameter is null. + /// . + public NullabilityInfo Create(FieldInfo fieldInfo) + { + IList attributes = fieldInfo.GetCustomAttributesData(); + NullableAttributeStateParser parser = IsPrivateOrInternalFieldAndAnnotationDisabled(fieldInfo) ? NullableAttributeStateParser.Unknown : CreateParser(attributes); + NullabilityInfo nullability = GetNullabilityInfo(fieldInfo, fieldInfo.FieldType, parser); + CheckNullabilityAttributes(nullability, attributes); + return nullability; + } + + private bool IsPrivateOrInternalFieldAndAnnotationDisabled(FieldInfo fieldInfo) + { + if ((fieldInfo.IsPrivate || fieldInfo.IsFamilyAndAssembly || fieldInfo.IsAssembly) && + IsPublicOnly(fieldInfo.IsPrivate, fieldInfo.IsFamilyAndAssembly, fieldInfo.IsAssembly, fieldInfo.Module)) + { + return true; + } + + return false; + } + + private bool IsPublicOnly(bool isPrivate, bool isFamilyAndAssembly, bool isAssembly, Module module) + { + if (!_publicOnlyModules.TryGetValue(module, out NotAnnotatedStatus value)) + { + value = PopulateAnnotationInfo(module.GetCustomAttributesData()); + _publicOnlyModules.Add(module, value); + } + + if (value == NotAnnotatedStatus.None) + { + return false; + } + + if (((isPrivate || isFamilyAndAssembly) && value.HasFlag(NotAnnotatedStatus.Private)) || + (isAssembly && value.HasFlag(NotAnnotatedStatus.Internal))) + { + return true; + } + + return false; + } + + private static NotAnnotatedStatus PopulateAnnotationInfo(IList customAttributes) + { + foreach (CustomAttributeData attribute in customAttributes) + { + if (attribute.AttributeType.Name == "NullablePublicOnlyAttribute" && + attribute.AttributeType.Namespace == CompilerServicesNameSpace && + attribute.ConstructorArguments.Count == 1) + { + if (attribute.ConstructorArguments[0].Value is bool boolValue && boolValue) + { + return NotAnnotatedStatus.Internal | NotAnnotatedStatus.Private; + } + else + { + return NotAnnotatedStatus.Private; + } + } + } + + return NotAnnotatedStatus.None; + } + + private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser) + { + int index = 0; + NullabilityInfo nullability = GetNullabilityInfo(memberInfo, type, parser, ref index); + + if (nullability.ReadState != NullabilityState.Unknown) + { + TryLoadGenericMetaTypeNullability(memberInfo, nullability); + } + + return nullability; + } + + private NullabilityInfo GetNullabilityInfo(MemberInfo memberInfo, Type type, NullableAttributeStateParser parser, ref int index) + { + NullabilityState state = NullabilityState.Unknown; + NullabilityInfo? elementState = null; + NullabilityInfo[] genericArgumentsState = Array.Empty(); + Type underlyingType = type; + + if (underlyingType.IsByRef || underlyingType.IsPointer) + { + underlyingType = underlyingType.GetElementType()!; + } + + if (underlyingType.IsValueType) + { + if (Nullable.GetUnderlyingType(underlyingType) is { } nullableUnderlyingType) + { + underlyingType = nullableUnderlyingType; + state = NullabilityState.Nullable; + } + else + { + state = NullabilityState.NotNull; + } + + if (underlyingType.IsGenericType) + { + ++index; + } + } + else + { + if (!parser.ParseNullableState(index++, ref state) + && GetNullableContext(memberInfo) is { } contextState) + { + state = contextState; + } + + if (underlyingType.IsArray) + { + elementState = GetNullabilityInfo(memberInfo, underlyingType.GetElementType()!, parser, ref index); + } + } + + if (underlyingType.IsGenericType) + { + Type[] genericArguments = underlyingType.GetGenericArguments(); + genericArgumentsState = new NullabilityInfo[genericArguments.Length]; + + for (int i = 0; i < genericArguments.Length; i++) + { + genericArgumentsState[i] = GetNullabilityInfo(memberInfo, genericArguments[i], parser, ref index); + } + } + + return new NullabilityInfo(type, state, state, elementState, genericArgumentsState); + } + + private static NullableAttributeStateParser CreateParser(IList customAttributes) + { + foreach (CustomAttributeData attribute in customAttributes) + { + if (attribute.AttributeType.Name == "NullableAttribute" && + attribute.AttributeType.Namespace == CompilerServicesNameSpace && + attribute.ConstructorArguments.Count == 1) + { + return new NullableAttributeStateParser(attribute.ConstructorArguments[0].Value); + } + } + + return new NullableAttributeStateParser(null); + } + + private void TryLoadGenericMetaTypeNullability(MemberInfo memberInfo, NullabilityInfo nullability) + { + MemberInfo? metaMember = GetMemberMetadataDefinition(memberInfo); + Type? metaType = null; + if (metaMember is FieldInfo field) + { + metaType = field.FieldType; + } + else if (metaMember is PropertyInfo property) + { + metaType = GetPropertyMetaType(property); + } + + if (metaType != null) + { + CheckGenericParameters(nullability, metaMember!, metaType, memberInfo.ReflectedType); + } + } + + private static MemberInfo GetMemberMetadataDefinition(MemberInfo member) + { + Type? type = member.DeclaringType; + if ((type != null) && type.IsGenericType && !type.IsGenericTypeDefinition) + { + return NullabilityInfoHelpers.GetMemberWithSameMetadataDefinitionAs(type.GetGenericTypeDefinition(), member); + } + + return member; + } + + private static Type GetPropertyMetaType(PropertyInfo property) + { + if (property.GetGetMethod(true) is MethodInfo method) + { + return method.ReturnType; + } + + return property.GetSetMethod(true)!.GetParameters()[0].ParameterType; + } + + private void CheckGenericParameters(NullabilityInfo nullability, MemberInfo metaMember, Type metaType, Type? reflectedType) + { + if (metaType.IsGenericParameter) + { + if (nullability.ReadState == NullabilityState.NotNull) + { + _ = TryUpdateGenericParameterNullability(nullability, metaType, reflectedType); + } + } + else if (metaType.ContainsGenericParameters) + { + if (nullability.GenericTypeArguments.Length > 0) + { + Type[] genericArguments = metaType.GetGenericArguments(); + + for (int i = 0; i < genericArguments.Length; i++) + { + CheckGenericParameters(nullability.GenericTypeArguments[i], metaMember, genericArguments[i], reflectedType); + } + } + else if (nullability.ElementType is { } elementNullability && metaType.IsArray) + { + CheckGenericParameters(elementNullability, metaMember, metaType.GetElementType()!, reflectedType); + } + + // We could also follow this branch for metaType.IsPointer, but since pointers must be unmanaged this + // will be a no-op regardless + else if (metaType.IsByRef) + { + CheckGenericParameters(nullability, metaMember, metaType.GetElementType()!, reflectedType); + } + } + } + + private bool TryUpdateGenericParameterNullability(NullabilityInfo nullability, Type genericParameter, Type? reflectedType) + { + Debug.Assert(genericParameter.IsGenericParameter, "must be generic parameter"); + + if (reflectedType is not null + && !genericParameter.IsGenericMethodParameter() + && TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, reflectedType, reflectedType)) + { + return true; + } + + if (IsValueTypeOrValueTypeByRef(nullability.Type)) + { + return true; + } + + var state = NullabilityState.Unknown; + if (CreateParser(genericParameter.GetCustomAttributesData()).ParseNullableState(0, ref state)) + { + nullability.ReadState = state; + nullability.WriteState = state; + return true; + } + + if (GetNullableContext(genericParameter) is { } contextState) + { + nullability.ReadState = contextState; + nullability.WriteState = contextState; + return true; + } + + return false; + } + + private bool TryUpdateGenericTypeParameterNullabilityFromReflectedType(NullabilityInfo nullability, Type genericParameter, Type context, Type reflectedType) + { + Debug.Assert(genericParameter.IsGenericParameter && !genericParameter.IsGenericMethodParameter(), "must be generic parameter"); + + Type contextTypeDefinition = context.IsGenericType && !context.IsGenericTypeDefinition ? context.GetGenericTypeDefinition() : context; + if (genericParameter.DeclaringType == contextTypeDefinition) + { + return false; + } + + Type? baseType = contextTypeDefinition.BaseType; + if (baseType is null) + { + return false; + } + + if (!baseType.IsGenericType + || (baseType.IsGenericTypeDefinition ? baseType : baseType.GetGenericTypeDefinition()) != genericParameter.DeclaringType) + { + return TryUpdateGenericTypeParameterNullabilityFromReflectedType(nullability, genericParameter, baseType, reflectedType); + } + + Type[] genericArguments = baseType.GetGenericArguments(); + Type genericArgument = genericArguments[genericParameter.GenericParameterPosition]; + if (genericArgument.IsGenericParameter) + { + return TryUpdateGenericParameterNullability(nullability, genericArgument, reflectedType); + } + + NullableAttributeStateParser parser = CreateParser(contextTypeDefinition.GetCustomAttributesData()); + int nullabilityStateIndex = 1; // start at 1 since index 0 is the type itself + for (int i = 0; i < genericParameter.GenericParameterPosition; i++) + { + nullabilityStateIndex += CountNullabilityStates(genericArguments[i]); + } + + return TryPopulateNullabilityInfo(nullability, parser, ref nullabilityStateIndex); + + static int CountNullabilityStates(Type type) + { + Type underlyingType = Nullable.GetUnderlyingType(type) ?? type; + if (underlyingType.IsGenericType) + { + int count = 1; + foreach (Type genericArgument in underlyingType.GetGenericArguments()) + { + count += CountNullabilityStates(genericArgument); + } + + return count; + } + + if (underlyingType.HasElementType) + { + return (underlyingType.IsArray ? 1 : 0) + CountNullabilityStates(underlyingType.GetElementType()!); + } + + return type.IsValueType ? 0 : 1; + } + } + +#pragma warning disable SA1204 // Static elements should appear before instance elements + private static bool TryPopulateNullabilityInfo(NullabilityInfo nullability, NullableAttributeStateParser parser, ref int index) +#pragma warning restore SA1204 // Static elements should appear before instance elements + { + bool isValueType = IsValueTypeOrValueTypeByRef(nullability.Type); + if (!isValueType) + { + var state = NullabilityState.Unknown; + if (!parser.ParseNullableState(index, ref state)) + { + return false; + } + + nullability.ReadState = state; + nullability.WriteState = state; + } + + if (!isValueType || (Nullable.GetUnderlyingType(nullability.Type) ?? nullability.Type).IsGenericType) + { + index++; + } + + if (nullability.GenericTypeArguments.Length > 0) + { + foreach (NullabilityInfo genericTypeArgumentNullability in nullability.GenericTypeArguments) + { + _ = TryPopulateNullabilityInfo(genericTypeArgumentNullability, parser, ref index); + } + } + else if (nullability.ElementType is { } elementTypeNullability) + { + _ = TryPopulateNullabilityInfo(elementTypeNullability, parser, ref index); + } + + return true; + } + + private static NullabilityState TranslateByte(object? value) + { + return value is byte b ? TranslateByte(b) : NullabilityState.Unknown; + } + + private static NullabilityState TranslateByte(byte b) => + b switch + { + 1 => NullabilityState.NotNull, + 2 => NullabilityState.Nullable, + _ => NullabilityState.Unknown + }; + + private static bool IsValueTypeOrValueTypeByRef(Type type) => + type.IsValueType || ((type.IsByRef || type.IsPointer) && type.GetElementType()!.IsValueType); + + private readonly struct NullableAttributeStateParser + { + private static readonly object UnknownByte = (byte)0; + + private readonly object? _nullableAttributeArgument; + + public NullableAttributeStateParser(object? nullableAttributeArgument) + { + _nullableAttributeArgument = nullableAttributeArgument; + } + + public static NullableAttributeStateParser Unknown => new(UnknownByte); + + public bool ParseNullableState(int index, ref NullabilityState state) + { + switch (_nullableAttributeArgument) + { + case byte b: + state = TranslateByte(b); + return true; + case ReadOnlyCollection args + when index < args.Count && args[index].Value is byte elementB: + state = TranslateByte(elementB); + return true; + default: + return false; + } + } + } + } +} +#endif diff --git a/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoHelpers.cs b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoHelpers.cs new file mode 100644 index 00000000000..1ee573a0020 --- /dev/null +++ b/src/Shared/JsonSchemaExporter/NullabilityInfoContext/NullabilityInfoHelpers.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#if !NET6_0_OR_GREATER +using System.Diagnostics.CodeAnalysis; + +#pragma warning disable IDE1006 // Naming Styles +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + +namespace System.Reflection +{ + /// + /// Polyfills for System.Private.CoreLib internals. + /// + [ExcludeFromCodeCoverage] + internal static class NullabilityInfoHelpers + { + public static MemberInfo GetMemberWithSameMetadataDefinitionAs(Type type, MemberInfo member) + { + const BindingFlags all = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; + foreach (var info in type.GetMembers(all)) + { + if (info.HasSameMetadataDefinitionAs(member)) + { + return info; + } + } + + throw new MissingMemberException(type.FullName, member.Name); + } + + // https://github.com/dotnet/runtime/blob/main/src/coreclr/System.Private.CoreLib/src/System/Reflection/MemberInfo.Internal.cs + public static bool HasSameMetadataDefinitionAs(this MemberInfo target, MemberInfo other) + { + return target.MetadataToken == other.MetadataToken && + target.Module.Equals(other.Module); + } + + // https://github.com/dotnet/runtime/issues/23493 + public static bool IsGenericMethodParameter(this Type target) + { + return target.IsGenericParameter && + target.DeclaringMethod != null; + } + } +} +#endif diff --git a/src/Shared/JsonSchemaExporter/README.md b/src/Shared/JsonSchemaExporter/README.md new file mode 100644 index 00000000000..1a4d13c5841 --- /dev/null +++ b/src/Shared/JsonSchemaExporter/README.md @@ -0,0 +1,11 @@ +# JsonSchemaExporter + +Provides a polyfill for the [.NET 9 `JsonSchemaExporter` component](https://learn.microsoft.com/dotnet/standard/serialization/system-text-json/extract-schema) that is compatible with all supported targets using System.Text.Json version 8. + +To use this in your project, add the following to your `.csproj` file: + +```xml + + true + +``` diff --git a/src/Shared/Shared.csproj b/src/Shared/Shared.csproj index f6cbb03ea83..439c3788557 100644 --- a/src/Shared/Shared.csproj +++ b/src/Shared/Shared.csproj @@ -12,11 +12,12 @@ true true true - true + true true true true true + true @@ -33,6 +34,10 @@ + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs index a9a544c8ca8..09f515fa066 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs @@ -90,4 +90,45 @@ static void AssertNotFound(T1 input) Assert.Equal(default(T2), value); } } + + [Fact] + public void TryAdd_AddsOnlyIfNonExistent() + { + AdditionalPropertiesDictionary d = []; + + Assert.False(d.ContainsKey("key")); + Assert.True(d.TryAdd("key", "value")); + Assert.True(d.ContainsKey("key")); + Assert.Equal("value", d["key"]); + + Assert.False(d.TryAdd("key", "value2")); + Assert.True(d.ContainsKey("key")); + Assert.Equal("value", d["key"]); + } + + [Fact] + public void Enumerator_EnumeratesAllItems() + { + AdditionalPropertiesDictionary d = []; + + const int NumProperties = 10; + for (int i = 0; i < NumProperties; i++) + { + d.Add($"key{i}", $"value{i}"); + } + + Assert.Equal(NumProperties, d.Count); + + // This depends on an implementation detail of the ordering in which the dictionary + // enumerates items. If that ever changes, this test will need to be updated. + int count = 0; + foreach (KeyValuePair item in d) + { + Assert.Equal($"key{count}", item.Key); + Assert.Equal($"value{count}", item.Value); + count++; + } + + Assert.Equal(NumProperties, count); + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs index f83169712c3..fcd40a2f446 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs @@ -19,6 +19,7 @@ public void Constructor_Parameterless_PropsDefaulted() Assert.Null(options.TopK); Assert.Null(options.FrequencyPenalty); Assert.Null(options.PresencePenalty); + Assert.Null(options.Seed); Assert.Null(options.ResponseFormat); Assert.Null(options.ModelId); Assert.Null(options.StopSequences); @@ -33,6 +34,7 @@ public void Constructor_Parameterless_PropsDefaulted() Assert.Null(clone.TopK); Assert.Null(clone.FrequencyPenalty); Assert.Null(clone.PresencePenalty); + Assert.Null(options.Seed); Assert.Null(clone.ResponseFormat); Assert.Null(clone.ModelId); Assert.Null(clone.StopSequences); @@ -69,6 +71,7 @@ public void Properties_Roundtrip() options.TopK = 42; options.FrequencyPenalty = 0.4f; options.PresencePenalty = 0.5f; + options.Seed = 12345; options.ResponseFormat = ChatResponseFormat.Json; options.ModelId = "modelId"; options.StopSequences = stopSequences; @@ -82,6 +85,7 @@ public void Properties_Roundtrip() Assert.Equal(42, options.TopK); Assert.Equal(0.4f, options.FrequencyPenalty); Assert.Equal(0.5f, options.PresencePenalty); + Assert.Equal(12345, options.Seed); Assert.Same(ChatResponseFormat.Json, options.ResponseFormat); Assert.Equal("modelId", options.ModelId); Assert.Same(stopSequences, options.StopSequences); @@ -96,6 +100,7 @@ public void Properties_Roundtrip() Assert.Equal(42, clone.TopK); Assert.Equal(0.4f, clone.FrequencyPenalty); Assert.Equal(0.5f, clone.PresencePenalty); + Assert.Equal(12345, options.Seed); Assert.Same(ChatResponseFormat.Json, clone.ResponseFormat); Assert.Equal("modelId", clone.ModelId); Assert.Equal(stopSequences, clone.StopSequences); @@ -126,6 +131,7 @@ public void JsonSerialization_Roundtrips() options.TopK = 42; options.FrequencyPenalty = 0.4f; options.PresencePenalty = 0.5f; + options.Seed = 12345; options.ResponseFormat = ChatResponseFormat.Json; options.ModelId = "modelId"; options.StopSequences = stopSequences; @@ -148,6 +154,7 @@ public void JsonSerialization_Roundtrips() Assert.Equal(42, deserialized.TopK); Assert.Equal(0.4f, deserialized.FrequencyPenalty); Assert.Equal(0.5f, deserialized.PresencePenalty); + Assert.Equal(12345, deserialized.Seed); Assert.Equal(ChatResponseFormat.Json, deserialized.ResponseFormat); Assert.NotSame(ChatResponseFormat.Json, deserialized.ResponseFormat); Assert.Equal("modelId", deserialized.ModelId); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj index 0d4d5fbfa96..911ce1b2bf8 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj @@ -5,16 +5,27 @@ - $(NoWarn);CA1063;CA1861;CA2201;VSTHRD003 + $(NoWarn);CA1063;CA1861;CA2201;VSTHRD003;S104 true + true + true + true true + true + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs similarity index 77% rename from test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs rename to test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index db482d26804..52f9cad246d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -3,7 +3,9 @@ using System.ComponentModel; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using Microsoft.Extensions.AI.JsonSchemaExporter; using Xunit; namespace Microsoft.Extensions.AI; @@ -130,7 +132,7 @@ public static void ResolveParameterJsonSchema_ReturnsExpectedValue() } [Fact] - public static void ResolveParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString() + public static void CreateParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString() { JsonElement expected = JsonDocument.Parse(""" { @@ -158,4 +160,38 @@ public enum MyEnumValue A = 1, B = 2 } + + [Fact] + public static void CreateJsonSchema_CanBeBoolean() + { + JsonElement schema = AIJsonUtilities.CreateJsonSchema(typeof(object)); + Assert.Equal(JsonValueKind.True, schema.ValueKind); + } + + [Theory] + [MemberData(nameof(TestTypes.GetTestDataUsingAllValues), MemberType = typeof(TestTypes))] + public static void CreateJsonSchema_ValidateWithTestData(ITestData testData) + { + // Stress tests the schema generation method using types from the JsonSchemaExporter test battery. + + JsonSerializerOptions options = testData.Options is { } opts + ? new(opts) { TypeInfoResolver = TestTypes.TestTypesContext.Default } + : TestTypes.TestTypesContext.Default.Options; + + JsonElement schema = AIJsonUtilities.CreateJsonSchema(testData.Type, serializerOptions: options); + JsonNode? schemaAsNode = JsonSerializer.SerializeToNode(schema, options); + + Assert.NotNull(schemaAsNode); + Assert.Equal(testData.ExpectedJsonSchema.GetValueKind(), schemaAsNode.GetValueKind()); + + if (testData.Value is null || testData.WritesNumbersAsStrings) + { + // By design, our generated schema does not accept null root values + // or numbers formatted as strings, so we skip schema validation. + return; + } + + JsonNode? serializedValue = JsonSerializer.SerializeToNode(testData.Value, testData.Type, options); + SchemaTestHelpers.AssertDocumentMatchesSchema(schemaAsNode, serializedValue); + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Microsoft.Extensions.AI.AotCompatibility.TestApp.csproj b/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Microsoft.Extensions.AI.AotCompatibility.TestApp.csproj new file mode 100644 index 00000000000..183cd150937 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Microsoft.Extensions.AI.AotCompatibility.TestApp.csproj @@ -0,0 +1,26 @@ + + + + Exe + $(LatestTargetFramework) + true + false + true + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Program.cs b/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Program.cs new file mode 100644 index 00000000000..b518dfa7739 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AotCompatibility.TestApp/Program.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable S125 // Remove this commented out code + +using Microsoft.Extensions.AI; + +// Use types from each library. + +// Microsoft.Extensions.AI.Ollama +using var b = new OllamaChatClient("http://localhost:11434", "llama3.2"); + +// Microsoft.Extensions.AI.AzureAIInference +// using var a = new Azure.AI.Inference.ChatCompletionClient(new Uri("http://localhost"), new("apikey")); // uncomment once warnings in Azure.AI.Inference are addressed + +// Microsoft.Extensions.AI.OpenAI +// using var c = new OpenAI.OpenAIClient("apikey").AsChatClient("gpt-4o-mini"); // uncomment once warnings in OpenAI are addressed + +// Microsoft.Extensions.AI +AIFunctionFactory.Create(() => { }); + +System.Console.WriteLine("Success!"); diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs index 4fb5122cc93..f404f5e61ef 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -247,8 +247,8 @@ public async Task MultipleMessages_NonStreaming() ], "presence_penalty": 0.5, "frequency_penalty": 0.75, - "model": "gpt-4o-mini", - "seed": 42 + "seed": 42, + "model": "gpt-4o-mini" } """; @@ -303,7 +303,7 @@ public async Task MultipleMessages_NonStreaming() FrequencyPenalty = 0.75f, PresencePenalty = 0.5f, StopSequences = ["great"], - AdditionalProperties = new() { ["seed"] = 42L }, + Seed = 42, }); Assert.NotNull(response); diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index 0863e31db37..e9c2bd57d65 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -6,6 +6,7 @@ using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; @@ -132,6 +133,27 @@ public virtual async Task CompleteStreamingAsync_UsageDataAvailable() Assert.Equal(usage.Details.InputTokenCount + usage.Details.OutputTokenCount, usage.Details.TotalTokenCount); } + protected virtual string? GetModel_MultiModal_DescribeImage() => null; + + [ConditionalFact] + public virtual async Task MultiModal_DescribeImage() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync( + [ + new(ChatRole.User, + [ + new TextContent("What does this logo say?"), + new ImageContent(GetImageDataUri()), + ]) + ], + new() { ModelId = GetModel_MultiModal_DescribeImage() }); + + Assert.Single(response.Choices); + Assert.True(response.Message.Text?.IndexOf("net", StringComparison.OrdinalIgnoreCase) >= 0, response.Message.Text); + } + [ConditionalFact] public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Parameterless() { @@ -714,6 +736,15 @@ private enum JobType Unknown, } + private static Uri GetImageDataUri() + { + using Stream? s = typeof(ChatClientIntegrationTests).Assembly.GetManifestResourceStream("Microsoft.Extensions.AI.dotnet.png"); + Assert.NotNull(s); + MemoryStream ms = new(); + s.CopyTo(ms); + return new Uri($"data:image/png;base64,{Convert.ToBase64String(ms.ToArray())}"); + } + [MemberNotNull(nameof(_chatClient))] protected void SkipIfNotEnabled() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj index e38ccd3268b..04d9bc6d29f 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj @@ -15,6 +15,10 @@ true + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/dotnet.png b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/dotnet.png new file mode 100644 index 00000000000..fb00ecf91e4 Binary files /dev/null and b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/dotnet.png differ diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs index 891378c0e86..4c71690baaf 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs @@ -30,6 +30,8 @@ public override Task FunctionInvocation_RequireAny() => public override Task FunctionInvocation_RequireSpecific() => throw new SkipTestException("Ollama does not currently support requiring function invocation."); + protected override string? GetModel_MultiModal_DescribeImage() => "llava"; + [ConditionalFact] public async Task PromptBasedFunctionCalling_NoArgs() { @@ -47,7 +49,7 @@ public async Task PromptBasedFunctionCalling_NoArgs() ModelId = "llama3:8b", Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")], Temperature = 0, - AdditionalProperties = new() { ["seed"] = 0L }, + Seed = 0, }); Assert.Single(response.Choices); @@ -81,7 +83,7 @@ public async Task PromptBasedFunctionCalling_WithArgs() { Tools = [stockPriceTool, irrelevantTool], Temperature = 0, - AdditionalProperties = new() { ["seed"] = 0L }, + Seed = 0, }); Assert.Single(response.Choices); diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs index 3e281173c8b..67b10e3f24b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs @@ -254,7 +254,7 @@ public async Task MultipleMessages_NonStreaming() FrequencyPenalty = 0.75f, PresencePenalty = 0.5f, StopSequences = ["great"], - AdditionalProperties = new() { ["seed"] = 42 }, + Seed = 42, }); Assert.NotNull(response); diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs index 691804e5fb8..05d2f5a22ff 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs @@ -348,7 +348,7 @@ public async Task MultipleMessages_NonStreaming() FrequencyPenalty = 0.75f, PresencePenalty = 0.5f, StopSequences = ["great"], - AdditionalProperties = new() { ["seed"] = 42 }, + Seed = 42, }); Assert.NotNull(response); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs index a27761c99ec..a911340813f 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs @@ -26,11 +26,13 @@ public void UseChatOptions_InvalidArgs_Throws() Assert.Throws("configureOptions", () => builder.UseChatOptions(null!)); } - [Fact] - public async Task ConfigureOptions_ReturnedInstancePassedToNextClient() + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullReturned) { ChatOptions providedOptions = new(); - ChatOptions returnedOptions = new(); + ChatOptions? returnedOptions = nullReturned ? null : new(); ChatCompletion expectedCompletion = new(Array.Empty()); var expectedUpdates = Enumerable.Range(0, 3).Select(i => new StreamingChatCompletionUpdate()).ToArray(); using CancellationTokenSource cts = new(); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/ConfigureOptionsEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/ConfigureOptionsEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..b8a4b82cb59 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/ConfigureOptionsEmbeddingGeneratorTests.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ConfigureOptionsEmbeddingGeneratorTests +{ + [Fact] + public void ConfigureOptionsEmbeddingGenerator_InvalidArgs_Throws() + { + Assert.Throws("innerGenerator", () => new ConfigureOptionsEmbeddingGenerator>(null!, _ => new EmbeddingGenerationOptions())); + Assert.Throws("configureOptions", () => new ConfigureOptionsEmbeddingGenerator>(new TestEmbeddingGenerator(), null!)); + } + + [Fact] + public void UseEmbeddingGenerationOptions_InvalidArgs_Throws() + { + var builder = new EmbeddingGeneratorBuilder>(); + Assert.Throws("configureOptions", () => builder.UseEmbeddingGenerationOptions(null!)); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullReturned) + { + EmbeddingGenerationOptions providedOptions = new(); + EmbeddingGenerationOptions? returnedOptions = nullReturned ? null : new(); + GeneratedEmbeddings> expectedEmbeddings = []; + using CancellationTokenSource cts = new(); + + using IEmbeddingGenerator> innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (inputs, options, cancellationToken) => + { + Assert.Same(returnedOptions, options); + Assert.Equal(cts.Token, cancellationToken); + return Task.FromResult(expectedEmbeddings); + } + }; + + using var generator = new EmbeddingGeneratorBuilder>() + .UseEmbeddingGenerationOptions(options => + { + Assert.Same(providedOptions, options); + return returnedOptions; + }) + .Use(innerGenerator); + + var embeddings = await generator.GenerateAsync([], providedOptions, cts.Token); + Assert.Same(expectedEmbeddings, embeddings); + } +} diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs new file mode 100644 index 00000000000..3a266af7ce3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/HybridCacheEventSourceTests.cs @@ -0,0 +1,205 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.Tracing; +using Microsoft.Extensions.Caching.Hybrid.Internal; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.Caching.Hybrid.Tests; + +public class HybridCacheEventSourceTests(ITestOutputHelper log, TestEventListener listener) : IClassFixture +{ + // see notes in TestEventListener for context on fixture usage + + [SkippableFact] + public void MatchesNameAndGuid() + { + // Assert + Assert.Equal("Microsoft-Extensions-HybridCache", listener.Source.Name); + Assert.Equal(Guid.Parse("b3aca39e-5dc9-5e21-f669-b72225b66cfc"), listener.Source.Guid); // from name + } + + [SkippableFact] + public async Task LocalCacheHit() + { + AssertEnabled(); + + listener.Reset().Source.LocalCacheHit(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheHit, "LocalCacheHit", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("total-local-cache-hits", "Total Local Cache Hits", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task LocalCacheMiss() + { + AssertEnabled(); + + listener.Reset().Source.LocalCacheMiss(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheMiss, "LocalCacheMiss", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("total-local-cache-misses", "Total Local Cache Misses", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task DistributedCacheGet() + { + AssertEnabled(); + + listener.Reset().Source.DistributedCacheGet(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheGet, "DistributedCacheGet", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("current-distributed-cache-fetches", "Current Distributed Cache Fetches", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task DistributedCacheHit() + { + AssertEnabled(); + + listener.Reset().Source.DistributedCacheGet(); + listener.Reset(resetCounters: false).Source.DistributedCacheHit(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheHit, "DistributedCacheHit", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("total-distributed-cache-hits", "Total Distributed Cache Hits", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task DistributedCacheMiss() + { + AssertEnabled(); + + listener.Reset().Source.DistributedCacheGet(); + listener.Reset(resetCounters: false).Source.DistributedCacheMiss(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheMiss, "DistributedCacheMiss", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("total-distributed-cache-misses", "Total Distributed Cache Misses", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task DistributedCacheFailed() + { + AssertEnabled(); + + listener.Reset().Source.DistributedCacheGet(); + listener.Reset(resetCounters: false).Source.DistributedCacheFailed(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheFailed, "DistributedCacheFailed", EventLevel.Error); + + await AssertCountersAsync(); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task UnderlyingDataQueryStart() + { + AssertEnabled(); + + listener.Reset().Source.UnderlyingDataQueryStart(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryStart, "UnderlyingDataQueryStart", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("current-data-query", "Current Data Queries", 1); + listener.AssertCounter("total-data-query", "Total Data Queries", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task UnderlyingDataQueryComplete() + { + AssertEnabled(); + + listener.Reset().Source.UnderlyingDataQueryStart(); + listener.Reset(resetCounters: false).Source.UnderlyingDataQueryComplete(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryComplete, "UnderlyingDataQueryComplete", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("total-data-query", "Total Data Queries", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task UnderlyingDataQueryFailed() + { + AssertEnabled(); + + listener.Reset().Source.UnderlyingDataQueryStart(); + listener.Reset(resetCounters: false).Source.UnderlyingDataQueryFailed(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdUnderlyingDataQueryFailed, "UnderlyingDataQueryFailed", EventLevel.Error); + + await AssertCountersAsync(); + listener.AssertCounter("total-data-query", "Total Data Queries", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task LocalCacheWrite() + { + AssertEnabled(); + + listener.Reset().Source.LocalCacheWrite(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdLocalCacheWrite, "LocalCacheWrite", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("total-local-cache-writes", "Total Local Cache Writes", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task DistributedCacheWrite() + { + AssertEnabled(); + + listener.Reset().Source.DistributedCacheWrite(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdDistributedCacheWrite, "DistributedCacheWrite", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("total-distributed-cache-writes", "Total Distributed Cache Writes", 1); + listener.AssertRemainingCountersZero(); + } + + [SkippableFact] + public async Task StampedeJoin() + { + AssertEnabled(); + + listener.Reset().Source.StampedeJoin(); + listener.AssertSingleEvent(HybridCacheEventSource.EventIdStampedeJoin, "StampedeJoin", EventLevel.Verbose); + + await AssertCountersAsync(); + listener.AssertCounter("total-stampede-joins", "Total Stampede Joins", 1); + listener.AssertRemainingCountersZero(); + } + + private void AssertEnabled() + { + // including this data for visibility when tests fail - ETW subsystem can be ... weird + log.WriteLine($".NET {Environment.Version} on {Environment.OSVersion}, {IntPtr.Size * 8}-bit"); + + Skip.IfNot(listener.Source.IsEnabled(), "Event source not enabled"); + } + + private async Task AssertCountersAsync() + { + var count = await listener.TryAwaitCountersAsync(); + + // ETW counters timing can be painfully unpredictable; generally + // it'll work fine locally, especially on modern .NET, but: + // CI servers and netfx in particular - not so much. The tests + // can still observe and validate the simple events, though, which + // should be enough to be credible that the eventing system is + // fundamentally working. We're not meant to be testing that + // the counters system *itself* works! + + Skip.If(count == 0, "No counters received"); + } +} diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/LogCollector.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/LogCollector.cs new file mode 100644 index 00000000000..bdb5ff981c0 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/LogCollector.cs @@ -0,0 +1,84 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.Caching.Hybrid.Tests; + +// dummy implementation for collecting test output +internal class LogCollector : ILoggerProvider +{ + private readonly List<(string categoryName, LogLevel logLevel, EventId eventId, Exception? exception, string message)> _items = []; + + public (string categoryName, LogLevel logLevel, EventId eventId, Exception? exception, string message)[] ToArray() + { + lock (_items) + { + return _items.ToArray(); + } + } + + public void WriteTo(ITestOutputHelper log) + { + lock (_items) + { + foreach (var logItem in _items) + { + var errSuffix = logItem.exception is null ? "" : $" - {logItem.exception.Message}"; + log.WriteLine($"{logItem.categoryName} {logItem.eventId}: {logItem.message}{errSuffix}"); + } + } + } + + public void AssertErrors(int[] errorIds) + { + lock (_items) + { + bool same; + if (errorIds.Length == _items.Count) + { + int index = 0; + same = true; + foreach (var item in _items) + { + if (item.eventId.Id != errorIds[index++]) + { + same = false; + break; + } + } + } + else + { + same = false; + } + + if (!same) + { + // we expect this to fail, then + Assert.Equal(string.Join(",", errorIds), string.Join(",", _items.Select(static x => x.eventId.Id))); + } + } + } + + ILogger ILoggerProvider.CreateLogger(string categoryName) => new TypedLogCollector(this, categoryName); + + void IDisposable.Dispose() + { + // nothing to do + } + + private sealed class TypedLogCollector(LogCollector parent, string categoryName) : ILogger + { + IDisposable? ILogger.BeginScope(TState state) => null; + bool ILogger.IsEnabled(LogLevel logLevel) => true; + void ILogger.Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + lock (parent._items) + { + parent._items.Add((categoryName, logLevel, eventId, exception, formatter(state, exception))); + } + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/Microsoft.Extensions.Caching.Hybrid.Tests.csproj b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/Microsoft.Extensions.Caching.Hybrid.Tests.csproj index ef80a84eee9..fb8863cf776 100644 --- a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/Microsoft.Extensions.Caching.Hybrid.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/Microsoft.Extensions.Caching.Hybrid.Tests.csproj @@ -12,13 +12,15 @@ + - + + diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/NullDistributedCache.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/NullDistributedCache.cs new file mode 100644 index 00000000000..d07cb51bb93 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/NullDistributedCache.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Caching.Distributed; + +namespace Microsoft.Extensions.Caching.Hybrid.Tests; + +// dummy L2 that doesn't actually store anything +internal class NullDistributedCache : IDistributedCache +{ + byte[]? IDistributedCache.Get(string key) => null; + Task IDistributedCache.GetAsync(string key, CancellationToken token) => Task.FromResult(null); + void IDistributedCache.Refresh(string key) + { + // nothing to do + } + + Task IDistributedCache.RefreshAsync(string key, CancellationToken token) => Task.CompletedTask; + void IDistributedCache.Remove(string key) + { + // nothing to do + } + + Task IDistributedCache.RemoveAsync(string key, CancellationToken token) => Task.CompletedTask; + void IDistributedCache.Set(string key, byte[] value, DistributedCacheEntryOptions options) + { + // nothing to do + } + + Task IDistributedCache.SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token) => Task.CompletedTask; +} diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/SizeTests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/SizeTests.cs index 119c2297882..66f4fc7628d 100644 --- a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/SizeTests.cs +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/SizeTests.cs @@ -1,31 +1,60 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; +using System.ComponentModel; +using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Caching.Hybrid.Internal; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Xunit.Abstractions; namespace Microsoft.Extensions.Caching.Hybrid.Tests; -public class SizeTests +public class SizeTests(ITestOutputHelper log) { [Theory] - [InlineData(null, true)] // does not enforce size limits - [InlineData(8L, false)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time - [InlineData(1024L, true)] // reasonable size limit - public async Task ValidateSizeLimit_Immutable(long? sizeLimit, bool expectFromL1) + [InlineData("abc", null, true, null, null)] // does not enforce size limits + [InlineData("", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key + [InlineData(" ", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key + [InlineData(null, null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key + [InlineData("abc", 8L, false, null, null)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time + [InlineData("abc", 1024L, true, null, null)] // reasonable size limit + [InlineData("abc", 1024L, true, 8L, null, Log.IdMaximumPayloadBytesExceeded)] // reasonable size limit, small HC quota + [InlineData("abc", null, false, null, 2, Log.IdMaximumKeyLengthExceeded, Log.IdMaximumKeyLengthExceeded)] // key limit exceeded + [InlineData("a\u0000c", null, false, null, null, Log.IdKeyInvalidContent, Log.IdKeyInvalidContent)] // invalid key + [InlineData("a\u001Fc", null, false, null, null, Log.IdKeyInvalidContent, Log.IdKeyInvalidContent)] // invalid key + [InlineData("a\u0020c", null, true, null, null)] // fine (this is just space) + public async Task ValidateSizeLimit_Immutable(string? key, long? sizeLimit, bool expectFromL1, long? maximumPayloadBytes, int? maximumKeyLength, + params int[] errorIds) { + using var collector = new LogCollector(); var services = new ServiceCollection(); services.AddMemoryCache(options => options.SizeLimit = sizeLimit); - services.AddHybridCache(); + services.AddHybridCache(options => + { + if (maximumKeyLength.HasValue) + { + options.MaximumKeyLength = maximumKeyLength.GetValueOrDefault(); + } + + if (maximumPayloadBytes.HasValue) + { + options.MaximumPayloadBytes = maximumPayloadBytes.GetValueOrDefault(); + } + }); + services.AddLogging(options => + { + options.ClearProviders(); + options.AddProvider(collector); + }); using ServiceProvider provider = services.BuildServiceProvider(); var cache = Assert.IsType(provider.GetRequiredService()); - const string Key = "abc"; - // this looks weird; it is intentionally not a const - we want to check // same instance without worrying about interning from raw literals string expected = new("simple value".ToArray()); - var actual = await cache.GetOrCreateAsync(Key, ct => new(expected)); + var actual = await cache.GetOrCreateAsync(key!, ct => new(expected)); // expect same contents Assert.Equal(expected, actual); @@ -35,7 +64,7 @@ public async Task ValidateSizeLimit_Immutable(long? sizeLimit, bool expectFromL1 Assert.Same(expected, actual); // rinse and repeat, to check we get the value from L1 - actual = await cache.GetOrCreateAsync(Key, ct => new(Guid.NewGuid().ToString())); + actual = await cache.GetOrCreateAsync(key!, ct => new(Guid.NewGuid().ToString())); if (expectFromL1) { @@ -51,30 +80,54 @@ public async Task ValidateSizeLimit_Immutable(long? sizeLimit, bool expectFromL1 // L1 cache not used Assert.NotEqual(expected, actual); } + + collector.WriteTo(log); + collector.AssertErrors(errorIds); } [Theory] - [InlineData(null, true)] // does not enforce size limits - [InlineData(8L, false)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time - [InlineData(1024L, true)] // reasonable size limit - public async Task ValidateSizeLimit_Mutable(long? sizeLimit, bool expectFromL1) + [InlineData("abc", null, true, null, null)] // does not enforce size limits + [InlineData("", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key + [InlineData(" ", null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key + [InlineData(null, null, false, null, null, Log.IdKeyEmptyOrWhitespace, Log.IdKeyEmptyOrWhitespace)] // invalid key + [InlineData("abc", 8L, false, null, null)] // unreasonably small limit; chosen because our test string has length 12 - hence no expectation to find the second time + [InlineData("abc", 1024L, true, null, null)] // reasonable size limit + [InlineData("abc", 1024L, true, 8L, null, Log.IdMaximumPayloadBytesExceeded)] // reasonable size limit, small HC quota + [InlineData("abc", null, false, null, 2, Log.IdMaximumKeyLengthExceeded, Log.IdMaximumKeyLengthExceeded)] // key limit exceeded + public async Task ValidateSizeLimit_Mutable(string? key, long? sizeLimit, bool expectFromL1, long? maximumPayloadBytes, int? maximumKeyLength, + params int[] errorIds) { + using var collector = new LogCollector(); var services = new ServiceCollection(); services.AddMemoryCache(options => options.SizeLimit = sizeLimit); - services.AddHybridCache(); + services.AddHybridCache(options => + { + if (maximumKeyLength.HasValue) + { + options.MaximumKeyLength = maximumKeyLength.GetValueOrDefault(); + } + + if (maximumPayloadBytes.HasValue) + { + options.MaximumPayloadBytes = maximumPayloadBytes.GetValueOrDefault(); + } + }); + services.AddLogging(options => + { + options.ClearProviders(); + options.AddProvider(collector); + }); using ServiceProvider provider = services.BuildServiceProvider(); var cache = Assert.IsType(provider.GetRequiredService()); - const string Key = "abc"; - string expected = "simple value"; - var actual = await cache.GetOrCreateAsync(Key, ct => new(new MutablePoco { Value = expected })); + var actual = await cache.GetOrCreateAsync(key!, ct => new(new MutablePoco { Value = expected })); // expect same contents Assert.Equal(expected, actual.Value); // rinse and repeat, to check we get the value from L1 - actual = await cache.GetOrCreateAsync(Key, ct => new(new MutablePoco { Value = Guid.NewGuid().ToString() })); + actual = await cache.GetOrCreateAsync(key!, ct => new(new MutablePoco { Value = Guid.NewGuid().ToString() })); if (expectFromL1) { @@ -86,10 +139,217 @@ public async Task ValidateSizeLimit_Mutable(long? sizeLimit, bool expectFromL1) // L1 cache not used Assert.NotEqual(expected, actual.Value); } + + collector.WriteTo(log); + collector.AssertErrors(errorIds); + } + + [Theory] + [InlineData("some value", false, 1, 1, 2, false)] + [InlineData("read fail", false, 1, 1, 1, true, Log.IdDeserializationFailure)] + [InlineData("write fail", true, 1, 1, 0, true, Log.IdSerializationFailure)] + public async Task BrokenSerializer_Mutable(string value, bool same, int runCount, int serializeCount, int deserializeCount, bool expectKnownFailure, params int[] errorIds) + { + using var collector = new LogCollector(); + var services = new ServiceCollection(); + services.AddMemoryCache(); + services.AddSingleton(); + var serializer = new MutablePoco.Serializer(); + services.AddHybridCache().AddSerializer(serializer); + services.AddLogging(options => + { + options.ClearProviders(); + options.AddProvider(collector); + }); + using ServiceProvider provider = services.BuildServiceProvider(); + var cache = Assert.IsType(provider.GetRequiredService()); + + int actualRunCount = 0; + Func> func = _ => + { + Interlocked.Increment(ref actualRunCount); + return new(new MutablePoco { Value = value }); + }; + + if (expectKnownFailure) + { + await Assert.ThrowsAsync(async () => await cache.GetOrCreateAsync("key", func)); + } + else + { + var first = await cache.GetOrCreateAsync("key", func); + var second = await cache.GetOrCreateAsync("key", func); + Assert.Equal(value, first.Value); + Assert.Equal(value, second.Value); + + if (same) + { + Assert.Same(first, second); + } + else + { + Assert.NotSame(first, second); + } + } + + Assert.Equal(runCount, Volatile.Read(ref actualRunCount)); + Assert.Equal(serializeCount, serializer.WriteCount); + Assert.Equal(deserializeCount, serializer.ReadCount); + collector.WriteTo(log); + collector.AssertErrors(errorIds); + } + + [Theory] + [InlineData("some value", true, 1, 1, 0, false, true)] + [InlineData("read fail", true, 1, 1, 0, false, true)] + [InlineData("write fail", true, 1, 1, 0, true, true, Log.IdSerializationFailure)] + + // without L2, we only need the serializer for sizing purposes (L1), not used for deserialize + [InlineData("some value", true, 1, 1, 0, false, false)] + [InlineData("read fail", true, 1, 1, 0, false, false)] + [InlineData("write fail", true, 1, 1, 0, true, false, Log.IdSerializationFailure)] + [System.Diagnostics.CodeAnalysis.SuppressMessage("Major Code Smell", "S107:Methods should not have too many parameters", Justification = "Test scenario range; reducing duplication")] + public async Task BrokenSerializer_Immutable(string value, bool same, int runCount, int serializeCount, int deserializeCount, bool expectKnownFailure, bool withL2, + params int[] errorIds) + { + using var collector = new LogCollector(); + var services = new ServiceCollection(); + services.AddMemoryCache(); + if (withL2) + { + services.AddSingleton(); + } + + var serializer = new ImmutablePoco.Serializer(); + services.AddHybridCache().AddSerializer(serializer); + services.AddLogging(options => + { + options.ClearProviders(); + options.AddProvider(collector); + }); + using ServiceProvider provider = services.BuildServiceProvider(); + var cache = Assert.IsType(provider.GetRequiredService()); + + int actualRunCount = 0; + Func> func = _ => + { + Interlocked.Increment(ref actualRunCount); + return new(new ImmutablePoco(value)); + }; + + if (expectKnownFailure) + { + await Assert.ThrowsAsync(async () => await cache.GetOrCreateAsync("key", func)); + } + else + { + var first = await cache.GetOrCreateAsync("key", func); + var second = await cache.GetOrCreateAsync("key", func); + Assert.Equal(value, first.Value); + Assert.Equal(value, second.Value); + + if (same) + { + Assert.Same(first, second); + } + else + { + Assert.NotSame(first, second); + } + } + + Assert.Equal(runCount, Volatile.Read(ref actualRunCount)); + Assert.Equal(serializeCount, serializer.WriteCount); + Assert.Equal(deserializeCount, serializer.ReadCount); + collector.WriteTo(log); + collector.AssertErrors(errorIds); + } + + public class KnownFailureException : Exception + { + public KnownFailureException(string message) + : base(message) + { + } } public class MutablePoco { public string Value { get; set; } = ""; + + public sealed class Serializer : IHybridCacheSerializer + { + private int _readCount; + private int _writeCount; + + public int ReadCount => Volatile.Read(ref _readCount); + public int WriteCount => Volatile.Read(ref _writeCount); + + public MutablePoco Deserialize(ReadOnlySequence source) + { + Interlocked.Increment(ref _readCount); + var value = InbuiltTypeSerializer.DeserializeString(source); + if (value == "read fail") + { + throw new KnownFailureException("read failure"); + } + + return new MutablePoco { Value = value }; + } + + public void Serialize(MutablePoco value, IBufferWriter target) + { + Interlocked.Increment(ref _writeCount); + if (value.Value == "write fail") + { + throw new KnownFailureException("write failure"); + } + + InbuiltTypeSerializer.SerializeString(value.Value, target); + } + } + } + + [ImmutableObject(true)] + public sealed class ImmutablePoco + { + public ImmutablePoco(string value) + { + Value = value; + } + + public string Value { get; } + + public sealed class Serializer : IHybridCacheSerializer + { + private int _readCount; + private int _writeCount; + + public int ReadCount => Volatile.Read(ref _readCount); + public int WriteCount => Volatile.Read(ref _writeCount); + + public ImmutablePoco Deserialize(ReadOnlySequence source) + { + Interlocked.Increment(ref _readCount); + var value = InbuiltTypeSerializer.DeserializeString(source); + if (value == "read fail") + { + throw new KnownFailureException("read failure"); + } + + return new ImmutablePoco(value); + } + + public void Serialize(ImmutablePoco value, IBufferWriter target) + { + Interlocked.Increment(ref _writeCount); + if (value.Value == "write fail") + { + throw new KnownFailureException("write failure"); + } + + InbuiltTypeSerializer.SerializeString(value.Value, target); + } + } } } diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/TestEventListener.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/TestEventListener.cs new file mode 100644 index 00000000000..ecb97ef3c7e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/TestEventListener.cs @@ -0,0 +1,189 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.Tracing; +using System.Globalization; +using Microsoft.Extensions.Caching.Hybrid.Internal; + +namespace Microsoft.Extensions.Caching.Hybrid.Tests; + +public sealed class TestEventListener : EventListener +{ + // captures both event and counter data + + // this is used as a class fixture from HybridCacheEventSourceTests, because there + // seems to be some unpredictable behaviours if multiple event sources/listeners are + // casually created etc + private const double EventCounterIntervalSec = 0.25; + + private readonly List<(int id, string name, EventLevel level)> _events = []; + private readonly Dictionary _counters = []; + + private object SyncLock => _events; + + internal HybridCacheEventSource Source { get; } = new(); + + public TestEventListener Reset(bool resetCounters = true) + { + lock (SyncLock) + { + _events.Clear(); + _counters.Clear(); + + if (resetCounters) + { + Source.ResetCounters(); + } + } + + Assert.True(Source.IsEnabled(), "should report as enabled"); + + return this; + } + + protected override void OnEventSourceCreated(EventSource eventSource) + { + if (ReferenceEquals(eventSource, Source)) + { + var args = new Dictionary + { + ["EventCounterIntervalSec"] = EventCounterIntervalSec.ToString("G", CultureInfo.InvariantCulture), + }; + EnableEvents(Source, EventLevel.Verbose, EventKeywords.All, args); + } + + base.OnEventSourceCreated(eventSource); + } + + protected override void OnEventWritten(EventWrittenEventArgs eventData) + { + if (ReferenceEquals(eventData.EventSource, Source)) + { + // capture counters/events + lock (SyncLock) + { + if (eventData.EventName == "EventCounters" + && eventData.Payload is { Count: > 0 }) + { + foreach (var payload in eventData.Payload) + { + if (payload is IDictionary map) + { + string? name = null; + string? displayName = null; + double? value = null; + bool isIncrement = false; + foreach (var pair in map) + { + switch (pair.Key) + { + case "Name" when pair.Value is string: + name = (string)pair.Value; + break; + case "DisplayName" when pair.Value is string s: + displayName = s; + break; + case "Mean": + isIncrement = false; + value = Convert.ToDouble(pair.Value); + break; + case "Increment": + isIncrement = true; + value = Convert.ToDouble(pair.Value); + break; + } + } + + if (name is not null && value is not null) + { + if (isIncrement && _counters.TryGetValue(name, out var oldPair)) + { + value += oldPair.value; // treat as delta from old + } + + Debug.WriteLine($"{name}={value}"); + _counters[name] = (displayName, value.Value); + } + } + } + } + else + { + _events.Add((eventData.EventId, eventData.EventName ?? "", eventData.Level)); + } + } + } + + base.OnEventWritten(eventData); + } + + public (int id, string name, EventLevel level) SingleEvent() + { + (int id, string name, EventLevel level) evt; + lock (SyncLock) + { + evt = Assert.Single(_events); + } + + return evt; + } + + public void AssertSingleEvent(int id, string name, EventLevel level) + { + var evt = SingleEvent(); + Assert.Equal(name, evt.name); + Assert.Equal(id, evt.id); + Assert.Equal(level, evt.level); + } + + public double AssertCounter(string name, string displayName) + { + lock (SyncLock) + { + Assert.True(_counters.TryGetValue(name, out var pair), $"counter not found: {name}"); + Assert.Equal(displayName, pair.displayName); + + _counters.Remove(name); // count as validated + return pair.value; + } + } + + public void AssertCounter(string name, string displayName, double expected) + { + var actual = AssertCounter(name, displayName); + if (!Equals(expected, actual)) + { + Assert.Fail($"{name}: expected {expected}, actual {actual}"); + } + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Major Bug", "S1244:Floating point numbers should not be tested for equality", Justification = "Test expects exact zero")] + public void AssertRemainingCountersZero() + { + lock (SyncLock) + { + foreach (var pair in _counters) + { + if (pair.Value.value != 0) + { + Assert.Fail($"{pair.Key}: expected 0, actual {pair.Value.value}"); + } + } + } + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1822:Mark members as static", Justification = "Clarity and usability")] + public async Task TryAwaitCountersAsync() + { + // allow 2 cycles because if we only allow 1, we run the risk of a + // snapshot being captured mid-cycle when we were setting up the test + // (ok, that's an unlikely race condition, but!) + await Task.Delay(TimeSpan.FromSeconds(EventCounterIntervalSec * 2)); + + lock (SyncLock) + { + return _counters.Count; + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/UnreliableL2Tests.cs b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/UnreliableL2Tests.cs new file mode 100644 index 00000000000..7af85f9cba2 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Caching.Hybrid.Tests/UnreliableL2Tests.cs @@ -0,0 +1,251 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Hybrid.Internal; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Xunit.Abstractions; + +namespace Microsoft.Extensions.Caching.Hybrid.Tests; + +// validate HC stability when the L2 is unreliable +public class UnreliableL2Tests(ITestOutputHelper testLog) +{ + [Theory] + [InlineData(BreakType.None)] + [InlineData(BreakType.Synchronous, Log.IdCacheBackendWriteFailure)] + [InlineData(BreakType.Asynchronous, Log.IdCacheBackendWriteFailure)] + [InlineData(BreakType.AsynchronousYield, Log.IdCacheBackendWriteFailure)] + [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")] + public async Task WriteFailureInvisible(BreakType writeBreak, params int[] errorIds) + { + using (GetServices(out var hc, out var l1, out var l2, out var log)) + using (log) + { + // normal behaviour when working fine + var x = await hc.GetOrCreateAsync("x", NewGuid); + Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid)); + Assert.NotNull(l2.Tail.Get("x")); // exists + + l2.WriteBreak = writeBreak; + var y = await hc.GetOrCreateAsync("y", NewGuid); + Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid)); + if (writeBreak == BreakType.None) + { + Assert.NotNull(l2.Tail.Get("y")); // exists + } + else + { + Assert.Null(l2.Tail.Get("y")); // does not exist + } + + await l2.LastWrite; // allows out-of-band write to complete + await Task.Delay(150); // even then: thread jitter can cause problems + + log.WriteTo(testLog); + log.AssertErrors(errorIds); + } + } + + [Theory] + [InlineData(BreakType.None)] + [InlineData(BreakType.Synchronous, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)] + [InlineData(BreakType.Asynchronous, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)] + [InlineData(BreakType.AsynchronousYield, Log.IdCacheBackendReadFailure, Log.IdCacheBackendReadFailure)] + public async Task ReadFailureInvisible(BreakType readBreak, params int[] errorIds) + { + using (GetServices(out var hc, out var l1, out var l2, out var log)) + using (log) + { + // create two new values via HC; this should go down to l2 + var x = await hc.GetOrCreateAsync("x", NewGuid); + var y = await hc.GetOrCreateAsync("y", NewGuid); + + // this should be reliable and repeatable + Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid)); + Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid)); + + // even if we clean L1, causing new L2 fetches + l1.Clear(); + Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid)); + Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid)); + + // now we break L2 in some predictable way, *without* clearing L1 - the + // values should still be available via L1 + l2.ReadBreak = readBreak; + Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid)); + Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid)); + + // but if we clear L1 to force L2 hits, we anticipate problems + l1.Clear(); + if (readBreak == BreakType.None) + { + Assert.Equal(x, await hc.GetOrCreateAsync("x", NewGuid)); + Assert.Equal(y, await hc.GetOrCreateAsync("y", NewGuid)); + } + else + { + // because L2 is unavailable and L1 is empty, we expect the callback + // to be used again, generating new values + var a = await hc.GetOrCreateAsync("x", NewGuid, NoL2Write); + var b = await hc.GetOrCreateAsync("y", NewGuid, NoL2Write); + + Assert.NotEqual(x, a); + Assert.NotEqual(y, b); + + // but those *new* values are at least reliable inside L1 + Assert.Equal(a, await hc.GetOrCreateAsync("x", NewGuid)); + Assert.Equal(b, await hc.GetOrCreateAsync("y", NewGuid)); + } + + log.WriteTo(testLog); + log.AssertErrors(errorIds); + } + } + + private static HybridCacheEntryOptions NoL2Write { get; } = new HybridCacheEntryOptions { Flags = HybridCacheEntryFlags.DisableDistributedCacheWrite }; + + public enum BreakType + { + None, // async API works correctly + Synchronous, // async API faults directly rather than return a faulted task + Asynchronous, // async API returns a completed asynchronous fault + AsynchronousYield, // async API returns an incomplete asynchronous fault + } + + private static ValueTask NewGuid(CancellationToken cancellationToken) => new(Guid.NewGuid()); + + private static IDisposable GetServices(out HybridCache hc, out MemoryCache l1, + out UnreliableDistributedCache l2, out LogCollector log) + { + // we need an entirely separate MC for the dummy backend, not connected to our + // "real" services + var services = new ServiceCollection(); + services.AddDistributedMemoryCache(); + var backend = services.BuildServiceProvider().GetRequiredService(); + + // now create the "real" services + l2 = new UnreliableDistributedCache(backend); + var collector = new LogCollector(); + log = collector; + services = new ServiceCollection(); + services.AddSingleton(l2); + services.AddHybridCache(); + services.AddLogging(options => + { + options.ClearProviders(); + options.AddProvider(collector); + }); + var lifetime = services.BuildServiceProvider(); + hc = lifetime.GetRequiredService(); + l1 = Assert.IsType(lifetime.GetRequiredService()); + return lifetime; + } + + private sealed class UnreliableDistributedCache : IDistributedCache + { + public UnreliableDistributedCache(IDistributedCache tail) + { + Tail = tail; + } + + public IDistributedCache Tail { get; } + public BreakType ReadBreak { get; set; } + public BreakType WriteBreak { get; set; } + + public Task LastWrite { get; private set; } = Task.CompletedTask; + + public byte[]? Get(string key) => throw new NotSupportedException(); // only async API in use + + public Task GetAsync(string key, CancellationToken token = default) + => TrackLast(ThrowIfBrokenAsync(ReadBreak) ?? Tail.GetAsync(key, token)); + + public void Refresh(string key) => throw new NotSupportedException(); // only async API in use + + public Task RefreshAsync(string key, CancellationToken token = default) + => TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.RefreshAsync(key, token)); + + public void Remove(string key) => throw new NotSupportedException(); // only async API in use + + public Task RemoveAsync(string key, CancellationToken token = default) + => TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.RemoveAsync(key, token)); + + public void Set(string key, byte[] value, DistributedCacheEntryOptions options) => throw new NotSupportedException(); // only async API in use + + public Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default) + => TrackLast(ThrowIfBrokenAsync(WriteBreak) ?? Tail.SetAsync(key, value, options, token)); + + [DoesNotReturn] + private static void Throw() => throw new IOException("L2 offline"); + + private static async Task ThrowAsync(bool yield) + { + if (yield) + { + await Task.Yield(); + } + + Throw(); + return default; // never reached + } + + private static Task? ThrowIfBrokenAsync(BreakType breakType) => ThrowIfBrokenAsync(breakType); + + [SuppressMessage("Critical Bug", "S4586:Non-async \"Task/Task\" methods should not return null", Justification = "Intentional for propagation")] + private static Task? ThrowIfBrokenAsync(BreakType breakType) + { + switch (breakType) + { + case BreakType.Asynchronous: + return ThrowAsync(false); + case BreakType.AsynchronousYield: + return ThrowAsync(true); + case BreakType.None: + return null; + default: + // includes BreakType.Synchronous and anything unknown + Throw(); + break; + } + + return null; + } + + [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")] + [SuppressMessage("Design", "CA1031:Do not catch general exception types", Justification = "We don't need the failure type - just the timing")] + private static Task IgnoreFailure(Task task) + { + return task.Status == TaskStatus.RanToCompletion + ? Task.CompletedTask : IgnoreAsync(task); + + static async Task IgnoreAsync(Task task) + { + try + { + await task; + } + catch + { + // we only care about the "when"; failure is fine + } + } + } + + [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")] + private Task TrackLast(Task lastWrite) + { + LastWrite = IgnoreFailure(lastWrite); + return lastWrite; + } + + [SuppressMessage("Usage", "VSTHRD003:Avoid awaiting foreign Tasks", Justification = "Intentional; tracking for out-of-band support only")] + private Task TrackLast(Task lastWrite) + { + LastWrite = IgnoreFailure(lastWrite); + return lastWrite; + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.Telemetry.Abstractions.Tests/Microsoft.Extensions.Telemetry.Abstractions.Tests.csproj b/test/Libraries/Microsoft.Extensions.Telemetry.Abstractions.Tests/Microsoft.Extensions.Telemetry.Abstractions.Tests.csproj index ac284fee861..387cec3c5c0 100644 --- a/test/Libraries/Microsoft.Extensions.Telemetry.Abstractions.Tests/Microsoft.Extensions.Telemetry.Abstractions.Tests.csproj +++ b/test/Libraries/Microsoft.Extensions.Telemetry.Abstractions.Tests/Microsoft.Extensions.Telemetry.Abstractions.Tests.csproj @@ -12,4 +12,8 @@ + + + + diff --git a/test/Shared/JsonSchemaExporter/JsonSchemaExporterConfigurationTests.cs b/test/Shared/JsonSchemaExporter/JsonSchemaExporterConfigurationTests.cs new file mode 100644 index 00000000000..1d2b6caa74e --- /dev/null +++ b/test/Shared/JsonSchemaExporter/JsonSchemaExporterConfigurationTests.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json.Schema; +using Xunit; + +namespace Microsoft.Extensions.AI.JsonSchemaExporter; + +public static class JsonSchemaExporterConfigurationTests +{ + [Theory] + [InlineData(false)] + [InlineData(true)] + public static void JsonSchemaExporterOptions_DefaultValues(bool useSingleton) + { + JsonSchemaExporterOptions configuration = useSingleton ? JsonSchemaExporterOptions.Default : new(); + Assert.False(configuration.TreatNullObliviousAsNonNullable); + Assert.Null(configuration.TransformSchemaNode); + } + + [Fact] + public static void JsonSchemaExporterOptions_Singleton_ReturnsSameInstance() + { + Assert.Same(JsonSchemaExporterOptions.Default, JsonSchemaExporterOptions.Default); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public static void JsonSchemaExporterOptions_TreatNullObliviousAsNonNullable(bool treatNullObliviousAsNonNullable) + { + JsonSchemaExporterOptions configuration = new() { TreatNullObliviousAsNonNullable = treatNullObliviousAsNonNullable }; + Assert.Equal(treatNullObliviousAsNonNullable, configuration.TreatNullObliviousAsNonNullable); + } +} diff --git a/test/Shared/JsonSchemaExporter/JsonSchemaExporterTests.cs b/test/Shared/JsonSchemaExporter/JsonSchemaExporterTests.cs new file mode 100644 index 00000000000..2ec81987dc2 --- /dev/null +++ b/test/Shared/JsonSchemaExporter/JsonSchemaExporterTests.cs @@ -0,0 +1,147 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Schema; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +#if !NET9_0_OR_GREATER +using System.Xml.Linq; +#endif +using Xunit; + +#pragma warning disable SA1402 // File may only contain a single type + +namespace Microsoft.Extensions.AI.JsonSchemaExporter; + +public abstract class JsonSchemaExporterTests +{ + protected abstract JsonSerializerOptions Options { get; } + + [Theory] + [MemberData(nameof(TestTypes.GetTestData), MemberType = typeof(TestTypes))] + public void TestTypes_GeneratesExpectedJsonSchema(ITestData testData) + { + JsonSerializerOptions options = testData.Options is { } opts + ? new(opts) { TypeInfoResolver = Options.TypeInfoResolver } + : Options; + + JsonNode schema = options.GetJsonSchemaAsNode(testData.Type, (JsonSchemaExporterOptions?)testData.ExporterOptions); + SchemaTestHelpers.AssertEqualJsonSchema(testData.ExpectedJsonSchema, schema); + } + + [Theory] + [MemberData(nameof(TestTypes.GetTestDataUsingAllValues), MemberType = typeof(TestTypes))] + public void TestTypes_SerializedValueMatchesGeneratedSchema(ITestData testData) + { + JsonSerializerOptions options = testData.Options is { } opts + ? new(opts) { TypeInfoResolver = Options.TypeInfoResolver } + : Options; + + JsonNode schema = options.GetJsonSchemaAsNode(testData.Type, (JsonSchemaExporterOptions?)testData.ExporterOptions); + JsonNode? instance = JsonSerializer.SerializeToNode(testData.Value, testData.Type, options); + SchemaTestHelpers.AssertDocumentMatchesSchema(schema, instance); + } + + [Theory] + [InlineData(typeof(string), "string")] + [InlineData(typeof(int[]), "array")] + [InlineData(typeof(Dictionary), "object")] + [InlineData(typeof(TestTypes.SimplePoco), "object")] + public void TreatNullObliviousAsNonNullable_True_MarksAllReferenceTypesAsNonNullable(Type referenceType, string expectedType) + { + Assert.True(!referenceType.IsValueType); + var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true }; + JsonNode schema = Options.GetJsonSchemaAsNode(referenceType, config); + JsonValue type = Assert.IsAssignableFrom(schema["type"]); + Assert.Equal(expectedType, (string)type!); + } + + [Theory] + [InlineData(typeof(int), "integer")] + [InlineData(typeof(double), "number")] + [InlineData(typeof(bool), "boolean")] + [InlineData(typeof(ImmutableArray), "array")] + [InlineData(typeof(TestTypes.StructDictionary), "object")] + [InlineData(typeof(TestTypes.SimpleRecordStruct), "object")] + public void TreatNullObliviousAsNonNullable_True_DoesNotImpactNonReferenceTypes(Type referenceType, string expectedType) + { + Assert.True(referenceType.IsValueType); + var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true }; + JsonNode schema = Options.GetJsonSchemaAsNode(referenceType, config); + JsonValue value = Assert.IsAssignableFrom(schema["type"]); + Assert.Equal(expectedType, (string)value!); + } + +#if !NET9_0 // Disable until https://github.com/dotnet/runtime/pull/108764 gets backported + [Fact] + public void CanGenerateXElementSchema() + { + JsonNode schema = Options.GetJsonSchemaAsNode(typeof(XElement)); + Assert.True(schema.ToJsonString().Length < 100_000); + } +#endif + + [Fact] + public void TreatNullObliviousAsNonNullable_True_DoesNotImpactObjectType() + { + var config = new JsonSchemaExporterOptions { TreatNullObliviousAsNonNullable = true }; + JsonNode schema = Options.GetJsonSchemaAsNode(typeof(object), config); + Assert.False(schema is JsonObject jObj && jObj.ContainsKey("type")); + } + + [Fact] + public void TypeWithDisallowUnmappedMembers_AdditionalPropertiesFailValidation() + { + JsonNode schema = Options.GetJsonSchemaAsNode(typeof(TestTypes.PocoDisallowingUnmappedMembers)); + JsonNode? jsonWithUnmappedProperties = JsonNode.Parse("""{ "UnmappedProperty" : {} }"""); + SchemaTestHelpers.AssertDoesNotMatchSchema(schema, jsonWithUnmappedProperties); + } + + [Fact] + public void GetJsonSchema_NullInputs_ThrowsArgumentNullException() + { + Assert.Throws(() => ((JsonSerializerOptions)null!).GetJsonSchemaAsNode(typeof(int))); + Assert.Throws(() => Options.GetJsonSchemaAsNode(type: null!)); + Assert.Throws(() => ((JsonTypeInfo)null!).GetJsonSchemaAsNode()); + } + + [Fact] + public void GetJsonSchema_NoResolver_ThrowInvalidOperationException() + { + var options = new JsonSerializerOptions(); + Assert.Throws(() => options.GetJsonSchemaAsNode(typeof(int))); + } + + [Fact] + public void MaxDepth_SetToZero_NonTrivialSchema_ThrowsInvalidOperationException() + { + JsonSerializerOptions options = new(Options) { MaxDepth = 1 }; + var ex = Assert.Throws(() => options.GetJsonSchemaAsNode(typeof(TestTypes.SimplePoco))); + Assert.Contains("The depth of the generated JSON schema exceeds the JsonSerializerOptions.MaxDepth setting.", ex.Message); + } + + [Fact] + public void ReferenceHandlePreserve_Enabled_ThrowsNotSupportedException() + { + var options = new JsonSerializerOptions(Options) { ReferenceHandler = ReferenceHandler.Preserve }; + options.MakeReadOnly(); + + var ex = Assert.Throws(() => options.GetJsonSchemaAsNode(typeof(TestTypes.SimplePoco))); + Assert.Contains("ReferenceHandler.Preserve", ex.Message); + } +} + +public sealed class ReflectionJsonSchemaExporterTests : JsonSchemaExporterTests +{ + protected override JsonSerializerOptions Options => JsonSerializerOptions.Default; +} + +public sealed class SourceGenJsonSchemaExporterTests : JsonSchemaExporterTests +{ + protected override JsonSerializerOptions Options => TestTypes.TestTypesContext.Default.Options; +} diff --git a/test/Shared/JsonSchemaExporter/SchemaTestHelpers.cs b/test/Shared/JsonSchemaExporter/SchemaTestHelpers.cs new file mode 100644 index 00000000000..02e659a27aa --- /dev/null +++ b/test/Shared/JsonSchemaExporter/SchemaTestHelpers.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using Json.Schema; +using Xunit.Sdk; + +namespace Microsoft.Extensions.AI.JsonSchemaExporter; + +internal static partial class SchemaTestHelpers +{ + public static void AssertEqualJsonSchema(JsonNode expectedJsonSchema, JsonNode actualJsonSchema) + { + if (!JsonNode.DeepEquals(expectedJsonSchema, actualJsonSchema)) + { + throw new XunitException($""" + Generated schema does not match the expected specification. + Expected: + {FormatJson(expectedJsonSchema)} + Actual: + {FormatJson(actualJsonSchema)} + """); + } + } + + public static void AssertDocumentMatchesSchema(JsonNode schema, JsonNode? instance) + { + EvaluationResults results = EvaluateSchemaCore(schema, instance); + if (!results.IsValid) + { + IEnumerable errors = results.Details + .Where(d => d.HasErrors) + .SelectMany(d => d.Errors!.Select(error => $"Path:${d.InstanceLocation} {error.Key}:{error.Value}")); + + throw new XunitException($""" + Instance JSON document does not match the specified schema. + Schema: + {FormatJson(schema)} + Instance: + {FormatJson(instance)} + Errors: + {string.Join(Environment.NewLine, errors)} + """); + } + } + + public static void AssertDoesNotMatchSchema(JsonNode schema, JsonNode? instance) + { + EvaluationResults results = EvaluateSchemaCore(schema, instance); + if (results.IsValid) + { + throw new XunitException($""" + Instance JSON document matches the specified schema. + Schema: + {FormatJson(schema)} + Instance: + {FormatJson(instance)} + """); + } + } + + private static EvaluationResults EvaluateSchemaCore(JsonNode schema, JsonNode? instance) + { + JsonSchema jsonSchema = JsonSerializer.Deserialize(schema, Context.Default.JsonSchema)!; + EvaluationOptions options = new() { OutputFormat = OutputFormat.List }; + return jsonSchema.Evaluate(instance, options); + } + + private static string FormatJson(JsonNode? node) => + JsonSerializer.Serialize(node, Context.Default.JsonNode!); + + [JsonSerializable(typeof(string))] + [JsonSerializable(typeof(JsonNode))] + [JsonSerializable(typeof(JsonSchema))] + [JsonSourceGenerationOptions(WriteIndented = true)] + private partial class Context : JsonSerializerContext; +} diff --git a/test/Shared/JsonSchemaExporter/TestData.cs b/test/Shared/JsonSchemaExporter/TestData.cs new file mode 100644 index 00000000000..0254a62b144 --- /dev/null +++ b/test/Shared/JsonSchemaExporter/TestData.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Schema; + +namespace Microsoft.Extensions.AI.JsonSchemaExporter; + +internal sealed record TestData( + T? Value, + [StringSyntax(StringSyntaxAttribute.Json)] string ExpectedJsonSchema, + IEnumerable? AdditionalValues = null, + JsonSchemaExporterOptions? ExporterOptions = null, + JsonSerializerOptions? Options = null, + bool WritesNumbersAsStrings = false) + : ITestData +{ + private static readonly JsonDocumentOptions _schemaParseOptions = new() { CommentHandling = JsonCommentHandling.Skip }; + + public Type Type => typeof(T); + object? ITestData.Value => Value; + object? ITestData.ExporterOptions => ExporterOptions; + JsonNode ITestData.ExpectedJsonSchema { get; } = + JsonNode.Parse(ExpectedJsonSchema, documentOptions: _schemaParseOptions) + ?? throw new ArgumentNullException("schema must not be null"); + + IEnumerable ITestData.GetTestDataForAllValues() + { + yield return this; + + if (default(T) is null && + ExporterOptions is { TreatNullObliviousAsNonNullable: false } && + Value is not null) + { + yield return this with { Value = default }; + } + + if (AdditionalValues != null) + { + foreach (T? value in AdditionalValues) + { + yield return this with { Value = value, AdditionalValues = null }; + } + } + } +} + +public interface ITestData +{ + Type Type { get; } + + object? Value { get; } + + JsonNode ExpectedJsonSchema { get; } + + object? ExporterOptions { get; } + + JsonSerializerOptions? Options { get; } + + bool WritesNumbersAsStrings { get; } + + IEnumerable GetTestDataForAllValues(); +} diff --git a/test/Shared/JsonSchemaExporter/TestTypes.cs b/test/Shared/JsonSchemaExporter/TestTypes.cs new file mode 100644 index 00000000000..d21a40640dd --- /dev/null +++ b/test/Shared/JsonSchemaExporter/TestTypes.cs @@ -0,0 +1,1291 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.ComponentModel; +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Schema; +using System.Text.Json.Serialization; +using System.Xml.Linq; + +#pragma warning disable SA1118 // Parameter should not span multiple lines +#pragma warning disable JSON001 // Comments not allowed +#pragma warning disable S2344 // Enumeration type names should not have "Flags" or "Enum" suffixes +#pragma warning disable SA1502 // Element should not be on a single line +#pragma warning disable SA1136 // Enum values should be on separate lines +#pragma warning disable SA1133 // Do not combine attributes +#pragma warning disable S3604 // Member initializer values should not be redundant +#pragma warning disable SA1515 // Single-line comment should be preceded by blank line +#pragma warning disable CA1052 // Static holder types should be Static or NotInheritable +#pragma warning disable S1121 // Assignments should not be made from within sub-expressions +#pragma warning disable IDE0073 // The file header is missing or not located at the top of the file + +namespace Microsoft.Extensions.AI.JsonSchemaExporter; + +public static partial class TestTypes +{ + public static IEnumerable GetTestData() => GetTestDataCore().Select(t => new object[] { t }); + + public static IEnumerable GetTestDataUsingAllValues() => + GetTestDataCore() + .SelectMany(t => t.GetTestDataForAllValues()) + .Select(t => new object[] { t }); + + public static IEnumerable GetTestDataCore() + { + // Primitives and built-in types + yield return new TestData( + Value: new(), + AdditionalValues: [42, false, 3.14, 3.14M, new int[] { 1, 2, 3 }, new SimpleRecord(1, "str", false, 3.14)], + ExpectedJsonSchema: "true"); + + yield return new TestData(true, """{"type":"boolean"}"""); + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(1.2f, """{"type":"number"}"""); + yield return new TestData(3.14159d, """{"type":"number"}"""); + yield return new TestData(3.14159M, """{"type":"number"}"""); +#if NET7_0_OR_GREATER + yield return new TestData(42, """{"type":"integer"}"""); + yield return new TestData(42, """{"type":"integer"}"""); +#endif +#if NET6_0_OR_GREATER + yield return new TestData((Half)3.141, """{"type":"number"}"""); +#endif + yield return new TestData("I am a string", """{"type":["string","null"]}"""); + yield return new TestData('c', """{"type":"string","minLength":1,"maxLength":1}"""); + yield return new TestData( + Value: [1, 2, 3], + AdditionalValues: [[]], + ExpectedJsonSchema: """{"type":["string","null"]}"""); + + yield return new TestData>(new byte[] { 1, 2, 3 }, """{"type":"string"}"""); + yield return new TestData>(new byte[] { 1, 2, 3 }, """{"type":"string"}"""); + yield return new TestData( + Value: new(2021, 1, 1), + AdditionalValues: [DateTime.MinValue, DateTime.MaxValue], + ExpectedJsonSchema: """{"type":"string","format": "date-time"}"""); + + yield return new TestData( + Value: new(new DateTime(2021, 1, 1), TimeSpan.Zero), + AdditionalValues: [DateTimeOffset.MinValue, DateTimeOffset.MaxValue], + ExpectedJsonSchema: """{"type":"string","format": "date-time"}"""); + + yield return new TestData( + Value: new(hours: 5, minutes: 13, seconds: 3), + AdditionalValues: [TimeSpan.MinValue, TimeSpan.MaxValue], + ExpectedJsonSchema: """{"$comment": "Represents a System.TimeSpan value.", "type":"string", "pattern": "^-?(\\d+\\.)?\\d{2}:\\d{2}:\\d{2}(\\.\\d{1,7})?$"}"""); + +#if NET6_0_OR_GREATER + yield return new TestData(new(2021, 1, 1), """{"type":"string","format": "date"}"""); + yield return new TestData(new(hour: 22, minute: 30, second: 33, millisecond: 100), """{"type":"string","format": "time"}"""); +#endif + yield return new TestData(Guid.Empty, """{"type":"string","format":"uuid"}"""); + yield return new TestData(new("http://example.com"), """{"type":["string","null"], "format":"uri"}"""); + yield return new TestData(new(1, 2, 3, 4), """{"$comment":"Represents a version string.", "type":["string","null"],"pattern":"^\\d+(\\.\\d+){1,3}$"}"""); + yield return new TestData(JsonDocument.Parse("""[{ "x" : 42 }]"""), "true"); + yield return new TestData(JsonDocument.Parse("""[{ "x" : 42 }]""").RootElement, "true"); + yield return new TestData(JsonNode.Parse("""[{ "x" : 42 }]"""), "true"); + yield return new TestData((JsonValue)42, "true"); + yield return new TestData(new() { ["x"] = 42 }, """{"type":["object","null"]}"""); + yield return new TestData([1, 2, 3], """{"type":["array","null"]}"""); + + // Enum types + yield return new TestData(IntEnum.A, """{"type":"integer"}"""); + yield return new TestData(StringEnum.A, """{"enum": ["A","B","C"]}"""); + yield return new TestData(FlagsStringEnum.A, """{"type":"string"}"""); + + // Nullable types + yield return new TestData(true, """{"type":["boolean","null"]}"""); + yield return new TestData(42, """{"type":["integer","null"]}"""); + yield return new TestData(3.14, """{"type":["number","null"]}"""); + yield return new TestData(Guid.Empty, """{"type":["string","null"],"format":"uuid"}"""); + yield return new TestData(JsonDocument.Parse("{}").RootElement, "true"); + yield return new TestData(IntEnum.A, """{"type":["integer","null"]}"""); + yield return new TestData(StringEnum.A, """{"enum":["A","B","C",null]}"""); + yield return new TestData( + new(1, "two", true, 3.14), + ExpectedJsonSchema: """ + { + "type":["object","null"], + "properties": { + "X": {"type":"integer"}, + "Y": {"type":"string"}, + "Z": {"type":"boolean"}, + "W": {"type":"number"} + } + } + """); + + // User-defined POCOs + yield return new TestData( + Value: new() { String = "string", StringNullable = "string", Int = 42, Double = 3.14, Boolean = true }, + AdditionalValues: [new() { String = "str", StringNullable = null }], + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "String": { "type": "string" }, + "StringNullable": { "type": ["string", "null"] }, + "Int": { "type": "integer" }, + "Double": { "type": "number" }, + "Boolean": { "type": "boolean" } + } + } + """); + + // Same as above but with nullable types set to non-nullable + yield return new TestData( + Value: new() { String = "string", StringNullable = "string", Int = 42, Double = 3.14, Boolean = true }, + AdditionalValues: [new() { String = "str", StringNullable = null }], + ExpectedJsonSchema: """ + { + "type": "object", + "properties": { + "String": { "type": "string" }, + "StringNullable": { "type": ["string", "null"] }, + "Int": { "type": "integer" }, + "Double": { "type": "number" }, + "Boolean": { "type": "boolean" } + } + } + """, + ExporterOptions: new() { TreatNullObliviousAsNonNullable = true }); + + yield return new TestData( + Value: new(1, "two", true, 3.14), + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "X": { "type": "integer" }, + "Y": { "type": "string" }, + "Z": { "type": "boolean" }, + "W": { "type": "number" } + }, + "required": ["X","Y","Z","W"] + } + """); + + yield return new TestData( + Value: new(1, "two", true, 3.14), + ExpectedJsonSchema: """ + { + "type": "object", + "properties": { + "X": { "type": "integer" }, + "Y": { "type": "string" }, + "Z": { "type": "boolean" }, + "W": { "type": "number" } + } + } + """); + + yield return new TestData( + Value: new(1, "two", true, 3.14, StringEnum.A), + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "X1": { "type": "integer" }, + "X2": { "type": "string" }, + "X3": { "type": "boolean" }, + "X4": { "type": "number" }, + "X5": { "enum": ["A", "B", "C"] }, + "Y1": { "type": "integer", "default": 42 }, + "Y2": { "type": "string", "default": "str" }, + "Y3": { "type": "boolean", "default": true }, + "Y4": { "type": "number", "default": 0 }, + "Y5": { "enum": ["A", "B", "C"], "default": "A" } + }, + "required": ["X1", "X2", "X3", "X4", "X5"] + } + """); + + yield return new TestData( + new() { X = "str1", Y = "str2" }, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "Y": { "type": "string" }, + "Z": { "type": "integer" }, + "X": { "type": "string" } + }, + "required": [ "Y", "Z", "X" ] + } + """); + + yield return new TestData( + new() { X = 1, Y = 2 }, + ExpectedJsonSchema: """ + { + "type": [ "object", "null" ], + "properties": { + "X": { "type": "integer" } + } + } + """); + yield return new TestData( + Value: new() { IntegerProperty = 1, StringProperty = "str" }, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "int": { "type": "integer" }, + "str": { "type": [ "string", "null"] } + } + } + """); + + yield return new TestData( + Value: new() { X = 1 }, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { "X": { "type": ["string","integer"], "pattern": "^-?(?:0|[1-9]\\d*)$" } } + } + """); + + yield return new TestData( + Value: new() { X = 1, Y = 2, Z = 3 }, + AdditionalValues: [ + new() { X = 1, Y = double.NaN, Z = 3 }, + new() { X = 1, Y = double.PositiveInfinity, Z = 3 }, + new() { X = 1, Y = double.NegativeInfinity, Z = 3 }, + ], + WritesNumbersAsStrings: true, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "X": { "type": ["string", "integer"], "pattern": "^-?(?:0|[1-9]\\d*)$" }, + "Y": { + "anyOf": [ + { "type": "number" }, + { "enum": ["NaN", "Infinity", "-Infinity"]} + ] + }, + "Z": { "type": ["string", "integer"], "pattern": "^-?(?:0|[1-9]\\d*)$" }, + "W" : { "type": "number" } + } + } + """); + + yield return new TestData( + Value: new() { Value = 1, Next = new() { Value = 2, Next = new() { Value = 3 } } }, + AdditionalValues: [new() { Value = 1, Next = null }], + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "Value": { "type": "integer" }, + "Next": { + "type": ["object","null"], + "properties": { + "Value": { "type": "integer" }, + "Next": { "$ref": "#/properties/Next" } + } + } + } + } + """); + + // Same as above but with non-nullable reference types by default. + yield return new TestData( + Value: new() { Value = 1, Next = new() { Value = 2, Next = new() { Value = 3 } } }, + AdditionalValues: [new() { Value = 1, Next = null }], + ExpectedJsonSchema: """ + { + "type": "object", + "properties": { + "Value": { "type": "integer" }, + "Next": { + "type": ["object","null"], + "properties": { + "Value": { "type": "integer" }, + "Next": { "$ref": "#/properties/Next" } + } + } + } + } + """, + ExporterOptions: new() { TreatNullObliviousAsNonNullable = true }); + +#if !NET9_0 // Disable until https://github.com/dotnet/runtime/pull/108764 gets backported + SimpleRecord recordValue = new(42, "str", true, 3.14); + yield return new TestData( + Value: new() { Value1 = recordValue, Value2 = recordValue, ArrayValue = [recordValue], ListValue = [recordValue] }, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "Value1": { + "type": ["object","null"], + "properties": { + "X": { "type": "integer" }, + "Y": { "type": "string" }, + "Z": { "type": "boolean" }, + "W": { "type": "number" } + }, + "required": ["X", "Y", "Z", "W"] + }, + /* The same type on a different property is repeated to + account for potential metadata resolved from attributes. */ + "Value2": { + "type": ["object","null"], + "properties": { + "X": { "type": "integer" }, + "Y": { "type": "string" }, + "Z": { "type": "boolean" }, + "W": { "type": "number" } + }, + "required": ["X", "Y", "Z", "W"] + }, + /* This collection element is the first occurrence + of the type without contextual metadata. */ + "ListValue": { + "type": ["array","null"], + "items": { + "type": ["object","null"], + "properties": { + "X": { "type": "integer" }, + "Y": { "type": "string" }, + "Z": { "type": "boolean" }, + "W": { "type": "number" } + }, + "required": ["X", "Y", "Z", "W"] + } + }, + /* This collection element is the second occurrence + of the type which points to the first occurrence. */ + "ArrayValue": { + "type": ["array","null"], + "items": { + "$ref": "#/properties/ListValue/items" + } + } + } + } + """); +#endif + + yield return new TestData( + Value: new() { X = 42 }, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "X": { + "type": "integer" + } + } + } + """); + + yield return new TestData(new() { Value = 42 }, "true"); + yield return new TestData(new() { Value = 42 }, """{"type":["object","null"],"properties":{"Value":true}}"""); + yield return new TestData( + Value: new() + { + IntEnum = IntEnum.A, + StringEnum = StringEnum.B, + IntEnumUsingStringConverter = IntEnum.A, + NullableIntEnumUsingStringConverter = IntEnum.B, + StringEnumUsingIntConverter = StringEnum.A, + NullableStringEnumUsingIntConverter = StringEnum.B + }, + AdditionalValues: [ + new() + { + IntEnum = (IntEnum)int.MaxValue, + StringEnum = StringEnum.A, + IntEnumUsingStringConverter = IntEnum.A, + NullableIntEnumUsingStringConverter = null, + StringEnumUsingIntConverter = (StringEnum)int.MaxValue, + NullableStringEnumUsingIntConverter = null + }, + ], + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "IntEnum": { "type": "integer" }, + "StringEnum": { "enum": [ "A", "B", "C" ] }, + "IntEnumUsingStringConverter": { "enum": [ "A", "B", "C" ] }, + "NullableIntEnumUsingStringConverter": { "enum": [ "A", "B", "C", null ] }, + "StringEnumUsingIntConverter": { "type": "integer" }, + "NullableStringEnumUsingIntConverter": { "type": [ "integer", "null" ] } + } + } + """); + + var recordStruct = new SimpleRecordStruct(42, "str", true, 3.14); + yield return new TestData( + Value: new() { Struct = recordStruct, NullableStruct = null }, + AdditionalValues: [new() { Struct = recordStruct, NullableStruct = recordStruct }], + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "Struct": { + "type": "object", + "properties": { + "X": {"type":"integer"}, + "Y": {"type":"string"}, + "Z": {"type":"boolean"}, + "W": {"type":"number"} + } + }, + "NullableStruct": { + "type": ["object","null"], + "properties": { + "X": {"type":"integer"}, + "Y": {"type":"string"}, + "Z": {"type":"boolean"}, + "W": {"type":"number"} + } + } + } + } + """); + + yield return new TestData( + Value: new() { NullableStruct = null, Struct = recordStruct }, + AdditionalValues: [new() { NullableStruct = recordStruct, Struct = recordStruct }], + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "NullableStruct": { + "type": ["object","null"], + "properties": { + "X": {"type":"integer"}, + "Y": {"type":"string"}, + "Z": {"type":"boolean"}, + "W": {"type":"number"} + } + }, + "Struct": { + "type": "object", + "properties": { + "X": {"type":"integer"}, + "Y": {"type":"string"}, + "Z": {"type":"boolean"}, + "W": {"type":"number"} + } + } + } + } + """); + + yield return new TestData( + Value: new() { Name = "name", ExtensionData = new() { ["x"] = 42 } }, + """{"type":["object","null"],"properties":{"Name":{"type":["string","null"]}}}"""); + + yield return new TestData( + Value: new() { Name = "name", Age = 42 }, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "Name": {"type":["string","null"]}, + "Age": {"type":"integer"} + }, + "additionalProperties": false + } + """); + + // Global JsonUnmappedMemberHandling.Disallow setting + yield return new TestData( + Value: new() { String = "string", StringNullable = "string", Int = 42, Double = 3.14, Boolean = true }, + AdditionalValues: [new() { String = "str", StringNullable = null }], + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "String": { "type": "string" }, + "StringNullable": { "type": ["string", "null"] }, + "Int": { "type": "integer" }, + "Double": { "type": "number" }, + "Boolean": { "type": "boolean" } + }, + "additionalProperties": false + } + """, + Options: new() { UnmappedMemberHandling = JsonUnmappedMemberHandling.Disallow }); + + yield return new TestData( + Value: new() { MaybeNull = null!, AllowNull = null, NotNull = null, DisallowNull = null!, NotNullDisallowNull = "str" }, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "MaybeNull": {"type":["string","null"]}, + "AllowNull": {"type":["string","null"]}, + "NotNull": {"type":["string","null"]}, + "DisallowNull": {"type":["string","null"]}, + "NotNullDisallowNull": {"type":"string"} + } + } + """); + + yield return new TestData( + Value: new(allowNull: null, disallowNull: "str"), + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "AllowNull": {"type":["string","null"]}, + "DisallowNull": {"type":"string"} + }, + "required": ["AllowNull", "DisallowNull"] + } + """); + + yield return new TestData( + Value: new(null), + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "Value": {"type":["string","null"]} + }, + "required": ["Value"] + } + """); + + yield return new TestData( + Value: new(), + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "X1": {"type":"string", "default": "str" }, + "X2": {"type":"integer", "default": 42 }, + "X3": {"type":"boolean", "default": true }, + "X4": {"type":"number", "default": 0 }, + "X5": {"enum":["A","B","C"], "default": "A" }, + "X6": {"type":["string","null"], "default": "str" }, + "X7": {"type":["integer","null"], "default": 42 }, + "X8": {"type":["boolean","null"], "default": true }, + "X9": {"type":["number","null"], "default": 0 }, + "X10": {"enum":["A","B","C", null], "default": "A" } + } + } + """); + + yield return new TestData>( + Value: new(null!), + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "Value": {"type":["string","null"]} + }, + "required": ["Value"] + } + """); + + yield return new TestData( + Value: new PocoWithPolymorphism.DerivedPocoStringDiscriminator { BaseValue = 42, DerivedValue = "derived" }, + AdditionalValues: [ + new PocoWithPolymorphism.DerivedPocoNoDiscriminator { BaseValue = 42, DerivedValue = "derived" }, + new PocoWithPolymorphism.DerivedPocoIntDiscriminator { BaseValue = 42, DerivedValue = "derived" }, + new PocoWithPolymorphism.DerivedCollection { BaseValue = 42 }, + new PocoWithPolymorphism.DerivedDictionary { BaseValue = 42 }, + ], + + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "anyOf": [ + { + "properties": { + "BaseValue": {"type":"integer"}, + "DerivedValue": {"type":["string","null"]} + } + }, + { + "properties": { + "$type": {"const":"derivedPoco"}, + "BaseValue": {"type":"integer"}, + "DerivedValue": {"type":["string","null"]} + }, + "required": ["$type"] + }, + { + "properties": { + "$type": {"const":42}, + "BaseValue": {"type":"integer"}, + "DerivedValue": {"type":["string","null"]} + }, + "required": ["$type"] + }, + { + "properties": { + "$type": {"const":"derivedCollection"}, + "$values": { + "type": "array", + "items": {"type":"integer"} + } + }, + "required": ["$type"] + }, + { + "properties": { + "$type": {"const":"derivedDictionary"} + }, + "additionalProperties":{"type": "integer"}, + "required": ["$type"] + } + ] + } + """); + + yield return new TestData( + Value: new NonAbstractClassWithSingleDerivedType(), + AdditionalValues: [new NonAbstractClassWithSingleDerivedType.Derived()], + ExpectedJsonSchema: """ + { + "type": ["object","null"] + } + """); + +#if !NET9_0 // Disable until https://github.com/microsoft/semantic-kernel/issues/8983 gets backported to .NET 9 + yield return new TestData( + Value: new(value: null), + AdditionalValues: [new(true), new(42), new(""), new(new object()), new(Array.Empty())], + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "Value": { "default": null } + } + } + """); +#endif + + yield return new TestData( + Value: new(), + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "PolymorphicValue": { + "type": "object", + "anyOf": [ + { + "properties": { + "BaseValue": {"type":"integer"}, + "DerivedValue": {"type":["string","null"]} + } + }, + { + "properties": { + "$type": {"const":"derivedPoco"}, + "BaseValue": {"type":"integer"}, + "DerivedValue": {"type":["string","null"]} + }, + "required": ["$type"] + }, + { + "properties": { + "$type": {"const":42}, + "BaseValue": {"type":"integer"}, + "DerivedValue": {"type":["string","null"]} + }, + "required": ["$type"] + }, + { + "properties": { + "$type": {"const":"derivedCollection"}, + "$values": { + "type": "array", + "items": {"type":"integer"} + } + }, + "required": ["$type"] + }, + { + "properties": { + "$type": {"const":"derivedDictionary"} + }, + "additionalProperties":{"type": "integer"}, + "required": ["$type"] + } + ] + }, + "DerivedValue1": { + "type": "object", + "properties": { + "BaseValue": { + "type": "integer" + }, + "DerivedValue": { + "type": [ + "string", + "null" + ] + } + } + }, + "DerivedValue2": { + "type": "object", + "properties": { + "BaseValue": {"type":"integer"}, + "DerivedValue": {"type":["string","null"]} + } + } + } + } + """); + + yield return new TestData( + Value: new("string", -1), + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "properties": { + "StringValue": {"type":"string","pattern":"\\w+"}, + "IntValue": {"type":"integer","default":42} + }, + "required": ["StringValue","IntValue"] + } + """, + ExporterOptions: new() + { + TransformSchemaNode = static (ctx, schema) => + { + if (ctx.PropertyInfo is null || schema is not JsonObject jObj) + { + return schema; + } + + if (ctx.ResolveAttribute() is { } attr) + { + jObj["default"] = JsonSerializer.SerializeToNode(attr.Value); + } + + if (ctx.ResolveAttribute() is { } regexAttr) + { + jObj["pattern"] = regexAttr.Pattern; + } + + return jObj; + } + }); + + // Collection types + yield return new TestData([1, 2, 3], """{"type":["array","null"],"items":{"type":"integer"}}"""); + yield return new TestData>([false, true, false], """{"type":["array","null"],"items":{"type":"boolean"}}"""); + yield return new TestData>(["one", "two", "three"], """{"type":["array","null"],"items":{"type":["string","null"]}}"""); + yield return new TestData>(new([1.1, 2.2, 3.3]), """{"type":["array","null"],"items":{"type":"number"}}"""); + yield return new TestData>(new(['x', '2', '+']), """{"type":["array","null"],"items":{"type":"string","minLength":1,"maxLength":1}}"""); + yield return new TestData>(ImmutableArray.Create(1, 2, 3), """{"type":"array","items":{"type":"integer"}}"""); + yield return new TestData>(ImmutableList.Create("one", "two", "three"), """{"type":["array","null"],"items":{"type":["string","null"]}}"""); + yield return new TestData>(ImmutableQueue.Create(false, false, true), """{"type":["array","null"],"items":{"type":"boolean"}}"""); + yield return new TestData([1, "two", 3.14], """{"type":["array","null"]}"""); + yield return new TestData([1, "two", 3.14], """{"type":["array","null"]}"""); + + // Dictionary types + yield return new TestData>( + Value: new() { ["one"] = 1, ["two"] = 2, ["three"] = 3 }, + ExpectedJsonSchema: """{"type":["object","null"],"additionalProperties":{"type": "integer"}}"""); + + yield return new TestData>( + Value: new([new("one", 1), new("two", 2), new("three", 3)]), + ExpectedJsonSchema: """{"type":"object","additionalProperties":{"type": "integer"}}"""); + + yield return new TestData>( + Value: new() { [1] = "one", [2] = "two", [3] = "three" }, + ExpectedJsonSchema: """{"type":["object","null"],"additionalProperties":{"type": ["string","null"]}}"""); + + yield return new TestData>( + Value: new() + { + ["one"] = new() { String = "string", StringNullable = "string", Int = 42, Double = 3.14, Boolean = true }, + ["two"] = new() { String = "string", StringNullable = null, Int = 42, Double = 3.14, Boolean = true }, + ["three"] = new() { String = "string", StringNullable = null, Int = 42, Double = 3.14, Boolean = true }, + }, + ExpectedJsonSchema: """ + { + "type": ["object","null"], + "additionalProperties": { + "properties": { + "String": { "type": "string" }, + "StringNullable": { "type": ["string", "null"] }, + "Int": { "type": "integer" }, + "Double": { "type": "number" }, + "Boolean": { "type": "boolean" } + }, + "type": ["object","null"] + } + } + """); + + yield return new TestData>( + Value: new() { ["one"] = 1, ["two"] = "two", ["three"] = 3.14 }, + ExpectedJsonSchema: """{"type":["object","null"]}"""); + + yield return new TestData( + Value: new() { ["one"] = 1, ["two"] = "two", ["three"] = 3.14 }, + ExpectedJsonSchema: """{"type":["object","null"]}"""); + } + + public enum IntEnum { A, B, C } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public enum StringEnum { A, B, C } + + [Flags, JsonConverter(typeof(JsonStringEnumConverter))] + public enum FlagsStringEnum { A = 1, B = 2, C = 4 } + + public class SimplePoco + { + public string String { get; set; } = "default"; + public string? StringNullable { get; set; } + + public int Int { get; set; } + public double Double { get; set; } + public bool Boolean { get; set; } + } + + public record SimpleRecord(int X, string Y, bool Z, double W); + public record struct SimpleRecordStruct(int X, string Y, bool Z, double W); + + public record RecordWithOptionalParameters( + [property: Description("required integer")] int X1, string X2, bool X3, double X4, [Description("required string enum")] StringEnum X5, + [property: Description("optional integer")] int Y1 = 42, string Y2 = "str", bool Y3 = true, double Y4 = 0, [Description("optional string enum")] StringEnum Y5 = StringEnum.A); + + public class PocoWithRequiredMembers + { + [JsonInclude] + public required string X; + + public required string Y { get; set; } + + [JsonRequired] + public int Z { get; set; } + } + + public class PocoWithIgnoredMembers + { + public int X { get; set; } + + [JsonIgnore] + public int Y { get; set; } + } + + public class PocoWithCustomNaming + { + [JsonPropertyName("int")] + public int IntegerProperty { get; set; } + + [JsonPropertyName("str")] + public string? StringProperty { get; set; } + } + + [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString)] + public class PocoWithCustomNumberHandling + { + public int X { get; set; } + } + + public class PocoWithCustomNumberHandlingOnProperties + { + [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString)] + public int X { get; set; } + + [JsonNumberHandling(JsonNumberHandling.AllowNamedFloatingPointLiterals)] + public double Y { get; set; } + + [JsonNumberHandling(JsonNumberHandling.WriteAsString)] + public int Z { get; set; } + + [JsonNumberHandling(JsonNumberHandling.AllowNamedFloatingPointLiterals)] + public decimal W { get; set; } + } + + public class PocoWithRecursiveMembers + { + public int Value { get; init; } + public PocoWithRecursiveMembers? Next { get; init; } + } + + public class PocoWithNonRecursiveDuplicateOccurrences + { + public SimpleRecord? Value1 { get; set; } + public SimpleRecord? Value2 { get; set; } + public List? ListValue { get; set; } + public SimpleRecord[]? ArrayValue { get; set; } + } + + [Description("The type description")] + public class PocoWithDescription + { + [Description("The property description")] + public int X { get; set; } + } + + [JsonConverter(typeof(CustomConverter))] + public class PocoWithCustomConverter + { + public int Value { get; set; } + + public class CustomConverter : JsonConverter + { + public override PocoWithCustomConverter Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => + new PocoWithCustomConverter { Value = reader.GetInt32() }; + + public override void Write(Utf8JsonWriter writer, PocoWithCustomConverter value, JsonSerializerOptions options) => + writer.WriteNumberValue(value.Value); + } + } + + public class PocoWithCustomPropertyConverter + { + [JsonConverter(typeof(CustomConverter))] + public int Value { get; set; } + + public class CustomConverter : JsonConverter + { + public override int Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + => int.Parse(reader.GetString()!); + + public override void Write(Utf8JsonWriter writer, int value, JsonSerializerOptions options) + => writer.WriteStringValue(value.ToString()); + } + } + + public class PocoWithEnums + { + public IntEnum IntEnum { get; init; } + public StringEnum StringEnum { get; init; } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public IntEnum IntEnumUsingStringConverter { get; set; } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public IntEnum? NullableIntEnumUsingStringConverter { get; set; } + + [JsonConverter(typeof(JsonNumberEnumConverter))] + public StringEnum StringEnumUsingIntConverter { get; set; } + + [JsonConverter(typeof(JsonNumberEnumConverter))] + public StringEnum? NullableStringEnumUsingIntConverter { get; set; } + } + + public class PocoWithStructFollowedByNullableStruct + { + public SimpleRecordStruct? NullableStruct { get; set; } + public SimpleRecordStruct Struct { get; set; } + } + + public class PocoWithNullableStructFollowedByStruct + { + public SimpleRecordStruct? NullableStruct { get; set; } + public SimpleRecordStruct Struct { get; set; } + } + + public class PocoWithExtensionDataProperty + { + public string? Name { get; set; } + + [JsonExtensionData] + public Dictionary? ExtensionData { get; set; } + } + + [JsonUnmappedMemberHandling(JsonUnmappedMemberHandling.Disallow)] + public class PocoDisallowingUnmappedMembers + { + public string? Name { get; set; } + public int Age { get; set; } + } + + public class PocoWithNullableAnnotationAttributes + { + [MaybeNull] + public string MaybeNull { get; set; } + + [AllowNull] + public string AllowNull { get; set; } + + [NotNull] + public string? NotNull { get; set; } + + [DisallowNull] + public string? DisallowNull { get; set; } + + [NotNull, DisallowNull] + public string? NotNullDisallowNull { get; set; } = ""; + } + + public class PocoWithNullableAnnotationAttributesOnConstructorParams([AllowNull] string allowNull, [DisallowNull] string? disallowNull) + { + public string AllowNull { get; } = allowNull!; + public string DisallowNull { get; } = disallowNull; + } + + public class PocoWithNullableConstructorParameter(string? value) + { + public string Value { get; } = value!; + } + + public class PocoWithOptionalConstructorParams( + string x1 = "str", int x2 = 42, bool x3 = true, double x4 = 0, StringEnum x5 = StringEnum.A, + string? x6 = "str", int? x7 = 42, bool? x8 = true, double? x9 = 0, StringEnum? x10 = StringEnum.A) + { + public string X1 { get; } = x1; + public int X2 { get; } = x2; + public bool X3 { get; } = x3; + public double X4 { get; } = x4; + public StringEnum X5 { get; } = x5; + + public string? X6 { get; } = x6; + public int? X7 { get; } = x7; + public bool? X8 { get; } = x8; + public double? X9 { get; } = x9; + public StringEnum? X10 { get; } = x10; + } + + // Regression test for https://github.com/dotnet/runtime/issues/92487 + public class GenericPocoWithNullableConstructorParameter(T value) + { + [NotNull] + public T Value { get; } = value!; + } + + [JsonDerivedType(typeof(DerivedPocoNoDiscriminator))] + [JsonDerivedType(typeof(DerivedPocoStringDiscriminator), "derivedPoco")] + [JsonDerivedType(typeof(DerivedPocoIntDiscriminator), 42)] + [JsonDerivedType(typeof(DerivedCollection), "derivedCollection")] + [JsonDerivedType(typeof(DerivedDictionary), "derivedDictionary")] + public abstract class PocoWithPolymorphism + { + public int BaseValue { get; set; } + + public class DerivedPocoNoDiscriminator : PocoWithPolymorphism + { + public string? DerivedValue { get; set; } + } + + public class DerivedPocoStringDiscriminator : PocoWithPolymorphism + { + public string? DerivedValue { get; set; } + } + + public class DerivedPocoIntDiscriminator : PocoWithPolymorphism + { + public string? DerivedValue { get; set; } + } + + public class DerivedCollection : PocoWithPolymorphism, IEnumerable + { + public IEnumerator GetEnumerator() => Enumerable.Repeat(BaseValue, 1).GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + public class DerivedDictionary : PocoWithPolymorphism, IReadOnlyDictionary + { + public int this[string key] => key == nameof(BaseValue) ? BaseValue : throw new KeyNotFoundException(); + public IEnumerable Keys => [nameof(BaseValue)]; + public IEnumerable Values => [BaseValue]; + public int Count => 1; + public bool ContainsKey(string key) => key == nameof(BaseValue); + public bool TryGetValue(string key, out int value) => key == nameof(BaseValue) ? (value = BaseValue) == BaseValue : (value = 0) == 0; + public IEnumerator> GetEnumerator() => Enumerable.Repeat(new KeyValuePair(nameof(BaseValue), BaseValue), 1).GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + } + + [JsonDerivedType(typeof(NonAbstractClassWithSingleDerivedType.Derived))] + public class NonAbstractClassWithSingleDerivedType + { + public class Derived : NonAbstractClassWithSingleDerivedType; + } + + public class PocoCombiningPolymorphicTypeAndDerivedTypes + { + public PocoWithPolymorphism PolymorphicValue { get; set; } = new PocoWithPolymorphism.DerivedPocoNoDiscriminator { DerivedValue = "derived" }; + public PocoWithPolymorphism.DerivedPocoNoDiscriminator DerivedValue1 { get; set; } = new() { DerivedValue = "derived" }; + public PocoWithPolymorphism.DerivedPocoStringDiscriminator DerivedValue2 { get; set; } = new() { DerivedValue = "derived" }; + } + + public class ClassWithComponentModelAttributes + { + public ClassWithComponentModelAttributes(string stringValue, [DefaultValue(42)] int intValue) + { + StringValue = stringValue; + IntValue = intValue; + } + + [RegularExpression(@"\w+")] + public string StringValue { get; } + + public int IntValue { get; } + } + + public class ClassWithOptionalObjectParameter(object? value = null) + { + public object? Value { get; } = value; + } + + public readonly struct StructDictionary(IEnumerable> values) + : IReadOnlyDictionary + where TKey : notnull + { + private readonly IReadOnlyDictionary _dictionary = values.ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + public TValue this[TKey key] => _dictionary[key]; + public IEnumerable Keys => _dictionary.Keys; + public IEnumerable Values => _dictionary.Values; + public int Count => _dictionary.Count; + public bool ContainsKey(TKey key) => _dictionary.ContainsKey(key); + public IEnumerator> GetEnumerator() => _dictionary.GetEnumerator(); +#if NETCOREAPP + public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) => _dictionary.TryGetValue(key, out value); +#else + public bool TryGetValue(TKey key, out TValue value) => _dictionary.TryGetValue(key, out value); +#endif + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_dictionary).GetEnumerator(); + } + + [JsonSerializable(typeof(object))] + [JsonSerializable(typeof(bool))] + [JsonSerializable(typeof(byte))] + [JsonSerializable(typeof(ushort))] + [JsonSerializable(typeof(uint))] + [JsonSerializable(typeof(ulong))] + [JsonSerializable(typeof(sbyte))] + [JsonSerializable(typeof(short))] + [JsonSerializable(typeof(int))] + [JsonSerializable(typeof(long))] + [JsonSerializable(typeof(float))] + [JsonSerializable(typeof(double))] + [JsonSerializable(typeof(decimal))] +#if NET7_0_OR_GREATER + [JsonSerializable(typeof(UInt128))] + [JsonSerializable(typeof(Int128))] +#endif +#if NET6_0_OR_GREATER + [JsonSerializable(typeof(Half))] +#endif + [JsonSerializable(typeof(string))] + [JsonSerializable(typeof(char))] + [JsonSerializable(typeof(byte[]))] + [JsonSerializable(typeof(Memory))] + [JsonSerializable(typeof(ReadOnlyMemory))] + [JsonSerializable(typeof(DateTime))] + [JsonSerializable(typeof(DateTimeOffset))] + [JsonSerializable(typeof(TimeSpan))] +#if NET6_0_OR_GREATER + [JsonSerializable(typeof(DateOnly))] + [JsonSerializable(typeof(TimeOnly))] +#endif + [JsonSerializable(typeof(Guid))] + [JsonSerializable(typeof(Uri))] + [JsonSerializable(typeof(Version))] + [JsonSerializable(typeof(JsonDocument))] + [JsonSerializable(typeof(JsonElement))] + [JsonSerializable(typeof(JsonNode))] + [JsonSerializable(typeof(JsonValue))] + [JsonSerializable(typeof(JsonObject))] + [JsonSerializable(typeof(JsonArray))] + // Enum types + [JsonSerializable(typeof(IntEnum))] + [JsonSerializable(typeof(StringEnum))] + [JsonSerializable(typeof(FlagsStringEnum))] + // Nullable types + [JsonSerializable(typeof(bool?))] + [JsonSerializable(typeof(int?))] + [JsonSerializable(typeof(double?))] + [JsonSerializable(typeof(Guid?))] + [JsonSerializable(typeof(JsonElement?))] + [JsonSerializable(typeof(IntEnum?))] + [JsonSerializable(typeof(StringEnum?))] + [JsonSerializable(typeof(SimpleRecordStruct?))] + // User-defined POCOs + [JsonSerializable(typeof(SimplePoco))] + [JsonSerializable(typeof(SimpleRecord))] + [JsonSerializable(typeof(SimpleRecordStruct))] + [JsonSerializable(typeof(RecordWithOptionalParameters))] + [JsonSerializable(typeof(PocoWithRequiredMembers))] + [JsonSerializable(typeof(PocoWithIgnoredMembers))] + [JsonSerializable(typeof(PocoWithCustomNaming))] + [JsonSerializable(typeof(PocoWithCustomNumberHandling))] + [JsonSerializable(typeof(PocoWithCustomNumberHandlingOnProperties))] + [JsonSerializable(typeof(PocoWithRecursiveMembers))] + [JsonSerializable(typeof(PocoWithNonRecursiveDuplicateOccurrences))] + [JsonSerializable(typeof(PocoWithDescription))] + [JsonSerializable(typeof(PocoWithCustomConverter))] + [JsonSerializable(typeof(PocoWithCustomPropertyConverter))] + [JsonSerializable(typeof(PocoWithEnums))] + [JsonSerializable(typeof(PocoWithStructFollowedByNullableStruct))] + [JsonSerializable(typeof(PocoWithNullableStructFollowedByStruct))] + [JsonSerializable(typeof(PocoWithExtensionDataProperty))] + [JsonSerializable(typeof(PocoDisallowingUnmappedMembers))] + [JsonSerializable(typeof(PocoWithNullableAnnotationAttributes))] + [JsonSerializable(typeof(PocoWithNullableAnnotationAttributesOnConstructorParams))] + [JsonSerializable(typeof(PocoWithNullableConstructorParameter))] + [JsonSerializable(typeof(PocoWithOptionalConstructorParams))] + [JsonSerializable(typeof(GenericPocoWithNullableConstructorParameter))] + [JsonSerializable(typeof(PocoWithPolymorphism))] + [JsonSerializable(typeof(NonAbstractClassWithSingleDerivedType))] + [JsonSerializable(typeof(PocoCombiningPolymorphicTypeAndDerivedTypes))] + [JsonSerializable(typeof(ClassWithComponentModelAttributes))] + [JsonSerializable(typeof(ClassWithOptionalObjectParameter))] + // Collection types + [JsonSerializable(typeof(int[]))] + [JsonSerializable(typeof(List))] + [JsonSerializable(typeof(HashSet))] + [JsonSerializable(typeof(Queue))] + [JsonSerializable(typeof(Stack))] + [JsonSerializable(typeof(ImmutableArray))] + [JsonSerializable(typeof(ImmutableList))] + [JsonSerializable(typeof(ImmutableQueue))] + [JsonSerializable(typeof(object[]))] + [JsonSerializable(typeof(System.Collections.ArrayList))] + [JsonSerializable(typeof(Dictionary))] + [JsonSerializable(typeof(SortedDictionary))] + [JsonSerializable(typeof(Dictionary))] + [JsonSerializable(typeof(Dictionary))] + [JsonSerializable(typeof(Hashtable))] + [JsonSerializable(typeof(StructDictionary))] + [JsonSerializable(typeof(XElement))] + public partial class TestTypesContext : JsonSerializerContext; + + private static TAttribute? ResolveAttribute(this JsonSchemaExporterContext ctx) + where TAttribute : Attribute + { + // Resolve attributes from locations in the following order: + // 1. Property-level attributes + // 2. Parameter-level attributes and + // 3. Type-level attributes. + return +#if NET9_0_OR_GREATER || !TESTS_JSON_SCHEMA_EXPORTER_POLYFILL + GetAttrs(ctx.PropertyInfo?.AttributeProvider) ?? + GetAttrs(ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider) ?? +#else + GetAttrs(ctx.PropertyAttributeProvider) ?? + GetAttrs(ctx.ParameterInfo) ?? +#endif + GetAttrs(ctx.TypeInfo.Type); + + static TAttribute? GetAttrs(ICustomAttributeProvider? provider) => + (TAttribute?)provider?.GetCustomAttributes(typeof(TAttribute), inherit: false).FirstOrDefault(); + } +} diff --git a/test/Shared/Shared.Tests.csproj b/test/Shared/Shared.Tests.csproj index d7bfa1801e2..456e50f67a9 100644 --- a/test/Shared/Shared.Tests.csproj +++ b/test/Shared/Shared.Tests.csproj @@ -2,19 +2,26 @@ Microsoft.Shared.Test Unit tests for Microsoft.Shared + $(DefineConstants);TESTS_JSON_SCHEMA_EXPORTER_POLYFILL - $(NoWarn);CA1716 + $(NoWarn);CA1716;S104 $(TestNetCoreTargetFrameworks) $(TestNetCoreTargetFrameworks)$(ConditionalNet462) + + true + true + + +