1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-09-18 15:24:49 +02:00

1 Commits

Author SHA1 Message Date
c9c9b91e34 Replace NetMQ with custom TCP logic 2025-09-17 17:05:25 +02:00
24 changed files with 347 additions and 81 deletions

View File

@@ -19,6 +19,8 @@ public sealed class AgentServices {
public ActorSystem ActorSystem { get; } public ActorSystem ActorSystem { get; }
private ControllerConnection ControllerConnection { get; }
private AgentInfo AgentInfo { get; } private AgentInfo AgentInfo { get; }
private AgentFolders AgentFolders { get; } private AgentFolders AgentFolders { get; }
private AgentState AgentState { get; } private AgentState AgentState { get; }
@@ -31,6 +33,8 @@ public sealed class AgentServices {
public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration, ControllerConnection controllerConnection) { public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration, ControllerConnection controllerConnection) {
this.ActorSystem = ActorSystemFactory.Create("Agent"); this.ActorSystem = ActorSystemFactory.Create("Agent");
this.ControllerConnection = controllerConnection;
this.AgentInfo = agentInfo; this.AgentInfo = agentInfo;
this.AgentFolders = agentFolders; this.AgentFolders = agentFolders;
this.AgentState = new AgentState(); this.AgentState = new AgentState();
@@ -49,14 +53,14 @@ public sealed class AgentServices {
} }
} }
public async Task<bool> Register(ControllerConnection connection, CancellationToken cancellationToken) { public async Task<bool> Register(CancellationToken cancellationToken) {
Logger.Information("Registering with the controller..."); Logger.Information("Registering with the controller...");
// TODO NEED TO SEND WHEN SERVER RESTARTS!!! // TODO NEED TO SEND WHEN SERVER RESTARTS!!!
ImmutableArray<ConfigureInstanceMessage> configureInstanceMessages; ImmutableArray<ConfigureInstanceMessage> configureInstanceMessages;
try { try {
configureInstanceMessages = await connection.Send<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(new RegisterAgentMessage(AgentInfo), TimeSpan.FromMinutes(1), cancellationToken); configureInstanceMessages = await ControllerConnection.Send<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(new RegisterAgentMessage(AgentInfo), TimeSpan.FromMinutes(1), cancellationToken);
} catch (Exception e) { } catch (Exception e) {
Logger.Fatal(e, "Registration failed."); Logger.Fatal(e, "Registration failed.");
return false; return false;
@@ -78,7 +82,7 @@ public sealed class AgentServices {
} }
} }
await connection.Send(new AdvertiseJavaRuntimesMessage(JavaRuntimeRepository.All), cancellationToken); await ControllerConnection.Send(new AdvertiseJavaRuntimesMessage(JavaRuntimeRepository.All), cancellationToken);
InstanceTicketManager.RefreshAgentStatus(); InstanceTicketManager.RefreshAgentStatus();
return true; return true;

View File

@@ -62,14 +62,12 @@ try {
return 1; return 1;
} }
var controllerConnection = new ControllerConnection(rpcClient.SendChannel);
Task? rpcClientListener = null; Task? rpcClientListener = null;
try { try {
PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent..."); PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent...");
var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts); var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts);
var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), controllerConnection); var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.SendChannel));
await agentServices.Initialize(); await agentServices.Initialize();
var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices); var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices);
@@ -77,18 +75,21 @@ try {
rpcClientListener = rpcClient.Listen(new IMessageReceiver<IMessageToAgent>.Actor(rpcMessageHandlerActor)); rpcClientListener = rpcClient.Listen(new IMessageReceiver<IMessageToAgent>.Actor(rpcMessageHandlerActor));
if (await agentServices.Register(controllerConnection, shutdownCancellationToken)) { if (await agentServices.Register(shutdownCancellationToken)) {
PhantomLogger.Root.Information("Phantom Panel agent is ready."); PhantomLogger.Root.Information("Phantom Panel agent is ready.");
await shutdownCancellationToken.WaitHandle.WaitOneAsync(); await shutdownCancellationToken.WaitHandle.WaitOneAsync();
} }
await agentServices.Shutdown(); await agentServices.Shutdown();
} finally { } finally {
PhantomLogger.Root.Information("Unregistering agent...");
try { try {
await controllerConnection.Send(new UnregisterAgentMessage(), CancellationToken.None); using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
// TODO wait for acknowledgment await rpcClient.SendChannel.SendMessage(new UnregisterAgentMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister agent after shutdown.");
} catch (Exception e) { } catch (Exception e) {
PhantomLogger.Root.Warning(e, "Could not unregister agent after shutdown."); PhantomLogger.Root.Warning(e, "Could not unregister agent during shutdown.");
} finally { } finally {
await rpcClient.Shutdown(); await rpcClient.Shutdown();

View File

@@ -14,7 +14,7 @@ sealed class AgentConnection(Guid agentGuid, string agentName) {
public void UpdateConnection(RpcServerToClientConnection<IMessageToController, IMessageToAgent> newConnection, string newAgentName) { public void UpdateConnection(RpcServerToClientConnection<IMessageToController, IMessageToAgent> newConnection, string newAgentName) {
lock (this) { lock (this) {
connection?.CloseSession(); connection?.ClientClosedSession();
connection = newConnection; connection = newConnection;
agentName = newAgentName; agentName = newAgentName;
} }
@@ -23,7 +23,7 @@ sealed class AgentConnection(Guid agentGuid, string agentName) {
public bool CloseIfSame(RpcServerToClientConnection<IMessageToController, IMessageToAgent> expected) { public bool CloseIfSame(RpcServerToClientConnection<IMessageToController, IMessageToAgent> expected) {
lock (this) { lock (this) {
if (connection != null && ReferenceEquals(connection, expected)) { if (connection != null && ReferenceEquals(connection, expected)) {
connection.CloseSession(); connection.ClientClosedSession();
connection = null; connection = null;
return true; return true;
} }

View File

@@ -1,4 +1,5 @@
using System.Collections.Immutable; using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToAgent; using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController; using Phantom.Common.Messages.Agent.ToController;
@@ -31,7 +32,7 @@ sealed class AgentMessageHandlerActor : ReceiveActor<IMessageToController> {
this.eventLogManager = init.EventLogManager; this.eventLogManager = init.EventLogManager;
ReceiveAsyncAndReply<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(HandleRegisterAgent); ReceiveAsyncAndReply<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(HandleRegisterAgent);
Receive<UnregisterAgentMessage>(HandleUnregisterAgent); ReceiveAsync<UnregisterAgentMessage>(HandleUnregisterAgent);
Receive<AdvertiseJavaRuntimesMessage>(HandleAdvertiseJavaRuntimes); Receive<AdvertiseJavaRuntimesMessage>(HandleAdvertiseJavaRuntimes);
Receive<ReportAgentStatusMessage>(HandleReportAgentStatus); Receive<ReportAgentStatusMessage>(HandleReportAgentStatus);
Receive<ReportInstanceStatusMessage>(HandleReportInstanceStatus); Receive<ReportInstanceStatusMessage>(HandleReportInstanceStatus);
@@ -49,11 +50,13 @@ sealed class AgentMessageHandlerActor : ReceiveActor<IMessageToController> {
return agentManager.RegisterAgent(message.AgentInfo, connection); return agentManager.RegisterAgent(message.AgentInfo, connection);
} }
private void HandleUnregisterAgent(UnregisterAgentMessage message) { private Task HandleUnregisterAgent(UnregisterAgentMessage message) {
Guid agentGuid = RequireAgentGuid(); Guid agentGuid = RequireAgentGuid();
agentManager.TellAgent(agentGuid, new AgentActor.UnregisterCommand(connection)); agentManager.TellAgent(agentGuid, new AgentActor.UnregisterCommand(connection));
agentManager.OnSessionClosed(connection.SessionId, agentGuid); agentManager.OnSessionClosed(connection.SessionId, agentGuid);
connection.CloseSession();
Self.Tell(PoisonPill.Instance);
return connection.ClientClosedSession();
} }
private void HandleAdvertiseJavaRuntimes(AdvertiseJavaRuntimesMessage message) { private void HandleAdvertiseJavaRuntimes(AdvertiseJavaRuntimesMessage message) {

View File

@@ -27,7 +27,7 @@ sealed class WebClientRegistrar(
) : IRpcServerClientRegistrar<IMessageToController, IMessageToWeb> { ) : IRpcServerClientRegistrar<IMessageToController, IMessageToWeb> {
public IMessageReceiver<IMessageToController> Register(RpcServerToClientConnection<IMessageToController, IMessageToWeb> connection) { public IMessageReceiver<IMessageToController> Register(RpcServerToClientConnection<IMessageToController, IMessageToWeb> connection) {
var name = "WebClient-" + connection.SessionId; var name = "WebClient-" + connection.SessionId;
var init = new WebMessageHandlerActor.Init(connection, this, controllerState, instanceLogManager, userManager, roleManager, userRoleManager, userLoginManager, auditLogManager, agentManager, minecraftVersions, eventLogManager); var init = new WebMessageHandlerActor.Init(connection, controllerState, instanceLogManager, userManager, roleManager, userRoleManager, userLoginManager, auditLogManager, agentManager, minecraftVersions, eventLogManager);
return new IMessageReceiver<IMessageToController>.Actor(actorSystem.ActorOf(WebMessageHandlerActor.Factory(init), name)); return new IMessageReceiver<IMessageToController>.Actor(actorSystem.ActorOf(WebMessageHandlerActor.Factory(init), name));
} }
} }

View File

@@ -1,4 +1,5 @@
using System.Collections.Immutable; using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Common.Data; using Phantom.Common.Data;
using Phantom.Common.Data.Java; using Phantom.Common.Data.Java;
using Phantom.Common.Data.Minecraft; using Phantom.Common.Data.Minecraft;
@@ -23,7 +24,6 @@ namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> { sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
public readonly record struct Init( public readonly record struct Init(
RpcServerToClientConnection<IMessageToController, IMessageToWeb> Connection, RpcServerToClientConnection<IMessageToController, IMessageToWeb> Connection,
WebClientRegistrar WebClientRegistrar,
ControllerState ControllerState, ControllerState ControllerState,
InstanceLogManager InstanceLogManager, InstanceLogManager InstanceLogManager,
UserManager UserManager, UserManager UserManager,
@@ -88,7 +88,8 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
} }
private Task HandleUnregisterWeb(UnregisterWebMessage message) { private Task HandleUnregisterWeb(UnregisterWebMessage message) {
return connection.CloseSession(); Self.Tell(PoisonPill.Instance);
return connection.ClientClosedSession();
} }
private Task<Optional<LogInSuccess>> HandleLogIn(LogInMessage message) { private Task<Optional<LogInSuccess>> HandleLogIn(LogInMessage message) {

View File

@@ -24,7 +24,7 @@ interface IFrame {
switch (oneByteBuffer[0]) { switch (oneByteBuffer[0]) {
case TypePingId: case TypePingId:
var pingTime = await PingFrame.Read(stream, cancellationToken); var pingTime = await PingFrame.Read(stream, cancellationToken);
await reader.OnPing(pingTime, cancellationToken); await reader.OnPingFrame(pingTime, cancellationToken);
break; break;
case TypePongId: case TypePongId:
@@ -48,7 +48,7 @@ interface IFrame {
break; break;
default: default:
reader.OnUnknownFrameStart(oneByteBuffer[0]); reader.OnUnknownFrameId(oneByteBuffer[0]);
break; break;
} }
} }

View File

@@ -3,10 +3,10 @@
namespace Phantom.Utils.Rpc.Frame; namespace Phantom.Utils.Rpc.Frame;
interface IFrameReader { interface IFrameReader {
ValueTask OnPing(DateTimeOffset pingTime, CancellationToken cancellationToken); ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken);
void OnPongFrame(PongFrame frame); void OnPongFrame(PongFrame frame);
Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken); Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken);
void OnReplyFrame(ReplyFrame frame); void OnReplyFrame(ReplyFrame frame);
void OnErrorFrame(ErrorFrame frame); void OnErrorFrame(ErrorFrame frame);
void OnUnknownFrameStart(byte id); void OnUnknownFrameId(byte frameId);
} }

View File

@@ -0,0 +1,13 @@
using Phantom.Utils.Collections;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageReceiveTracker {
private readonly RangeSet<uint> receivedMessageIds = new ();
public bool ReceiveMessage(uint messageId) {
lock (receivedMessageIds) {
return receivedMessageIds.Add(messageId);
}
}
}

View File

@@ -13,15 +13,17 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
private readonly string loggerName; private readonly string loggerName;
private readonly ILogger logger; private readonly ILogger logger;
private readonly MessageRegistry<TServerToClientMessage> serverToClientMessageRegistry; private readonly MessageRegistry<TServerToClientMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly RpcClientToServerConnection connection; private readonly RpcClientToServerConnection connection;
private readonly CancellationTokenSource shutdownCancellationTokenSource = new ();
public RpcSendChannel<TClientToServerMessage> SendChannel { get; } public RpcSendChannel<TClientToServerMessage> SendChannel { get; }
private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection) { private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection) {
this.loggerName = loggerName; this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName); this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.serverToClientMessageRegistry = messageDefinitions.ToClient; this.messageRegistry = messageDefinitions.ToClient;
this.connection = new RpcClientToServerConnection(loggerName, connector, connection); this.connection = new RpcClientToServerConnection(loggerName, connector, connection);
this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters.Common, this.connection, messageDefinitions.ToServer); this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters.Common, this.connection, messageDefinitions.ToServer);
@@ -29,9 +31,9 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
public async Task Listen(IMessageReceiver<TServerToClientMessage> receiver) { public async Task Listen(IMessageReceiver<TServerToClientMessage> receiver) {
var messageHandler = new RpcMessageHandler<TServerToClientMessage>(receiver, SendChannel); var messageHandler = new RpcMessageHandler<TServerToClientMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, serverToClientMessageRegistry, messageHandler, SendChannel); var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
try { try {
await connection.ReadConnection(frameReader, CancellationToken.None); await connection.ReadConnection(frameReader, shutdownCancellationTokenSource.Token);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
// Ignore. // Ignore.
} }
@@ -47,11 +49,13 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
} }
try { try {
connection.Close(); connection.StopReconnecting();
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Caught exception while closing connection."); logger.Error(e, "Caught exception while closing connection.");
} }
await shutdownCancellationTokenSource.CancelAsync();
logger.Information("Client shut down."); logger.Information("Client shut down.");
} }

View File

@@ -1,4 +1,5 @@
using Phantom.Utils.Logging; using System.Net.Sockets;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame; using Phantom.Utils.Rpc.Frame;
using Serilog; using Serilog;
@@ -12,7 +13,7 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
private readonly CancellationTokenSource newConnectionCancellationTokenSource = new (); private readonly CancellationTokenSource newConnectionCancellationTokenSource = new ();
public async Task<Stream> GetStream() { public async Task<Stream> GetStream(CancellationToken cancellationToken) {
return (await GetConnection()).Stream; return (await GetConnection()).Stream;
} }
@@ -54,20 +55,22 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
throw; throw;
} catch (EndOfStreamException) { } catch (EndOfStreamException) {
logger.Warning("Socket was closed."); logger.Warning("Socket was closed.");
} catch (SocketException e) {
logger.Error("Socket reading was interrupted. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Closing socket due to an exception while reading it."); logger.Error(e, "Socket reading was interrupted.");
}
try { try {
await connection.Shutdown(); await connection.Shutdown();
} catch (Exception e2) { } catch (Exception e) {
logger.Error(e2, "Caught exception closing the socket."); logger.Error(e, "Caught exception closing the socket.");
}
} }
} }
} finally { } finally {
if (connection != null) { if (connection != null) {
try { try {
await connection.Disconnect(); // TODO what happens if already disconnected? await connection.Disconnect();
} finally { } finally {
connection.Dispose(); connection.Dispose();
} }
@@ -75,7 +78,7 @@ sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerCon
} }
} }
public void Close() { public void StopReconnecting() {
newConnectionCancellationTokenSource.Cancel(); newConnectionCancellationTokenSource.Cancel();
} }

