Skip to content

Commit

Permalink
Update tests and fix post data missed disposal
Browse files Browse the repository at this point in the history
  • Loading branch information
stevejgordon committed Oct 30, 2024
1 parent 29bae46 commit 27cd6b1
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,7 @@ private async ValueTask<TResponse> SetBodyCoreAsync<TResponse>(bool isAsync,
if (details.HttpStatusCode.HasValue &&
requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value))
{
// In this scenario, we always dispose as we've explicitly skipped reading the response
if (ownsStream)
responseStream.Dispose();

ConditionalDisposal(responseStream, ownsStream, response);
return null;
}

Expand Down Expand Up @@ -296,7 +293,6 @@ private async ValueTask<TResponse> SetBodyCoreAsync<TResponse>(bool isAsync,
{
// Note the exception this handles is ONLY thrown after a check if the stream length is zero.
// When the length is zero, `default` is returned by Deserialize(Async) instead.

ConditionalDisposal(responseStream, ownsStream, response);
return default;
}
Expand Down
6 changes: 6 additions & 0 deletions src/Elastic.Transport/Requests/Body/PostData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ protected void FinishStream(Stream writableStream, MemoryStream buffer, ITranspo
buffer.Position = 0;
buffer.CopyTo(writableStream, BufferSize);
WrittenBytes ??= buffer.ToArray();
buffer.Dispose();
}

/// <summary>
Expand All @@ -132,5 +133,10 @@ protected async
buffer.Position = 0;
await buffer.CopyToAsync(writableStream, BufferSize, ctx).ConfigureAwait(false);
WrittenBytes ??= buffer.ToArray();
#if NET
await buffer.DisposeAsync().ConfigureAwait(false);
#else
buffer.Dispose();
#endif
}
}
106 changes: 78 additions & 28 deletions tests/Elastic.Transport.IntegrationTests/Http/StreamResponseTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text;
using System.Threading.Tasks;
using Elastic.Transport.IntegrationTests.Plumbing;
using Elastic.Transport.Products.Elasticsearch;
Expand Down Expand Up @@ -33,23 +33,6 @@ public async Task StreamResponse_ShouldNotBeDisposed()
_ = sr.ReadToEndAsync();
}

//[Fact]
//public async Task StreamResponse_MemoryStreamShouldNotBeDisposed()
//{
// var nodePool = new SingleNodePool(Server.Uri);
// var memoryStreamFactory = new TrackMemoryStreamFactory();
// var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient)))
// .MemoryStreamFactory(memoryStreamFactory);

// var transport = new DistributedTransport(config);

// _ = await transport.PostAsync<StreamResponse>(Path, PostData.String("{}"));

// var memoryStream = memoryStreamFactory.Created.Last();

// memoryStream.IsDisposed.Should().BeFalse();
//}

[Fact]
public async Task StreamResponse_MemoryStreamShouldNotBeDisposed()
{
Expand All @@ -63,9 +46,9 @@ public async Task StreamResponse_MemoryStreamShouldNotBeDisposed()

_ = await transport.PostAsync<StreamResponse>(Path, PostData.String("{}"));

var memoryStream = memoryStreamFactory.Created.Last();

memoryStream.IsDisposed.Should().BeFalse();
// When disable direct streaming, we have 1 for the original content, 1 for the buffered request bytes and the last for the buffered response
memoryStreamFactory.Created.Count.Should().Be(3);
memoryStreamFactory.Created.Last().IsDisposed.Should().BeFalse();
}

[Fact]
Expand All @@ -80,13 +63,36 @@ public async Task StringResponse_MemoryStreamShouldBeDisposed()

_ = await transport.PostAsync<StringResponse>(Path, PostData.String("{}"));

var memoryStream = memoryStreamFactory.Created.Last();
memoryStreamFactory.Created.Count.Should().Be(2);
foreach (var memoryStream in memoryStreamFactory.Created)
{
memoryStream.IsDisposed.Should().BeTrue();
}
}

[Fact]
public async Task WhenInvalidJson_MemoryStreamShouldBeDisposed()
{
var nodePool = new SingleNodePool(Server.Uri);
var memoryStreamFactory = new TrackMemoryStreamFactory();
var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient)))
.MemoryStreamFactory(memoryStreamFactory)
.DisableDirectStreaming(true);

var transport = new DistributedTransport(config);

var payload = new Payload { ResponseJsonString = " " };
_ = await transport.PostAsync<TestResponse>(Path, PostData.Serializable(payload));

