From 2ce24c676cbd62bb2ed531d33bda0f9dd76598e1 Mon Sep 17 00:00:00 2001 From: Joannis Orlandos Date: Sun, 18 Dec 2022 00:43:05 +0100 Subject: [PATCH] Add a pipeline builder that checks the order of handlers --- Sources/NIOCore/ChannelBuilder.swift | 200 +++++++++++++++++++ Tests/NIOCoreTests/ChannelBuilderTests.swift | 93 +++++++++ 2 files changed, 293 insertions(+) create mode 100644 Sources/NIOCore/ChannelBuilder.swift create mode 100644 Tests/NIOCoreTests/ChannelBuilderTests.swift diff --git a/Sources/NIOCore/ChannelBuilder.swift b/Sources/NIOCore/ChannelBuilder.swift new file mode 100644 index 0000000000..9accb5e7f5 --- /dev/null +++ b/Sources/NIOCore/ChannelBuilder.swift @@ -0,0 +1,200 @@ +#if swift(>=5.7) +@resultBuilder public struct ChannelPipelineBuilder { + public static func buildPartialBlock( + first handler: Handler + ) -> ModifiedTypedChannel where InboundOut == Handler.InboundIn, OutboundIn == Handler.OutboundOut { + ModifiedTypedChannel<_, _>(handlers: [ handler ]) + } + + @_disfavoredOverload + public static func buildPartialBlock( + first handler: Handler + ) -> ModifiedTypedChannel where InboundOut == Handler.InboundIn { + ModifiedTypedChannel<_, _>(handlers: [ handler ]) + } + + @_disfavoredOverload + public static func buildPartialBlock( + first handler: Handler + ) -> ModifiedTypedChannel where OutboundIn == Handler.OutboundOut { + ModifiedTypedChannel<_, _>(handlers: [ handler ]) + } + + public static func buildPartialBlock< + PartialIn, PartialOut, + Handler: ChannelDuplexHandler + >( + accumulated base: ModifiedTypedChannel, + next handler: Handler + ) -> ModifiedTypedChannel where PartialIn == Handler.InboundIn, OutboundIn == Handler.OutboundOut + { + ModifiedTypedChannel<_, _>(handlers: base.handlers + [handler]) + } + + @_disfavoredOverload + public static func buildPartialBlock< + PartialIn, PartialOut, + HandlerInboundOut, HandlerOutboundIn + >( + accumulated base: ModifiedTypedChannel, + next pipeline: CheckedPipeline + ) -> CheckedPipeline + { + CheckedPipeline<_ ,_, _, _>(handlers: base.handlers + pipeline.handlers) + } + + @_disfavoredOverload + public static func buildPartialBlock( + accumulated base: ModifiedTypedChannel, + next decoder: Decoder + ) -> ModifiedTypedChannel { + ModifiedTypedChannel<_, _>( + handlers: base.handlers + [ByteToMessageHandler(decoder)] + ) + } + + @_disfavoredOverload + public static func buildPartialBlock( + accumulated base: ModifiedTypedChannel, + next encoder: Encoder + ) -> ModifiedTypedChannel { + ModifiedTypedChannel<_, _>( + handlers: base.handlers + [MessageToByteHandler(encoder)] + ) + } + + @_disfavoredOverload + public static func buildPartialBlock< + PartialIn, PartialOut, + Handler: ChannelInboundHandler + >( + accumulated base: ModifiedTypedChannel, + next handler: Handler + ) -> ModifiedTypedChannel where PartialIn == Handler.InboundIn + { + ModifiedTypedChannel<_, _>(handlers: base.handlers + [handler]) + } + + @_disfavoredOverload + public static func buildPartialBlock< + PartialIn, PartialOut, + Handler: ChannelOutboundHandler + >( + accumulated base: ModifiedTypedChannel, + next handler: Handler + ) -> ModifiedTypedChannel where PartialOut == Handler.OutboundOut + { + ModifiedTypedChannel<_, _>(handlers: base.handlers + [handler]) + } + + @_disfavoredOverload + public static func buildFinalResult( + _ component: CheckedPipeline + ) -> CheckedPipeline { + component + } + + public static func buildFinalResult( + _ component: ModifiedTypedChannel + ) -> CheckedPipeline { + CheckedPipeline<_, _, _, _>(handlers: component.handlers) + } +} + +extension ChannelPipelineBuilder { + public static func buildPartialBlock( + first handler: Handler + ) -> ModifiedTypedChannel where Handler.InboundIn == ByteBuffer, Handler.OutboundOut == ByteBuffer, InboundOut == IOData, OutboundIn == IOData + { + ModifiedTypedChannel<_, _>(handlers: [ handler ]) + } + + @_disfavoredOverload + public static func buildPartialBlock( + first handler: Handler + ) -> ModifiedTypedChannel where Handler.InboundIn == ByteBuffer, InboundOut == IOData { + ModifiedTypedChannel<_, _>(handlers: [ handler ]) + } + + @_disfavoredOverload + public static func buildPartialBlock( + first handler: Handler + ) -> ModifiedTypedChannel where Handler.OutboundOut == ByteBuffer, OutboundIn == IOData { + ModifiedTypedChannel<_, _>(handlers: [ handler ]) + } + + @_disfavoredOverload + public static func buildPartialBlock< + PartialOut, + Handler: ChannelInboundHandler + >( + accumulated base: ModifiedTypedChannel, + next handler: Handler + ) -> ModifiedTypedChannel where Handler.InboundIn == ByteBuffer + { + ModifiedTypedChannel<_, _>(handlers: base.handlers + [handler]) + } + + @_disfavoredOverload + public static func buildPartialBlock< + PartialIn, + Handler: ChannelOutboundHandler + >( + accumulated base: ModifiedTypedChannel, + next handler: Handler + ) -> ModifiedTypedChannel where Handler.OutboundOut == ByteBuffer + { + ModifiedTypedChannel<_, _>(handlers: base.handlers + [handler]) + } +} + +extension ChannelPipelineBuilder { + public static func buildPartialBlock< + Handler: ChannelDuplexHandler + >( + accumulated base: ModifiedTypedChannel, + next handler: Handler + ) -> ModifiedTypedChannel where Handler.InboundIn == ByteBuffer, Handler.OutboundOut == ByteBuffer + { + ModifiedTypedChannel<_, _>(handlers: base.handlers + [handler]) + } +} + +public struct CheckedPipeline { + internal let handlers: [ChannelHandler] +} + +public struct ModifiedTypedChannel { + internal let handlers: [ChannelHandler] +} + +public extension ChannelPipeline { + func addHandlers< + ChannelOutput, PipelineOutput, + ChannelInput, PipelineInput + >( + reading channelOutput: ChannelOutput.Type, + writing channelInput: ChannelInput.Type, + @ChannelPipelineBuilder buildPipeline: () -> CheckedPipeline + ) -> EventLoopFuture { + addHandlers(buildPipeline().handlers) + } + + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + func addHandlers< + ChannelOutput, PipelineOutput, + ChannelInput, PipelineInput + >( + reading channelOutput: ChannelOutput.Type, + writing channelInput: ChannelInput.Type, + @ChannelPipelineBuilder buildPipeline: () -> CheckedPipeline + ) async throws { + try await addHandlers( + reading: ChannelOutput.self, + writing: ChannelInput.self, + buildPipeline: buildPipeline + ).get() + } +} +#endif diff --git a/Tests/NIOCoreTests/ChannelBuilderTests.swift b/Tests/NIOCoreTests/ChannelBuilderTests.swift new file mode 100644 index 0000000000..1328f4c8b4 --- /dev/null +++ b/Tests/NIOCoreTests/ChannelBuilderTests.swift @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest +@testable import NIOCore +import NIOEmbedded +import NIOTestUtils + +final class BytesToStringInboundHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = String + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var data = unwrapInboundIn(data) + let string = data.readString(length: data.readableBytes)! + context.fireChannelRead(wrapInboundOut(string)) + } +} + +final class BytesToStringOutboundHandler: ChannelOutboundHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = String + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + var data = unwrapOutboundIn(data) + let string = data.readString(length: data.readableBytes)! + context.writeAndFlush(wrapOutboundOut(string), promise: promise) + } +} + +final class StringToIntInboundHandler: ChannelInboundHandler { + typealias InboundIn = String + typealias InboundOut = FWI + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let string = unwrapInboundIn(data) + let int = InboundOut(string)! + context.fireChannelRead(wrapInboundOut(int)) + } +} + +public final class ChannelTests: XCTestCase { + func testSingleStepPipeline() async throws { + let channel = EmbeddedChannel() + try await channel.pipeline.addHandlers( + reading: ByteBuffer.self, + writing: ByteBuffer.self + ) { + BytesToStringInboundHandler() + } + + try channel.writeInbound(ByteBuffer(string: "msg")) + XCTAssertEqual(try channel.readInbound(as: String.self), "msg") + } + + func testMisconfiguredPipelineFails() async throws { + let channel = EmbeddedChannel() + try await channel.pipeline.addHandlers( + reading: ByteBuffer.self, + writing: ByteBuffer.self + ) { + BytesToStringInboundHandler() + } + + try channel.writeInbound(ByteBuffer(string: "msg")) + await XCTAssertThrowsError(try channel.readInbound(as: ByteBuffer.self)) + } + + func testTwoStepPipeline() async throws { + let channel = EmbeddedChannel() + try await channel.pipeline.addHandlers( + reading: ByteBuffer.self, + writing: ByteBuffer.self + ) { + BytesToStringInboundHandler() + StringToIntInboundHandler() + } + + try channel.writeInbound(ByteBuffer(string: "2022")) + XCTAssertEqual(try channel.readInbound(as: Int.self), 2022) + } +}