View File

@@ -53,7 +53,8 @@ internal sealed class RpcClientToServerConnector {
return newConnection; return newConnection;
} }
logger.Warning("Failed to connect to server, trying again in {}.", nextAttemptDelay.TotalSeconds.ToString("F1")); cancellationToken.ThrowIfCancellationRequested();
logger.Information("Trying to reconnect in {Seconds}s.", nextAttemptDelay.TotalSeconds.ToString("F1"));
await Task.Delay(nextAttemptDelay, cancellationToken); await Task.Delay(nextAttemptDelay, cancellationToken);
nextAttemptDelay = Comparables.Min(nextAttemptDelay.Multiply(1.5), MaximumRetryDelay); nextAttemptDelay = Comparables.Min(nextAttemptDelay.Multiply(1.5), MaximumRetryDelay);
@@ -66,9 +67,12 @@ internal sealed class RpcClientToServerConnector {
Socket clientSocket = new Socket(SocketType.Stream, ProtocolType.Tcp); Socket clientSocket = new Socket(SocketType.Stream, ProtocolType.Tcp);
try { try {
await clientSocket.ConnectAsync(parameters.Host, parameters.Port, cancellationToken); await clientSocket.ConnectAsync(parameters.Host, parameters.Port, cancellationToken);
} catch (SocketException e) {
logger.Error("Could not connect. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
return null;
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not connect."); logger.Error(e, "Could not connect.");
throw; return null;
} }
SslStream? stream; SslStream? stream;

View File

@@ -1,5 +1,5 @@
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
interface IRpcConnectionProvider { interface IRpcConnectionProvider {
Task<Stream> GetStream(); Task<Stream> GetStream(CancellationToken cancellationToken);
} }

View File

@@ -3,7 +3,8 @@
public enum RpcError : byte { public enum RpcError : byte {
InvalidData = 0, InvalidData = 0,
UnknownMessageRegistryCode = 1, UnknownMessageRegistryCode = 1,
MessageDeserializationError = 2, MessageTooLarge = 2,
MessageHandlingError = 3, MessageDeserializationError = 3,
MessageTooLarge = 4, MessageHandlingError = 4,
MessageAlreadyHandled = 5,
} }

View File

@@ -1,13 +1,14 @@
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcErrorException : Exception { sealed class RpcErrorException : Exception {
internal static RpcErrorException From(RpcError error) { internal static RpcErrorException From(RpcError error) {
return error switch { return error switch {
RpcError.InvalidData => new RpcErrorException("Invalid data", error), RpcError.InvalidData => new RpcErrorException("Invalid data", error),
RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error), RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error),
RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error), RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error),
RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error), RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error), RpcError.MessageAlreadyHandled => new RpcErrorException("Message already handled", error),
_ => new RpcErrorException("Unknown error", error), _ => new RpcErrorException("Unknown error", error),
}; };
} }

