1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-09-17 21:24:49 +02:00

1 Commits

Author SHA1 Message Date
c9c9b91e34 Replace NetMQ with custom TCP logic 2025-09-17 17:05:25 +02:00
131 changed files with 2166 additions and 1912 deletions

View File

@@ -1,21 +0,0 @@
using Phantom.Common.Messages.Agent;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
namespace Phantom.Agent.Rpc;
public sealed class ControllerConnection {
private static readonly ILogger Logger = PhantomLogger.Create(nameof(ControllerConnection));
private readonly RpcConnectionToServer<IMessageToController> connection;
public ControllerConnection(RpcConnectionToServer<IMessageToController> connection) {
this.connection = connection;
Logger.Information("Connection ready.");
}
public Task Send<TMessage>(TMessage message) where TMessage : IMessageToController {
return connection.Send(message);
}
}

View File

@@ -1,44 +0,0 @@
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
namespace Phantom.Agent.Rpc;
sealed class KeepAliveLoop {
private static readonly ILogger Logger = PhantomLogger.Create<KeepAliveLoop>();
private static readonly TimeSpan KeepAliveInterval = TimeSpan.FromSeconds(10);
private readonly RpcConnectionToServer<IMessageToController> connection;
private readonly CancellationTokenSource cancellationTokenSource = new ();
public KeepAliveLoop(RpcConnectionToServer<IMessageToController> connection) {
this.connection = connection;
Task.Run(Run);
}
private async Task Run() {
var cancellationToken = cancellationTokenSource.Token;
try {
await connection.IsReady.WaitAsync(cancellationToken);
Logger.Information("Started keep-alive loop.");
while (true) {
await Task.Delay(KeepAliveInterval, cancellationToken);
await connection.Send(new AgentIsAliveMessage()).WaitAsync(cancellationToken);
}
} catch (OperationCanceledException) {
// Ignore.
} finally {
cancellationTokenSource.Dispose();
Logger.Information("Stopped keep-alive loop.");
}
}
public void Cancel() {
cancellationTokenSource.Cancel();
}
}

View File

@@ -1,12 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\Common\Phantom.Common.Messages.Agent\Phantom.Common.Messages.Agent.csproj" />
</ItemGroup>
</Project>

View File

@@ -1,7 +0,0 @@
using Phantom.Common.Data;
namespace Phantom.Agent.Rpc;
sealed class RpcClientAgentHandshake(AuthToken authToken) {
}

View File

@@ -1,37 +0,0 @@
using NetMQ;
using NetMQ.Sockets;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.BiDirectional;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Sockets;
using Serilog;
namespace Phantom.Agent.Rpc;
public sealed class RpcClientRuntime : RpcClientRuntime<IMessageToAgent, IMessageToController, ReplyMessage> {
public static Task Launch(RpcClientSocket<IMessageToAgent, IMessageToController, ReplyMessage> socket, ActorRef<IMessageToAgent> handlerActorRef, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) {
return new RpcClientRuntime(socket, handlerActorRef, disconnectSemaphore, receiveCancellationToken).Launch();
}
private RpcClientRuntime(RpcClientSocket<IMessageToAgent, IMessageToController, ReplyMessage> socket, ActorRef<IMessageToAgent> handlerActor, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket, handlerActor, disconnectSemaphore, receiveCancellationToken) {}
protected override async Task RunWithConnection(ClientSocket socket, RpcConnectionToServer<IMessageToController> connection) {
var keepAliveLoop = new KeepAliveLoop(connection);
try {
await base.RunWithConnection(socket, connection);
} finally {
keepAliveLoop.Cancel();
}
}
protected override async Task SendDisconnectMessage(ClientSocket socket, ILogger logger) {
var unregisterMessageBytes = AgentMessageRegistries.ToController.Write(new UnregisterAgentMessage()).ToArray();
try {
await socket.SendAsync(unregisterMessageBytes).AsTask().WaitAsync(TimeSpan.FromSeconds(5), CancellationToken.None);
} catch (TimeoutException) {
logger.Error("Timed out communicating agent shutdown with the controller.");
}
}
}

View File