memoryStream.IsDisposed.Should().BeTrue();
memoryStreamFactory.Created.Count.Should().Be(3);
foreach (var memoryStream in memoryStreamFactory.Created)
{
memoryStream.IsDisposed.Should().BeTrue();
}
}

[Fact]
public async Task Response_MemoryStreamShouldBeDisposed()
public async Task WhenNoContent_MemoryStreamShouldBeDisposed()
{
var nodePool = new SingleNodePool(Server.Uri);
var memoryStreamFactory = new TrackMemoryStreamFactory();
Expand All @@ -95,11 +101,37 @@ public async Task Response_MemoryStreamShouldBeDisposed()

var transport = new DistributedTransport(config);

_ = await transport.PostAsync<TestResponse>(Path, PostData.String("{}"));
var payload = new Payload { ResponseJsonString = "", StatusCode = 204 };
_ = await transport.PostAsync<TestResponse>(Path, PostData.Serializable(payload));

var memoryStream = memoryStreamFactory.Created.Last();
// We expect one for sending the request payload, but as the response is 204, we shouldn't
// see other memory streams being created for the response.
memoryStreamFactory.Created.Count.Should().Be(1);
foreach (var memoryStream in memoryStreamFactory.Created)
{
memoryStream.IsDisposed.Should().BeTrue();
}
}

[Fact]
public async Task PlainText_MemoryStreamShouldBeDisposed()
{
var nodePool = new SingleNodePool(Server.Uri);
var memoryStreamFactory = new TrackMemoryStreamFactory();
var config = new TransportConfiguration(nodePool, productRegistration: new ElasticsearchProductRegistration(typeof(Clients.Elasticsearch.ElasticsearchClient)))
.MemoryStreamFactory(memoryStreamFactory)
.DisableDirectStreaming(true);

var transport = new DistributedTransport(config);

var payload = new Payload { ResponseJsonString = "text", ContentType = "text/plain" };
_ = await transport.PostAsync<TestResponse>(Path, PostData.Serializable(payload));

memoryStream.IsDisposed.Should().BeTrue();
memoryStreamFactory.Created.Count.Should().Be(3);
foreach (var memoryStream in memoryStreamFactory.Created)
{
memoryStream.IsDisposed.Should().BeTrue();
}
}

private class TestResponse : TransportResponse
Expand Down Expand Up @@ -150,9 +182,27 @@ public override MemoryStream Create(byte[] bytes, int index, int count)
}
}

public class Payload
{
public string ResponseJsonString { get; set; } = "{}";
public string ContentType { get; set; } = "application/json";
public int StatusCode { get; set; } = 200;
}

[ApiController, Route("[controller]")]
public class StreamResponseController : ControllerBase
{
[HttpPost]
public Task<JsonElement> Post([FromBody] JsonElement body) => Task.FromResult(body);
public async Task<ActionResult> Post([FromBody] Payload payload)
{
Response.ContentType = payload.ContentType;

if (payload.StatusCode != 204)
{
await Response.BodyWriter.WriteAsync(Encoding.UTF8.GetBytes(payload.ResponseJsonString));
await Response.BodyWriter.CompleteAsync();
}

return StatusCode(payload.StatusCode);
}
}
72 changes: 18 additions & 54 deletions tests/Elastic.Transport.Tests/ResponseBuilderDisposeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,66 +29,34 @@ public async Task StreamResponseWithPotentialBodyAndDisableDirectStreaming_Memor
await AssertResponse<StreamResponse>(true, expectedDisposed: false);

[Fact]
public async Task StreamResponseWith204StatusCode_MemoryStreamIsDisposed() =>
await AssertResponse<StreamResponse>(true, 204);
public async Task ResponseWithPotentialBodyButInvalidMimeType_MemoryStreamIsDisposed() =>
await AssertResponse<TestResponse>(true, mimeType: "application/not-valid", expectedDisposed: true);

[Fact]
public async Task StreamResponseForHeadRequest_StreamIsDisposed() =>
await AssertResponse<StreamResponse>(false, httpMethod: HttpMethod.HEAD);
public async Task ResponseWithPotentialBodyButSkippedStatusCode_MemoryStreamIsDisposed() =>
await AssertResponse<TestResponse>(true, skipStatusCode: 200, expectedDisposed: true);

[Fact]
public async Task StreamResponseWithZeroContentLength_StreamIsDisposed() =>
await AssertResponse<StreamResponse>(false, contentLength: 0);