View File

@@ -9,12 +9,13 @@ namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameReader<TSentMessage, TReceivedMessage>( sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
string loggerName, string loggerName,
MessageRegistry<TReceivedMessage> messageRegistry, MessageRegistry<TReceivedMessage> messageRegistry,
MessageReceiveTracker messageReceiveTracker,
RpcMessageHandler<TReceivedMessage> messageHandler, RpcMessageHandler<TReceivedMessage> messageHandler,
RpcSendChannel<TSentMessage> sendChannel RpcSendChannel<TSentMessage> sendChannel
) : IFrameReader { ) : IFrameReader {
private readonly ILogger logger = PhantomLogger.Create<RpcFrameReader<TSentMessage, TReceivedMessage>>(loggerName); private readonly ILogger logger = PhantomLogger.Create<RpcFrameReader<TSentMessage, TReceivedMessage>>(loggerName);
public ValueTask OnPing(DateTimeOffset pingTime, CancellationToken cancellationToken) { public ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken) {
messageHandler.OnPing(); messageHandler.OnPing();
return sendChannel.SendPong(pingTime, cancellationToken); return sendChannel.SendPong(pingTime, cancellationToken);
} }
@@ -24,6 +25,11 @@ sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
} }
public Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) { public Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) {
if (!messageReceiveTracker.ReceiveMessage(frame.MessageId)) {
logger.Warning("Received duplicate message {MessageId}.", frame.MessageId);
return messageHandler.SendError(frame.MessageId, RpcError.MessageAlreadyHandled, cancellationToken).AsTask();
}
if (messageRegistry.TryGetType(frame, out var messageType)) { if (messageRegistry.TryGetType(frame, out var messageType)) {
logger.Verbose("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length); logger.Verbose("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length);
} }
@@ -41,7 +47,7 @@ sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
sendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error); sendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error);
} }
public void OnUnknownFrameStart(byte id) { public void OnUnknownFrameId(byte frameId) {
logger.Error("Received unknown frame ID: {FrameId}", id); logger.Error("Received unknown frame ID: {FrameId}", frameId);
} }
} }