@@ -1,9 +1,13 @@
using Akka.Actor;
using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Agent.Minecraft.Java;
using Phantom.Agent.Rpc;
using Phantom.Agent.Services.Backups;
using Phantom.Agent.Services.Instances;
using Phantom.Agent.Services.Rpc;
using Phantom.Common.Data.Agent;
using Phantom.Common.Data.Replies;
using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Serilog;
@@ -15,6 +19,9 @@ public sealed class AgentServices {
public ActorSystem ActorSystem { get; }
private ControllerConnection ControllerConnection { get; }
private AgentInfo AgentInfo { get; }
private AgentFolders AgentFolders { get; }
private AgentState AgentState { get; }
private BackupManager BackupManager { get; }
@@ -26,6 +33,9 @@ public sealed class AgentServices {
public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration, ControllerConnection controllerConnection) {
this.ActorSystem = ActorSystemFactory.Create("Agent");
this.ControllerConnection = controllerConnection;
this.AgentInfo = agentInfo;
this.AgentFolders = agentFolders;
this.AgentState = new AgentState();
this.BackupManager = new BackupManager(agentFolders, serviceConfiguration.MaxConcurrentCompressionTasks);
@@ -43,6 +53,41 @@ public sealed class AgentServices {
}
}
public async Task<bool> Register(CancellationToken cancellationToken) {
Logger.Information("Registering with the controller...");
// TODO NEED TO SEND WHEN SERVER RESTARTS!!!
ImmutableArray<ConfigureInstanceMessage> configureInstanceMessages;
try {
configureInstanceMessages = await ControllerConnection.Send<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(new RegisterAgentMessage(AgentInfo), TimeSpan.FromMinutes(1), cancellationToken);
} catch (Exception e) {
Logger.Fatal(e, "Registration failed.");
return false;
}
foreach (var configureInstanceMessage in configureInstanceMessages) {
var configureInstanceCommand = new InstanceManagerActor.ConfigureInstanceCommand(
configureInstanceMessage.InstanceGuid,
configureInstanceMessage.Configuration,
configureInstanceMessage.LaunchProperties,
configureInstanceMessage.LaunchNow,
AlwaysReportStatus: true
);
var configureInstanceResult = await InstanceManager.Request(configureInstanceCommand, cancellationToken);
if (!configureInstanceResult.Is(ConfigureInstanceResult.Success)) {
Logger.Fatal("Unable to configure instance \"{Name}\" (GUID {Guid}), shutting down.", configureInstanceMessage.Configuration.InstanceName, configureInstanceMessage.InstanceGuid);
return false;
}
}
await ControllerConnection.Send(new AdvertiseJavaRuntimesMessage(JavaRuntimeRepository.All), cancellationToken);
InstanceTicketManager.RefreshAgentStatus();
return true;
}
public async Task Shutdown() {
Logger.Information("Stopping services...");

View File

@@ -58,7 +58,7 @@ sealed class InstanceActor : ReceiveActor<InstanceActor.ICommand> {
private void ReportCurrentStatus() {
agentState.UpdateInstance(new Instance(instanceGuid, currentStatus));
instanceServices.ControllerConnection.Send(new ReportInstanceStatusMessage(instanceGuid, currentStatus));
instanceServices.ControllerConnection.TrySend(new ReportInstanceStatusMessage(instanceGuid, currentStatus));
}
private void TransitionState(InstanceRunningState? newState) {

View File

@@ -7,6 +7,6 @@ namespace Phantom.Agent.Services.Instances;
sealed record InstanceContext(Guid InstanceGuid, string ShortName, ILogger Logger, InstanceServices Services, ActorRef<InstanceActor.ICommand> Actor, CancellationToken ActorCancellationToken) {
public void ReportEvent(IInstanceEvent instanceEvent) {
Services.ControllerConnection.Send(new ReportInstanceEventMessage(Guid.NewGuid(), DateTime.UtcNow, InstanceGuid, instanceEvent));
Services.ControllerConnection.TrySend(new ReportInstanceEventMessage(Guid.NewGuid(), DateTime.UtcNow, InstanceGuid, instanceEvent));
}
}

View File

@@ -4,8 +4,8 @@ using Phantom.Agent.Minecraft.Launcher;
using Phantom.Agent.Minecraft.Launcher.Types;
using Phantom.Agent.Minecraft.Properties;
using Phantom.Agent.Minecraft.Server;
using Phantom.Agent.Rpc;
using Phantom.Agent.Services.Backups;
using Phantom.Agent.Services.Rpc;
using Phantom.Common.Data;
using Phantom.Common.Data.Instance;
using Phantom.Common.Data.Minecraft;
@@ -56,11 +56,6 @@ sealed class InstanceManagerActor : ReceiveActor<InstanceManagerActor.ICommand>
ReceiveAsync<ShutdownCommand>(Shutdown);
}
private string GetInstanceLoggerName(Guid guid) {
var prefix = guid.ToString();
return prefix[..prefix.IndexOf('-')] + "/" + Interlocked.Increment(ref instanceLoggerSequenceId);
}
private sealed record InstanceInfo(ActorRef<InstanceActor.ICommand> Actor, InstanceConfiguration Configuration, IServerLauncher Launcher);
public interface ICommand {}
@@ -118,7 +113,8 @@ sealed class InstanceManagerActor : ReceiveActor<InstanceManagerActor.ICommand>
}
}
else {
var instanceInit = new InstanceActor.Init(agentState, instanceGuid, GetInstanceLoggerName(instanceGuid), instanceServices, instanceTicketManager, shutdownCancellationToken);
var instanceLoggerName = PhantomLogger.ShortenGuid(instanceGuid) + "/" + Interlocked.Increment(ref instanceLoggerSequenceId);;
var instanceInit = new InstanceActor.Init(agentState, instanceGuid, instanceLoggerName, instanceServices, instanceTicketManager, shutdownCancellationToken);
instances[instanceGuid] = instance = new InstanceInfo(Context.ActorOf(InstanceActor.Factory(instanceInit), "Instance-" + instanceGuid), configuration, launcher);
Logger.Information("Created instance \"{Name}\" (GUID {Guid}).", configuration.InstanceName, instanceGuid);

View File

@@ -1,6 +1,6 @@
using Phantom.Agent.Minecraft.Launcher;
using Phantom.Agent.Rpc;
using Phantom.Agent.Services.Backups;
using Phantom.Agent.Services.Rpc;
namespace Phantom.Agent.Services.Instances;

View File

@@ -1,4 +1,4 @@
using Phantom.Agent.Rpc;
using Phantom.Agent.Services.Rpc;
using Phantom.Common.Data;
using Phantom.Common.Data.Agent;
using Phantom.Common.Data.Instance;
@@ -91,7 +91,7 @@ sealed class InstanceTicketManager {
public void RefreshAgentStatus() {
lock (this) {
controllerConnection.Send(new ReportAgentStatusMessage(activeTicketGuids.Count, usedMemory));
controllerConnection.TrySend(new ReportAgentStatusMessage(activeTicketGuids.Count, usedMemory));
}
}

View File

@@ -1,6 +1,6 @@
using System.Collections.Immutable;
using System.Threading.Channels;
using Phantom.Agent.Rpc;
using Phantom.Agent.Services.Rpc;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Logging;
using Phantom.Utils.Tasks;
@@ -63,7 +63,7 @@ sealed class InstanceLogSender : CancellableBackgroundTask {
private void SendOutputToServer(ImmutableArray<string> lines) {
if (!lines.IsEmpty) {
controllerConnection.Send(new InstanceOutputMessage(instanceGuid, lines));
controllerConnection.TrySend(new InstanceOutputMessage(instanceGuid, lines));
}
}

View File

@@ -1,6 +1,6 @@
using Phantom.Agent.Minecraft.Instance;
using Phantom.Agent.Minecraft.Server;
using Phantom.Agent.Rpc;
using Phantom.Agent.Services.Rpc;
using Phantom.Common.Data.Instance;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Logging;
@@ -38,7 +38,7 @@ sealed class InstancePlayerCountTracker : CancellableBackgroundTask {
}
onlinePlayerCountChanged?.Invoke(this, value?.Online);
controllerConnection.Send(new ReportInstancePlayerCountsMessage(instanceGuid, value));
controllerConnection.TrySend(new ReportInstancePlayerCountsMessage(instanceGuid, value));
}
}

View File

@@ -8,7 +8,6 @@
<ItemGroup>
<ProjectReference Include="..\..\Common\Phantom.Common.Messages.Agent\Phantom.Common.Messages.Agent.csproj" />
<ProjectReference Include="..\Phantom.Agent.Minecraft\Phantom.Agent.Minecraft.csproj" />
<ProjectReference Include="..\Phantom.Agent.Rpc\Phantom.Agent.Rpc.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,20 @@
using Phantom.Common.Messages.Agent;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Agent.Services.Rpc;
public sealed class ControllerConnection(RpcSendChannel<IMessageToController> sendChannel) {
public ValueTask Send<TMessage>(TMessage message, CancellationToken cancellationToken) where TMessage : IMessageToController {
return sendChannel.SendMessage(message, cancellationToken);
}
// TODO handle properly
public bool TrySend<TMessage>(TMessage message) where TMessage : IMessageToController {
return sendChannel.TrySendMessage(message);
}
public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : IMessageToController, ICanReply<TReply> {
return sendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, cancellationToken);
}
}

View File

@@ -1,86 +1,32 @@
using Phantom.Agent.Services.Instances;
using Phantom.Common.Data;
using Phantom.Common.Data.Instance;
using Phantom.Common.Data.Replies;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.BiDirectional;
using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
namespace Phantom.Agent.Services.Rpc;
public sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToAgent> {
private static ILogger Logger { get; } = PhantomLogger.Create<ControllerMessageHandlerActor>();
public readonly record struct Init(RpcConnectionToServer<IMessageToController> Connection, AgentServices Agent, CancellationTokenSource ShutdownTokenSource);
public readonly record struct Init(AgentServices Agent);
public static Props<IMessageToAgent> Factory(Init init) {
return Props<IMessageToAgent>.Create(() => new ControllerMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
}
private readonly RpcConnectionToServer<IMessageToController> connection;
private readonly AgentServices agent;
private readonly CancellationTokenSource shutdownTokenSource;
private ControllerMessageHandlerActor(Init init) {
this.connection = init.Connection;
this.agent = init.Agent;
this.shutdownTokenSource = init.ShutdownTokenSource;
ReceiveAsync<RegisterAgentSuccessMessage>(HandleRegisterAgentSuccess);
Receive<RegisterAgentFailureMessage>(HandleRegisterAgentFailure);
ReceiveAndReplyLater<ConfigureInstanceMessage, Result<ConfigureInstanceResult, InstanceActionFailure>>(HandleConfigureInstance);
ReceiveAndReplyLater<LaunchInstanceMessage, Result<LaunchInstanceResult, InstanceActionFailure>>(HandleLaunchInstance);
ReceiveAndReplyLater<StopInstanceMessage, Result<StopInstanceResult, InstanceActionFailure>>(HandleStopInstance);
ReceiveAndReplyLater<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, InstanceActionFailure>>(HandleSendCommandToInstance);
Receive<ReplyMessage>(HandleReply);
}
private async Task HandleRegisterAgentSuccess(RegisterAgentSuccessMessage message) {
Logger.Information("Agent authentication successful.");
void ShutdownAfterConfigurationFailed(Guid instanceGuid, InstanceConfiguration configuration) {
Logger.Fatal("Unable to configure instance \"{Name}\" (GUID {Guid}), shutting down.", configuration.InstanceName, instanceGuid);
shutdownTokenSource.Cancel();
}
foreach (var configureInstanceMessage in message.InitialInstanceConfigurations) {
var result = await HandleConfigureInstance(configureInstanceMessage, alwaysReportStatus: true);
if (!result.Is(ConfigureInstanceResult.Success)) {
ShutdownAfterConfigurationFailed(configureInstanceMessage.InstanceGuid, configureInstanceMessage.Configuration);
return;
}
}
connection.SetIsReady();
await connection.Send(new AdvertiseJavaRuntimesMessage(agent.JavaRuntimeRepository.All));
agent.InstanceTicketManager.RefreshAgentStatus();
}
private void HandleRegisterAgentFailure(RegisterAgentFailureMessage message) {
string errorMessage = message.FailureKind switch {
RegisterAgentFailure.ConnectionAlreadyHasAnAgent => "This connection already has an associated agent.",
RegisterAgentFailure.InvalidToken => "Invalid token.",
_ => "Unknown error " + (byte) message.FailureKind + "."
};
Logger.Fatal("Agent authentication failed: {Error}", errorMessage);
PhantomLogger.Dispose();
Environment.Exit(1);
}
private Task<Result<ConfigureInstanceResult, InstanceActionFailure>> HandleConfigureInstance(ConfigureInstanceMessage message, bool alwaysReportStatus) {
return agent.InstanceManager.Request(new InstanceManagerActor.ConfigureInstanceCommand(message.InstanceGuid, message.Configuration, message.LaunchProperties, message.LaunchNow, alwaysReportStatus));
}
private async Task<Result<ConfigureInstanceResult, InstanceActionFailure>> HandleConfigureInstance(ConfigureInstanceMessage message) {
return await HandleConfigureInstance(message, alwaysReportStatus: false);
return await agent.InstanceManager.Request(new InstanceManagerActor.ConfigureInstanceCommand(message.InstanceGuid, message.Configuration, message.LaunchProperties, message.LaunchNow, AlwaysReportStatus: false));
}
private async Task<Result<LaunchInstanceResult, InstanceActionFailure>> HandleLaunchInstance(LaunchInstanceMessage message) {
@@ -94,8 +40,4 @@ public sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToAgent
private async Task<Result<SendCommandToInstanceResult, InstanceActionFailure>> HandleSendCommandToInstance(SendCommandToInstanceMessage message) {
return await agent.InstanceManager.Request(new InstanceManagerActor.SendCommandToInstanceCommand(message.InstanceGuid, message.Command));
}
private void HandleReply(ReplyMessage message) {
connection.Receive(message);
}
}

View File

@@ -1,12 +1,16 @@
using System.Reflection;
using Phantom.Agent;
using Phantom.Agent.Rpc;
using Phantom.Agent.Services;
using Phantom.Agent.Services.Rpc;
using Phantom.Common.Data.Agent;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.New;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime;
using Phantom.Utils.Threading;
const int ProtocolVersion = 1;
@@ -43,46 +47,57 @@ try {
return 1;
}
var (certificateThumbprint, authToken) = agentKey.Value;
var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts);
var rpcClientConnectionParameters = new RpcClientConnectionParameters(
Host: controllerHost,
Port: controllerPort,
DistinguishedName: "phantom-controller",
CertificateThumbprint: agentKey.Value.CertificateThumbprint,
AuthToken: agentKey.Value.AuthToken,
SendQueueCapacity: 500,
PingInterval: TimeSpan.FromSeconds(10)
);
PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent...");
var rpcClient = new RpcClient<IMessageToController, IMessageToAgent>("Controller", controllerHost, controllerPort, "phantom-controller", certificateThumbprint, null);
var rpcConnection = await rpcClient.Connect(shutdownCancellationToken);
if (rpcConnection == null) {
using var rpcClient = await RpcClient<IMessageToController, IMessageToAgent>.Connect("Controller", rpcClientConnectionParameters, AgentMessageRegistries.Definitions, shutdownCancellationToken);
if (rpcClient == null) {
return 1;
}
// var rpcConfiguration = new RpcConfiguration("Agent", controllerHost, controllerPort, controllerCertificate);
// var rpcSocket = RpcClientSocket.Connect(rpcConfiguration, AgentMessageRegistries.Definitions, new RegisterAgentMessage(agentToken, agentInfo));
//
var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcSocket.Connection));
Task? rpcClientListener = null;
try {
PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent...");
var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts);
var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.SendChannel));
await agentServices.Initialize();
try {
var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices);
var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler");
} finally {
await agentServices.Shutdown();
rpcClientListener = rpcClient.Listen(new IMessageReceiver<IMessageToAgent>.Actor(rpcMessageHandlerActor));
if (await agentServices.Register(shutdownCancellationToken)) {
PhantomLogger.Root.Information("Phantom Panel agent is ready.");
await shutdownCancellationToken.WaitHandle.WaitOneAsync();
}
await agentServices.Shutdown();
} finally {
PhantomLogger.Root.Information("Unregistering agent...");
try {
using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
await rpcClient.SendChannel.SendMessage(new UnregisterAgentMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister agent after shutdown.");
} catch (Exception e) {
PhantomLogger.Root.Warning(e, "Could not unregister agent during shutdown.");
} finally {
await rpcClient.Shutdown();
if (rpcClientListener != null) {
await rpcClientListener;
}
}
}
//
// var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(rpcSocket.Connection, agentServices, shutdownCancellationTokenSource);
// var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler");
//
// var rpcDisconnectSemaphore = new SemaphoreSlim(0, 1);
// var rpcTask = RpcClientRuntime.Launch(rpcSocket, rpcMessageHandlerActor, rpcDisconnectSemaphore, shutdownCancellationToken);
// try {
// await rpcTask.WaitAsync(shutdownCancellationToken);
// } finally {
// shutdownCancellationTokenSource.Cancel();
// await agentServices.Shutdown();
//
// rpcDisconnectSemaphore.Release();
// await rpcTask;
// rpcDisconnectSemaphore.Dispose();
//
// NetMQConfig.Cleanup();
// }
return 0;
} catch (OperationCanceledException) {

View File

@@ -1,4 +1,5 @@
using Phantom.Utils.Rpc.New;
using Phantom.Utils.Rpc;
using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Common.Data;
@@ -7,7 +8,7 @@ public readonly record struct ConnectionKey(RpcCertificateThumbprint Certificate
public byte[] ToBytes() {
Span<byte> result = stackalloc byte[TokenLength + CertificateThumbprint.Bytes.Length];
AuthToken.WriteTo(result[..TokenLength]);
AuthToken.Bytes.CopyTo(result[..TokenLength]);
CertificateThumbprint.Bytes.CopyTo(result[TokenLength..]);
return result.ToArray();
}

View File

@@ -1,6 +0,0 @@
namespace Phantom.Common.Data.Replies;
public enum RegisterAgentFailure : byte {
ConnectionAlreadyHasAnAgent,
InvalidToken
}

View File

@@ -1,5 +1,6 @@
using System.Diagnostics.CodeAnalysis;
using MemoryPack;
using Phantom.Utils.Monads;
using Phantom.Utils.Result;
namespace Phantom.Common.Data;
@@ -24,6 +25,9 @@ public sealed partial class Result<TValue, TError> {
[MemoryPackIgnore]
public TError Error => !hasValue ? error! : throw new InvalidOperationException("Attempted to get error from a success result.");
[MemoryPackIgnore]
public Either<TValue, TError> AsEither => hasValue ? Either.Left(value!) : Either.Right(error!);
private Result(bool hasValue, TValue? value, TError? error) {
this.hasValue = hasValue;
this.value = value;

View File

@@ -1,6 +1,6 @@
using Phantom.Common.Data;
using System.Collections.Immutable;
using Phantom.Common.Data;
using Phantom.Common.Data.Replies;
using Phantom.Common.Messages.Agent.BiDirectional;
using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Logging;
@@ -12,35 +12,26 @@ public static class AgentMessageRegistries {
public static MessageRegistry<IMessageToAgent> ToAgent { get; } = new (PhantomLogger.Create("MessageRegistry", nameof(ToAgent)));
public static MessageRegistry<IMessageToController> ToController { get; } = new (PhantomLogger.Create("MessageRegistry", nameof(ToController)));
public static IMessageDefinitions<IMessageToAgent, IMessageToController, ReplyMessage> Definitions { get; } = new MessageDefinitions();
public static IMessageDefinitions<IMessageToController, IMessageToAgent> Definitions { get; } = new MessageDefinitions();
static AgentMessageRegistries() {
ToAgent.Add<RegisterAgentSuccessMessage>(0);
ToAgent.Add<RegisterAgentFailureMessage>(1);
ToAgent.Add<ConfigureInstanceMessage, Result<ConfigureInstanceResult, InstanceActionFailure>>(2);
ToAgent.Add<LaunchInstanceMessage, Result<LaunchInstanceResult, InstanceActionFailure>>(3);
ToAgent.Add<StopInstanceMessage, Result<StopInstanceResult, InstanceActionFailure>>(4);
ToAgent.Add<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, InstanceActionFailure>>(5);
ToAgent.Add<ReplyMessage>(127);
ToController.Add<RegisterAgentMessage>(0);
ToController.Add<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(0);
ToController.Add<UnregisterAgentMessage>(1);
ToController.Add<AgentIsAliveMessage>(2);
ToController.Add<AdvertiseJavaRuntimesMessage>(3);
ToController.Add<ReportInstanceStatusMessage>(4);
ToController.Add<InstanceOutputMessage>(5);
ToController.Add<ReportAgentStatusMessage>(6);
ToController.Add<ReportInstanceEventMessage>(7);
ToController.Add<ReportInstancePlayerCountsMessage>(8);
ToController.Add<ReplyMessage>(127);
}
private sealed class MessageDefinitions : IMessageDefinitions<IMessageToAgent, IMessageToController, ReplyMessage> {
private sealed class MessageDefinitions : IMessageDefinitions<IMessageToController, IMessageToAgent> {
public MessageRegistry<IMessageToAgent> ToClient => ToAgent;
public MessageRegistry<IMessageToController> ToServer => ToController;
public ReplyMessage CreateReplyMessage(uint sequenceId, byte[] serializedReply) {
return new ReplyMessage(sequenceId, serializedReply);
}
}
}

View File

@@ -1,10 +0,0 @@
using MemoryPack;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Common.Messages.Agent.BiDirectional;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record ReplyMessage(
[property: MemoryPackOrder(0)] uint SequenceId,
[property: MemoryPackOrder(1)] byte[] SerializedReply
) : IMessageToController, IMessageToAgent, IReply;

View File

@@ -1,9 +0,0 @@
using MemoryPack;
using Phantom.Common.Data.Replies;
namespace Phantom.Common.Messages.Agent.ToAgent;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterAgentFailureMessage(
[property: MemoryPackOrder(0)] RegisterAgentFailure FailureKind
) : IMessageToAgent;

View File

@@ -1,9 +0,0 @@
using System.Collections.Immutable;
using MemoryPack;
namespace Phantom.Common.Messages.Agent.ToAgent;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterAgentSuccessMessage(
[property: MemoryPackOrder(0)] ImmutableArray<ConfigureInstanceMessage> InitialInstanceConfigurations
) : IMessageToAgent;

View File

@@ -1,6 +0,0 @@
using MemoryPack;
namespace Phantom.Common.Messages.Agent.ToController;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record AgentIsAliveMessage : IMessageToController;

View File

@@ -1,11 +1,12 @@
using MemoryPack;
using Phantom.Common.Data;
using System.Collections.Immutable;
using MemoryPack;
using Phantom.Common.Data.Agent;
using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Utils.Actor;
namespace Phantom.Common.Messages.Agent.ToController;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterAgentMessage(
[property: MemoryPackOrder(0)] AuthToken AuthToken,
[property: MemoryPackOrder(1)] AgentInfo AgentInfo
) : IMessageToController;
[property: MemoryPackOrder(0)] AgentInfo AgentInfo
) : IMessageToController, ICanReply<ImmutableArray<ConfigureInstanceMessage>>;

View File

@@ -1,10 +0,0 @@
using MemoryPack;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Common.Messages.Web.BiDirectional;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record ReplyMessage(
[property: MemoryPackOrder(0)] uint SequenceId,
[property: MemoryPackOrder(1)] byte[] SerializedReply
) : IMessageToController, IMessageToWeb, IReply;

View File

@@ -1,9 +0,0 @@
using MemoryPack;
using Phantom.Common.Data;
namespace Phantom.Common.Messages.Web.ToController;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterWebMessage(
[property: MemoryPackOrder(0)] AuthToken AuthToken
) : IMessageToController;

View File

@@ -1,8 +0,0 @@
using MemoryPack;
namespace Phantom.Common.Messages.Web.ToWeb;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterWebResultMessage(
[property: MemoryPackOrder(0)] bool Success
) : IMessageToWeb;

View File

@@ -7,7 +7,6 @@ using Phantom.Common.Data.Web.AuditLog;
using Phantom.Common.Data.Web.EventLog;
using Phantom.Common.Data.Web.Instance;
using Phantom.Common.Data.Web.Users;
using Phantom.Common.Messages.Web.BiDirectional;
using Phantom.Common.Messages.Web.ToController;
using Phantom.Common.Messages.Web.ToWeb;
using Phantom.Utils.Logging;
@@ -19,10 +18,9 @@ public static class WebMessageRegistries {
public static MessageRegistry<IMessageToController> ToController { get; } = new (PhantomLogger.Create("MessageRegistry", nameof(ToController)));
public static MessageRegistry<IMessageToWeb> ToWeb { get; } = new (PhantomLogger.Create("MessageRegistry", nameof(ToWeb)));
public static IMessageDefinitions<IMessageToWeb, IMessageToController, ReplyMessage> Definitions { get; } = new MessageDefinitions();
public static IMessageDefinitions<IMessageToController, IMessageToWeb> Definitions { get; } = new MessageDefinitions();
static WebMessageRegistries() {
ToController.Add<RegisterWebMessage>(0);
ToController.Add<UnregisterWebMessage>(1);
ToController.Add<LogInMessage, Optional<LogInSuccess>>(2);
ToController.Add<LogOutMessage>(3);
@@ -42,22 +40,15 @@ public static class WebMessageRegistries {
ToController.Add<GetAgentJavaRuntimesMessage, ImmutableDictionary<Guid, ImmutableArray<TaggedJavaRuntime>>>(17);
ToController.Add<GetAuditLogMessage, Result<ImmutableArray<AuditLogItem>, UserActionFailure>>(18);
ToController.Add<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(19);
ToController.Add<ReplyMessage>(127);
ToWeb.Add<RegisterWebResultMessage>(0);
ToWeb.Add<RefreshAgentsMessage>(1);
ToWeb.Add<RefreshInstancesMessage>(2);
ToWeb.Add<InstanceOutputMessage>(3);
ToWeb.Add<RefreshUserSessionMessage>(4);
ToWeb.Add<ReplyMessage>(127);
}
private sealed class MessageDefinitions : IMessageDefinitions<IMessageToWeb, IMessageToController, ReplyMessage> {
private sealed class MessageDefinitions : IMessageDefinitions<IMessageToController, IMessageToWeb> {
public MessageRegistry<IMessageToWeb> ToClient => ToWeb;
public MessageRegistry<IMessageToController> ToServer => ToController;
public ReplyMessage CreateReplyMessage(uint sequenceId, byte[] serializedReply) {
return new ReplyMessage(sequenceId, serializedReply);
}
}
}

View File

@@ -21,7 +21,7 @@ using Phantom.Utils.Actor.Mailbox;
using Phantom.Utils.Actor.Tasks;
using Phantom.Utils.Collections;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Server;
using Serilog;
namespace Phantom.Controller.Services.Agents;
@@ -168,13 +168,13 @@ sealed class AgentActor : ReceiveActor<AgentActor.ICommand> {
return configurationMessages.ToImmutable();
}
public interface ICommand {}
public interface ICommand;
private sealed record InitializeCommand : ICommand;
public sealed record RegisterCommand(AgentConfiguration Configuration, RpcConnectionToClient<IMessageToAgent> Connection) : ICommand, ICanReply<ImmutableArray<ConfigureInstanceMessage>>;
public sealed record RegisterCommand(AgentConfiguration Configuration, RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection) : ICommand, ICanReply<ImmutableArray<ConfigureInstanceMessage>>;
public sealed record UnregisterCommand(RpcConnectionToClient<IMessageToAgent> Connection) : ICommand;
public sealed record UnregisterCommand(RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection) : ICommand;
private sealed record RefreshConnectionStatusCommand : ICommand;

View File

@@ -1,36 +1,30 @@
using Phantom.Common.Messages.Agent;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Server;
using Serilog;
namespace Phantom.Controller.Services.Agents;
sealed class AgentConnection {
sealed class AgentConnection(Guid agentGuid, string agentName) {
private static readonly ILogger Logger = PhantomLogger.Create<AgentConnection>();
private readonly Guid agentGuid;
private string agentName;
private string agentName = agentName;
private RpcServerToClientConnection<IMessageToController, IMessageToAgent>? connection;
private RpcConnectionToClient<IMessageToAgent>? connection;
public AgentConnection(Guid agentGuid, string agentName) {
this.agentName = agentName;
this.agentGuid = agentGuid;
}
public void UpdateConnection(RpcConnectionToClient<IMessageToAgent> newConnection, string newAgentName) {
public void UpdateConnection(RpcServerToClientConnection<IMessageToController, IMessageToAgent> newConnection, string newAgentName) {
lock (this) {
connection?.Close();
connection?.ClientClosedSession();
connection = newConnection;
agentName = newAgentName;
}
}
public bool CloseIfSame(RpcConnectionToClient<IMessageToAgent> expected) {
public bool CloseIfSame(RpcServerToClientConnection<IMessageToController, IMessageToAgent> expected) {
lock (this) {
if (connection != null && connection.IsSame(expected)) {
connection.Close();
if (connection != null && ReferenceEquals(connection, expected)) {
connection.ClientClosedSession();
connection = null;
return true;
}
}
@@ -45,7 +39,7 @@ sealed class AgentConnection {
return Task.CompletedTask;
}
return connection.Send(message);
return connection.SendChannel.SendMessage(message).AsTask();
}
}
@@ -53,10 +47,10 @@ sealed class AgentConnection {
lock (this) {
if (connection == null) {
LogAgentOffline();
return Task.FromResult<TReply?>(default);
return Task.FromResult<TReply?>(null);
}
return connection.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!;
return connection.SendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!;
}
}

View File

@@ -13,7 +13,7 @@ using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Server;
using Serilog;
namespace Phantom.Controller.Services.Agents;
@@ -22,19 +22,18 @@ sealed class AgentManager {
private static readonly ILogger Logger = PhantomLogger.Create<AgentManager>();
private readonly IActorRefFactory actorSystem;
private readonly AuthToken authToken;
private readonly ControllerState controllerState;
private readonly MinecraftVersions minecraftVersions;
private readonly UserLoginManager userLoginManager;
private readonly IDbContextProvider dbProvider;
private readonly CancellationToken cancellationToken;
private readonly ConcurrentDictionary<Guid, ActorRef<AgentActor.ICommand>> agentsByGuid = new ();
private readonly ConcurrentDictionary<Guid, ActorRef<AgentActor.ICommand>> agentsByAgentGuid = new ();
private readonly ConcurrentDictionary<Guid, Guid> agentGuidsBySessionGuid = new ();
private readonly Func<Guid, AgentConfiguration, ActorRef<AgentActor.ICommand>> addAgentActorFactory;
public AgentManager(IActorRefFactory actorSystem, AuthToken authToken, ControllerState controllerState, MinecraftVersions minecraftVersions, UserLoginManager userLoginManager, IDbContextProvider dbProvider, CancellationToken cancellationToken) {
public AgentManager(IActorRefFactory actorSystem, ControllerState controllerState, MinecraftVersions minecraftVersions, UserLoginManager userLoginManager, IDbContextProvider dbProvider, CancellationToken cancellationToken) {
this.actorSystem = actorSystem;
this.authToken = authToken;
this.controllerState = controllerState;
this.minecraftVersions = minecraftVersions;
this.userLoginManager = userLoginManager;
@@ -57,28 +56,21 @@ sealed class AgentManager {
var agentGuid = entity.AgentGuid;
var agentConfiguration = new AgentConfiguration(entity.Name, entity.ProtocolVersion, entity.BuildVersion, entity.MaxInstances, entity.MaxMemory);
if (agentsByGuid.TryAdd(agentGuid, CreateAgentActor(agentGuid, agentConfiguration))) {
if (agentsByAgentGuid.TryAdd(agentGuid, CreateAgentActor(agentGuid, agentConfiguration))) {
Logger.Information("Loaded agent \"{AgentName}\" (GUID {AgentGuid}) from database.", agentConfiguration.AgentName, agentGuid);
}
}
}
public async Task<bool> RegisterAgent(AuthToken authToken, AgentInfo agentInfo, RpcConnectionToClient<IMessageToAgent> connection) {
if (!this.authToken.FixedTimeEquals(authToken)) {
await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.InvalidToken));
return false;
}
public async Task<ImmutableArray<ConfigureInstanceMessage>> RegisterAgent(AgentInfo agentInfo, RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection) {
var agentProperties = AgentConfiguration.From(agentInfo);
var agentActor = agentsByGuid.GetOrAdd(agentInfo.AgentGuid, addAgentActorFactory, agentProperties);
var configureInstanceMessages = await agentActor.Request(new AgentActor.RegisterCommand(agentProperties, connection), cancellationToken);
await connection.Send(new RegisterAgentSuccessMessage(configureInstanceMessages));
return true;
var agentActor = agentsByAgentGuid.GetOrAdd(agentInfo.AgentGuid, addAgentActorFactory, agentProperties);
agentGuidsBySessionGuid[connection.SessionId] = agentInfo.AgentGuid;
return await agentActor.Request(new AgentActor.RegisterCommand(agentProperties, connection), cancellationToken);
}
public bool TellAgent(Guid agentGuid, AgentActor.ICommand command) {
if (agentsByGuid.TryGetValue(agentGuid, out var agent)) {
if (agentsByAgentGuid.TryGetValue(agentGuid, out var agent)) {
agent.Tell(command);
return true;
}
@@ -94,7 +86,7 @@ sealed class AgentManager {
return (UserInstanceActionFailure) UserActionFailure.NotAuthorized;
}
if (!agentsByGuid.TryGetValue(agentGuid, out var agent)) {
if (!agentsByAgentGuid.TryGetValue(agentGuid, out var agent)) {
return (UserInstanceActionFailure) InstanceActionFailure.AgentDoesNotExist;
}
@@ -102,4 +94,12 @@ sealed class AgentManager {
var result = await agent.Request(command, cancellationToken);
return result.MapError(static error => (UserInstanceActionFailure) error);
}
public bool TryGetAgentGuidBySessionGuid(Guid sessionGuid, out Guid agentGuid) {
return agentGuidsBySessionGuid.TryGetValue(sessionGuid, out agentGuid);
}
public void OnSessionClosed(Guid sessionId, Guid agentGuid) {
agentGuidsBySessionGuid.TryRemove(new KeyValuePair<Guid, Guid>(sessionId, agentGuid));
}
}

View File

@@ -1,9 +1,4 @@
using Akka.Actor;
using Phantom.Common.Data;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToController;
using Phantom.Controller.Database;
using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents;
@@ -13,9 +8,9 @@ using Phantom.Controller.Services.Rpc;
using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
using IMessageFromAgentToController = Phantom.Common.Messages.Agent.IMessageToController;
using IMessageFromWebToController = Phantom.Common.Messages.Web.IMessageToController;
using Phantom.Utils.Rpc;
using IRpcAgentRegistrar = Phantom.Utils.Rpc.Runtime.Server.IRpcServerClientRegistrar<Phantom.Common.Messages.Agent.IMessageToController, Phantom.Common.Messages.Agent.IMessageToAgent>;
using IRpcWebRegistrar = Phantom.Utils.Rpc.Runtime.Server.IRpcServerClientRegistrar<Phantom.Common.Messages.Web.IMessageToController, Phantom.Common.Messages.Web.IMessageToWeb>;
namespace Phantom.Controller.Services;
@@ -38,13 +33,13 @@ public sealed class ControllerServices : IDisposable {
private AuditLogManager AuditLogManager { get; }
private EventLogManager EventLogManager { get; }
public IRegistrationHandler<IMessageToAgent, IMessageFromAgentToController, RegisterAgentMessage> AgentRegistrationHandler { get; }
public IRegistrationHandler<IMessageToWeb, IMessageFromWebToController, RegisterWebMessage> WebRegistrationHandler { get; }
public IRpcAgentRegistrar AgentRegistrar { get; }
public IRpcWebRegistrar WebRegistrar { get; }
private readonly IDbContextProvider dbProvider;
private readonly CancellationToken cancellationToken;
public ControllerServices(IDbContextProvider dbProvider, AuthToken agentAuthToken, AuthToken webAuthToken, CancellationToken shutdownCancellationToken) {
public ControllerServices(IDbContextProvider dbProvider, AuthToken agentAuthToken, CancellationToken shutdownCancellationToken) {
this.dbProvider = dbProvider;
this.cancellationToken = shutdownCancellationToken;
@@ -60,14 +55,14 @@ public sealed class ControllerServices : IDisposable {
this.UserLoginManager = new UserLoginManager(AuthenticatedUserCache, UserManager, dbProvider);
this.PermissionManager = new PermissionManager(dbProvider);
this.AgentManager = new AgentManager(ActorSystem, agentAuthToken, ControllerState, MinecraftVersions, UserLoginManager, dbProvider, cancellationToken);
this.AgentManager = new AgentManager(ActorSystem, ControllerState, MinecraftVersions, UserLoginManager, dbProvider, cancellationToken);
this.InstanceLogManager = new InstanceLogManager();
this.AuditLogManager = new AuditLogManager(dbProvider);
this.EventLogManager = new EventLogManager(ControllerState, ActorSystem, dbProvider, shutdownCancellationToken);
this.AgentRegistrationHandler = new AgentRegistrationHandler(AgentManager, InstanceLogManager, EventLogManager);
this.WebRegistrationHandler = new WebRegistrationHandler(webAuthToken, ControllerState, InstanceLogManager, UserManager, RoleManager, UserRoleManager, UserLoginManager, AuditLogManager, AgentManager, MinecraftVersions, EventLogManager);
this.AgentRegistrar = new AgentClientRegistrar(ActorSystem, AgentManager, InstanceLogManager, EventLogManager);
this.WebRegistrar = new WebClientRegistrar(ActorSystem, ControllerState, InstanceLogManager, UserManager, RoleManager, UserRoleManager, UserLoginManager, AuditLogManager, AgentManager, MinecraftVersions, EventLogManager);
}
public async Task Initialize() {

View File

@@ -0,0 +1,31 @@
using Akka.Actor;
using Phantom.Common.Messages.Agent;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc;
sealed class AgentClientRegistrar(
IActorRefFactory actorSystem,
AgentManager agentManager,
InstanceLogManager instanceLogManager,
EventLogManager eventLogManager
) : IRpcServerClientRegistrar<IMessageToController, IMessageToAgent> {
public IMessageReceiver<IMessageToController> Register(RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection) {
var name = "AgentClient-" + connection.SessionId;
var init = new AgentMessageHandlerActor.Init(connection, agentManager, instanceLogManager, eventLogManager);
return new Receiver(connection.SessionId, agentManager, actorSystem.ActorOf(AgentMessageHandlerActor.Factory(init), name));
}
private sealed class Receiver(Guid sessionId, AgentManager agentManager, ActorRef<IMessageToController> actor) : IMessageReceiver<IMessageToController>.Actor(actor) {
public override void OnPing() {
if (agentManager.TryGetAgentGuidBySessionGuid(sessionId, out var agentGuid)) {
agentManager.TellAgent(agentGuid, new AgentActor.NotifyIsAliveCommand());
}
}
}
}

View File

@@ -1,93 +1,85 @@
using Phantom.Common.Data.Replies;
using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.BiDirectional;
using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc;
sealed class AgentMessageHandlerActor : ReceiveActor<IMessageToController> {
public readonly record struct Init(Guid AgentGuid, RpcConnectionToClient<IMessageToAgent> Connection, AgentRegistrationHandler AgentRegistrationHandler, AgentManager AgentManager, InstanceLogManager InstanceLogManager, EventLogManager EventLogManager);
public readonly record struct Init(RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection, AgentManager AgentManager, InstanceLogManager InstanceLogManager, EventLogManager EventLogManager);
public static Props<IMessageToController> Factory(Init init) {
return Props<IMessageToController>.Create(() => new AgentMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
}
private readonly Guid agentGuid;
private readonly RpcConnectionToClient<IMessageToAgent> connection;
private readonly AgentRegistrationHandler agentRegistrationHandler;
private readonly RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection;
private readonly AgentManager agentManager;
private readonly InstanceLogManager instanceLogManager;
private readonly EventLogManager eventLogManager;
private Guid? registeredAgentGuid;
private AgentMessageHandlerActor(Init init) {
this.agentGuid = init.AgentGuid;
this.connection = init.Connection;
this.agentRegistrationHandler = init.AgentRegistrationHandler;
this.agentManager = init.AgentManager;
this.instanceLogManager = init.InstanceLogManager;
this.eventLogManager = init.EventLogManager;
ReceiveAsync<RegisterAgentMessage>(HandleRegisterAgent);
Receive<UnregisterAgentMessage>(HandleUnregisterAgent);
Receive<AgentIsAliveMessage>(HandleAgentIsAlive);
ReceiveAsyncAndReply<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(HandleRegisterAgent);
ReceiveAsync<UnregisterAgentMessage>(HandleUnregisterAgent);
Receive<AdvertiseJavaRuntimesMessage>(HandleAdvertiseJavaRuntimes);
Receive<ReportAgentStatusMessage>(HandleReportAgentStatus);
Receive<ReportInstanceStatusMessage>(HandleReportInstanceStatus);
Receive<ReportInstancePlayerCountsMessage>(HandleReportInstancePlayerCounts);
Receive<ReportInstanceEventMessage>(HandleReportInstanceEvent);
Receive<InstanceOutputMessage>(HandleInstanceOutput);
Receive<ReplyMessage>(HandleReply);
}
private async Task HandleRegisterAgent(RegisterAgentMessage message) {
if (agentGuid != message.AgentInfo.AgentGuid) {
await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.ConnectionAlreadyHasAnAgent));
}
else {
await agentRegistrationHandler.TryRegisterImpl(connection, message);
}
private Guid RequireAgentGuid() {
return registeredAgentGuid ?? throw new InvalidOperationException("Agent has not registered yet.");
}
private void HandleUnregisterAgent(UnregisterAgentMessage message) {
private Task<ImmutableArray<ConfigureInstanceMessage>> HandleRegisterAgent(RegisterAgentMessage message) {
registeredAgentGuid = message.AgentInfo.AgentGuid;
return agentManager.RegisterAgent(message.AgentInfo, connection);
}
private Task HandleUnregisterAgent(UnregisterAgentMessage message) {
Guid agentGuid = RequireAgentGuid();
agentManager.TellAgent(agentGuid, new AgentActor.UnregisterCommand(connection));
connection.Close();
}
agentManager.OnSessionClosed(connection.SessionId, agentGuid);
private void HandleAgentIsAlive(AgentIsAliveMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.NotifyIsAliveCommand());
Self.Tell(PoisonPill.Instance);
return connection.ClientClosedSession();
}
private void HandleAdvertiseJavaRuntimes(AdvertiseJavaRuntimesMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateJavaRuntimesCommand(message.Runtimes));
agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateJavaRuntimesCommand(message.Runtimes));
}
private void HandleReportAgentStatus(ReportAgentStatusMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateStatsCommand(message.RunningInstanceCount, message.RunningInstanceMemory));
agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateStatsCommand(message.RunningInstanceCount, message.RunningInstanceMemory));
}
private void HandleReportInstanceStatus(ReportInstanceStatusMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateInstanceStatusCommand(message.InstanceGuid, message.InstanceStatus));
agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateInstanceStatusCommand(message.InstanceGuid, message.InstanceStatus));
}
private void HandleReportInstancePlayerCounts(ReportInstancePlayerCountsMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateInstancePlayerCountsCommand(message.InstanceGuid, message.PlayerCounts));
agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateInstancePlayerCountsCommand(message.InstanceGuid, message.PlayerCounts));
}
private void HandleReportInstanceEvent(ReportInstanceEventMessage message) {
message.Event.Accept(eventLogManager.CreateInstanceEventVisitor(message.EventGuid, message.UtcTime, agentGuid, message.InstanceGuid));
message.Event.Accept(eventLogManager.CreateInstanceEventVisitor(message.EventGuid, message.UtcTime, RequireAgentGuid(), message.InstanceGuid));
}
private void HandleInstanceOutput(InstanceOutputMessage message) {
instanceLogManager.ReceiveLines(message.InstanceGuid, message.Lines);
}
private void HandleReply(ReplyMessage message) {
connection.Receive(message);
}
}

View File

@@ -1,34 +0,0 @@
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Controller.Services.Rpc;
sealed class AgentRegistrationHandler : IRegistrationHandler<IMessageToAgent, IMessageToController, RegisterAgentMessage> {
private readonly AgentManager agentManager;
private readonly InstanceLogManager instanceLogManager;
private readonly EventLogManager eventLogManager;
public AgentRegistrationHandler(AgentManager agentManager, InstanceLogManager instanceLogManager, EventLogManager eventLogManager) {
this.agentManager = agentManager;
this.instanceLogManager = instanceLogManager;
this.eventLogManager = eventLogManager;
}
async Task<Props<IMessageToController>?> IRegistrationHandler<IMessageToAgent, IMessageToController, RegisterAgentMessage>.TryRegister(RpcConnectionToClient<IMessageToAgent> connection, RegisterAgentMessage message) {
return await TryRegisterImpl(connection, message) ? CreateMessageHandlerActorProps(message.AgentInfo.AgentGuid, connection) : null;
}
public Task<bool> TryRegisterImpl(RpcConnectionToClient<IMessageToAgent> connection, RegisterAgentMessage message) {
return agentManager.RegisterAgent(message.AuthToken, message.AgentInfo, connection);
}
private Props<IMessageToController> CreateMessageHandlerActorProps(Guid agentGuid, RpcConnectionToClient<IMessageToAgent> connection) {
var init = new AgentMessageHandlerActor.Init(agentGuid, connection, this, agentManager, instanceLogManager, eventLogManager);
return AgentMessageHandlerActor.Factory(init);
}
}

View File

@@ -1,34 +0,0 @@
using Phantom.Common.Data;
using Phantom.Common.Data.Agent;
using Phantom.Common.Messages.Agent.Handshake;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.New;
using Serilog;
namespace Phantom.Controller.Services.Rpc;
public sealed class RpcServerAgentHandshake(AuthToken authToken) : RpcServerHandshake {
private static readonly ILogger Logger = PhantomLogger.Create<RpcServerAgentHandshake>();
protected override async Task<bool> AcceptClient(string remoteAddress, Stream stream, CancellationToken cancellationToken) {
Memory<byte> buffer = new Memory<byte>(new byte[AuthToken.Length]);
await stream.ReadExactlyAsync(buffer, cancellationToken: cancellationToken);
if (!authToken.FixedTimeEquals(buffer.Span)) {
Logger.Warning("Rejected client {}, invalid authorization token.", remoteAddress);
await Respond(remoteAddress, stream, new InvalidAuthToken(), cancellationToken);
return false;
}
AgentInfo agentInfo = await Serialization.Deserialize<AgentInfo>(stream, cancellationToken);
return true;
}
private async ValueTask Respond(string remoteAddress, Stream stream, IAgentHandshakeResult result, CancellationToken cancellationToken) {
try {
await Serialization.Serialize(result, stream, cancellationToken);
} catch (Exception e) {
Logger.Error(e, "Could not send handshake result to client {}.", remoteAddress);
}
}
}

View File

@@ -0,0 +1,33 @@
using Akka.Actor;
using Phantom.Common.Messages.Web;
using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc;
sealed class WebClientRegistrar(
IActorRefFactory actorSystem,
ControllerState controllerState,
InstanceLogManager instanceLogManager,
UserManager userManager,
RoleManager roleManager,
UserRoleManager userRoleManager,
UserLoginManager userLoginManager,
AuditLogManager auditLogManager,
AgentManager agentManager,
MinecraftVersions minecraftVersions,
EventLogManager eventLogManager
) : IRpcServerClientRegistrar<IMessageToController, IMessageToWeb> {
public IMessageReceiver<IMessageToController> Register(RpcServerToClientConnection<IMessageToController, IMessageToWeb> connection) {
var name = "WebClient-" + connection.SessionId;
var init = new WebMessageHandlerActor.Init(connection, controllerState, instanceLogManager, userManager, roleManager, userRoleManager, userLoginManager, auditLogManager, agentManager, minecraftVersions, eventLogManager);
return new IMessageReceiver<IMessageToController>.Actor(actorSystem.ActorOf(WebMessageHandlerActor.Factory(init), name));
}
}

View File

@@ -10,13 +10,13 @@ using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdateSenderActor.ICommand> {
public readonly record struct Init(RpcConnectionToClient<IMessageToWeb> Connection, ControllerState ControllerState, InstanceLogManager InstanceLogManager);
public readonly record struct Init(RpcSendChannel<IMessageToWeb> Connection, ControllerState ControllerState, InstanceLogManager InstanceLogManager);
public static Props<ICommand> Factory(Init init) {
return Props<ICommand>.Create(() => new WebMessageDataUpdateSenderActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
}
private readonly RpcConnectionToClient<IMessageToWeb> connection;
private readonly RpcSendChannel<IMessageToWeb> connection;
private readonly ControllerState controllerState;
private readonly InstanceLogManager instanceLogManager;
private readonly ActorRef<ICommand> selfCached;
@@ -70,18 +70,18 @@ sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdate
private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand;
private Task RefreshAgents(RefreshAgentsCommand command) {
return connection.Send(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray()));
return connection.SendMessage(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())).AsTask();
}
private Task RefreshInstances(RefreshInstancesCommand command) {
return connection.Send(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray()));
return connection.SendMessage(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())).AsTask();
}
private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) {
return connection.Send(new InstanceOutputMessage(command.InstanceGuid, command.Lines));
return connection.SendMessage(new InstanceOutputMessage(command.InstanceGuid, command.Lines)).AsTask();
}
private Task RefreshUserSession(RefreshUserSessionCommand command) {
return connection.Send(new RefreshUserSessionMessage(command.UserGuid));
return connection.SendMessage(new RefreshUserSessionMessage(command.UserGuid)).AsTask();
}
}

View File

@@ -1,4 +1,5 @@
using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Common.Data;
using Phantom.Common.Data.Java;
using Phantom.Common.Data.Minecraft;
@@ -7,7 +8,6 @@ using Phantom.Common.Data.Web.AuditLog;
using Phantom.Common.Data.Web.EventLog;
using Phantom.Common.Data.Web.Instance;
using Phantom.Common.Data.Web.Users;
using Phantom.Common.Messages.Agent.BiDirectional;
using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToController;
using Phantom.Controller.Minecraft;
@@ -17,14 +17,13 @@ using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
public readonly record struct Init(
RpcConnectionToClient<IMessageToWeb> Connection,
WebRegistrationHandler WebRegistrationHandler,
RpcServerToClientConnection<IMessageToController, IMessageToWeb> Connection,
ControllerState ControllerState,
InstanceLogManager InstanceLogManager,
UserManager UserManager,
@@ -41,8 +40,7 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
return Props<IMessageToController>.Create(() => new WebMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
}
private readonly RpcConnectionToClient<IMessageToWeb> connection;
private readonly WebRegistrationHandler webRegistrationHandler;
private readonly RpcServerToClientConnection<IMessageToController, IMessageToWeb> connection;
private readonly ControllerState controllerState;
private readonly UserManager userManager;
private readonly RoleManager roleManager;
@@ -55,7 +53,6 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
private WebMessageHandlerActor(Init init) {
this.connection = init.Connection;
this.webRegistrationHandler = init.WebRegistrationHandler;
this.controllerState = init.ControllerState;
this.userManager = init.UserManager;
this.roleManager = init.RoleManager;
@@ -66,11 +63,10 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
this.minecraftVersions = init.MinecraftVersions;
this.eventLogManager = init.EventLogManager;
var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection, controllerState, init.InstanceLogManager);
var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection.SendChannel, controllerState, init.InstanceLogManager);
Context.ActorOf(WebMessageDataUpdateSenderActor.Factory(senderActorInit), "DataUpdateSender");
ReceiveAsync<RegisterWebMessage>(HandleRegisterWeb);
Receive<UnregisterWebMessage>(HandleUnregisterWeb);
ReceiveAsync<UnregisterWebMessage>(HandleUnregisterWeb);
ReceiveAndReplyLater<LogInMessage, Optional<LogInSuccess>>(HandleLogIn);
Receive<LogOutMessage>(HandleLogOut);
ReceiveAndReply<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>(GetAuthenticatedUser);
@@ -89,15 +85,11 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
ReceiveAndReply<GetAgentJavaRuntimesMessage, ImmutableDictionary<Guid, ImmutableArray<TaggedJavaRuntime>>>(HandleGetAgentJavaRuntimes);
ReceiveAndReplyLater<GetAuditLogMessage, Result<ImmutableArray<AuditLogItem>, UserActionFailure>>(HandleGetAuditLog);
ReceiveAndReplyLater<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(HandleGetEventLog);
Receive<ReplyMessage>(HandleReply);
}
private async Task HandleRegisterWeb(RegisterWebMessage message) {
await webRegistrationHandler.TryRegisterImpl(connection, message);
}
private void HandleUnregisterWeb(UnregisterWebMessage message) {
connection.Close();
private Task HandleUnregisterWeb(UnregisterWebMessage message) {
Self.Tell(PoisonPill.Instance);
return connection.ClientClosedSession();
}
private Task<Optional<LogInSuccess>> HandleLogIn(LogInMessage message) {
@@ -191,8 +183,4 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
private Task<Result<ImmutableArray<EventLogItem>, UserActionFailure>> HandleGetEventLog(GetEventLogMessage message) {
return eventLogManager.GetMostRecentItems(userLoginManager.GetLoggedInUser(message.AuthToken), message.Count);
}
private void HandleReply(ReplyMessage message) {
connection.Receive(message);
}
}

View File

@@ -1,68 +0,0 @@
using Phantom.Common.Data;
using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToController;
using Phantom.Common.Messages.Web.ToWeb;
using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
namespace Phantom.Controller.Services.Rpc;
sealed class WebRegistrationHandler : IRegistrationHandler<IMessageToWeb, IMessageToController, RegisterWebMessage> {
private static readonly ILogger Logger = PhantomLogger.Create<WebRegistrationHandler>();
private readonly AuthToken webAuthToken;
private readonly ControllerState controllerState;
private readonly InstanceLogManager instanceLogManager;
private readonly UserManager userManager;
private readonly RoleManager roleManager;
private readonly UserRoleManager userRoleManager;
private readonly UserLoginManager userLoginManager;
private readonly AuditLogManager auditLogManager;
private readonly AgentManager agentManager;
private readonly MinecraftVersions minecraftVersions;
private readonly EventLogManager eventLogManager;
public WebRegistrationHandler(AuthToken webAuthToken, ControllerState controllerState, InstanceLogManager instanceLogManager, UserManager userManager, RoleManager roleManager, UserRoleManager userRoleManager, UserLoginManager userLoginManager, AuditLogManager auditLogManager, AgentManager agentManager, MinecraftVersions minecraftVersions, EventLogManager eventLogManager) {
this.webAuthToken = webAuthToken;
this.controllerState = controllerState;
this.userManager = userManager;
this.roleManager = roleManager;
this.userRoleManager = userRoleManager;
this.userLoginManager = userLoginManager;
this.auditLogManager = auditLogManager;
this.agentManager = agentManager;
this.minecraftVersions = minecraftVersions;
this.eventLogManager = eventLogManager;
this.instanceLogManager = instanceLogManager;
}
async Task<Props<IMessageToController>?> IRegistrationHandler<IMessageToWeb, IMessageToController, RegisterWebMessage>.TryRegister(RpcConnectionToClient<IMessageToWeb> connection, RegisterWebMessage message) {
return await TryRegisterImpl(connection, message) ? CreateMessageHandlerActorProps(connection) : null;
}
public async Task<bool> TryRegisterImpl(RpcConnectionToClient<IMessageToWeb> connection, RegisterWebMessage message) {
if (webAuthToken.FixedTimeEquals(message.AuthToken)) {
Logger.Information("Web authorized successfully.");
await connection.Send(new RegisterWebResultMessage(true));
return true;
}
else {
Logger.Warning("Web failed to authorize, invalid token.");
await connection.Send(new RegisterWebResultMessage(false));
return false;
}
}
private Props<IMessageToController> CreateMessageHandlerActorProps(RpcConnectionToClient<IMessageToWeb> connection) {
var init = new WebMessageHandlerActor.Init(connection, this, controllerState, instanceLogManager, userManager, roleManager, userRoleManager, userLoginManager, auditLogManager, agentManager, minecraftVersions, eventLogManager);
return WebMessageHandlerActor.Factory(init);
}
}

View File

@@ -1,5 +1,5 @@
using Phantom.Common.Data;
using Phantom.Utils.Rpc.New;
using Phantom.Utils.Rpc;
using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Controller;

View File

@@ -3,7 +3,8 @@ using Phantom.Utils.Cryptography;
using Phantom.Utils.IO;
using Phantom.Utils.Logging;
using Phantom.Utils.Monads;
using Phantom.Utils.Rpc.New;
using Phantom.Utils.Rpc;
using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog;
namespace Phantom.Controller;

View File

@@ -1,13 +1,16 @@
using System.Reflection;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Web;
using Phantom.Controller;
using Phantom.Controller.Database.Postgres;
using Phantom.Controller.Services;
using Phantom.Controller.Services.Rpc;
using Phantom.Utils.IO;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.New;
using Phantom.Utils.Rpc.Runtime.Server;
using Phantom.Utils.Runtime;
using Phantom.Utils.Tasks;
using RpcAgentServer = Phantom.Utils.Rpc.Runtime.Server.RpcServer<Phantom.Common.Messages.Agent.IMessageToController, Phantom.Common.Messages.Agent.IMessageToAgent>;
using RpcWebServer = Phantom.Utils.Rpc.Runtime.Server.RpcServer<Phantom.Common.Messages.Web.IMessageToController, Phantom.Common.Messages.Web.IMessageToWeb>;
var shutdownCancellationTokenSource = new CancellationTokenSource();
var shutdownCancellationToken = shutdownCancellationTokenSource.Token;
@@ -54,12 +57,28 @@ try {
var dbContextFactory = new ApplicationDbContextFactory(sqlConnectionString);
using var controllerServices = new ControllerServices(dbContextFactory, agentKeyData.AuthToken, webKeyData.AuthToken, shutdownCancellationToken);
using var controllerServices = new ControllerServices(dbContextFactory, agentKeyData.AuthToken, shutdownCancellationToken);
await controllerServices.Initialize();
var agentConnectionParameters = new RpcServerConnectionParameters(
EndPoint: agentRpcServerHost,
Certificate: agentKeyData.Certificate,
AuthToken: agentKeyData.AuthToken,
SendQueueCapacity: 100,
PingInterval: TimeSpan.FromSeconds(10)
);
var webConnectionParameters = new RpcServerConnectionParameters(
EndPoint: webRpcServerHost,
Certificate: webKeyData.Certificate,
AuthToken: webKeyData.AuthToken,
SendQueueCapacity: 500,
PingInterval: TimeSpan.FromMinutes(1)
);
LinkedTasks<bool> rpcServerTasks = new LinkedTasks<bool>([
new RpcServer("Agent", agentRpcServerHost, agentKeyData.Certificate, new RpcServerAgentHandshake(agentKeyData.AuthToken)).Run(shutdownCancellationToken),
// new RpcServer("Web", webRpcServerHost, webKeyData.Certificate).Run(shutdownCancellationToken),
new RpcAgentServer("Agent", agentConnectionParameters, AgentMessageRegistries.Definitions, controllerServices.AgentRegistrar).Run(shutdownCancellationToken),
new RpcWebServer("Web", webConnectionParameters, WebMessageRegistries.Definitions, controllerServices.WebRegistrar).Run(shutdownCancellationToken),
]);
// If either RPC server crashes, stop the whole process.
@@ -72,11 +91,6 @@ try {
}
return 0;
// await Task.WhenAll(
// RpcServerRuntime.Launch(ConfigureRpc("Agent", agentRpcServerHost, agentRpcServerPort, agentKeyData), AgentMessageRegistries.Definitions, controllerServices.AgentRegistrationHandler, controllerServices.ActorSystem, shutdownCancellationToken),
// RpcServerRuntime.Launch(ConfigureRpc("Web", webRpcServerHost, webRpcServerPort, webKeyData), WebMessageRegistries.Definitions, controllerServices.WebRegistrationHandler, controllerServices.ActorSystem, shutdownCancellationToken)
// );
} catch (OperationCanceledException) {
return 0;
} catch (StopProcedureException) {

View File

@@ -26,7 +26,7 @@ sealed record Variables(
EndPoint webRpcServerHost = new IPEndPoint(
EnvironmentVariables.GetIpAddress("WEB_RPC_SERVER_HOST").WithDefault(IPAddress.Any),
EnvironmentVariables.GetPortNumber("WEB_RPC_SERVER_PORT").WithDefault(9401)
EnvironmentVariables.GetPortNumber("WEB_RPC_SERVER_PORT").WithDefault(9402)
);
return new Variables(

View File

@@ -23,7 +23,6 @@
<ItemGroup>
<PackageReference Update="BCrypt.Net-Next.StrongName" Version="4.0.3" />
<PackageReference Update="MemoryPack" Version="1.10.0" />
<PackageReference Update="NetMQ" Version="4.0.1.13" />
</ItemGroup>
<ItemGroup>

View File

@@ -18,8 +18,6 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent", "Agent\Phan
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent.Minecraft", "Agent\Phantom.Agent.Minecraft\Phantom.Agent.Minecraft.csproj", "{9FE000D0-91AC-4CB4-8956-91CCC0270015}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent.Rpc", "Agent\Phantom.Agent.Rpc\Phantom.Agent.Rpc.csproj", "{665C7B87-0165-48BC-B6A6-17A3812A70C9}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent.Services", "Agent\Phantom.Agent.Services\Phantom.Agent.Services.csproj", "{AEE8B77E-AB07-423F-9981-8CD829ACB834}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Common.Data", "Common\Phantom.Common.Data\Phantom.Common.Data.csproj", "{6C3DB1E5-F695-4D70-8F3A-78C2957274BE}"
@@ -76,10 +74,6 @@ Global
{9FE000D0-91AC-4CB4-8956-91CCC0270015}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9FE000D0-91AC-4CB4-8956-91CCC0270015}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9FE000D0-91AC-4CB4-8956-91CCC0270015}.Release|Any CPU.Build.0 = Release|Any CPU
{665C7B87-0165-48BC-B6A6-17A3812A70C9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{665C7B87-0165-48BC-B6A6-17A3812A70C9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{665C7B87-0165-48BC-B6A6-17A3812A70C9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{665C7B87-0165-48BC-B6A6-17A3812A70C9}.Release|Any CPU.Build.0 = Release|Any CPU
{AEE8B77E-AB07-423F-9981-8CD829ACB834}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{AEE8B77E-AB07-423F-9981-8CD829ACB834}.Debug|Any CPU.Build.0 = Debug|Any CPU
{AEE8B77E-AB07-423F-9981-8CD829ACB834}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -164,7 +158,6 @@ Global
GlobalSection(NestedProjects) = preSolution
{418BE1BF-9F63-4B46-B4E4-DF64C3B3DDA7} = {F5878792-64C8-4ECF-A075-66341FF97127}
{9FE000D0-91AC-4CB4-8956-91CCC0270015} = {F5878792-64C8-4ECF-A075-66341FF97127}
{665C7B87-0165-48BC-B6A6-17A3812A70C9} = {F5878792-64C8-4ECF-A075-66341FF97127}
{AEE8B77E-AB07-423F-9981-8CD829ACB834} = {F5878792-64C8-4ECF-A075-66341FF97127}
{6C3DB1E5-F695-4D70-8F3A-78C2957274BE} = {01CB1A81-8950-471C-BFDF-F135FDDB2C18}
{95B55357-F8F0-48C2-A1C2-5EA997651783} = {01CB1A81-8950-471C-BFDF-F135FDDB2C18}

View File

@@ -31,25 +31,48 @@ public static class PhantomLogger {
}
public static ILogger Create<T>() {
return Create(typeof(T).Name);
return Create(TypeName<T>());
}
public static ILogger Create<T>(string name) {
return Create(typeof(T).Name, name);
return Create(ConcatNames(TypeName<T>(), name));
}
public static ILogger Create<T>(string name1, string name2) {
return Create(typeof(T).Name, ConcatNames(name1, name2));
return Create(ConcatNames(TypeName<T>(), name1, name2));
}
public static ILogger Create<T1, T2>() {
return Create(typeof(T1).Name, typeof(T2).Name);
return Create(ConcatNames(TypeName<T1>(), TypeName<T2>()));
}
private static string ConcatNames(string name1, string name2) {
public static ILogger Create<T1, T2>(string name) {
return Create(ConcatNames(TypeName<T1>(), TypeName<T2>(), name));
}
public static ILogger Create<T1, T2>(string name1, string name2) {
return Create(ConcatNames(TypeName<T1>(), TypeName<T2>(), ConcatNames(name1, name2)));
}
private static string TypeName<T>() {
string typeName = typeof(T).Name;
int genericsStartIndex = typeName.IndexOf('`');
return genericsStartIndex > 0 ? typeName[..genericsStartIndex] : typeName;
}
public static string ConcatNames(string name1, string name2) {
return name1 + ":" + name2;
}
public static string ConcatNames(string name1, string name2, string name3) {
return ConcatNames(name1, ConcatNames(name2, name3));
}
public static string ShortenGuid(Guid guid) {
var prefix = guid.ToString();
return prefix[..prefix.IndexOf('-')];
}
public static void Dispose() {
Root.Dispose();
Base.Dispose();

View File

@@ -1,18 +1,12 @@
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
using MemoryPack;
namespace Phantom.Common.Data;
namespace Phantom.Utils.Rpc;
[MemoryPackable(GenerateType.VersionTolerant)]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public sealed partial class AuthToken {
public sealed class AuthToken {
public const int Length = 12;
[MemoryPackOrder(0)]
[MemoryPackInclude]
public readonly ImmutableArray<byte> Bytes;
public ImmutableArray<byte> Bytes { get; }
public AuthToken(ImmutableArray<byte> bytes) {
if (bytes.Length != Length) {
@@ -30,10 +24,6 @@ public sealed partial class AuthToken {
return CryptographicOperations.FixedTimeEquals(Bytes.AsSpan(), other);
}
internal void WriteTo(Span<byte> span) {
Bytes.CopyTo(span);
}
public static AuthToken Generate() {
return new AuthToken([..RandomNumberGenerator.GetBytes(Length)]);
}

View File

@@ -0,0 +1,60 @@
using Phantom.Utils.Rpc.Frame.Types;
namespace Phantom.Utils.Rpc.Frame;
interface IFrame {
private const byte TypePingId = 0;
private const byte TypePongId = 1;
private const byte TypeMessageId = 2;
private const byte TypeReplyId = 3;
private const byte TypeErrorId = 4;
static readonly ReadOnlyMemory<byte> TypePing = new ([TypePingId]);
static readonly ReadOnlyMemory<byte> TypePong = new ([TypePongId]);
static readonly ReadOnlyMemory<byte> TypeMessage = new ([TypeMessageId]);
static readonly ReadOnlyMemory<byte> TypeReply = new ([TypeReplyId]);
static readonly ReadOnlyMemory<byte> TypeError = new ([TypeErrorId]);
internal static async Task ReadFrom(Stream stream, IFrameReader reader, CancellationToken cancellationToken) {
byte[] oneByteBuffer = new byte[1];
while (!cancellationToken.IsCancellationRequested) {
await stream.ReadExactlyAsync(oneByteBuffer, cancellationToken);
switch (oneByteBuffer[0]) {
case TypePingId:
var pingTime = await PingFrame.Read(stream, cancellationToken);
await reader.OnPingFrame(pingTime, cancellationToken);
break;
case TypePongId:
var pongFrame = await PongFrame.Read(stream, cancellationToken);
reader.OnPongFrame(pongFrame);
break;
case TypeMessageId:
var messageFrame = await MessageFrame.Read(stream, cancellationToken);
await reader.OnMessageFrame(messageFrame, cancellationToken);
break;
case TypeReplyId:
var replyFrame = await ReplyFrame.Read(stream, cancellationToken);
reader.OnReplyFrame(replyFrame);
break;
case TypeErrorId:
var errorFrame = await ErrorFrame.Read(stream, cancellationToken);
reader.OnErrorFrame(errorFrame);
break;
default:
reader.OnUnknownFrameId(oneByteBuffer[0]);
break;
}
}
}
ReadOnlyMemory<byte> FrameType { get; }
Task Write(Stream stream, CancellationToken cancellationToken = default);
}

View File

@@ -0,0 +1,12 @@
using Phantom.Utils.Rpc.Frame.Types;
namespace Phantom.Utils.Rpc.Frame;
interface IFrameReader {
ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken);
void OnPongFrame(PongFrame frame);
Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken);
void OnReplyFrame(ReplyFrame frame);
void OnErrorFrame(ErrorFrame frame);
void OnUnknownFrameId(byte frameId);
}

View File

@@ -0,0 +1,18 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record ErrorFrame(uint ReplyingToMessageId, RpcError Error) : IFrame {
public ReadOnlyMemory<byte> FrameType => IFrame.TypeError;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken);
await Serialization.WriteByte((byte) Error, stream, cancellationToken);
}
public static async Task<ErrorFrame> Read(Stream stream, CancellationToken cancellationToken) {
var replyingToMessageId = await Serialization.ReadUnsignedInt(stream, cancellationToken);
var messageError = (RpcError) await Serialization.ReadByte(stream, cancellationToken);
return new ErrorFrame(replyingToMessageId, messageError);
}
}

View File

@@ -0,0 +1,39 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<byte> SerializedMessage) : IFrame {
public const int MaxMessageBytes = 1024 * 1024 * 8;
public ReadOnlyMemory<byte> FrameType => IFrame.TypeMessage;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
int messageLength = SerializedMessage.Length;
CheckMessageLength(messageLength);
await Serialization.WriteUnsignedInt(MessageId, stream, cancellationToken);
await Serialization.WriteUnsignedShort(RegistryCode, stream, cancellationToken);
await Serialization.WriteSignedInt(messageLength, stream, cancellationToken);
await stream.WriteAsync(SerializedMessage, cancellationToken);
}
public static async Task<MessageFrame> Read(Stream stream, CancellationToken cancellationToken) {
var messageId = await Serialization.ReadUnsignedInt(stream, cancellationToken);
var registryCode = await Serialization.ReadUnsignedShort(stream, cancellationToken);
var essageLength = await Serialization.ReadSignedInt(stream, cancellationToken);
CheckMessageLength(essageLength);
var serializedMessage = await Serialization.ReadBytes(essageLength, stream, cancellationToken);
return new MessageFrame(messageId, registryCode, serializedMessage);
}
private static void CheckMessageLength(int messageLength) {
if (messageLength < 0) {
throw new RpcErrorException("Message length is negative.", RpcError.InvalidData);
}
if (messageLength > MaxMessageBytes) {
throw new RpcErrorException("Message is too large: " + messageLength + " > " + MaxMessageBytes + " bytes", RpcError.MessageTooLarge);
}
}
}

View File

@@ -0,0 +1,15 @@
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record PingFrame : IFrame {
public static PingFrame Instance { get; } = new ();
public ReadOnlyMemory<byte> FrameType => IFrame.TypePing;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteSignedLong(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), stream, cancellationToken);
}
public static async Task<DateTimeOffset> Read(Stream stream, CancellationToken cancellationToken) {
return DateTimeOffset.FromUnixTimeMilliseconds(await Serialization.ReadSignedLong(stream, cancellationToken));
}
}

View File

@@ -0,0 +1,13 @@
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record PongFrame(DateTimeOffset PingTime) : IFrame {
public ReadOnlyMemory<byte> FrameType => IFrame.TypePong;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteSignedLong(PingTime.ToUnixTimeMilliseconds(), stream, cancellationToken);
}
public static async Task<PongFrame> Read(Stream stream, CancellationToken cancellationToken) {
return new PongFrame(DateTimeOffset.FromUnixTimeMilliseconds(await Serialization.ReadSignedLong(stream, cancellationToken)));
}
}

View File

@@ -0,0 +1,37 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record ReplyFrame(uint ReplyingToMessageId, ReadOnlyMemory<byte> SerializedReply) : IFrame {
public const int MaxReplyBytes = 1024 * 1024 * 32;
public ReadOnlyMemory<byte> FrameType => IFrame.TypeReply;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
int replyLength = SerializedReply.Length;
CheckReplyLength(replyLength);
await Serialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken);
await Serialization.WriteSignedInt(replyLength, stream, cancellationToken);
await stream.WriteAsync(SerializedReply, cancellationToken);
}
public static async Task<ReplyFrame> Read(Stream stream, CancellationToken cancellationToken) {
var replyingToMessageId = await Serialization.ReadUnsignedInt(stream, cancellationToken);
var replyLength = await Serialization.ReadSignedInt(stream, cancellationToken);
CheckReplyLength(replyLength);
var reply = await Serialization.ReadBytes(replyLength, stream, cancellationToken);
return new ReplyFrame(replyingToMessageId, reply);
}
private static void CheckReplyLength(int replyLength) {
if (replyLength < 0) {
throw new RpcErrorException("Reply length is negative.", RpcError.InvalidData);
}
if (replyLength > MaxReplyBytes) {
throw new RpcErrorException("Reply is too large: " + replyLength + " > " + MaxReplyBytes + " bytes", RpcError.MessageTooLarge);
}
}
}

View File

@@ -0,0 +1,5 @@
namespace Phantom.Utils.Rpc;
public interface IRpcListener {
}

View File

@@ -0,0 +1,8 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc;
interface IRpcReplySender {
ValueTask SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken);
ValueTask SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken);
}

View File

@@ -1,6 +1,6 @@
namespace Phantom.Utils.Rpc.Message;
public interface IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> : IReplyMessageFactory<TReplyMessage> where TReplyMessage : TClientMessage, TServerMessage {
MessageRegistry<TClientMessage> ToClient { get; }
MessageRegistry<TServerMessage> ToServer { get; }
public interface IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> {
MessageRegistry<TServerToClientMessage> ToClient { get; }
MessageRegistry<TClientToServerMessage> ToServer { get; }
}

View File

@@ -0,0 +1,21 @@
using Phantom.Utils.Actor;
namespace Phantom.Utils.Rpc.Message;
public interface IMessageReceiver<TMessageBase> {
void OnPing();
void OnMessage(TMessageBase message);
Task<TReply> OnMessage<TMessage, TReply>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase, ICanReply<TReply>;
class Actor(ActorRef<TMessageBase> actor) : IMessageReceiver<TMessageBase> {
public virtual void OnPing() {}
public void OnMessage(TMessageBase message) {
actor.Tell(message);
}
public Task<TReply> OnMessage<TMessage, TReply>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase, ICanReply<TReply> {
return actor.Request(message, cancellationToken);
}
}
}

View File

@@ -1,6 +0,0 @@
namespace Phantom.Utils.Rpc.Message;
public interface IReply {
uint SequenceId { get; }
byte[] SerializedReply { get; }
}

View File

@@ -1,5 +0,0 @@
namespace Phantom.Utils.Rpc.Message;
public interface IReplyMessageFactory<TReplyMessage> {
TReplyMessage CreateReplyMessage(uint sequenceId, byte[] serializedReply);
}

View File

@@ -1,5 +0,0 @@
namespace Phantom.Utils.Rpc.Message;
interface IReplySender {
Task SendReply(uint sequenceId, byte[] serializedReply);
}

View File

@@ -1,35 +0,0 @@
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Serilog;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageHandler<TMessageBase> {
private readonly ILogger logger;
private readonly ActorRef<TMessageBase> handlerActor;
private readonly IReplySender replySender;
public MessageHandler(string loggerName, ActorRef<TMessageBase> handlerActor, IReplySender replySender) {
this.logger = PhantomLogger.Create("MessageHandler", loggerName);
this.handlerActor = handlerActor;
this.replySender = replySender;
}
public void Tell(TMessageBase message) {
handlerActor.Tell(message);
}
public Task TellAndReply<TMessage, TReply>(TMessage message, uint sequenceId) where TMessage : ICanReply<TReply> {
return handlerActor.Request(message).ContinueWith(task => {
if (task.IsCompletedSuccessfully) {
return replySender.SendReply(sequenceId, MessageSerializer.Serialize(task.Result));
}
if (task.IsFaulted) {
logger.Error(task.Exception, "Failed to handle message {Type}.", message.GetType().Name);
}
return task;
}, TaskScheduler.Default);
}
}

View File

@@ -0,0 +1,13 @@
using Phantom.Utils.Collections;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageReceiveTracker {
private readonly RangeSet<uint> receivedMessageIds = new ();
public bool ReceiveMessage(uint messageId) {
lock (receivedMessageIds) {
return receivedMessageIds.Add(messageId);
}
}
}

View File

@@ -1,26 +1,19 @@
using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.CodeAnalysis;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
using Serilog.Events;
namespace Phantom.Utils.Rpc.Message;
public sealed class MessageRegistry<TMessageBase> {
private const int DefaultBufferSize = 512;
private readonly ILogger logger;
public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
private readonly Dictionary<Type, ushort> typeToCodeMapping = new ();
private readonly Dictionary<ushort, Type> codeToTypeMapping = new ();
private readonly Dictionary<ushort, Action<ReadOnlyMemory<byte>, ushort, MessageHandler<TMessageBase>>> codeToHandlerMapping = new ();
public MessageRegistry(ILogger logger) {
this.logger = logger;
}
private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, RpcMessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new ();
public void Add<TMessage>(ushort code) where TMessage : TMessageBase {
if (HasReplyType(typeof(TMessage))) {
throw new ArgumentException("This overload is for messages without a reply");
throw new ArgumentException("This overload is for messages without a reply.");
}
AddTypeCodeMapping<TMessage>(code);
@@ -44,140 +37,68 @@ public sealed class MessageRegistry<TMessageBase> {
return messageType.GetInterfaces().Any(type => type.FullName is {} name && name.StartsWith(replyInterfaceName, StringComparison.Ordinal));
}
internal bool TryGetType(ReadOnlyMemory<byte> data, [NotNullWhen(true)] out Type? type) {
try {
var code = MessageSerializer.ReadCode(ref data);
return codeToTypeMapping.TryGetValue(code, out type);
} catch (Exception) {
type = null;
return false;
}
internal bool TryGetType(MessageFrame frame, [NotNullWhen(true)] out Type? type) {
return codeToTypeMapping.TryGetValue(frame.RegistryCode, out type);
}
public ReadOnlySpan<byte> Write<TMessage>(TMessage message) where TMessage : TMessageBase {
if (!GetMessageCode<TMessage>(out var code)) {
return default;
}
var buffer = new ArrayBufferWriter<byte>(DefaultBufferSize);
try {
MessageSerializer.WriteCode(buffer, code);
MessageSerializer.Serialize(buffer, message);
CheckWrittenBufferLength<TMessage>(buffer);
return buffer.WrittenSpan;
} catch (Exception e) {
LogWriteFailure<TMessage>(e);
return default;
}
}
public ReadOnlySpan<byte> Write<TMessage, TReply>(uint sequenceId, TMessage message) where TMessage : TMessageBase, ICanReply<TReply> {
if (!GetMessageCode<TMessage>(out var code)) {
return default;
}
var buffer = new ArrayBufferWriter<byte>(DefaultBufferSize);
try {
MessageSerializer.WriteCode(buffer, code);
MessageSerializer.WriteSequenceId(buffer, sequenceId);
MessageSerializer.Serialize(buffer, message);
CheckWrittenBufferLength<TMessage>(buffer);
return buffer.WrittenSpan;
} catch (Exception e) {
LogWriteFailure<TMessage>(e);
return default;
}
}
private bool GetMessageCode<TMessage>(out ushort code) where TMessage : TMessageBase {
if (typeToCodeMapping.TryGetValue(typeof(TMessage), out code)) {
return true;
internal MessageFrame CreateFrame<TMessage>(uint messageId, TMessage message) where TMessage : TMessageBase {
if (typeToCodeMapping.TryGetValue(typeof(TMessage), out ushort code)) {
return new MessageFrame(messageId, code, Serialization.Serialize(message));
}
else {
logger.Error("Unknown message type {Type}.", typeof(TMessage));
return false;
throw new ArgumentException("Unknown message type: " + typeof(TMessage));
}
}
private void CheckWrittenBufferLength<TMessage>(ArrayBufferWriter<byte> buffer) where TMessage : TMessageBase {
if (buffer.WrittenCount > DefaultBufferSize && logger.IsEnabled(LogEventLevel.Verbose)) {
logger.Verbose("Serializing {Type} exceeded default buffer size: {WrittenSize} B > {DefaultBufferSize} B", typeof(TMessage).Name, buffer.WrittenCount, DefaultBufferSize);
}
}
internal async Task Handle(MessageFrame frame, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) {
uint messageId = frame.MessageId;
private void LogWriteFailure<TMessage>(Exception e) where TMessage : TMessageBase {
logger.Error(e, "Failed to serialize message {Type}.", typeof(TMessage).Name);
}
internal bool Read<TMessage>(ReadOnlyMemory<byte> data, out TMessage message) where TMessage : TMessageBase {
if (ReadTypeCode(ref data, out ushort code) && codeToTypeMapping.TryGetValue(code, out var expectedType) && expectedType == typeof(TMessage) && ReadMessage(data, out message)) {
return true;
if (codeToHandlerMapping.TryGetValue(frame.RegistryCode, out var action)) {
await action(messageId, frame.SerializedMessage, handler, cancellationToken);
}
else {
message = default!;
return false;
logger.Error("Unknown message code {Code} for message {MessageId}.", frame.RegistryCode, messageId);
await handler.SendError(messageId, RpcError.UnknownMessageRegistryCode, cancellationToken);
}
}
internal void Handle(ReadOnlyMemory<byte> data, MessageHandler<TMessageBase> handler) {
if (!ReadTypeCode(ref data, out var code)) {
private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
TMessage message;
try {
message = Serialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
return;
}
if (!codeToHandlerMapping.TryGetValue(code, out var handle)) {
logger.Error("Unknown message code {Code}.", code);
handler.Receiver.OnMessage(message);
}
private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
TMessage message;
try {
message = Serialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
return;
}
handle(data, code, handler);
}
private bool ReadTypeCode(ref ReadOnlyMemory<byte> data, out ushort code) {
TReply reply;
try {
code = MessageSerializer.ReadCode(ref data);
return true;
reply = await handler.Receiver.OnMessage<TMessage, TReply>(message, cancellationToken);
} catch (Exception e) {
code = default;
logger.Error(e, "Failed to deserialize message code.");
return false;
}
logger.Error(e, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
return;
}
private bool ReadSequenceId<TMessage, TReply>(ref ReadOnlyMemory<byte> data, out uint sequenceId) where TMessage : TMessageBase, ICanReply<TReply> {
try {
sequenceId = MessageSerializer.ReadSequenceId(ref data);
return true;
await handler.SendReply(messageId, reply, cancellationToken);
} catch (Exception e) {
sequenceId = default;
logger.Error(e, "Failed to deserialize sequence ID of message {Type}.", typeof(TMessage).Name);
return false;
}
}
private bool ReadMessage<TMessage>(ReadOnlyMemory<byte> data, out TMessage message) where TMessage : TMessageBase {
try {
message = MessageSerializer.Deserialize<TMessage>(data);
return true;
} catch (Exception e) {
message = default!;
logger.Error(e, "Failed to deserialize message {Type}.", typeof(TMessage).Name);
return false;
}
}
private void DeserializationHandler<TMessage>(ReadOnlyMemory<byte> data, ushort code, MessageHandler<TMessageBase> handler) where TMessage : TMessageBase {
if (ReadMessage<TMessage>(data, out var message)) {
handler.Tell(message);
}
}
private void DeserializationHandler<TMessage, TReply>(ReadOnlyMemory<byte> data, ushort code, MessageHandler<TMessageBase> handler) where TMessage : TMessageBase, ICanReply<TReply> {
if (ReadSequenceId<TMessage, TReply>(ref data, out var sequenceId) && ReadMessage<TMessage>(data, out var message)) {
handler.TellAndReply<TMessage, TReply>(message, sequenceId);
logger.Error(e, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
}
}
}

View File

@@ -1,5 +1,6 @@
using System.Collections.Concurrent;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Phantom.Utils.Tasks;
using Serilog;
@@ -7,55 +8,57 @@ namespace Phantom.Utils.Rpc.Message;
sealed class MessageReplyTracker {
private readonly ILogger logger;
private readonly ConcurrentDictionary<uint, TaskCompletionSource<byte[]>> replyTasks = new (4, 16);
private uint lastSequenceId;
private readonly ConcurrentDictionary<uint, TaskCompletionSource<ReadOnlyMemory<byte>>> replyTasks = new (concurrencyLevel: 2, capacity: 16);
internal MessageReplyTracker(string loggerName) {
this.logger = PhantomLogger.Create<MessageReplyTracker>(loggerName);
}
public uint RegisterReply() {
var sequenceId = Interlocked.Increment(ref lastSequenceId);
replyTasks[sequenceId] = AsyncTasks.CreateCompletionSource<byte[]>();
return sequenceId;
public void RegisterReply(uint messageId) {
replyTasks[messageId] = AsyncTasks.CreateCompletionSource<ReadOnlyMemory<byte>>();
}
public async Task<TReply> WaitForReply<TReply>(uint sequenceId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) {
if (!replyTasks.TryGetValue(sequenceId, out var completionSource)) {
logger.Warning("No reply callback for id {SequenceId}.", sequenceId);
throw new ArgumentException("No reply callback for id: " + sequenceId, nameof(sequenceId));
public async Task<TReply> WaitForReply<TReply>(uint messageId, TimeSpan waitForReplyTime, CancellationToken cancellationToken) {
if (!replyTasks.TryGetValue(messageId, out var completionSource)) {
logger.Warning("No reply callback for id {MessageId}.", messageId);
throw new ArgumentException("No reply callback for id: " + messageId, nameof(messageId));
}
try {
byte[] replyBytes = await completionSource.Task.WaitAsync(waitForReplyTime, cancellationToken);
return MessageSerializer.Deserialize<TReply>(replyBytes);
ReadOnlyMemory<byte> serializedReply = await completionSource.Task.WaitAsync(waitForReplyTime, cancellationToken);
return Serialization.Deserialize<TReply>(serializedReply);
} catch (TimeoutException) {
logger.Debug("Timed out waiting for reply with id {SequenceId}.", sequenceId);
logger.Debug("Timed out waiting for reply with id {MessageId}.", messageId);
throw;
} catch (OperationCanceledException) {
logger.Debug("Cancelled waiting for reply with id {SequenceId}.", sequenceId);
logger.Debug("Cancelled waiting for reply with id {MessageId}.", messageId);
throw;
} catch (Exception e) {
logger.Warning(e, "Error processing reply with id {SequenceId}.", sequenceId);
logger.Warning(e, "Error processing reply with id {MessageId}.", messageId);
throw;
} finally {
ForgetReply(sequenceId);
ForgetReply(messageId);
}
}
public void ForgetReply(uint sequenceId) {
if (replyTasks.TryRemove(sequenceId, out var task)) {
public void ForgetReply(uint messageId) {
if (replyTasks.TryRemove(messageId, out var task)) {
task.SetCanceled();
}
}
public void ReceiveReply(uint sequenceId, byte[] serializedReply) {
if (replyTasks.TryRemove(sequenceId, out var task)) {
public void FailReply(uint messageId, RpcErrorException e) {
if (replyTasks.TryRemove(messageId, out var task)) {
task.SetException(e);
}
}
public void ReceiveReply(uint messageId, ReadOnlyMemory<byte> serializedReply) {
if (replyTasks.TryRemove(messageId, out var task)) {
task.SetResult(serializedReply);
}
else {
logger.Warning("Received a reply with id {SequenceId} but no registered callback.", sequenceId);
logger.Warning("Received a reply with id {MessageId} but no registered callback.", messageId);
}
}
}

View File

@@ -1,45 +0,0 @@
using System.Buffers;
using System.Buffers.Binary;
using MemoryPack;
namespace Phantom.Utils.Rpc.Message;
static class MessageSerializer {
private static readonly MemoryPackSerializerOptions SerializerOptions = MemoryPackSerializerOptions.Utf8;
public static byte[] Serialize<T>(T message) {
return MemoryPackSerializer.Serialize(message, SerializerOptions);
}
public static void Serialize<T>(IBufferWriter<byte> destination, T message) {
MemoryPackSerializer.Serialize(typeof(T), destination, message, SerializerOptions);
}
public static T Deserialize<T>(ReadOnlyMemory<byte> memory) {
return MemoryPackSerializer.Deserialize<T>(memory.Span) ?? throw new NullReferenceException();
}
public static void WriteCode(IBufferWriter<byte> destination, ushort value) {
Span<byte> buffer = stackalloc byte[2];
BinaryPrimitives.WriteUInt16LittleEndian(buffer, value);
destination.Write(buffer);
}
public static ushort ReadCode(ref ReadOnlyMemory<byte> memory) {
ushort value = BinaryPrimitives.ReadUInt16LittleEndian(memory.Span);
memory = memory[2..];
return value;
}
public static void WriteSequenceId(IBufferWriter<byte> destination, uint sequenceId) {
Span<byte> buffer = stackalloc byte[4];
BinaryPrimitives.WriteUInt32LittleEndian(buffer, sequenceId);
destination.Write(buffer);
}
public static uint ReadSequenceId(ref ReadOnlyMemory<byte> memory) {
uint value = BinaryPrimitives.ReadUInt32LittleEndian(memory.Span);
memory = memory[4..];
return value;
}
}

View File

@@ -1,17 +0,0 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Message;
sealed class ReplySender<TMessageBase, TReplyMessage> : IReplySender where TReplyMessage : TMessageBase {
private readonly RpcConnection<TMessageBase> connection;
private readonly IReplyMessageFactory<TReplyMessage> replyMessageFactory;
public ReplySender(RpcConnection<TMessageBase> connection, IReplyMessageFactory<TReplyMessage> replyMessageFactory) {
this.connection = connection;
this.replyMessageFactory = replyMessageFactory;
}
public Task SendReply(uint sequenceId, byte[] serializedReply) {
return connection.Send(replyMessageFactory.CreateReplyMessage(sequenceId, serializedReply));
}
}

View File

@@ -1,100 +0,0 @@
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Logging;
using Serilog;
namespace Phantom.Utils.Rpc.New;
public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage>(string loggerName, string host, ushort port, string distinguishedName, RpcCertificateThumbprint certificateThumbprint, RpcClientHandshake handshake) {
private readonly ILogger logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
private bool loggedCertificateValidationError = false;
private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) {
if (certificate == null || sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable)) {
logger.Error("Could not establish a secure connection, server did not provide a certificate.");
}
else if (sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch)) {
logger.Error("Could not establish a secure connection, server certificate has the wrong name: {Name}", certificate.Subject);
}
else if (!certificateThumbprint.Check(certificate)) {
logger.Error("Could not establish a secure connection, server certificate does not match.");
}
else if (TlsSupport.CheckAlgorithm((X509Certificate2) certificate) is {} error) {
logger.Error("Could not establish a secure connection, server certificate rejected because it uses {ActualAlgorithmName} instead of {ExpectedAlgorithmName}.", error.ActualAlgorithmName, error.ExpectedAlgorithmName);
}
else if ((sslPolicyErrors & ~SslPolicyErrors.RemoteCertificateChainErrors) != SslPolicyErrors.None) {
logger.Error("Could not establish a secure connection, server certificate validation failed.");
}
else {
return true;
}
loggedCertificateValidationError = true;
return false;
}
public async Task<RpcClientConnection<TClientToServerMessage>?> Connect(CancellationToken shutdownToken) {
SslClientAuthenticationOptions sslOptions = new () {
AllowRenegotiation = false,
AllowTlsResume = true,
CertificateRevocationCheckMode = X509RevocationMode.NoCheck,
EnabledSslProtocols = TlsSupport.SupportedProtocols,
EncryptionPolicy = EncryptionPolicy.RequireEncryption,
RemoteCertificateValidationCallback = ValidateServerCertificate,
TargetHost = distinguishedName,
};
try {
using var clientSocket = new Socket(SocketType.Stream, ProtocolType.Tcp);
logger.Information("Connecting to {Host}:{Port}...", host, port);
try {
await clientSocket.ConnectAsync(host, port, shutdownToken);
} catch (Exception e) {
logger.Error(e, "Could not connect.");
return null;
}
await using var stream = new SslStream(new NetworkStream(clientSocket, ownsSocket: false), leaveInnerStreamOpen: false);
try {
loggedCertificateValidationError = false;
await stream.AuthenticateAsClientAsync(sslOptions, shutdownToken);
} catch (AuthenticationException e) {
if (!loggedCertificateValidationError) {
logger.Error(e, "Could not establish a secure connection.");
}
return null;
}
logger.Information("Established a secure connection.");
try {
await handshake.AcceptServer(stream, shutdownToken);
} catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost.");
return null;
} catch (Exception e) {
logger.Warning(e, "Could not perform application handshake.");
return null;
}
// await stream.WriteAsync(new byte[] { 1, 2, 3 }, shutdownToken);
byte[] buffer = new byte[1024];
int readBytes;
while ((readBytes = await stream.ReadAsync(buffer, shutdownToken)) > 0) {}
} catch (Exception e) {
logger.Error(e, "Client crashed with uncaught exception.");
return null;
} finally {
logger.Information("Client stopped.");
}
return true;
}
}

View File

@@ -1,30 +0,0 @@
using System.Threading.Channels;
using Phantom.Utils.Logging;
using Serilog;
namespace Phantom.Utils.Rpc.New;
public class RpcClientConnection<TClientToServerMessage>(string loggerName, CancellationToken shutdownCancellationToken) : IAsyncDisposable {
private readonly ILogger logger = PhantomLogger.Create<RpcClientConnection<TClientToServerMessage>>(loggerName);
private readonly Channel<TClientToServerMessage> sendQueue = Channel.CreateBounded<TClientToServerMessage>(new BoundedChannelOptions(500) {
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
});
public async Task WaitFor() {
}
public async Task Send(TClientToServerMessage message, CancellationToken cancellationToken) {
if (!sendQueue.Writer.TryWrite(message)) {
await sendQueue.Writer.WriteAsync(message, cancellationToken);
}
}
public async ValueTask DisposeAsync() {
// TODO release managed resources here
}
}

View File

@@ -1,5 +0,0 @@
namespace Phantom.Utils.Rpc.New;
public abstract class RpcClientHandshake {
protected internal abstract Task<bool> AcceptServer(Stream stream, CancellationToken cancellationToken);
}

View File

@@ -1,6 +0,0 @@
namespace Phantom.Utils.Rpc.New;
public enum RpcHandshakeResult : byte {
UnknownError = 0,
InvalidFormat = 1,
}

View File

@@ -1,122 +0,0 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Logging;
using Serilog;
namespace Phantom.Utils.Rpc.New;
public sealed class RpcServer(string loggerName, EndPoint endPoint, RpcServerCertificate certificate, RpcServerHandshake handshake) {
private readonly ILogger logger = PhantomLogger.Create<RpcServer>(loggerName);
private readonly List<Client> clients = [];
public async Task<bool> Run(CancellationToken shutdownToken) {
SslServerAuthenticationOptions sslOptions = new () {
AllowRenegotiation = false,
AllowTlsResume = true,
CertificateRevocationCheckMode = X509RevocationMode.NoCheck,
ClientCertificateRequired = false,
EnabledSslProtocols = TlsSupport.SupportedProtocols,
EncryptionPolicy = EncryptionPolicy.RequireEncryption,
ServerCertificate = certificate.Certificate,
};
try {
using var serverSocket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try {
serverSocket.Bind(endPoint);
serverSocket.Listen(5);
} catch (Exception e) {
logger.Error(e, "Could not bind to {EndPoint}.", endPoint);
return false;
}
try {
logger.Information("Server listening on {EndPoint}.", endPoint);
while (!shutdownToken.IsCancellationRequested) {
Socket clientSocket = await serverSocket.AcceptAsync(shutdownToken);
clients.Add(new Client(logger, clientSocket, sslOptions, handshake, shutdownToken));
clients.RemoveAll(static client => client.Task.IsCompleted);
}
} finally {
await Stop(serverSocket);
}
} catch (Exception e) {
logger.Error(e, "Server crashed with uncaught exception.");
return false;
}
return true;
}
private async Task Stop(Socket serverSocket) {
try {
serverSocket.Close();
} catch (Exception e) {
logger.Error(e, "Server socket failed to close.");
return;
}
logger.Information("Server socket closed, waiting for client sockets to close.");
try {
await Task.WhenAll(clients.Select(static client => client.Task));
} catch (Exception) {
// Ignore exceptions when shutting down.
}
logger.Information("Server stopped.");
}
private sealed class Client {
public Task Task { get; }
private readonly ILogger logger;
private readonly Socket socket;
private readonly SslServerAuthenticationOptions sslOptions;
private readonly RpcServerHandshake handshake;
private readonly CancellationToken shutdownToken;
public Client(ILogger logger, Socket socket, SslServerAuthenticationOptions sslOptions, RpcServerHandshake handshake, CancellationToken shutdownToken) {
this.logger = logger;
this.socket = socket;
this.sslOptions = sslOptions;
this.handshake = handshake;
this.shutdownToken = shutdownToken;
this.Task = Run();
}
private async Task Run() {
try {
await using var stream = new SslStream(new NetworkStream(socket, ownsSocket: false), leaveInnerStreamOpen: false);
try {
await stream.AuthenticateAsServerAsync(sslOptions, shutdownToken);
} catch (Exception e) {
logger.Error(e, "Could not establish a secure connection.");
return;
}
try {
await handshake.AcceptClient(socket.RemoteEndPoint?.ToString() ?? "<unknown address>", stream, shutdownToken);
} catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost.");
return;
} catch (Exception e) {
logger.Warning(e, "Could not perform application handshake.");
return;
}
byte[] buffer = new byte[1024];
int readBytes;
while ((readBytes = await stream.ReadAsync(buffer, shutdownToken)) > 0) {}
} finally {
socket.Close();
}
}
}
}

View File

@@ -1,5 +0,0 @@
namespace Phantom.Utils.Rpc.New;
public abstract class RpcServerHandshake {
protected internal abstract Task<bool> AcceptClient(string remoteAddress, Stream stream, CancellationToken cancellationToken);
}

View File

@@ -1,15 +0,0 @@
using MemoryPack;
namespace Phantom.Utils.Rpc.New;
public static class Serialization {
private static readonly MemoryPackSerializerOptions SerializerOptions = MemoryPackSerializerOptions.Utf8;
public static ValueTask Serialize<T>(T value, Stream stream, CancellationToken cancellationToken) {
return MemoryPackSerializer.SerializeAsync(stream, value, SerializerOptions, cancellationToken);
}
public static async ValueTask<T> Deserialize<T>(Stream stream, CancellationToken cancellationToken) {
return (await MemoryPackSerializer.DeserializeAsync<T>(stream, SerializerOptions, cancellationToken))!;
}
}

View File

@@ -7,7 +7,6 @@
<ItemGroup>
<PackageReference Include="MemoryPack" />
<PackageReference Include="NetMQ" />
<PackageReference Include="Serilog" />
</ItemGroup>

View File

@@ -1,8 +0,0 @@
using NetMQ;
namespace Phantom.Utils.Rpc;
public sealed record RpcConfiguration(string ServiceName, string Host, ushort Port, NetMQCertificate ServerCertificate) {
internal string LoggerName => "Rpc:" + ServiceName;
internal string TcpUrl => "tcp://" + Host + ":" + Port;
}

View File

@@ -1,32 +0,0 @@
using NetMQ;
using NetMQ.Sockets;
namespace Phantom.Utils.Rpc;
static class RpcExtensions {
public static ReadOnlyMemory<byte> Receive(this ClientSocket socket, CancellationToken cancellationToken) {
var msg = new Msg();
msg.InitEmpty();
try {
socket.Receive(ref msg, cancellationToken);
return msg.SliceAsMemory();
} finally {
// Only releases references, so the returned ReadOnlyMemory is safe.
msg.Close();
}
}
public static (uint, ReadOnlyMemory<byte>) Receive(this ServerSocket socket, CancellationToken cancellationToken) {
var msg = new Msg();
msg.InitEmpty();
try {
socket.Receive(ref msg, cancellationToken);
return (msg.RoutingId, msg.SliceAsMemory());
} finally {
// Only releases references, so the returned ReadOnlyMemory is safe.
msg.Close();
}
}
}

View File

@@ -0,0 +1,66 @@
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client;
public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> : IDisposable {
public static async Task<RpcClient<TClientToServerMessage, TServerToClientMessage>?> Connect(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, CancellationToken cancellationToken) {
RpcClientToServerConnector connector = new RpcClientToServerConnector(loggerName, connectionParameters);
RpcClientToServerConnector.Connection? connection = await connector.EstablishNewConnection(cancellationToken);
return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions, connector, connection);
}
private readonly string loggerName;
private readonly ILogger logger;
private readonly MessageRegistry<TServerToClientMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly RpcClientToServerConnection connection;
private readonly CancellationTokenSource shutdownCancellationTokenSource = new ();
public RpcSendChannel<TClientToServerMessage> SendChannel { get; }
private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection) {
this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.messageRegistry = messageDefinitions.ToClient;
this.connection = new RpcClientToServerConnection(loggerName, connector, connection);
this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters.Common, this.connection, messageDefinitions.ToServer);
}
public async Task Listen(IMessageReceiver<TServerToClientMessage> receiver) {
var messageHandler = new RpcMessageHandler<TServerToClientMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
try {
await connection.ReadConnection(frameReader, shutdownCancellationTokenSource.Token);
} catch (OperationCanceledException) {
// Ignore.
}
}
public async Task Shutdown() {
logger.Information("Shutting down client...");
try {
await SendChannel.Close();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing send channel.");
}
try {
connection.StopReconnecting();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing connection.");
}
await shutdownCancellationTokenSource.CancelAsync();
logger.Information("Client shut down.");
}
public void Dispose() {
connection.Dispose();
SendChannel.Dispose();
}
}

View File

@@ -0,0 +1,15 @@
using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime.Client;
public readonly record struct RpcClientConnectionParameters(
string Host,
ushort Port,
string DistinguishedName,
RpcCertificateThumbprint CertificateThumbprint,
AuthToken AuthToken,
ushort SendQueueCapacity,
TimeSpan PingInterval
) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}

View File

@@ -0,0 +1,89 @@
using System.Net.Sockets;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client;
sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection initialConnection) : IRpcConnectionProvider, IDisposable {
private readonly ILogger logger = PhantomLogger.Create<RpcClientToServerConnection>(loggerName);
private readonly SemaphoreSlim semaphore = new (1);
private RpcClientToServerConnector.Connection currentConnection = initialConnection;
private readonly CancellationTokenSource newConnectionCancellationTokenSource = new ();
public async Task<Stream> GetStream(CancellationToken cancellationToken) {
return (await GetConnection()).Stream;
}
private async Task<RpcClientToServerConnector.Connection> GetConnection() {
CancellationToken cancellationToken = newConnectionCancellationTokenSource.Token;
await semaphore.WaitAsync(cancellationToken);
try {
if (!currentConnection.Socket.Connected) {
currentConnection = await connector.EstablishNewConnectionWithRetry(cancellationToken);
}
return currentConnection;
} finally {
semaphore.Release();
}
}
public async Task ReadConnection(IFrameReader frameReader, CancellationToken cancellationToken) {
RpcClientToServerConnector.Connection? connection = null;
try {
while (true) {
connection?.Dispose();
connection = null;
try {
connection = await GetConnection();
} catch (OperationCanceledException) {
throw;
} catch (Exception e) {
logger.Warning(e, "Could not obtain a new connection.");
continue;
}
try {
await IFrame.ReadFrom(connection.Stream, frameReader, cancellationToken);
} catch (OperationCanceledException) {
throw;
} catch (EndOfStreamException) {
logger.Warning("Socket was closed.");
} catch (SocketException e) {
logger.Error("Socket reading was interrupted. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
} catch (Exception e) {
logger.Error(e, "Socket reading was interrupted.");
}
try {
await connection.Shutdown();
} catch (Exception e) {
logger.Error(e, "Caught exception closing the socket.");
}
}
} finally {
if (connection != null) {
try {
await connection.Disconnect();
} finally {
connection.Dispose();
}
}
}
}
public void StopReconnecting() {
newConnectionCancellationTokenSource.Cancel();
}
public void Dispose() {
semaphore.Dispose();
newConnectionCancellationTokenSource.Dispose();
}
}

View File

@@ -0,0 +1,194 @@
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Collections;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client;
internal sealed class RpcClientToServerConnector {
private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(100);
private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30);
private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10);
private readonly ILogger logger;
private readonly Guid sessionId;
private readonly RpcClientConnectionParameters parameters;
private readonly SslClientAuthenticationOptions sslOptions;
private bool loggedCertificateValidationError = false;
public RpcClientToServerConnector(string loggerName, RpcClientConnectionParameters parameters) {
this.logger = PhantomLogger.Create<RpcClientToServerConnector>(loggerName);
this.sessionId = Guid.NewGuid();
this.parameters = parameters;
this.sslOptions = new SslClientAuthenticationOptions {
AllowRenegotiation = false,
AllowTlsResume = true,
CertificateRevocationCheckMode = X509RevocationMode.NoCheck,
EnabledSslProtocols = TlsSupport.SupportedProtocols,
EncryptionPolicy = EncryptionPolicy.RequireEncryption,
RemoteCertificateValidationCallback = ValidateServerCertificate,
TargetHost = parameters.DistinguishedName,
};
}
internal async Task<Connection> EstablishNewConnectionWithRetry(CancellationToken cancellationToken) {
TimeSpan nextAttemptDelay = InitialRetryDelay;
while (true) {
Connection? newConnection;
try {
newConnection = await EstablishNewConnection(cancellationToken);
} catch (Exception e) {
logger.Error(e, "Caught unhandled exception while connecting.");
newConnection = null;
}
if (newConnection != null) {
return newConnection;
}
cancellationToken.ThrowIfCancellationRequested();
logger.Information("Trying to reconnect in {Seconds}s.", nextAttemptDelay.TotalSeconds.ToString("F1"));
await Task.Delay(nextAttemptDelay, cancellationToken);
nextAttemptDelay = Comparables.Min(nextAttemptDelay.Multiply(1.5), MaximumRetryDelay);
}
}
internal async Task<Connection?> EstablishNewConnection(CancellationToken cancellationToken) {
logger.Information("Connecting to {Host}:{Port}...", parameters.Host, parameters.Port);
Socket clientSocket = new Socket(SocketType.Stream, ProtocolType.Tcp);
try {
await clientSocket.ConnectAsync(parameters.Host, parameters.Port, cancellationToken);
} catch (SocketException e) {
logger.Error("Could not connect. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
return null;
} catch (Exception e) {
logger.Error(e, "Could not connect.");
return null;
}
SslStream? stream;
try {
stream = new SslStream(new NetworkStream(clientSocket, ownsSocket: false), leaveInnerStreamOpen: false);
if (await FinalizeConnection(stream, cancellationToken)) {
logger.Information("Connected to {Host}:{Port}.", parameters.Host, parameters.Port);
return new Connection(clientSocket, stream);
}
} catch (Exception e) {
logger.Error(e, "Caught unhandled exception.");
stream = null;
}
try {
await DisconnectSocket(clientSocket, stream);
} finally {
clientSocket.Close();
}
return null;
}
private async Task<bool> FinalizeConnection(SslStream stream, CancellationToken cancellationToken) {
try {
loggedCertificateValidationError = false;
await stream.AuthenticateAsClientAsync(sslOptions, cancellationToken);
} catch (AuthenticationException e) {
if (!loggedCertificateValidationError) {
logger.Error(e, "Could not establish a secure connection.");
}
return false;
}
logger.Information("Established a secure connection.");
try {
if (!await PerformApplicationHandshake(stream, cancellationToken)) {
return false;
}
} catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost.");
} catch (Exception e) {
logger.Warning(e, "Could not perform application handshake.");
}
return true;
}
private async Task<bool> PerformApplicationHandshake(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteAuthToken(parameters.AuthToken, stream, cancellationToken);
await Serialization.WriteGuid(sessionId, stream, cancellationToken);
var result = (RpcHandshakeResult) await Serialization.ReadByte(stream, cancellationToken);
switch (result) {
case RpcHandshakeResult.Success:
return true;
case RpcHandshakeResult.InvalidAuthToken:
logger.Error("Server rejected authorization token.");
return false;
default:
logger.Error("Server rejected client due to unknown error: {ErrorId}", result);
return false;
}
}
private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) {
if (certificate == null || sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable)) {
logger.Error("Could not establish a secure connection, server did not provide a certificate.");
}
else if (sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch)) {
logger.Error("Could not establish a secure connection, server certificate has the wrong name: {Name}", certificate.Subject);
}
else if (!parameters.CertificateThumbprint.Check(certificate)) {
logger.Error("Could not establish a secure connection, server certificate does not match.");
}
else if (TlsSupport.CheckAlgorithm((X509Certificate2) certificate) is {} error) {
logger.Error("Could not establish a secure connection, server certificate rejected because it uses {ActualAlgorithmName} instead of {ExpectedAlgorithmName}.", error.ActualAlgorithmName, error.ExpectedAlgorithmName);
}
else if ((sslPolicyErrors & ~SslPolicyErrors.RemoteCertificateChainErrors) != SslPolicyErrors.None) {
logger.Error("Could not establish a secure connection, server certificate validation failed.");
}
else {
return true;
}
loggedCertificateValidationError = true;
return false;
}
private static async Task DisconnectSocket(Socket socket, Stream? stream) {
if (stream != null) {
await stream.DisposeAsync();
}
using CancellationTokenSource timeoutTokenSource = new CancellationTokenSource(DisconnectTimeout);
await socket.DisconnectAsync(reuseSocket: false, timeoutTokenSource.Token);
}
internal sealed record Connection(Socket Socket, Stream Stream) : IDisposable {
public async Task Disconnect() {
await DisconnectSocket(Socket, Stream);
}
public async ValueTask Shutdown() {
await Stream.DisposeAsync();
Socket.Shutdown(SocketShutdown.Both);
}
public void Dispose() {
Stream.Dispose();
Socket.Close();
}
}
}

View File

@@ -1,7 +0,0 @@
using Phantom.Utils.Actor;
namespace Phantom.Utils.Rpc.Runtime;
public interface IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> where TRegistrationMessage : TServerMessage {
Task<Props<TServerMessage>?> TryRegister(RpcConnectionToClient<TClientMessage> connection, TRegistrationMessage message);
}

View File

@@ -0,0 +1,5 @@
namespace Phantom.Utils.Rpc.Runtime;
interface IRpcConnectionProvider {
Task<Stream> GetStream(CancellationToken cancellationToken);
}

View File

@@ -1,9 +0,0 @@
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcClientConnectionClosedEventArgs : EventArgs {
internal uint RoutingId { get; }
internal RpcClientConnectionClosedEventArgs(uint routingId) {
RoutingId = routingId;
}
}

View File

@@ -1,72 +0,0 @@
using NetMQ.Sockets;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Sockets;
using Serilog;
using Serilog.Events;
namespace Phantom.Utils.Rpc.Runtime;
public abstract class RpcClientRuntime<TClientMessage, TServerMessage, TReplyMessage> : RpcRuntime<ClientSocket> where TReplyMessage : TClientMessage, TServerMessage {
private readonly RpcConnectionToServer<TServerMessage> connection;
private readonly IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions;
private readonly ActorRef<TClientMessage> handlerActor;
private readonly SemaphoreSlim disconnectSemaphore;
private readonly CancellationToken receiveCancellationToken;
protected RpcClientRuntime(RpcClientSocket<TClientMessage, TServerMessage, TReplyMessage> socket, ActorRef<TClientMessage> handlerActor, SemaphoreSlim disconnectSemaphore, CancellationToken receiveCancellationToken) : base(socket) {
this.connection = socket.Connection;
this.messageDefinitions = socket.MessageDefinitions;
this.handlerActor = handlerActor;
this.disconnectSemaphore = disconnectSemaphore;
this.receiveCancellationToken = receiveCancellationToken;
}
private protected sealed override Task Run(ClientSocket socket) {
return RunWithConnection(socket, connection);
}
protected virtual async Task RunWithConnection(ClientSocket socket, RpcConnectionToServer<TServerMessage> connection) {
var replySender = new ReplySender<TServerMessage, TReplyMessage>(connection, messageDefinitions);
var messageHandler = new MessageHandler<TClientMessage>(LoggerName, handlerActor, replySender);
try {
while (!receiveCancellationToken.IsCancellationRequested) {
var data = socket.Receive(receiveCancellationToken);
LogMessageType(RuntimeLogger, data);
if (data.Length > 0) {
messageDefinitions.ToClient.Handle(data, messageHandler);
}
}
} catch (OperationCanceledException) {
// Ignore.
} finally {
await handlerActor.Stop();
RuntimeLogger.Debug("ZeroMQ client stopped receiving messages.");
await disconnectSemaphore.WaitAsync(CancellationToken.None);
}
}
private protected sealed override async Task Disconnect(ClientSocket socket) {
await SendDisconnectMessage(socket, RuntimeLogger);
}
protected abstract Task SendDisconnectMessage(ClientSocket socket, ILogger logger);
private void LogMessageType(ILogger logger, ReadOnlyMemory<byte> data) {
if (!logger.IsEnabled(LogEventLevel.Verbose)) {
return;
}
if (data.Length > 0 && messageDefinitions.ToClient.TryGetType(data, out var type)) {
logger.Verbose("Received {MessageType} ({Bytes} B).", type.Name, data.Length);
}
else {
logger.Verbose("Received {Bytes} B message.", data.Length);
}
}
}

View File

@@ -0,0 +1,6 @@
namespace Phantom.Utils.Rpc.Runtime;
readonly record struct RpcCommonConnectionParameters(
ushort SendQueueCapacity,
TimeSpan PingInterval
);

View File

@@ -1,40 +0,0 @@
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime;
public abstract class RpcConnection<TMessageBase> {
private readonly MessageRegistry<TMessageBase> messageRegistry;
private readonly MessageReplyTracker replyTracker;
internal RpcConnection(MessageRegistry<TMessageBase> messageRegistry, MessageReplyTracker replyTracker) {
this.messageRegistry = messageRegistry;
this.replyTracker = replyTracker;
}
private protected abstract ValueTask Send(byte[] bytes);
public async Task Send<TMessage>(TMessage message) where TMessage : TMessageBase {
var bytes = messageRegistry.Write(message).ToArray();
if (bytes.Length > 0) {
await Send(bytes);
}
}
public async Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
var sequenceId = replyTracker.RegisterReply();
var bytes = messageRegistry.Write<TMessage, TReply>(sequenceId, message).ToArray();
if (bytes.Length == 0) {
replyTracker.ForgetReply(sequenceId);
throw new ArgumentException("Could not write message.", nameof(message));
}
await Send(bytes);
return await replyTracker.WaitForReply<TReply>(sequenceId, waitForReplyTime, waitForReplyCancellationToken);
}
public void Receive(IReply message) {
replyTracker.ReceiveReply(message.SequenceId, message.SerializedReply);
}
}

View File

@@ -1,41 +0,0 @@
using NetMQ;
using NetMQ.Sockets;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcConnectionToClient<TMessageBase> : RpcConnection<TMessageBase> {
private readonly ServerSocket socket;
private readonly uint routingId;
internal event EventHandler<RpcClientConnectionClosedEventArgs>? Closed;
private bool isClosed;
internal RpcConnectionToClient(ServerSocket socket, uint routingId, MessageRegistry<TMessageBase> messageRegistry, MessageReplyTracker replyTracker) : base(messageRegistry, replyTracker) {
this.socket = socket;
this.routingId = routingId;
}
public bool IsSame(RpcConnectionToClient<TMessageBase> other) {
return this.routingId == other.routingId && this.socket == other.socket;
}
public void Close() {
bool hasClosed = false;
lock (this) {
if (!isClosed) {
isClosed = true;
hasClosed = true;
}
}
if (hasClosed) {
Closed?.Invoke(this, new RpcClientConnectionClosedEventArgs(routingId));
}
}
private protected override ValueTask Send(byte[] bytes) {
return socket.SendAsync(routingId, bytes);
}
}

View File

@@ -1,25 +0,0 @@
using NetMQ;
using NetMQ.Sockets;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Tasks;
namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcConnectionToServer<TMessageBase> : RpcConnection<TMessageBase> {
private readonly ClientSocket socket;
private readonly TaskCompletionSource isReady = AsyncTasks.CreateCompletionSource();
public Task IsReady => isReady.Task;
internal RpcConnectionToServer(ClientSocket socket, MessageRegistry<TMessageBase> messageRegistry, MessageReplyTracker replyTracker) : base(messageRegistry, replyTracker) {
this.socket = socket;
}
public void SetIsReady() {
isReady.TrySetResult();
}
private protected override ValueTask Send(byte[] bytes) {
return socket.SendAsync(bytes);
}
}

View File

@@ -0,0 +1,10 @@
namespace Phantom.Utils.Rpc.Runtime;
public enum RpcError : byte {
InvalidData = 0,
UnknownMessageRegistryCode = 1,
MessageTooLarge = 2,
MessageDeserializationError = 3,
MessageHandlingError = 4,
MessageAlreadyHandled = 5,
}

View File

@@ -0,0 +1,21 @@
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcErrorException : Exception {
internal static RpcErrorException From(RpcError error) {
return error switch {
RpcError.InvalidData => new RpcErrorException("Invalid data", error),
RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error),
RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error),
RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error),
RpcError.MessageAlreadyHandled => new RpcErrorException("Message already handled", error),
_ => new RpcErrorException("Unknown error", error),
};
}
public RpcError Error { get; }
internal RpcErrorException(string message, RpcError error) : base(message) {
this.Error = error;
}
}

View File

@@ -0,0 +1,53 @@
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
string loggerName,
MessageRegistry<TReceivedMessage> messageRegistry,
MessageReceiveTracker messageReceiveTracker,
RpcMessageHandler<TReceivedMessage> messageHandler,
RpcSendChannel<TSentMessage> sendChannel
) : IFrameReader {
private readonly ILogger logger = PhantomLogger.Create<RpcFrameReader<TSentMessage, TReceivedMessage>>(loggerName);
public ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken) {
messageHandler.OnPing();
return sendChannel.SendPong(pingTime, cancellationToken);
}
public void OnPongFrame(PongFrame frame) {
sendChannel.ReceivePong(frame);
}
public Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) {
if (!messageReceiveTracker.ReceiveMessage(frame.MessageId)) {
logger.Warning("Received duplicate message {MessageId}.", frame.MessageId);
return messageHandler.SendError(frame.MessageId, RpcError.MessageAlreadyHandled, cancellationToken).AsTask();
}
if (messageRegistry.TryGetType(frame, out var messageType)) {
logger.Verbose("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length);
}
return messageRegistry.Handle(frame, messageHandler, cancellationToken);
}
public void OnReplyFrame(ReplyFrame frame) {
logger.Verbose("Received reply to message {MesageId} ({Bytes} B).", frame.ReplyingToMessageId, frame.SerializedReply.Length);
sendChannel.ReceiveReply(frame);
}
public void OnErrorFrame(ErrorFrame frame) {
logger.Warning("Received error response to message {MesageId}: {Error}", frame.ReplyingToMessageId, frame.Error);
sendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error);
}
public void OnUnknownFrameId(byte frameId) {
logger.Error("Received unknown frame ID: {FrameId}", frameId);
}
}