[Fact]
public async Task ResponseWithPotentialBody_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, expectedDisposed: true);

[Fact]
public async Task ResponseWithPotentialBodyButInvalidMimeType_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, mimeType: "application/not-valid", expectedDisposed: true);

[Fact]
public async Task ResponseWithPotentialBodyButSkippedStatusCode_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, skipStatusCode: 200, expectedDisposed: true);

[Fact]
public async Task ResponseWithPotentialBodyButEmptyJson_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, responseJson: " ", expectedDisposed: true);
public async Task ResponseWithPotentialBodyButEmptyJson_MemoryStreamIsDisposed() =>
await AssertResponse<TestResponse>(true, responseJson: " ", expectedDisposed: true);

[Fact]
// NOTE: The empty string here hits a fast path in STJ which returns default if the stream length is zero.
public async Task ResponseWithPotentialBodyButNullResponseDuringDeserialization_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, responseJson: "", expectedDisposed: true);
public async Task ResponseWithPotentialBodyButNullResponseDuringDeserialization_MemoryStreamIsDisposed() =>
await AssertResponse<TestResponse>(true, responseJson: "", expectedDisposed: true);

[Fact]
public async Task ResponseWithPotentialBodyAndCustomResponseBuilder_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, customResponseBuilder: new TestCustomResponseBuilder(), expectedDisposed: true);
public async Task ResponseWithPotentialBodyAndCustomResponseBuilder_MemoryStreamIsDisposed() =>
await AssertResponse<TestResponse>(true, customResponseBuilder: new TestCustomResponseBuilder(), expectedDisposed: true);

[Fact]
// NOTE: We expect one memory stream factory creation when handling error responses
public async Task ResponseWithPotentialBodyAndErrorResponse_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, productRegistration: new TestProductRegistration(), expectedDisposed: true, memoryStreamCreateExpected: 1);

[Fact]
public async Task ResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() =>
await AssertResponse<TestResponse>(true, expectedDisposed: true);

[Fact]
public async Task ResponseWith204StatusCode_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, 204);

[Fact]
public async Task ResponseForHeadRequest_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, httpMethod: HttpMethod.HEAD);

[Fact]
public async Task ResponseWithZeroContentLength_StreamIsDisposed() =>
await AssertResponse<TestResponse>(false, contentLength: 0);
await AssertResponse<TestResponse>(true, productRegistration: new TestProductRegistration(), expectedDisposed: true);

[Fact]
public async Task StringResponseWithPotentialBodyAndDisableDirectStreaming_MemoryStreamIsDisposed() =>
await AssertResponse<StringResponse>(true, expectedDisposed: true, memoryStreamCreateExpected: 1);
await AssertResponse<StringResponse>(false, expectedDisposed: true, memoryStreamCreateExpected: 1);

private async Task AssertResponse<T>(bool disableDirectStreaming, int statusCode = 200, HttpMethod httpMethod = HttpMethod.GET, int contentLength = 10,
bool expectedDisposed = true, string mimeType = "application/json", string responseJson = "{}", int skipStatusCode = -1,
Expand Down Expand Up @@ -139,13 +107,11 @@ private async Task AssertResponse<T>(bool disableDirectStreaming, int statusCode
if (disableDirectStreaming)
{
var memoryStream = memoryStreamFactory.Created[0];
stream.IsDisposed.Should().BeTrue();
memoryStream.IsDisposed.Should().Be(expectedDisposed);
}
else
{
stream.IsDisposed.Should().Be(expectedDisposed);
}

// The latest implementation should never dispose the incoming stream and assumes the caller will handler disposal
stream.IsDisposed.Should().Be(false);

stream = new TrackDisposeStream();
var ct = new CancellationToken();
Expand All @@ -159,13 +125,11 @@ private async Task AssertResponse<T>(bool disableDirectStreaming, int statusCode
if (disableDirectStreaming)
{
var memoryStream = memoryStreamFactory.Created[0];
stream.IsDisposed.Should().BeTrue();
memoryStream.IsDisposed.Should().Be(expectedDisposed);
}
else
{
stream.IsDisposed.Should().Be(expectedDisposed);
}

// The latest implementation should never dispose the incoming stream and assumes the caller will handler disposal
stream.IsDisposed.Should().Be(false);
}

private class TestProductRegistration : ProductRegistration
Expand Down

0 comments on commit 27cd6b1

Please sign in to comment.