View File

@@ -19,6 +19,7 @@ public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable
private readonly Task sendQueueTask; private readonly Task sendQueueTask;
private readonly Task pingTask; private readonly Task pingTask;
private readonly CancellationTokenSource sendQueueCancellationTokenSource = new ();
private readonly CancellationTokenSource pingCancellationTokenSource = new (); private readonly CancellationTokenSource pingCancellationTokenSource = new ();
private uint nextMessageId; private uint nextMessageId;
@@ -88,13 +89,15 @@ public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable
} }
private async Task ProcessSendQueue() { private async Task ProcessSendQueue() {
await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync()) { CancellationToken cancellationToken = sendQueueCancellationTokenSource.Token;
await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync(cancellationToken)) {
while (true) { while (true) {
try { try {
Stream stream = await connectionProvider.GetStream(); Stream stream = await connectionProvider.GetStream(cancellationToken);
await stream.WriteAsync(frame.FrameType); await stream.WriteAsync(frame.FrameType, cancellationToken);
await frame.Write(stream); await frame.Write(stream, cancellationToken);
await stream.FlushAsync(); await stream.FlushAsync(cancellationToken);
break; break;
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
throw; throw;
@@ -115,15 +118,16 @@ public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable
pongTask = new TaskCompletionSource<DateTimeOffset>(); pongTask = new TaskCompletionSource<DateTimeOffset>();
if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) { if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) {
cancellationToken.ThrowIfCancellationRequested();
logger.Warning("Skipped a ping due to a full queue."); logger.Warning("Skipped a ping due to a full queue.");
continue; continue;
} }
DateTimeOffset pingTime = await pongTask.Task; DateTimeOffset pingTime = await pongTask.Task.WaitAsync(cancellationToken);
DateTimeOffset currentTime = DateTimeOffset.UtcNow; DateTimeOffset currentTime = DateTimeOffset.UtcNow;
TimeSpan roundTripTime = currentTime - pingTime; TimeSpan roundTripTime = currentTime - pingTime;
logger.Information("Received pong (rtt: {RoundTripTime} ms).", (long) roundTripTime.TotalMilliseconds); logger.Information("Received pong, round trip time: {RoundTripTime} ms", (long) roundTripTime.TotalMilliseconds);
} }
} }
@@ -144,7 +148,16 @@ public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable
sendQueue.Writer.TryComplete(); sendQueue.Writer.TryComplete();
try { try {
await Task.WhenAll(sendQueueTask, pingTask); await pingTask;
} catch (Exception) {
// Ignore.
}
try {
await sendQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing send queue before timeout, forcibly shutting it down.");
await sendQueueCancellationTokenSource.CancelAsync();
} catch (Exception) { } catch (Exception) {
// Ignore. // Ignore.
} }
@@ -152,6 +165,7 @@ public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable
public void Dispose() { public void Dispose() {
sendQueueTask.Dispose(); sendQueueTask.Dispose();
sendQueueCancellationTokenSource.Dispose();
pingCancellationTokenSource.Dispose(); pingCancellationTokenSource.Dispose();
} }
} }