View File

@@ -0,0 +1,6 @@
namespace Phantom.Utils.Rpc.Runtime;
enum RpcHandshakeResult : byte {
Success = 0,
InvalidAuthToken = 1,
}

View File

@@ -0,0 +1,19 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcMessageHandler<TMessageBase>(IMessageReceiver<TMessageBase> receiver, IRpcReplySender replySender) {
public IMessageReceiver<TMessageBase> Receiver => receiver;
public void OnPing() {
receiver.OnPing();
}
public ValueTask SendReply<TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) {
return replySender.SendReply(messageId, reply, cancellationToken);
}
public ValueTask SendError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return replySender.SendError(messageId, error, cancellationToken);
}
}

View File

@@ -1,75 +0,0 @@
using Akka.Actor;
using Akka.Event;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage> : ReceiveActor<RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.ReceiveMessageCommand>, IWithStash where TRegistrationMessage : TServerMessage where TReplyMessage : TClientMessage, TServerMessage {
public readonly record struct Init(
string LoggerName,
IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> MessageDefinitions,
IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> RegistrationHandler,
RpcConnectionToClient<TClientMessage> Connection
);
public static Props<ReceiveMessageCommand> Factory(Init init) {
return Props<ReceiveMessageCommand>.Create(() => new RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>(init), new ActorConfiguration {
SupervisorStrategy = SupervisorStrategies.Resume,
StashCapacity = 100
});
}
public IStash Stash { get; set; } = null!;
private readonly string loggerName;
private readonly IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions;
private readonly IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler;
private readonly RpcConnectionToClient<TClientMessage> connection;
private RpcReceiverActor(Init init) {
this.loggerName = init.LoggerName;
this.messageDefinitions = init.MessageDefinitions;
this.registrationHandler = init.RegistrationHandler;
this.connection = init.Connection;
ReceiveAsync<ReceiveMessageCommand>(ReceiveMessageUnauthorized);
}
public sealed record ReceiveMessageCommand(Type MessageType, ReadOnlyMemory<byte> Data);
private async Task ReceiveMessageUnauthorized(ReceiveMessageCommand command) {
if (command.MessageType == typeof(TRegistrationMessage)) {
await HandleRegistrationMessage(command);
}
else if (Stash.IsFull) {
Context.GetLogger().Warning("Stash is full, dropping message: {MessageType}", command.MessageType);
}
else {
Stash.Stash();
}
}
private async Task HandleRegistrationMessage(ReceiveMessageCommand command) {
if (!messageDefinitions.ToServer.Read(command.Data, out TRegistrationMessage message)) {
return;
}
var props = await registrationHandler.TryRegister(connection, message);
if (props == null) {
return;
}
var handlerActor = Context.ActorOf(props, "Handler");
var replySender = new ReplySender<TClientMessage, TReplyMessage>(connection, messageDefinitions);
BecomeAuthorized(new MessageHandler<TServerMessage>(loggerName, handlerActor, replySender));
}
private void BecomeAuthorized(MessageHandler<TServerMessage> handler) {
Stash.UnstashAll();
Become(() => {
Receive<ReceiveMessageCommand>(command => messageDefinitions.ToServer.Handle(command.Data, handler));
});
}
}

