diff --git a/GuildWarsPartySearch/Endpoints/LiveFeed.cs b/GuildWarsPartySearch/Endpoints/LiveFeed.cs index a73e7b3..c0152a9 100644 --- a/GuildWarsPartySearch/Endpoints/LiveFeed.cs +++ b/GuildWarsPartySearch/Endpoints/LiveFeed.cs @@ -35,8 +35,15 @@ public override async Task ExecuteAsync(TextContent? content, CancellationToken public override async Task SocketAccepted(CancellationToken cancellationToken) { - var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), this.Context?.Connection.RemoteIpAddress?.ToString() ?? string.Empty); - this.liveFeedService.AddClient(this.WebSocket!); + var ipAddress = this.Context?.Connection.RemoteIpAddress?.ToString(); + var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), ipAddress ?? string.Empty); + if (!await this.liveFeedService.AddClient(this.WebSocket!, ipAddress, cancellationToken)) + { + scopedLogger.LogError("Client rejected"); + this.WebSocket?.CloseAsync(System.Net.WebSockets.WebSocketCloseStatus.NormalClosure, "Connection rejected", cancellationToken); + return; + } + scopedLogger.LogDebug("Client accepted to livefeed"); scopedLogger.LogDebug("Sending all party searches"); @@ -46,8 +53,9 @@ public override async Task SocketAccepted(CancellationToken cancellationToken) public override Task SocketClosed() { - var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), this.Context?.Connection.RemoteIpAddress?.ToString() ?? string.Empty); - this.liveFeedService.RemoveClient(this.WebSocket!); + var ipAddress = this.Context?.Connection.RemoteIpAddress?.ToString(); + var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), ipAddress ?? string.Empty); + this.liveFeedService.RemoveClient(this.WebSocket!, ipAddress); scopedLogger.LogDebug("Client removed from livefeed"); return Task.CompletedTask; } diff --git a/GuildWarsPartySearch/Options/IpWhitelistOptions.cs b/GuildWarsPartySearch/Options/IpWhitelistOptions.cs index c2847ce..87554ac 100644 --- a/GuildWarsPartySearch/Options/IpWhitelistOptions.cs +++ b/GuildWarsPartySearch/Options/IpWhitelistOptions.cs @@ -5,5 +5,5 @@ namespace GuildWarsPartySearch.Server.Options; public class IpWhitelistOptions { [JsonPropertyName(nameof(Addresses))] - public List Addresses { get; set; } = [ "127.0.0.1" ]; + public List Addresses { get; set; } = []; } diff --git a/GuildWarsPartySearch/Services/Feed/ILiveFeedService.cs b/GuildWarsPartySearch/Services/Feed/ILiveFeedService.cs index f1923d0..f5e1c88 100644 --- a/GuildWarsPartySearch/Services/Feed/ILiveFeedService.cs +++ b/GuildWarsPartySearch/Services/Feed/ILiveFeedService.cs @@ -4,7 +4,7 @@ namespace GuildWarsPartySearch.Server.Services.Feed; public interface ILiveFeedService { - void AddClient(WebSocket webSocket); - void RemoveClient(WebSocket webSocket); + Task AddClient(WebSocket webSocket, string? ipAddress, CancellationToken cancellationToken); + void RemoveClient(WebSocket webSocket, string? ipAddress); Task PushUpdate(Models.PartySearch partySearchUpdate, CancellationToken cancellationToken); } diff --git a/GuildWarsPartySearch/Services/Feed/LiveFeedService.cs b/GuildWarsPartySearch/Services/Feed/LiveFeedService.cs index 73b17e0..8ca0b6c 100644 --- a/GuildWarsPartySearch/Services/Feed/LiveFeedService.cs +++ b/GuildWarsPartySearch/Services/Feed/LiveFeedService.cs @@ -1,5 +1,7 @@ using GuildWarsPartySearch.Server.Models.Endpoints; +using GuildWarsPartySearch.Server.Services.Database; using System.Core.Extensions; +using System.Extensions; using System.Net.WebSockets; using System.Text; using System.Text.Json; @@ -8,24 +10,29 @@ namespace GuildWarsPartySearch.Server.Services.Feed; public sealed class LiveFeedService : ILiveFeedService { + private const int MaxConnectionsPerIP = 2; + private readonly SemaphoreSlim semaphore = new(1); - private readonly List clients = []; + private readonly Dictionary> clients = []; + private readonly IIpWhitelistDatabase ipWhitelistDatabase; private readonly JsonSerializerOptions jsonSerializerOptions; private readonly ILogger logger; public LiveFeedService( IHostApplicationLifetime lifetime, + IIpWhitelistDatabase ipWhitelistDatabase, JsonSerializerOptions jsonSerializerOptions, ILogger logger) { lifetime.ApplicationStopping.Register(this.ShutDownConnections); + this.ipWhitelistDatabase = ipWhitelistDatabase.ThrowIfNull(); this.jsonSerializerOptions = jsonSerializerOptions.ThrowIfNull(); this.logger = logger.ThrowIfNull(); } - public void AddClient(WebSocket client) + public Task AddClient(WebSocket client, string? ipAddress, CancellationToken cancellationToken) { - AddClientInternal(client); + return AddClientInternal(client, ipAddress, cancellationToken); } public async Task PushUpdate(Models.PartySearch partySearchUpdate, CancellationToken cancellationToken) @@ -33,7 +40,7 @@ public async Task PushUpdate(Models.PartySearch partySearchUpdate, CancellationT // Since LiveFeed endpoint expects a PartySearchList, so we send a PartySearchList with only the update to keep it consistent var payloadString = JsonSerializer.Serialize(new PartySearchList { Searches = [partySearchUpdate] }, this.jsonSerializerOptions); var payload = Encoding.UTF8.GetBytes(payloadString); - await ExecuteOnClientsInternal(async client => + await ExecuteOnClientsInternal(async (address, client) => { try { @@ -42,27 +49,71 @@ await ExecuteOnClientsInternal(async client => catch(Exception ex) { this.logger.LogError(ex, $"Encountered exception while broadcasting update"); - RemoveClientInternal(client); + RemoveClientInternal(client, address); } }); } - public void RemoveClient(WebSocket client) + public void RemoveClient(WebSocket client, string? ipAddress) { - RemoveClientInternal(client); + RemoveClientInternal(client, ipAddress); } - private void AddClientInternal(WebSocket client) + private async Task AddClientInternal(WebSocket client, string? ipAddress, CancellationToken cancellationToken) { - this.semaphore.Wait(); - this.clients.Add(client); - this.semaphore.Release(); + var scopedLogger = this.logger.CreateScopedLogger(nameof(this.AddClientInternal), ipAddress ?? string.Empty); + + await this.semaphore.WaitAsync(cancellationToken); + try + { + if (ipAddress is null || + ipAddress.IsNullOrWhiteSpace()) + { + return false; + } + + var whitelistedIps = await this.ipWhitelistDatabase.GetWhitelistedAddresses(cancellationToken); + if (whitelistedIps.None(addr => addr == ipAddress) && + this.clients.TryGetValue(ipAddress, out var sockets) && + sockets.Count >= 2) + { + scopedLogger.LogError("Too many live connections. Rejecting"); + return false; + } + + if (!this.clients.TryGetValue(ipAddress, out var existingSockets)) + { + existingSockets = []; + this.clients[ipAddress] = existingSockets; + } + + existingSockets.Add(client); + return true; + } + finally + { + this.semaphore.Release(); + } } - private void RemoveClientInternal(WebSocket client) + private void RemoveClientInternal(WebSocket client, string? ipAddress) { this.semaphore.Wait(); - this.clients.Remove(client); + if (ipAddress is null || + ipAddress.IsNullOrWhiteSpace()) + { + return; + } + + if (this.clients.TryGetValue(ipAddress, out var sockets)) + { + sockets.Remove(client); + if (sockets.Count == 0) + { + this.clients.Remove(ipAddress); + } + } + if (client?.State is not WebSocketState.Closed or WebSocketState.Aborted) { client?.Abort(); @@ -71,19 +122,22 @@ private void RemoveClientInternal(WebSocket client) this.semaphore.Release(); } - private async Task ExecuteOnClientsInternal(Func action) + private async Task ExecuteOnClientsInternal(Func action) { await this.semaphore.WaitAsync(); - await Task.WhenAll(this.clients.Select(client => action(client))); + await Task.WhenAll(this.clients.SelectMany(pair => pair.Value.Select(client => action(pair.Key, client)))); this.semaphore.Release(); } private void ShutDownConnections() { this.semaphore.Wait(); - foreach(var client in this.clients) + foreach(var sockets in this.clients.Values) { - client.Abort(); + foreach(var client in sockets) + { + client.Abort(); + } } this.semaphore.Release();