View File

@@ -17,7 +17,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage>(
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> clientRegistrar IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> clientRegistrar
) { ) {
private readonly ILogger logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>>(loggerName); private readonly ILogger logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>>(loggerName);
private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions = new (loggerName, connectionParameters.Common, messageDefinitions.ToClient); private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions = new (connectionParameters.Common, messageDefinitions.ToClient);
private readonly List<Client> clients = []; private readonly List<Client> clients = [];
public async Task<bool> Run(CancellationToken shutdownToken) { public async Task<bool> Run(CancellationToken shutdownToken) {
@@ -102,7 +102,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage>(
private string Address => socket.RemoteEndPoint?.ToString() ?? "<unknown address>"; private string Address => socket.RemoteEndPoint?.ToString() ?? "<unknown address>";
private readonly ILogger logger; private ILogger logger;
private readonly string serverLoggerName; private readonly string serverLoggerName;
private readonly IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions; private readonly IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions;
@@ -151,9 +151,10 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage>(
} }
if (sessionIdResult.HasValue) { if (sessionIdResult.HasValue) {
logger.Information("Client connected.");
await RunConnectedSession(sessionIdResult.Value, stream); await RunConnectedSession(sessionIdResult.Value, stream);
} }
} catch (Exception e) {
logger.Error(e, "Caught exception while processing client.");
} finally { } finally {
logger.Information("Disconnecting client..."); logger.Information("Disconnecting client...");
try { try {
@@ -220,10 +221,31 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage>(
private async Task RunConnectedSession(Guid sessionId, Stream stream) { private async Task RunConnectedSession(Guid sessionId, Stream stream) {
var loggerName = PhantomLogger.ConcatNames(serverLoggerName, clientSessions.NextLoggerName(sessionId)); var loggerName = PhantomLogger.ConcatNames(serverLoggerName, clientSessions.NextLoggerName(sessionId));
var session = clientSessions.OnConnected(sessionId, stream);
logger.Information("Client connected with session {SessionId}, new logger name: {LoggerName}", sessionId, loggerName);
logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>, Client>(loggerName);
var session = clientSessions.OnConnected(sessionId, loggerName, stream);
try { try {
var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(loggerName, clientSessions, sessionId, messageDefinitions.ToServer, stream, session); var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(loggerName, clientSessions, sessionId, messageDefinitions.ToServer, stream, session);
await connection.Listen(clientRegistrar.Register(connection));
IMessageReceiver<TClientToServerMessage> messageReceiver;
try {
messageReceiver = clientRegistrar.Register(connection);
} catch (Exception e) {
logger.Error(e, "Could not register client.");
return;
}
try {
await connection.Listen(messageReceiver);
} catch (EndOfStreamException) {
logger.Warning("Socket reading was interrupted, connection lost.");
} catch (SocketException e) {
logger.Error("Socket reading was interrupted. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
} catch (Exception e) {
logger.Error(e, "Socket reading was interrupted.");
}
} finally { } finally {
clientSessions.OnDisconnected(sessionId); clientSessions.OnDisconnected(sessionId);
} }

View File

@@ -4,11 +4,13 @@ namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProvider { sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProvider {
public RpcSendChannel<TServerToClientMessage> SendChannel { get; } public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
public MessageReceiveTracker MessageReceiveTracker { get; }
private TaskCompletionSource<Stream> nextStream = new (); private TaskCompletionSource<Stream> nextStream = new ();
public RpcServerClientSession(string loggerName, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) { public RpcServerClientSession(string loggerName, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
this.SendChannel = new RpcSendChannel<TServerToClientMessage>(loggerName, connectionParameters, this, messageRegistry); this.SendChannel = new RpcSendChannel<TServerToClientMessage>(loggerName, connectionParameters, this, messageRegistry);
this.MessageReceiveTracker = new MessageReceiveTracker();
} }
public void OnConnected(Stream stream) { public void OnConnected(Stream stream) {
@@ -29,7 +31,7 @@ sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProv
} }
} }
Task<Stream> IRpcConnectionProvider.GetStream() { Task<Stream> IRpcConnectionProvider.GetStream(CancellationToken cancellationToken) {
lock (this) { lock (this) {
return nextStream.Task; return nextStream.Task;
} }

View File

@@ -4,23 +4,34 @@ using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime.Server; namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSessions<TServerToClientMessage>(string loggerName, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) { sealed class RpcServerClientSessions<TServerToClientMessage> {
private readonly RpcCommonConnectionParameters connectionParameters;
private readonly MessageRegistry<TServerToClientMessage> messageRegistry;
private readonly ConcurrentDictionary<Guid, RpcServerClientSession<TServerToClientMessage>> sessionsById = new (); private readonly ConcurrentDictionary<Guid, RpcServerClientSession<TServerToClientMessage>> sessionsById = new ();
private readonly ConcurrentDictionary<Guid, uint> sessionLoggerSequenceIds = new (); private readonly ConcurrentDictionary<Guid, uint> sessionLoggerSequenceIds = new ();
private readonly Func<Guid, string, RpcServerClientSession<TServerToClientMessage>> createSessionFunction;
public RpcServerClientSessions(RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
this.connectionParameters = connectionParameters;
this.messageRegistry = messageRegistry;
this.createSessionFunction = CreateSession;
}
public string NextLoggerName(Guid sessionId) { public string NextLoggerName(Guid sessionId) {
string name = PhantomLogger.ShortenGuid(sessionId); string name = PhantomLogger.ShortenGuid(sessionId);
return name + "/" + sessionLoggerSequenceIds.AddOrUpdate(sessionId, static _ => 1, static (_, prev) => prev + 1); return name + "/" + sessionLoggerSequenceIds.AddOrUpdate(sessionId, static _ => 1, static (_, prev) => prev + 1);
} }
public RpcSendChannel<TServerToClientMessage> OnConnected(Guid sessionId, Stream stream) { public RpcServerClientSession<TServerToClientMessage> OnConnected(Guid sessionId, string loggerName, Stream stream) {
var session = sessionsById.GetOrAdd(sessionId, CreateSession); var session = sessionsById.GetOrAdd(sessionId, createSessionFunction, loggerName);
session.OnConnected(stream); session.OnConnected(stream);
return session.SendChannel; return session;
} }
private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId) { private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId, string loggerName) {
return new RpcServerClientSession<TServerToClientMessage>(PhantomLogger.ConcatNames(loggerName, sessionId.ToString()), connectionParameters, messageRegistry); return new RpcServerClientSession<TServerToClientMessage>(loggerName, connectionParameters, messageRegistry);
} }
public void OnDisconnected(Guid sessionId) { public void OnDisconnected(Guid sessionId) {

View File

@@ -1,38 +1,49 @@
using Phantom.Utils.Rpc.Frame; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server; namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> { public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> {
private readonly string loggerName; private readonly string loggerName;
private readonly ILogger logger;
private readonly RpcServerClientSessions<TServerToClientMessage> sessions; private readonly RpcServerClientSessions<TServerToClientMessage> sessions;
private readonly MessageRegistry<TClientToServerMessage> messageRegistry; private readonly MessageRegistry<TClientToServerMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker;
private readonly Stream stream; private readonly Stream stream;
private readonly CancellationTokenSource closeCancellationTokenSource = new ();
public Guid SessionId { get; } public Guid SessionId { get; }
public RpcSendChannel<TServerToClientMessage> SendChannel { get; } public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
internal RpcServerToClientConnection(string loggerName, RpcServerClientSessions<TServerToClientMessage> sessions, Guid sessionId, MessageRegistry<TClientToServerMessage> messageRegistry, Stream stream, RpcSendChannel<TServerToClientMessage> sendChannel) { internal RpcServerToClientConnection(string loggerName, RpcServerClientSessions<TServerToClientMessage> sessions, Guid sessionId, MessageRegistry<TClientToServerMessage> messageRegistry, Stream stream, RpcServerClientSession<TServerToClientMessage> session) {
this.loggerName = loggerName; this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.sessions = sessions; this.sessions = sessions;
this.messageRegistry = messageRegistry; this.messageRegistry = messageRegistry;
this.messageReceiveTracker = session.MessageReceiveTracker;
this.stream = stream; this.stream = stream;
this.SessionId = sessionId; this.SessionId = sessionId;
this.SendChannel = sendChannel; this.SendChannel = session.SendChannel;
}
public Task CloseSession() {
return sessions.CloseSession(SessionId);
} }
internal async Task Listen(IMessageReceiver<TClientToServerMessage> receiver) { internal async Task Listen(IMessageReceiver<TClientToServerMessage> receiver) {
var messageHandler = new RpcMessageHandler<TClientToServerMessage>(receiver, SendChannel); var messageHandler = new RpcMessageHandler<TClientToServerMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(loggerName, messageRegistry, messageHandler, SendChannel); var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
try { try {
await IFrame.ReadFrom(stream, frameReader, CancellationToken.None); await IFrame.ReadFrom(stream, frameReader, closeCancellationTokenSource.Token);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
// Ignore. // Ignore.
} }
} }
public async Task ClientClosedSession() {
logger.Information("Client closed session.");
Task closeSessionTask = sessions.CloseSession(SessionId);
await closeCancellationTokenSource.CancelAsync();
await closeSessionTask;
}
} }

View File

@@ -0,0 +1,93 @@
using NUnit.Framework;
using Phantom.Utils.Collections;
using Range = Phantom.Utils.Collections.RangeSet<int>.Range;
namespace Phantom.Utils.Tests.Collections;
[TestFixture]
public class RangeSetTests {
[Test]
public void OneValue() {
var set = new RangeSet<int>();
set.Add(5);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 5),
}));
}
[Test]
public void MultipleDisjointValues() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(7);
set.Add(1);
set.Add(3);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 1, Max: 1),
new Range(Min: 3, Max: 3),
new Range(Min: 5, Max: 5),
new Range(Min: 7, Max: 7),
}));
}
[Test]
public void ExtendMin() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(4);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 4, Max: 5),
}));
}
[Test]
public void ExtendMax() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(6);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 6),
}));
}
[Test]
public void ExtendMaxAndMerge() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(7);
set.Add(6);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 7),
}));
}
[Test]
public void MultipleMergingAndDisjointValues() {
var set = new RangeSet<int>();
set.Add(1);
set.Add(2);
set.Add(5);
set.Add(4);
set.Add(10);
set.Add(7);
set.Add(9);
set.Add(11);
set.Add(16);
set.Add(12);
set.Add(3);
set.Add(14);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 1, Max: 5),
new Range(Min: 7, Max: 7),
new Range(Min: 9, Max: 12),
new Range(Min: 14, Max: 14),
new Range(Min: 16, Max: 16),
}));
}
}