View File

@@ -1,50 +0,0 @@
using System.Diagnostics.CodeAnalysis;
using NetMQ;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Sockets;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
public abstract class RpcRuntime<TSocket> where TSocket : ThreadSafeSocket {
private readonly TSocket socket;
private protected string LoggerName { get; }
private protected ILogger RuntimeLogger { get; }
private protected MessageReplyTracker ReplyTracker { get; }
protected RpcRuntime(RpcSocket<TSocket> socket) {
this.socket = socket.Socket;
this.LoggerName = socket.Config.LoggerName;
this.RuntimeLogger = PhantomLogger.Create(LoggerName);
this.ReplyTracker = socket.ReplyTracker;
}
protected async Task Launch() {
[SuppressMessage("ReSharper", "AccessToDisposedClosure")]
async Task RunTask() {
try {
await Run(socket);
} catch (Exception e) {
RuntimeLogger.Error(e, "Caught exception in RPC thread.");
}
}
try {
await Task.Factory.StartNew(RunTask, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default).Unwrap();
} catch (OperationCanceledException) {
// Ignore.
} finally {
await Disconnect(socket);
socket.Dispose();
RuntimeLogger.Information("ZeroMQ runtime stopped.");
}
}
private protected abstract Task Run(TSocket socket);
private protected abstract Task Disconnect(TSocket socket);
}

View File

@@ -0,0 +1,171 @@
using System.Diagnostics.CodeAnalysis;
using System.Threading.Channels;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable {
private readonly ILogger logger;
private readonly IRpcConnectionProvider connectionProvider;
private readonly MessageRegistry<TMessageBase> messageRegistry;
private readonly MessageReplyTracker messageReplyTracker;
private readonly Channel<IFrame> sendQueue;
private readonly Task sendQueueTask;
private readonly Task pingTask;
private readonly CancellationTokenSource sendQueueCancellationTokenSource = new ();
private readonly CancellationTokenSource pingCancellationTokenSource = new ();
private uint nextMessageId;
private TaskCompletionSource<DateTimeOffset>? pongTask;
internal RpcSendChannel(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageRegistry<TMessageBase> messageRegistry) {
this.logger = PhantomLogger.Create<RpcSendChannel<TMessageBase>>(loggerName);
this.connectionProvider = connectionProvider;
this.messageRegistry = messageRegistry;
this.messageReplyTracker = new MessageReplyTracker(loggerName);
this.sendQueue = Channel.CreateBounded<IFrame>(new BoundedChannelOptions(connectionParameters.SendQueueCapacity) {
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait,
SingleReader = true,
SingleWriter = false,
});
this.sendQueueTask = ProcessSendQueue();
this.pingTask = Ping(connectionParameters.PingInterval);
}
internal async ValueTask SendPong(DateTimeOffset pingTime, CancellationToken cancellationToken) {
await SendFrame(new PongFrame(pingTime), cancellationToken);
}
public bool TrySendMessage<TMessage>(TMessage message) where TMessage : TMessageBase {
return sendQueue.Writer.TryWrite(NextMessageFrame(message));
}
public async ValueTask SendMessage<TMessage>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase {
await SendFrame(NextMessageFrame(message), cancellationToken);
}
public async Task<TReply> SendMessage<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
MessageFrame frame = NextMessageFrame(message);
uint messageId = frame.MessageId;
messageReplyTracker.RegisterReply(messageId);
try {
await SendFrame(frame, cancellationToken);
} catch (Exception) {
messageReplyTracker.ForgetReply(messageId);
throw;
}
return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken);
}
async ValueTask IRpcReplySender.SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, Serialization.Serialize(reply)), cancellationToken);
}
async ValueTask IRpcReplySender.SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) {
await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken);
}
private async ValueTask SendFrame(IFrame frame, CancellationToken cancellationToken) {
if (!sendQueue.Writer.TryWrite(frame)) {
await sendQueue.Writer.WriteAsync(frame, cancellationToken);
}
}
private MessageFrame NextMessageFrame<T>(T message) where T : TMessageBase {
uint messageId = Interlocked.Increment(ref nextMessageId);
return messageRegistry.CreateFrame(messageId, message);
}
private async Task ProcessSendQueue() {
CancellationToken cancellationToken = sendQueueCancellationTokenSource.Token;
await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync(cancellationToken)) {
while (true) {
try {
Stream stream = await connectionProvider.GetStream(cancellationToken);
await stream.WriteAsync(frame.FrameType, cancellationToken);
await frame.Write(stream, cancellationToken);
await stream.FlushAsync(cancellationToken);
break;
} catch (OperationCanceledException) {
throw;
} catch (Exception) {
// Retry.
}
}
}
}
[SuppressMessage("ReSharper", "FunctionNeverReturns")]
private async Task Ping(TimeSpan interval) {
CancellationToken cancellationToken = pingCancellationTokenSource.Token;
while (true) {
await Task.Delay(interval, cancellationToken);
pongTask = new TaskCompletionSource<DateTimeOffset>();
if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) {
cancellationToken.ThrowIfCancellationRequested();
logger.Warning("Skipped a ping due to a full queue.");
continue;
}
DateTimeOffset pingTime = await pongTask.Task.WaitAsync(cancellationToken);
DateTimeOffset currentTime = DateTimeOffset.UtcNow;
TimeSpan roundTripTime = currentTime - pingTime;
logger.Information("Received pong, round trip time: {RoundTripTime} ms", (long) roundTripTime.TotalMilliseconds);
}
}
internal void ReceivePong(PongFrame frame) {
pongTask?.TrySetResult(frame.PingTime);
}
internal void ReceiveReply(ReplyFrame frame) {
messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply);
}
internal void ReceiveError(uint messageId, RpcError error) {
messageReplyTracker.FailReply(messageId, RpcErrorException.From(error));
}
internal async Task Close() {
await pingCancellationTokenSource.CancelAsync();
sendQueue.Writer.TryComplete();
try {
await pingTask;
} catch (Exception) {
// Ignore.
}
try {
await sendQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing send queue before timeout, forcibly shutting it down.");
await sendQueueCancellationTokenSource.CancelAsync();
} catch (Exception) {
// Ignore.
}
}
public void Dispose() {
sendQueueTask.Dispose();
sendQueueCancellationTokenSource.Dispose();
pingCancellationTokenSource.Dispose();
}
}

Some files were not shown because too many files have changed in this diff Show More