View File

@@ -0,0 +1,69 @@
using System.Collections;
using System.Numerics;
namespace Phantom.Utils.Collections;
public sealed class RangeSet<T> : IEnumerable<RangeSet<T>.Range> where T : IBinaryInteger<T> {
private readonly List<Range> ranges = [];
public bool Add(T value) {
int index = 0;
for (; index < ranges.Count; index++) {
var range = ranges[index];
if (range.Contains(value)) {
return false;
}
if (range.ExtendIfAtEdge(value, out var extendedRange)) {
ranges[index] = extendedRange;
if (index < ranges.Count - 1) {
var nextRange = ranges[index + 1];
if (extendedRange.Max + T.One == nextRange.Min) {
ranges[index] = new Range(extendedRange.Min, nextRange.Max);
ranges.RemoveAt(index + 1);
}
}
return true;
}
if (range.Max > value) {
break;
}
}
ranges.Insert(index, new Range(value, value));
return true;
}
public IEnumerator<Range> GetEnumerator() {
return ranges.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator() {
return GetEnumerator();
}
public readonly record struct Range(T Min, T Max) {
internal bool ExtendIfAtEdge(T value, out Range newRange) {
if (value == Min - T.One) {
newRange = this with { Min = value };
return true;
}
else if (value == Max + T.One) {
newRange = this with { Max = value };
return true;
}
else {
newRange = default;
return false;
}
}
internal bool Contains(T value) {
return value >= Min && value <= Max;
}
}
}

View File

@@ -93,11 +93,14 @@ try {
await shutdownCancellationToken.WaitHandle.WaitOneAsync(); await shutdownCancellationToken.WaitHandle.WaitOneAsync();
await webApplication.StopAsync(); await webApplication.StopAsync();
} finally { } finally {
PhantomLogger.Root.Information("Unregistering web...");
try { try {
await rpcClient.SendChannel.SendMessage(new UnregisterWebMessage()); using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
// TODO wait for acknowledgment await rpcClient.SendChannel.SendMessage(new UnregisterWebMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister web after shutdown.");
} catch (Exception e) { } catch (Exception e) {
PhantomLogger.Root.Warning(e, "Could not unregister agent after shutdown."); PhantomLogger.Root.Warning(e, "Could not unregister web after shutdown.");
} finally { } finally {
await rpcClient.Shutdown(); await rpcClient.Shutdown();