1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-09-18 06:24:48 +02:00

1 Commits

Author SHA1 Message Date
c9c9b91e34 Replace NetMQ with custom TCP logic 2025-09-17 17:05:25 +02:00
90 changed files with 1404 additions and 1005 deletions

View File

@@ -1,9 +1,13 @@
using Akka.Actor; using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Agent.Minecraft.Java; using Phantom.Agent.Minecraft.Java;
using Phantom.Agent.Services.Backups; using Phantom.Agent.Services.Backups;
using Phantom.Agent.Services.Instances; using Phantom.Agent.Services.Instances;
using Phantom.Agent.Services.Rpc; using Phantom.Agent.Services.Rpc;
using Phantom.Common.Data.Agent; using Phantom.Common.Data.Agent;
using Phantom.Common.Data.Replies;
using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Serilog; using Serilog;
@@ -15,6 +19,9 @@ public sealed class AgentServices {
public ActorSystem ActorSystem { get; } public ActorSystem ActorSystem { get; }
private ControllerConnection ControllerConnection { get; }
private AgentInfo AgentInfo { get; }
private AgentFolders AgentFolders { get; } private AgentFolders AgentFolders { get; }
private AgentState AgentState { get; } private AgentState AgentState { get; }
private BackupManager BackupManager { get; } private BackupManager BackupManager { get; }
@@ -26,6 +33,9 @@ public sealed class AgentServices {
public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration, ControllerConnection controllerConnection) { public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration, ControllerConnection controllerConnection) {
this.ActorSystem = ActorSystemFactory.Create("Agent"); this.ActorSystem = ActorSystemFactory.Create("Agent");
this.ControllerConnection = controllerConnection;
this.AgentInfo = agentInfo;
this.AgentFolders = agentFolders; this.AgentFolders = agentFolders;
this.AgentState = new AgentState(); this.AgentState = new AgentState();
this.BackupManager = new BackupManager(agentFolders, serviceConfiguration.MaxConcurrentCompressionTasks); this.BackupManager = new BackupManager(agentFolders, serviceConfiguration.MaxConcurrentCompressionTasks);
@@ -43,6 +53,41 @@ public sealed class AgentServices {
} }
} }
public async Task<bool> Register(CancellationToken cancellationToken) {
Logger.Information("Registering with the controller...");
// TODO NEED TO SEND WHEN SERVER RESTARTS!!!
ImmutableArray<ConfigureInstanceMessage> configureInstanceMessages;
try {
configureInstanceMessages = await ControllerConnection.Send<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(new RegisterAgentMessage(AgentInfo), TimeSpan.FromMinutes(1), cancellationToken);
} catch (Exception e) {
Logger.Fatal(e, "Registration failed.");
return false;
}
foreach (var configureInstanceMessage in configureInstanceMessages) {
var configureInstanceCommand = new InstanceManagerActor.ConfigureInstanceCommand(
configureInstanceMessage.InstanceGuid,
configureInstanceMessage.Configuration,
configureInstanceMessage.LaunchProperties,
configureInstanceMessage.LaunchNow,
AlwaysReportStatus: true
);
var configureInstanceResult = await InstanceManager.Request(configureInstanceCommand, cancellationToken);
if (!configureInstanceResult.Is(ConfigureInstanceResult.Success)) {
Logger.Fatal("Unable to configure instance \"{Name}\" (GUID {Guid}), shutting down.", configureInstanceMessage.Configuration.InstanceName, configureInstanceMessage.InstanceGuid);
return false;
}
}
await ControllerConnection.Send(new AdvertiseJavaRuntimesMessage(JavaRuntimeRepository.All), cancellationToken);
InstanceTicketManager.RefreshAgentStatus();
return true;
}
public async Task Shutdown() { public async Task Shutdown() {
Logger.Information("Stopping services..."); Logger.Information("Stopping services...");

View File

@@ -56,11 +56,6 @@ sealed class InstanceManagerActor : ReceiveActor<InstanceManagerActor.ICommand>
ReceiveAsync<ShutdownCommand>(Shutdown); ReceiveAsync<ShutdownCommand>(Shutdown);
} }
private string GetInstanceLoggerName(Guid guid) {
var prefix = guid.ToString();
return prefix[..prefix.IndexOf('-')] + "/" + Interlocked.Increment(ref instanceLoggerSequenceId);
}
private sealed record InstanceInfo(ActorRef<InstanceActor.ICommand> Actor, InstanceConfiguration Configuration, IServerLauncher Launcher); private sealed record InstanceInfo(ActorRef<InstanceActor.ICommand> Actor, InstanceConfiguration Configuration, IServerLauncher Launcher);
public interface ICommand {} public interface ICommand {}
@@ -118,7 +113,8 @@ sealed class InstanceManagerActor : ReceiveActor<InstanceManagerActor.ICommand>
} }
} }
else { else {
var instanceInit = new InstanceActor.Init(agentState, instanceGuid, GetInstanceLoggerName(instanceGuid), instanceServices, instanceTicketManager, shutdownCancellationToken); var instanceLoggerName = PhantomLogger.ShortenGuid(instanceGuid) + "/" + Interlocked.Increment(ref instanceLoggerSequenceId);;
var instanceInit = new InstanceActor.Init(agentState, instanceGuid, instanceLoggerName, instanceServices, instanceTicketManager, shutdownCancellationToken);
instances[instanceGuid] = instance = new InstanceInfo(Context.ActorOf(InstanceActor.Factory(instanceInit), "Instance-" + instanceGuid), configuration, launcher); instances[instanceGuid] = instance = new InstanceInfo(Context.ActorOf(InstanceActor.Factory(instanceInit), "Instance-" + instanceGuid), configuration, launcher);
Logger.Information("Created instance \"{Name}\" (GUID {Guid}).", configuration.InstanceName, instanceGuid); Logger.Information("Created instance \"{Name}\" (GUID {Guid}).", configuration.InstanceName, instanceGuid);

View File

@@ -1,15 +1,20 @@
using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Agent.Services.Rpc; namespace Phantom.Agent.Services.Rpc;
public sealed class ControllerConnection(RpcSendChannel<IMessageToController> sendChannel) { public sealed class ControllerConnection(RpcSendChannel<IMessageToController> sendChannel) {
public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToController { public ValueTask Send<TMessage>(TMessage message, CancellationToken cancellationToken) where TMessage : IMessageToController {
return sendChannel.SendMessage(message, CancellationToken.None /* TODO */); return sendChannel.SendMessage(message, cancellationToken);
} }
// TODO handle properly // TODO handle properly
public bool TrySend<TMessage>(TMessage message) where TMessage : IMessageToController { public bool TrySend<TMessage>(TMessage message) where TMessage : IMessageToController {
return sendChannel.TrySendMessage(message); return sendChannel.TrySendMessage(message);
} }
public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken cancellationToken) where TMessage : IMessageToController, ICanReply<TReply> {
return sendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, cancellationToken);
}
} }

View File

@@ -1,82 +1,32 @@
using Phantom.Agent.Services.Instances; using Phantom.Agent.Services.Instances;
using Phantom.Common.Data; using Phantom.Common.Data;
using Phantom.Common.Data.Instance;
using Phantom.Common.Data.Replies; using Phantom.Common.Data.Replies;
using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToAgent; using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
namespace Phantom.Agent.Services.Rpc; namespace Phantom.Agent.Services.Rpc;
public sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToAgent> { public sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToAgent> {
private static ILogger Logger { get; } = PhantomLogger.Create<ControllerMessageHandlerActor>(); public readonly record struct Init(AgentServices Agent);
public readonly record struct Init(RpcSendChannel<IMessageToController> SendChannel, AgentServices Agent, CancellationTokenSource ShutdownTokenSource);
public static Props<IMessageToAgent> Factory(Init init) { public static Props<IMessageToAgent> Factory(Init init) {
return Props<IMessageToAgent>.Create(() => new ControllerMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<IMessageToAgent>.Create(() => new ControllerMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
private readonly RpcSendChannel<IMessageToController> sendChannel;
private readonly AgentServices agent; private readonly AgentServices agent;
private readonly CancellationTokenSource shutdownTokenSource;
private ControllerMessageHandlerActor(Init init) { private ControllerMessageHandlerActor(Init init) {
this.sendChannel = init.SendChannel;
this.agent = init.Agent; this.agent = init.Agent;
this.shutdownTokenSource = init.ShutdownTokenSource;
ReceiveAsync<RegisterAgentSuccessMessage>(HandleRegisterAgentSuccess);
Receive<RegisterAgentFailureMessage>(HandleRegisterAgentFailure);
ReceiveAndReplyLater<ConfigureInstanceMessage, Result<ConfigureInstanceResult, InstanceActionFailure>>(HandleConfigureInstance); ReceiveAndReplyLater<ConfigureInstanceMessage, Result<ConfigureInstanceResult, InstanceActionFailure>>(HandleConfigureInstance);
ReceiveAndReplyLater<LaunchInstanceMessage, Result<LaunchInstanceResult, InstanceActionFailure>>(HandleLaunchInstance); ReceiveAndReplyLater<LaunchInstanceMessage, Result<LaunchInstanceResult, InstanceActionFailure>>(HandleLaunchInstance);
ReceiveAndReplyLater<StopInstanceMessage, Result<StopInstanceResult, InstanceActionFailure>>(HandleStopInstance); ReceiveAndReplyLater<StopInstanceMessage, Result<StopInstanceResult, InstanceActionFailure>>(HandleStopInstance);
ReceiveAndReplyLater<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, InstanceActionFailure>>(HandleSendCommandToInstance); ReceiveAndReplyLater<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, InstanceActionFailure>>(HandleSendCommandToInstance);
} }
private async Task HandleRegisterAgentSuccess(RegisterAgentSuccessMessage message) {
Logger.Information("Agent authentication successful.");
void ShutdownAfterConfigurationFailed(Guid instanceGuid, InstanceConfiguration configuration) {
Logger.Fatal("Unable to configure instance \"{Name}\" (GUID {Guid}), shutting down.", configuration.InstanceName, instanceGuid);
shutdownTokenSource.Cancel();
}
foreach (var configureInstanceMessage in message.InitialInstanceConfigurations) {
var result = await HandleConfigureInstance(configureInstanceMessage, alwaysReportStatus: true);
if (!result.Is(ConfigureInstanceResult.Success)) {
ShutdownAfterConfigurationFailed(configureInstanceMessage.InstanceGuid, configureInstanceMessage.Configuration);
return;
}
}
await sendChannel.SendMessage(new AdvertiseJavaRuntimesMessage(agent.JavaRuntimeRepository.All), CancellationToken.None);
agent.InstanceTicketManager.RefreshAgentStatus();
}
private void HandleRegisterAgentFailure(RegisterAgentFailureMessage message) {
string errorMessage = message.FailureKind switch {
RegisterAgentFailure.ConnectionAlreadyHasAnAgent => "This connection already has an associated agent.",
RegisterAgentFailure.InvalidToken => "Invalid token.",
_ => "Unknown error " + (byte) message.FailureKind + "."
};
Logger.Fatal("Agent authentication failed: {Error}", errorMessage);
PhantomLogger.Dispose();
Environment.Exit(1);
}
private Task<Result<ConfigureInstanceResult, InstanceActionFailure>> HandleConfigureInstance(ConfigureInstanceMessage message, bool alwaysReportStatus) {
return agent.InstanceManager.Request(new InstanceManagerActor.ConfigureInstanceCommand(message.InstanceGuid, message.Configuration, message.LaunchProperties, message.LaunchNow, alwaysReportStatus));
}
private async Task<Result<ConfigureInstanceResult, InstanceActionFailure>> HandleConfigureInstance(ConfigureInstanceMessage message) { private async Task<Result<ConfigureInstanceResult, InstanceActionFailure>> HandleConfigureInstance(ConfigureInstanceMessage message) {
return await HandleConfigureInstance(message, alwaysReportStatus: false); return await agent.InstanceManager.Request(new InstanceManagerActor.ConfigureInstanceCommand(message.InstanceGuid, message.Configuration, message.LaunchProperties, message.LaunchNow, AlwaysReportStatus: false));
} }
private async Task<Result<LaunchInstanceResult, InstanceActionFailure>> HandleLaunchInstance(LaunchInstanceMessage message) { private async Task<Result<LaunchInstanceResult, InstanceActionFailure>> HandleLaunchInstance(LaunchInstanceMessage message) {

View File

@@ -1,7 +0,0 @@
using Phantom.Common.Data;
namespace Phantom.Agent.Services.Rpc;
sealed class RpcClientAgentHandshake(AuthToken authToken) {
}

View File

@@ -7,7 +7,8 @@ using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController; using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Threading; using Phantom.Utils.Threading;
@@ -46,19 +47,17 @@ try {
return 1; return 1;
} }
var (certificateThumbprint, authToken) = agentKey.Value;
var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts);
var rpcClientConnectionParameters = new RpcClientConnectionParameters( var rpcClientConnectionParameters = new RpcClientConnectionParameters(
Host: controllerHost, Host: controllerHost,
Port: controllerPort, Port: controllerPort,
DistinguishedName: "phantom-controller", DistinguishedName: "phantom-controller",
CertificateThumbprint: certificateThumbprint, CertificateThumbprint: agentKey.Value.CertificateThumbprint,
AuthToken: agentKey.Value.AuthToken,
SendQueueCapacity: 500, SendQueueCapacity: 500,
PingInterval: TimeSpan.FromSeconds(10) PingInterval: TimeSpan.FromSeconds(10)
); );
using var rpcClient = await RpcClient<IMessageToController, IMessageToAgent>.Connect("Controller", rpcClientConnectionParameters, null, AgentMessageRegistries.Definitions, shutdownCancellationToken); using var rpcClient = await RpcClient<IMessageToController, IMessageToAgent>.Connect("Controller", rpcClientConnectionParameters, AgentMessageRegistries.Definitions, shutdownCancellationToken);
if (rpcClient == null) { if (rpcClient == null) {
return 1; return 1;
} }
@@ -67,23 +66,30 @@ try {
try { try {
PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent..."); PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent...");
var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts);
var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.SendChannel)); var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.SendChannel));
await agentServices.Initialize(); await agentServices.Initialize();
var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(rpcClient.SendChannel, agentServices, shutdownCancellationTokenSource); var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices);
var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler"); var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler");
PhantomLogger.Root.Information("Phantom Panel agent is ready."); rpcClientListener = rpcClient.Listen(new IMessageReceiver<IMessageToAgent>.Actor(rpcMessageHandlerActor));
rpcClientListener = rpcClient.Listen(rpcMessageHandlerActor);
if (await agentServices.Register(shutdownCancellationToken)) {
PhantomLogger.Root.Information("Phantom Panel agent is ready.");
await shutdownCancellationToken.WaitHandle.WaitOneAsync();
}
await shutdownCancellationToken.WaitHandle.WaitOneAsync();
await agentServices.Shutdown(); await agentServices.Shutdown();
} finally { } finally {
PhantomLogger.Root.Information("Unregistering agent...");
try { try {
await rpcClient.SendChannel.SendMessage(new UnregisterAgentMessage(), CancellationToken.None); using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
// TODO wait for acknowledgment await rpcClient.SendChannel.SendMessage(new UnregisterAgentMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister agent after shutdown.");
} catch (Exception e) { } catch (Exception e) {
PhantomLogger.Root.Warning(e, "Could not unregister agent after shutdown."); PhantomLogger.Root.Warning(e, "Could not unregister agent during shutdown.");
} finally { } finally {
await rpcClient.Shutdown(); await rpcClient.Shutdown();

View File

@@ -1,4 +1,5 @@
using Phantom.Utils.Rpc.Runtime.Tls; using Phantom.Utils.Rpc;
using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Common.Data; namespace Phantom.Common.Data;
@@ -7,7 +8,7 @@ public readonly record struct ConnectionKey(RpcCertificateThumbprint Certificate
public byte[] ToBytes() { public byte[] ToBytes() {
Span<byte> result = stackalloc byte[TokenLength + CertificateThumbprint.Bytes.Length]; Span<byte> result = stackalloc byte[TokenLength + CertificateThumbprint.Bytes.Length];
AuthToken.WriteTo(result[..TokenLength]); AuthToken.Bytes.CopyTo(result[..TokenLength]);
CertificateThumbprint.Bytes.CopyTo(result[TokenLength..]); CertificateThumbprint.Bytes.CopyTo(result[TokenLength..]);
return result.ToArray(); return result.ToArray();
} }

View File

@@ -1,6 +0,0 @@
namespace Phantom.Common.Data.Replies;
public enum RegisterAgentFailure : byte {
ConnectionAlreadyHasAnAgent,
InvalidToken
}

View File

@@ -1,5 +1,6 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using MemoryPack; using MemoryPack;
using Phantom.Utils.Monads;
using Phantom.Utils.Result; using Phantom.Utils.Result;
namespace Phantom.Common.Data; namespace Phantom.Common.Data;
@@ -24,6 +25,9 @@ public sealed partial class Result<TValue, TError> {
[MemoryPackIgnore] [MemoryPackIgnore]
public TError Error => !hasValue ? error! : throw new InvalidOperationException("Attempted to get error from a success result."); public TError Error => !hasValue ? error! : throw new InvalidOperationException("Attempted to get error from a success result.");
[MemoryPackIgnore]
public Either<TValue, TError> AsEither => hasValue ? Either.Left(value!) : Either.Right(error!);
private Result(bool hasValue, TValue? value, TError? error) { private Result(bool hasValue, TValue? value, TError? error) {
this.hasValue = hasValue; this.hasValue = hasValue;
this.value = value; this.value = value;

View File

@@ -1,4 +1,5 @@
using Phantom.Common.Data; using System.Collections.Immutable;
using Phantom.Common.Data;
using Phantom.Common.Data.Replies; using Phantom.Common.Data.Replies;
using Phantom.Common.Messages.Agent.ToAgent; using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController; using Phantom.Common.Messages.Agent.ToController;
@@ -14,16 +15,13 @@ public static class AgentMessageRegistries {
public static IMessageDefinitions<IMessageToController, IMessageToAgent> Definitions { get; } = new MessageDefinitions(); public static IMessageDefinitions<IMessageToController, IMessageToAgent> Definitions { get; } = new MessageDefinitions();
static AgentMessageRegistries() { static AgentMessageRegistries() {
ToAgent.Add<RegisterAgentSuccessMessage>(0);
ToAgent.Add<RegisterAgentFailureMessage>(1);
ToAgent.Add<ConfigureInstanceMessage, Result<ConfigureInstanceResult, InstanceActionFailure>>(2); ToAgent.Add<ConfigureInstanceMessage, Result<ConfigureInstanceResult, InstanceActionFailure>>(2);
ToAgent.Add<LaunchInstanceMessage, Result<LaunchInstanceResult, InstanceActionFailure>>(3); ToAgent.Add<LaunchInstanceMessage, Result<LaunchInstanceResult, InstanceActionFailure>>(3);
ToAgent.Add<StopInstanceMessage, Result<StopInstanceResult, InstanceActionFailure>>(4); ToAgent.Add<StopInstanceMessage, Result<StopInstanceResult, InstanceActionFailure>>(4);
ToAgent.Add<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, InstanceActionFailure>>(5); ToAgent.Add<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, InstanceActionFailure>>(5);
ToController.Add<RegisterAgentMessage>(0); ToController.Add<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(0);
ToController.Add<UnregisterAgentMessage>(1); ToController.Add<UnregisterAgentMessage>(1);
ToController.Add<AgentIsAliveMessage>(2);
ToController.Add<AdvertiseJavaRuntimesMessage>(3); ToController.Add<AdvertiseJavaRuntimesMessage>(3);
ToController.Add<ReportInstanceStatusMessage>(4); ToController.Add<ReportInstanceStatusMessage>(4);
ToController.Add<InstanceOutputMessage>(5); ToController.Add<InstanceOutputMessage>(5);

View File

@@ -1,9 +0,0 @@
using MemoryPack;
using Phantom.Common.Data.Replies;
namespace Phantom.Common.Messages.Agent.ToAgent;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterAgentFailureMessage(
[property: MemoryPackOrder(0)] RegisterAgentFailure FailureKind
) : IMessageToAgent;

View File

@@ -1,9 +0,0 @@
using System.Collections.Immutable;
using MemoryPack;
namespace Phantom.Common.Messages.Agent.ToAgent;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterAgentSuccessMessage(
[property: MemoryPackOrder(0)] ImmutableArray<ConfigureInstanceMessage> InitialInstanceConfigurations
) : IMessageToAgent;

View File

@@ -1,6 +0,0 @@
using MemoryPack;
namespace Phantom.Common.Messages.Agent.ToController;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record AgentIsAliveMessage : IMessageToController;

View File

@@ -1,11 +1,12 @@
using MemoryPack; using System.Collections.Immutable;
using Phantom.Common.Data; using MemoryPack;
using Phantom.Common.Data.Agent; using Phantom.Common.Data.Agent;
using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Utils.Actor;
namespace Phantom.Common.Messages.Agent.ToController; namespace Phantom.Common.Messages.Agent.ToController;
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterAgentMessage( public sealed partial record RegisterAgentMessage(
[property: MemoryPackOrder(0)] AuthToken AuthToken, [property: MemoryPackOrder(0)] AgentInfo AgentInfo
[property: MemoryPackOrder(1)] AgentInfo AgentInfo ) : IMessageToController, ICanReply<ImmutableArray<ConfigureInstanceMessage>>;
) : IMessageToController;

View File

@@ -1,9 +0,0 @@
using MemoryPack;
using Phantom.Common.Data;
namespace Phantom.Common.Messages.Web.ToController;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterWebMessage(
[property: MemoryPackOrder(0)] AuthToken AuthToken
) : IMessageToController;

View File

@@ -1,8 +0,0 @@
using MemoryPack;
namespace Phantom.Common.Messages.Web.ToWeb;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterWebResultMessage(
[property: MemoryPackOrder(0)] bool Success
) : IMessageToWeb;

View File

@@ -21,7 +21,6 @@ public static class WebMessageRegistries {
public static IMessageDefinitions<IMessageToController, IMessageToWeb> Definitions { get; } = new MessageDefinitions(); public static IMessageDefinitions<IMessageToController, IMessageToWeb> Definitions { get; } = new MessageDefinitions();
static WebMessageRegistries() { static WebMessageRegistries() {
ToController.Add<RegisterWebMessage>(0);
ToController.Add<UnregisterWebMessage>(1); ToController.Add<UnregisterWebMessage>(1);
ToController.Add<LogInMessage, Optional<LogInSuccess>>(2); ToController.Add<LogInMessage, Optional<LogInSuccess>>(2);
ToController.Add<LogOutMessage>(3); ToController.Add<LogOutMessage>(3);
@@ -42,7 +41,6 @@ public static class WebMessageRegistries {
ToController.Add<GetAuditLogMessage, Result<ImmutableArray<AuditLogItem>, UserActionFailure>>(18); ToController.Add<GetAuditLogMessage, Result<ImmutableArray<AuditLogItem>, UserActionFailure>>(18);
ToController.Add<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(19); ToController.Add<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(19);
ToWeb.Add<RegisterWebResultMessage>(0);
ToWeb.Add<RefreshAgentsMessage>(1); ToWeb.Add<RefreshAgentsMessage>(1);
ToWeb.Add<RefreshInstancesMessage>(2); ToWeb.Add<RefreshInstancesMessage>(2);
ToWeb.Add<InstanceOutputMessage>(3); ToWeb.Add<InstanceOutputMessage>(3);

View File

@@ -21,6 +21,7 @@ using Phantom.Utils.Actor.Mailbox;
using Phantom.Utils.Actor.Tasks; using Phantom.Utils.Actor.Tasks;
using Phantom.Utils.Collections; using Phantom.Utils.Collections;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Server;
using Serilog; using Serilog;
namespace Phantom.Controller.Services.Agents; namespace Phantom.Controller.Services.Agents;
@@ -167,13 +168,13 @@ sealed class AgentActor : ReceiveActor<AgentActor.ICommand> {
return configurationMessages.ToImmutable(); return configurationMessages.ToImmutable();
} }
public interface ICommand {} public interface ICommand;
private sealed record InitializeCommand : ICommand; private sealed record InitializeCommand : ICommand;
public sealed record RegisterCommand(AgentConfiguration Configuration, RpcConnectionToClient<IMessageToAgent> Connection) : ICommand, ICanReply<ImmutableArray<ConfigureInstanceMessage>>; public sealed record RegisterCommand(AgentConfiguration Configuration, RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection) : ICommand, ICanReply<ImmutableArray<ConfigureInstanceMessage>>;
public sealed record UnregisterCommand(RpcConnectionToClient<IMessageToAgent> Connection) : ICommand; public sealed record UnregisterCommand(RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection) : ICommand;
private sealed record RefreshConnectionStatusCommand : ICommand; private sealed record RefreshConnectionStatusCommand : ICommand;

View File

@@ -1,35 +1,30 @@
using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Server;
using Serilog; using Serilog;
namespace Phantom.Controller.Services.Agents; namespace Phantom.Controller.Services.Agents;
sealed class AgentConnection { sealed class AgentConnection(Guid agentGuid, string agentName) {
private static readonly ILogger Logger = PhantomLogger.Create<AgentConnection>(); private static readonly ILogger Logger = PhantomLogger.Create<AgentConnection>();
private readonly Guid agentGuid; private string agentName = agentName;
private string agentName; private RpcServerToClientConnection<IMessageToController, IMessageToAgent>? connection;
private RpcConnectionToClient<IMessageToAgent>? connection; public void UpdateConnection(RpcServerToClientConnection<IMessageToController, IMessageToAgent> newConnection, string newAgentName) {
public AgentConnection(Guid agentGuid, string agentName) {
this.agentName = agentName;
this.agentGuid = agentGuid;
}
public void UpdateConnection(RpcConnectionToClient<IMessageToAgent> newConnection, string newAgentName) {
lock (this) { lock (this) {
connection?.Close(); connection?.ClientClosedSession();
connection = newConnection; connection = newConnection;
agentName = newAgentName; agentName = newAgentName;
} }
} }
public bool CloseIfSame(RpcConnectionToClient<IMessageToAgent> expected) { public bool CloseIfSame(RpcServerToClientConnection<IMessageToController, IMessageToAgent> expected) {
lock (this) { lock (this) {
if (connection != null && connection.IsSame(expected)) { if (connection != null && ReferenceEquals(connection, expected)) {
connection.Close(); connection.ClientClosedSession();
connection = null;
return true; return true;
} }
} }
@@ -44,7 +39,7 @@ sealed class AgentConnection {
return Task.CompletedTask; return Task.CompletedTask;
} }
return connection.Send(message); return connection.SendChannel.SendMessage(message).AsTask();
} }
} }
@@ -52,10 +47,10 @@ sealed class AgentConnection {
lock (this) { lock (this) {
if (connection == null) { if (connection == null) {
LogAgentOffline(); LogAgentOffline();
return Task.FromResult<TReply?>(default); return Task.FromResult<TReply?>(null);
} }
return connection.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!; return connection.SendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!;
} }
} }

View File

@@ -13,6 +13,7 @@ using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Users.Sessions; using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Server;
using Serilog; using Serilog;
namespace Phantom.Controller.Services.Agents; namespace Phantom.Controller.Services.Agents;
@@ -21,19 +22,18 @@ sealed class AgentManager {
private static readonly ILogger Logger = PhantomLogger.Create<AgentManager>(); private static readonly ILogger Logger = PhantomLogger.Create<AgentManager>();
private readonly IActorRefFactory actorSystem; private readonly IActorRefFactory actorSystem;
private readonly AuthToken authToken;
private readonly ControllerState controllerState; private readonly ControllerState controllerState;
private readonly MinecraftVersions minecraftVersions; private readonly MinecraftVersions minecraftVersions;
private readonly UserLoginManager userLoginManager; private readonly UserLoginManager userLoginManager;
private readonly IDbContextProvider dbProvider; private readonly IDbContextProvider dbProvider;
private readonly CancellationToken cancellationToken; private readonly CancellationToken cancellationToken;
private readonly ConcurrentDictionary<Guid, ActorRef<AgentActor.ICommand>> agentsByGuid = new (); private readonly ConcurrentDictionary<Guid, ActorRef<AgentActor.ICommand>> agentsByAgentGuid = new ();
private readonly ConcurrentDictionary<Guid, Guid> agentGuidsBySessionGuid = new ();
private readonly Func<Guid, AgentConfiguration, ActorRef<AgentActor.ICommand>> addAgentActorFactory; private readonly Func<Guid, AgentConfiguration, ActorRef<AgentActor.ICommand>> addAgentActorFactory;
public AgentManager(IActorRefFactory actorSystem, AuthToken authToken, ControllerState controllerState, MinecraftVersions minecraftVersions, UserLoginManager userLoginManager, IDbContextProvider dbProvider, CancellationToken cancellationToken) { public AgentManager(IActorRefFactory actorSystem, ControllerState controllerState, MinecraftVersions minecraftVersions, UserLoginManager userLoginManager, IDbContextProvider dbProvider, CancellationToken cancellationToken) {
this.actorSystem = actorSystem; this.actorSystem = actorSystem;
this.authToken = authToken;
this.controllerState = controllerState; this.controllerState = controllerState;
this.minecraftVersions = minecraftVersions; this.minecraftVersions = minecraftVersions;
this.userLoginManager = userLoginManager; this.userLoginManager = userLoginManager;
@@ -56,28 +56,21 @@ sealed class AgentManager {
var agentGuid = entity.AgentGuid; var agentGuid = entity.AgentGuid;
var agentConfiguration = new AgentConfiguration(entity.Name, entity.ProtocolVersion, entity.BuildVersion, entity.MaxInstances, entity.MaxMemory); var agentConfiguration = new AgentConfiguration(entity.Name, entity.ProtocolVersion, entity.BuildVersion, entity.MaxInstances, entity.MaxMemory);
if (agentsByGuid.TryAdd(agentGuid, CreateAgentActor(agentGuid, agentConfiguration))) { if (agentsByAgentGuid.TryAdd(agentGuid, CreateAgentActor(agentGuid, agentConfiguration))) {
Logger.Information("Loaded agent \"{AgentName}\" (GUID {AgentGuid}) from database.", agentConfiguration.AgentName, agentGuid); Logger.Information("Loaded agent \"{AgentName}\" (GUID {AgentGuid}) from database.", agentConfiguration.AgentName, agentGuid);
} }
} }
} }
public async Task<bool> RegisterAgent(AuthToken authToken, AgentInfo agentInfo, RpcConnectionToClient<IMessageToAgent> connection) { public async Task<ImmutableArray<ConfigureInstanceMessage>> RegisterAgent(AgentInfo agentInfo, RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection) {
if (!this.authToken.FixedTimeEquals(authToken)) {
await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.InvalidToken));
return false;
}
var agentProperties = AgentConfiguration.From(agentInfo); var agentProperties = AgentConfiguration.From(agentInfo);
var agentActor = agentsByGuid.GetOrAdd(agentInfo.AgentGuid, addAgentActorFactory, agentProperties); var agentActor = agentsByAgentGuid.GetOrAdd(agentInfo.AgentGuid, addAgentActorFactory, agentProperties);
var configureInstanceMessages = await agentActor.Request(new AgentActor.RegisterCommand(agentProperties, connection), cancellationToken); agentGuidsBySessionGuid[connection.SessionId] = agentInfo.AgentGuid;
await connection.Send(new RegisterAgentSuccessMessage(configureInstanceMessages)); return await agentActor.Request(new AgentActor.RegisterCommand(agentProperties, connection), cancellationToken);
return true;
} }
public bool TellAgent(Guid agentGuid, AgentActor.ICommand command) { public bool TellAgent(Guid agentGuid, AgentActor.ICommand command) {
if (agentsByGuid.TryGetValue(agentGuid, out var agent)) { if (agentsByAgentGuid.TryGetValue(agentGuid, out var agent)) {
agent.Tell(command); agent.Tell(command);
return true; return true;
} }
@@ -93,7 +86,7 @@ sealed class AgentManager {
return (UserInstanceActionFailure) UserActionFailure.NotAuthorized; return (UserInstanceActionFailure) UserActionFailure.NotAuthorized;
} }
if (!agentsByGuid.TryGetValue(agentGuid, out var agent)) { if (!agentsByAgentGuid.TryGetValue(agentGuid, out var agent)) {
return (UserInstanceActionFailure) InstanceActionFailure.AgentDoesNotExist; return (UserInstanceActionFailure) InstanceActionFailure.AgentDoesNotExist;
} }
@@ -101,4 +94,12 @@ sealed class AgentManager {
var result = await agent.Request(command, cancellationToken); var result = await agent.Request(command, cancellationToken);
return result.MapError(static error => (UserInstanceActionFailure) error); return result.MapError(static error => (UserInstanceActionFailure) error);
} }
public bool TryGetAgentGuidBySessionGuid(Guid sessionGuid, out Guid agentGuid) {
return agentGuidsBySessionGuid.TryGetValue(sessionGuid, out agentGuid);
}
public void OnSessionClosed(Guid sessionId, Guid agentGuid) {
agentGuidsBySessionGuid.TryRemove(new KeyValuePair<Guid, Guid>(sessionId, agentGuid));
}
} }

View File

@@ -1,9 +1,4 @@
using Akka.Actor; using Akka.Actor;
using Phantom.Common.Data;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToController;
using Phantom.Controller.Database; using Phantom.Controller.Database;
using Phantom.Controller.Minecraft; using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents; using Phantom.Controller.Services.Agents;
@@ -13,9 +8,9 @@ using Phantom.Controller.Services.Rpc;
using Phantom.Controller.Services.Users; using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions; using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime2; using Phantom.Utils.Rpc;
using IMessageFromAgentToController = Phantom.Common.Messages.Agent.IMessageToController; using IRpcAgentRegistrar = Phantom.Utils.Rpc.Runtime.Server.IRpcServerClientRegistrar<Phantom.Common.Messages.Agent.IMessageToController, Phantom.Common.Messages.Agent.IMessageToAgent>;
using IMessageFromWebToController = Phantom.Common.Messages.Web.IMessageToController; using IRpcWebRegistrar = Phantom.Utils.Rpc.Runtime.Server.IRpcServerClientRegistrar<Phantom.Common.Messages.Web.IMessageToController, Phantom.Common.Messages.Web.IMessageToWeb>;
namespace Phantom.Controller.Services; namespace Phantom.Controller.Services;
@@ -38,13 +33,13 @@ public sealed class ControllerServices : IDisposable {
private AuditLogManager AuditLogManager { get; } private AuditLogManager AuditLogManager { get; }
private EventLogManager EventLogManager { get; } private EventLogManager EventLogManager { get; }
public IRegistrationHandler<IMessageToAgent, IMessageFromAgentToController, RegisterAgentMessage> AgentRegistrationHandler { get; } public IRpcAgentRegistrar AgentRegistrar { get; }
public IRegistrationHandler<IMessageToWeb, IMessageFromWebToController, RegisterWebMessage> WebRegistrationHandler { get; } public IRpcWebRegistrar WebRegistrar { get; }
private readonly IDbContextProvider dbProvider; private readonly IDbContextProvider dbProvider;
private readonly CancellationToken cancellationToken; private readonly CancellationToken cancellationToken;
public ControllerServices(IDbContextProvider dbProvider, AuthToken agentAuthToken, AuthToken webAuthToken, CancellationToken shutdownCancellationToken) { public ControllerServices(IDbContextProvider dbProvider, AuthToken agentAuthToken, CancellationToken shutdownCancellationToken) {
this.dbProvider = dbProvider; this.dbProvider = dbProvider;
this.cancellationToken = shutdownCancellationToken; this.cancellationToken = shutdownCancellationToken;
@@ -60,14 +55,14 @@ public sealed class ControllerServices : IDisposable {
this.UserLoginManager = new UserLoginManager(AuthenticatedUserCache, UserManager, dbProvider); this.UserLoginManager = new UserLoginManager(AuthenticatedUserCache, UserManager, dbProvider);
this.PermissionManager = new PermissionManager(dbProvider); this.PermissionManager = new PermissionManager(dbProvider);
this.AgentManager = new AgentManager(ActorSystem, agentAuthToken, ControllerState, MinecraftVersions, UserLoginManager, dbProvider, cancellationToken); this.AgentManager = new AgentManager(ActorSystem, ControllerState, MinecraftVersions, UserLoginManager, dbProvider, cancellationToken);
this.InstanceLogManager = new InstanceLogManager(); this.InstanceLogManager = new InstanceLogManager();
this.AuditLogManager = new AuditLogManager(dbProvider); this.AuditLogManager = new AuditLogManager(dbProvider);
this.EventLogManager = new EventLogManager(ControllerState, ActorSystem, dbProvider, shutdownCancellationToken); this.EventLogManager = new EventLogManager(ControllerState, ActorSystem, dbProvider, shutdownCancellationToken);
this.AgentRegistrationHandler = new AgentRegistrationHandler(AgentManager, InstanceLogManager, EventLogManager); this.AgentRegistrar = new AgentClientRegistrar(ActorSystem, AgentManager, InstanceLogManager, EventLogManager);
this.WebRegistrationHandler = new WebRegistrationHandler(webAuthToken, ControllerState, InstanceLogManager, UserManager, RoleManager, UserRoleManager, UserLoginManager, AuditLogManager, AgentManager, MinecraftVersions, EventLogManager); this.WebRegistrar = new WebClientRegistrar(ActorSystem, ControllerState, InstanceLogManager, UserManager, RoleManager, UserRoleManager, UserLoginManager, AuditLogManager, AgentManager, MinecraftVersions, EventLogManager);
} }
public async Task Initialize() { public async Task Initialize() {

View File

@@ -0,0 +1,31 @@
using Akka.Actor;
using Phantom.Common.Messages.Agent;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc;
sealed class AgentClientRegistrar(
IActorRefFactory actorSystem,
AgentManager agentManager,
InstanceLogManager instanceLogManager,
EventLogManager eventLogManager
) : IRpcServerClientRegistrar<IMessageToController, IMessageToAgent> {
public IMessageReceiver<IMessageToController> Register(RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection) {
var name = "AgentClient-" + connection.SessionId;
var init = new AgentMessageHandlerActor.Init(connection, agentManager, instanceLogManager, eventLogManager);
return new Receiver(connection.SessionId, agentManager, actorSystem.ActorOf(AgentMessageHandlerActor.Factory(init), name));
}
private sealed class Receiver(Guid sessionId, AgentManager agentManager, ActorRef<IMessageToController> actor) : IMessageReceiver<IMessageToController>.Actor(actor) {
public override void OnPing() {
if (agentManager.TryGetAgentGuidBySessionGuid(sessionId, out var agentGuid)) {
agentManager.TellAgent(agentGuid, new AgentActor.NotifyIsAliveCommand());
}
}
}
}

View File

@@ -1,4 +1,5 @@
using Phantom.Common.Data.Replies; using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToAgent; using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController; using Phantom.Common.Messages.Agent.ToController;
@@ -6,86 +7,79 @@ using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events; using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances; using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc; namespace Phantom.Controller.Services.Rpc;
sealed class AgentMessageHandlerActor : ReceiveActor<IMessageToController> { sealed class AgentMessageHandlerActor : ReceiveActor<IMessageToController> {
public readonly record struct Init(Guid AgentGuid, RpcConnectionToClient<IMessageToAgent> Connection, AgentRegistrationHandler AgentRegistrationHandler, AgentManager AgentManager, InstanceLogManager InstanceLogManager, EventLogManager EventLogManager); public readonly record struct Init(RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection, AgentManager AgentManager, InstanceLogManager InstanceLogManager, EventLogManager EventLogManager);
public static Props<IMessageToController> Factory(Init init) { public static Props<IMessageToController> Factory(Init init) {
return Props<IMessageToController>.Create(() => new AgentMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<IMessageToController>.Create(() => new AgentMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
private readonly Guid agentGuid; private readonly RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection;
private readonly RpcConnectionToClient<IMessageToAgent> connection;
private readonly AgentRegistrationHandler agentRegistrationHandler;
private readonly AgentManager agentManager; private readonly AgentManager agentManager;
private readonly InstanceLogManager instanceLogManager; private readonly InstanceLogManager instanceLogManager;
private readonly EventLogManager eventLogManager; private readonly EventLogManager eventLogManager;
private Guid? registeredAgentGuid;
private AgentMessageHandlerActor(Init init) { private AgentMessageHandlerActor(Init init) {
this.agentGuid = init.AgentGuid;
this.connection = init.Connection; this.connection = init.Connection;
this.agentRegistrationHandler = init.AgentRegistrationHandler;
this.agentManager = init.AgentManager; this.agentManager = init.AgentManager;
this.instanceLogManager = init.InstanceLogManager; this.instanceLogManager = init.InstanceLogManager;
this.eventLogManager = init.EventLogManager; this.eventLogManager = init.EventLogManager;
ReceiveAsync<RegisterAgentMessage>(HandleRegisterAgent); ReceiveAsyncAndReply<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(HandleRegisterAgent);
Receive<UnregisterAgentMessage>(HandleUnregisterAgent); ReceiveAsync<UnregisterAgentMessage>(HandleUnregisterAgent);
Receive<AgentIsAliveMessage>(HandleAgentIsAlive);
Receive<AdvertiseJavaRuntimesMessage>(HandleAdvertiseJavaRuntimes); Receive<AdvertiseJavaRuntimesMessage>(HandleAdvertiseJavaRuntimes);
Receive<ReportAgentStatusMessage>(HandleReportAgentStatus); Receive<ReportAgentStatusMessage>(HandleReportAgentStatus);
Receive<ReportInstanceStatusMessage>(HandleReportInstanceStatus); Receive<ReportInstanceStatusMessage>(HandleReportInstanceStatus);
Receive<ReportInstancePlayerCountsMessage>(HandleReportInstancePlayerCounts); Receive<ReportInstancePlayerCountsMessage>(HandleReportInstancePlayerCounts);
Receive<ReportInstanceEventMessage>(HandleReportInstanceEvent); Receive<ReportInstanceEventMessage>(HandleReportInstanceEvent);
Receive<InstanceOutputMessage>(HandleInstanceOutput); Receive<InstanceOutputMessage>(HandleInstanceOutput);
Receive<ReplyMessage>(HandleReply);
} }
private async Task HandleRegisterAgent(RegisterAgentMessage message) { private Guid RequireAgentGuid() {
if (agentGuid != message.AgentInfo.AgentGuid) { return registeredAgentGuid ?? throw new InvalidOperationException("Agent has not registered yet.");
await connection.Send(new RegisterAgentFailureMessage(RegisterAgentFailure.ConnectionAlreadyHasAnAgent));
}
else {
await agentRegistrationHandler.TryRegisterImpl(connection, message);
}
} }
private void HandleUnregisterAgent(UnregisterAgentMessage message) { private Task<ImmutableArray<ConfigureInstanceMessage>> HandleRegisterAgent(RegisterAgentMessage message) {
registeredAgentGuid = message.AgentInfo.AgentGuid;
return agentManager.RegisterAgent(message.AgentInfo, connection);
}
private Task HandleUnregisterAgent(UnregisterAgentMessage message) {
Guid agentGuid = RequireAgentGuid();
agentManager.TellAgent(agentGuid, new AgentActor.UnregisterCommand(connection)); agentManager.TellAgent(agentGuid, new AgentActor.UnregisterCommand(connection));
connection.Close(); agentManager.OnSessionClosed(connection.SessionId, agentGuid);
}
Self.Tell(PoisonPill.Instance);
private void HandleAgentIsAlive(AgentIsAliveMessage message) { return connection.ClientClosedSession();
agentManager.TellAgent(agentGuid, new AgentActor.NotifyIsAliveCommand());
} }
private void HandleAdvertiseJavaRuntimes(AdvertiseJavaRuntimesMessage message) { private void HandleAdvertiseJavaRuntimes(AdvertiseJavaRuntimesMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateJavaRuntimesCommand(message.Runtimes)); agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateJavaRuntimesCommand(message.Runtimes));
} }
private void HandleReportAgentStatus(ReportAgentStatusMessage message) { private void HandleReportAgentStatus(ReportAgentStatusMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateStatsCommand(message.RunningInstanceCount, message.RunningInstanceMemory)); agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateStatsCommand(message.RunningInstanceCount, message.RunningInstanceMemory));
} }
private void HandleReportInstanceStatus(ReportInstanceStatusMessage message) { private void HandleReportInstanceStatus(ReportInstanceStatusMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateInstanceStatusCommand(message.InstanceGuid, message.InstanceStatus)); agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateInstanceStatusCommand(message.InstanceGuid, message.InstanceStatus));
} }
private void HandleReportInstancePlayerCounts(ReportInstancePlayerCountsMessage message) { private void HandleReportInstancePlayerCounts(ReportInstancePlayerCountsMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateInstancePlayerCountsCommand(message.InstanceGuid, message.PlayerCounts)); agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateInstancePlayerCountsCommand(message.InstanceGuid, message.PlayerCounts));
} }
private void HandleReportInstanceEvent(ReportInstanceEventMessage message) { private void HandleReportInstanceEvent(ReportInstanceEventMessage message) {
message.Event.Accept(eventLogManager.CreateInstanceEventVisitor(message.EventGuid, message.UtcTime, agentGuid, message.InstanceGuid)); message.Event.Accept(eventLogManager.CreateInstanceEventVisitor(message.EventGuid, message.UtcTime, RequireAgentGuid(), message.InstanceGuid));
} }
private void HandleInstanceOutput(InstanceOutputMessage message) { private void HandleInstanceOutput(InstanceOutputMessage message) {
instanceLogManager.ReceiveLines(message.InstanceGuid, message.Lines); instanceLogManager.ReceiveLines(message.InstanceGuid, message.Lines);
} }
private void HandleReply(ReplyMessage message) {
connection.Receive(message);
}
} }

View File

@@ -1,34 +0,0 @@
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime2;
namespace Phantom.Controller.Services.Rpc;
sealed class AgentRegistrationHandler : IRegistrationHandler<IMessageToAgent, IMessageToController, RegisterAgentMessage> {
private readonly AgentManager agentManager;
private readonly InstanceLogManager instanceLogManager;
private readonly EventLogManager eventLogManager;
public AgentRegistrationHandler(AgentManager agentManager, InstanceLogManager instanceLogManager, EventLogManager eventLogManager) {
this.agentManager = agentManager;
this.instanceLogManager = instanceLogManager;
this.eventLogManager = eventLogManager;
}
async Task<Props<IMessageToController>?> IRegistrationHandler<IMessageToAgent, IMessageToController, RegisterAgentMessage>.TryRegister(RpcConnectionToClient<IMessageToAgent> connection, RegisterAgentMessage message) {
return await TryRegisterImpl(connection, message) ? CreateMessageHandlerActorProps(message.AgentInfo.AgentGuid, connection) : null;
}
public Task<bool> TryRegisterImpl(RpcConnectionToClient<IMessageToAgent> connection, RegisterAgentMessage message) {
return agentManager.RegisterAgent(message.AuthToken, message.AgentInfo, connection);
}
private Props<IMessageToController> CreateMessageHandlerActorProps(Guid agentGuid, RpcConnectionToClient<IMessageToAgent> connection) {
var init = new AgentMessageHandlerActor.Init(agentGuid, connection, this, agentManager, instanceLogManager, eventLogManager);
return AgentMessageHandlerActor.Factory(init);
}
}

View File

@@ -1,34 +0,0 @@
using Phantom.Common.Data;
using Phantom.Common.Data.Agent;
using Phantom.Common.Messages.Agent.Handshake;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime;
using Serilog;
namespace Phantom.Controller.Services.Rpc;
public sealed class RpcServerAgentHandshake(AuthToken authToken) : RpcServerHandshake {
private static readonly ILogger Logger = PhantomLogger.Create<RpcServerAgentHandshake>();
protected override async Task<bool> AcceptClient(string remoteAddress, Stream stream, CancellationToken cancellationToken) {
Memory<byte> buffer = new Memory<byte>(new byte[AuthToken.Length]);
await stream.ReadExactlyAsync(buffer, cancellationToken: cancellationToken);
if (!authToken.FixedTimeEquals(buffer.Span)) {
Logger.Warning("Rejected client {}, invalid authorization token.", remoteAddress);
await Respond(remoteAddress, stream, new InvalidAuthToken(), cancellationToken);
return false;
}
AgentInfo agentInfo = await Serialization.Deserialize<AgentInfo>(stream, cancellationToken);
return true;
}
private async ValueTask Respond(string remoteAddress, Stream stream, IAgentHandshakeResult result, CancellationToken cancellationToken) {
try {
await Serialization.Serialize(result, stream, cancellationToken);
} catch (Exception e) {
Logger.Error(e, "Could not send handshake result to client {}.", remoteAddress);
}
}
}

View File

@@ -0,0 +1,33 @@
using Akka.Actor;
using Phantom.Common.Messages.Web;
using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc;
sealed class WebClientRegistrar(
IActorRefFactory actorSystem,
ControllerState controllerState,
InstanceLogManager instanceLogManager,
UserManager userManager,
RoleManager roleManager,
UserRoleManager userRoleManager,
UserLoginManager userLoginManager,
AuditLogManager auditLogManager,
AgentManager agentManager,
MinecraftVersions minecraftVersions,
EventLogManager eventLogManager
) : IRpcServerClientRegistrar<IMessageToController, IMessageToWeb> {
public IMessageReceiver<IMessageToController> Register(RpcServerToClientConnection<IMessageToController, IMessageToWeb> connection) {
var name = "WebClient-" + connection.SessionId;
var init = new WebMessageHandlerActor.Init(connection, controllerState, instanceLogManager, userManager, roleManager, userRoleManager, userLoginManager, auditLogManager, agentManager, minecraftVersions, eventLogManager);
return new IMessageReceiver<IMessageToController>.Actor(actorSystem.ActorOf(WebMessageHandlerActor.Factory(init), name));
}
}

View File

@@ -5,17 +5,18 @@ using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToWeb; using Phantom.Common.Messages.Web.ToWeb;
using Phantom.Controller.Services.Instances; using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Controller.Services.Rpc; namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdateSenderActor.ICommand> { sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdateSenderActor.ICommand> {
public readonly record struct Init(RpcConnectionToClient<IMessageToWeb> Connection, ControllerState ControllerState, InstanceLogManager InstanceLogManager); public readonly record struct Init(RpcSendChannel<IMessageToWeb> Connection, ControllerState ControllerState, InstanceLogManager InstanceLogManager);
public static Props<ICommand> Factory(Init init) { public static Props<ICommand> Factory(Init init) {
return Props<ICommand>.Create(() => new WebMessageDataUpdateSenderActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<ICommand>.Create(() => new WebMessageDataUpdateSenderActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
private readonly RpcConnectionToClient<IMessageToWeb> connection; private readonly RpcSendChannel<IMessageToWeb> connection;
private readonly ControllerState controllerState; private readonly ControllerState controllerState;
private readonly InstanceLogManager instanceLogManager; private readonly InstanceLogManager instanceLogManager;
private readonly ActorRef<ICommand> selfCached; private readonly ActorRef<ICommand> selfCached;
@@ -69,18 +70,18 @@ sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdate
private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand; private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand;
private Task RefreshAgents(RefreshAgentsCommand command) { private Task RefreshAgents(RefreshAgentsCommand command) {
return connection.Send(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())); return connection.SendMessage(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())).AsTask();
} }
private Task RefreshInstances(RefreshInstancesCommand command) { private Task RefreshInstances(RefreshInstancesCommand command) {
return connection.Send(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())); return connection.SendMessage(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())).AsTask();
} }
private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) { private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) {
return connection.Send(new InstanceOutputMessage(command.InstanceGuid, command.Lines)); return connection.SendMessage(new InstanceOutputMessage(command.InstanceGuid, command.Lines)).AsTask();
} }
private Task RefreshUserSession(RefreshUserSessionCommand command) { private Task RefreshUserSession(RefreshUserSessionCommand command) {
return connection.Send(new RefreshUserSessionMessage(command.UserGuid)); return connection.SendMessage(new RefreshUserSessionMessage(command.UserGuid)).AsTask();
} }
} }

View File

@@ -1,4 +1,5 @@
using System.Collections.Immutable; using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Common.Data; using Phantom.Common.Data;
using Phantom.Common.Data.Java; using Phantom.Common.Data.Java;
using Phantom.Common.Data.Minecraft; using Phantom.Common.Data.Minecraft;
@@ -16,13 +17,13 @@ using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users; using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions; using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc; namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> { sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
public readonly record struct Init( public readonly record struct Init(
RpcConnectionToClient<IMessageToWeb> Connection, RpcServerToClientConnection<IMessageToController, IMessageToWeb> Connection,
WebRegistrationHandler WebRegistrationHandler,
ControllerState ControllerState, ControllerState ControllerState,
InstanceLogManager InstanceLogManager, InstanceLogManager InstanceLogManager,
UserManager UserManager, UserManager UserManager,
@@ -39,8 +40,7 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
return Props<IMessageToController>.Create(() => new WebMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<IMessageToController>.Create(() => new WebMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
private readonly RpcConnectionToClient<IMessageToWeb> connection; private readonly RpcServerToClientConnection<IMessageToController, IMessageToWeb> connection;
private readonly WebRegistrationHandler webRegistrationHandler;
private readonly ControllerState controllerState; private readonly ControllerState controllerState;
private readonly UserManager userManager; private readonly UserManager userManager;
private readonly RoleManager roleManager; private readonly RoleManager roleManager;
@@ -53,7 +53,6 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
private WebMessageHandlerActor(Init init) { private WebMessageHandlerActor(Init init) {
this.connection = init.Connection; this.connection = init.Connection;
this.webRegistrationHandler = init.WebRegistrationHandler;
this.controllerState = init.ControllerState; this.controllerState = init.ControllerState;
this.userManager = init.UserManager; this.userManager = init.UserManager;
this.roleManager = init.RoleManager; this.roleManager = init.RoleManager;
@@ -64,11 +63,10 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
this.minecraftVersions = init.MinecraftVersions; this.minecraftVersions = init.MinecraftVersions;
this.eventLogManager = init.EventLogManager; this.eventLogManager = init.EventLogManager;
var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection, controllerState, init.InstanceLogManager); var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection.SendChannel, controllerState, init.InstanceLogManager);
Context.ActorOf(WebMessageDataUpdateSenderActor.Factory(senderActorInit), "DataUpdateSender"); Context.ActorOf(WebMessageDataUpdateSenderActor.Factory(senderActorInit), "DataUpdateSender");
ReceiveAsync<RegisterWebMessage>(HandleRegisterWeb); ReceiveAsync<UnregisterWebMessage>(HandleUnregisterWeb);
Receive<UnregisterWebMessage>(HandleUnregisterWeb);
ReceiveAndReplyLater<LogInMessage, Optional<LogInSuccess>>(HandleLogIn); ReceiveAndReplyLater<LogInMessage, Optional<LogInSuccess>>(HandleLogIn);
Receive<LogOutMessage>(HandleLogOut); Receive<LogOutMessage>(HandleLogOut);
ReceiveAndReply<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>(GetAuthenticatedUser); ReceiveAndReply<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>(GetAuthenticatedUser);
@@ -87,15 +85,11 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
ReceiveAndReply<GetAgentJavaRuntimesMessage, ImmutableDictionary<Guid, ImmutableArray<TaggedJavaRuntime>>>(HandleGetAgentJavaRuntimes); ReceiveAndReply<GetAgentJavaRuntimesMessage, ImmutableDictionary<Guid, ImmutableArray<TaggedJavaRuntime>>>(HandleGetAgentJavaRuntimes);
ReceiveAndReplyLater<GetAuditLogMessage, Result<ImmutableArray<AuditLogItem>, UserActionFailure>>(HandleGetAuditLog); ReceiveAndReplyLater<GetAuditLogMessage, Result<ImmutableArray<AuditLogItem>, UserActionFailure>>(HandleGetAuditLog);
ReceiveAndReplyLater<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(HandleGetEventLog); ReceiveAndReplyLater<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(HandleGetEventLog);
Receive<ReplyMessage>(HandleReply);
} }
private async Task HandleRegisterWeb(RegisterWebMessage message) { private Task HandleUnregisterWeb(UnregisterWebMessage message) {
await webRegistrationHandler.TryRegisterImpl(connection, message); Self.Tell(PoisonPill.Instance);
} return connection.ClientClosedSession();
private void HandleUnregisterWeb(UnregisterWebMessage message) {
connection.Close();
} }
private Task<Optional<LogInSuccess>> HandleLogIn(LogInMessage message) { private Task<Optional<LogInSuccess>> HandleLogIn(LogInMessage message) {
@@ -189,8 +183,4 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
private Task<Result<ImmutableArray<EventLogItem>, UserActionFailure>> HandleGetEventLog(GetEventLogMessage message) { private Task<Result<ImmutableArray<EventLogItem>, UserActionFailure>> HandleGetEventLog(GetEventLogMessage message) {
return eventLogManager.GetMostRecentItems(userLoginManager.GetLoggedInUser(message.AuthToken), message.Count); return eventLogManager.GetMostRecentItems(userLoginManager.GetLoggedInUser(message.AuthToken), message.Count);
} }
private void HandleReply(ReplyMessage message) {
connection.Receive(message);
}
} }

View File

@@ -1,68 +0,0 @@
using Phantom.Common.Data;
using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToController;
using Phantom.Common.Messages.Web.ToWeb;
using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime2;
using Serilog;
namespace Phantom.Controller.Services.Rpc;
sealed class WebRegistrationHandler : IRegistrationHandler<IMessageToWeb, IMessageToController, RegisterWebMessage> {
private static readonly ILogger Logger = PhantomLogger.Create<WebRegistrationHandler>();
private readonly AuthToken webAuthToken;
private readonly ControllerState controllerState;
private readonly InstanceLogManager instanceLogManager;
private readonly UserManager userManager;
private readonly RoleManager roleManager;
private readonly UserRoleManager userRoleManager;
private readonly UserLoginManager userLoginManager;
private readonly AuditLogManager auditLogManager;
private readonly AgentManager agentManager;
private readonly MinecraftVersions minecraftVersions;
private readonly EventLogManager eventLogManager;
public WebRegistrationHandler(AuthToken webAuthToken, ControllerState controllerState, InstanceLogManager instanceLogManager, UserManager userManager, RoleManager roleManager, UserRoleManager userRoleManager, UserLoginManager userLoginManager, AuditLogManager auditLogManager, AgentManager agentManager, MinecraftVersions minecraftVersions, EventLogManager eventLogManager) {
this.webAuthToken = webAuthToken;
this.controllerState = controllerState;
this.userManager = userManager;
this.roleManager = roleManager;
this.userRoleManager = userRoleManager;
this.userLoginManager = userLoginManager;
this.auditLogManager = auditLogManager;
this.agentManager = agentManager;
this.minecraftVersions = minecraftVersions;
this.eventLogManager = eventLogManager;
this.instanceLogManager = instanceLogManager;
}
async Task<Props<IMessageToController>?> IRegistrationHandler<IMessageToWeb, IMessageToController, RegisterWebMessage>.TryRegister(RpcConnectionToClient<IMessageToWeb> connection, RegisterWebMessage message) {
return await TryRegisterImpl(connection, message) ? CreateMessageHandlerActorProps(connection) : null;
}
public async Task<bool> TryRegisterImpl(RpcConnectionToClient<IMessageToWeb> connection, RegisterWebMessage message) {
if (webAuthToken.FixedTimeEquals(message.AuthToken)) {
Logger.Information("Web authorized successfully.");
await connection.Send(new RegisterWebResultMessage(true));
return true;
}
else {
Logger.Warning("Web failed to authorize, invalid token.");
await connection.Send(new RegisterWebResultMessage(false));
return false;
}
}
private Props<IMessageToController> CreateMessageHandlerActorProps(RpcConnectionToClient<IMessageToWeb> connection) {
var init = new WebMessageHandlerActor.Init(connection, this, controllerState, instanceLogManager, userManager, roleManager, userRoleManager, userLoginManager, auditLogManager, agentManager, minecraftVersions, eventLogManager);
return WebMessageHandlerActor.Factory(init);
}
}

View File

@@ -1,5 +1,5 @@
using Phantom.Common.Data; using Phantom.Utils.Rpc;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Controller; namespace Phantom.Controller;

View File

@@ -3,7 +3,7 @@ using Phantom.Utils.Cryptography;
using Phantom.Utils.IO; using Phantom.Utils.IO;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Monads; using Phantom.Utils.Monads;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc;
using Phantom.Utils.Rpc.Runtime.Tls; using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog; using Serilog;

View File

@@ -1,13 +1,16 @@
using System.Reflection; using System.Reflection;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Web;
using Phantom.Controller; using Phantom.Controller;
using Phantom.Controller.Database.Postgres; using Phantom.Controller.Database.Postgres;
using Phantom.Controller.Services; using Phantom.Controller.Services;
using Phantom.Controller.Services.Rpc;
using Phantom.Utils.IO; using Phantom.Utils.IO;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Runtime.Server;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Tasks; using Phantom.Utils.Tasks;
using RpcAgentServer = Phantom.Utils.Rpc.Runtime.Server.RpcServer<Phantom.Common.Messages.Agent.IMessageToController, Phantom.Common.Messages.Agent.IMessageToAgent>;
using RpcWebServer = Phantom.Utils.Rpc.Runtime.Server.RpcServer<Phantom.Common.Messages.Web.IMessageToController, Phantom.Common.Messages.Web.IMessageToWeb>;
var shutdownCancellationTokenSource = new CancellationTokenSource(); var shutdownCancellationTokenSource = new CancellationTokenSource();
var shutdownCancellationToken = shutdownCancellationTokenSource.Token; var shutdownCancellationToken = shutdownCancellationTokenSource.Token;
@@ -54,12 +57,28 @@ try {
var dbContextFactory = new ApplicationDbContextFactory(sqlConnectionString); var dbContextFactory = new ApplicationDbContextFactory(sqlConnectionString);
using var controllerServices = new ControllerServices(dbContextFactory, agentKeyData.AuthToken, webKeyData.AuthToken, shutdownCancellationToken); using var controllerServices = new ControllerServices(dbContextFactory, agentKeyData.AuthToken, shutdownCancellationToken);
await controllerServices.Initialize(); await controllerServices.Initialize();
var agentConnectionParameters = new RpcServerConnectionParameters(
EndPoint: agentRpcServerHost,
Certificate: agentKeyData.Certificate,
AuthToken: agentKeyData.AuthToken,
SendQueueCapacity: 100,
PingInterval: TimeSpan.FromSeconds(10)
);
var webConnectionParameters = new RpcServerConnectionParameters(
EndPoint: webRpcServerHost,
Certificate: webKeyData.Certificate,
AuthToken: webKeyData.AuthToken,
SendQueueCapacity: 500,
PingInterval: TimeSpan.FromMinutes(1)
);
LinkedTasks<bool> rpcServerTasks = new LinkedTasks<bool>([ LinkedTasks<bool> rpcServerTasks = new LinkedTasks<bool>([
new RpcServer("Agent", agentRpcServerHost, agentKeyData.Certificate, new RpcServerAgentHandshake(agentKeyData.AuthToken)).Run(shutdownCancellationToken), new RpcAgentServer("Agent", agentConnectionParameters, AgentMessageRegistries.Definitions, controllerServices.AgentRegistrar).Run(shutdownCancellationToken),
// new RpcServer("Web", webRpcServerHost, webKeyData.Certificate).Run(shutdownCancellationToken), new RpcWebServer("Web", webConnectionParameters, WebMessageRegistries.Definitions, controllerServices.WebRegistrar).Run(shutdownCancellationToken),
]); ]);
// If either RPC server crashes, stop the whole process. // If either RPC server crashes, stop the whole process.
@@ -72,11 +91,6 @@ try {
} }
return 0; return 0;
// await Task.WhenAll(
// RpcServerRuntime.Launch(ConfigureRpc("Agent", agentRpcServerHost, agentRpcServerPort, agentKeyData), AgentMessageRegistries.Definitions, controllerServices.AgentRegistrationHandler, controllerServices.ActorSystem, shutdownCancellationToken),
// RpcServerRuntime.Launch(ConfigureRpc("Web", webRpcServerHost, webRpcServerPort, webKeyData), WebMessageRegistries.Definitions, controllerServices.WebRegistrationHandler, controllerServices.ActorSystem, shutdownCancellationToken)
// );
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
return 0; return 0;
} catch (StopProcedureException) { } catch (StopProcedureException) {

View File

@@ -26,7 +26,7 @@ sealed record Variables(
EndPoint webRpcServerHost = new IPEndPoint( EndPoint webRpcServerHost = new IPEndPoint(
EnvironmentVariables.GetIpAddress("WEB_RPC_SERVER_HOST").WithDefault(IPAddress.Any), EnvironmentVariables.GetIpAddress("WEB_RPC_SERVER_HOST").WithDefault(IPAddress.Any),
EnvironmentVariables.GetPortNumber("WEB_RPC_SERVER_PORT").WithDefault(9401) EnvironmentVariables.GetPortNumber("WEB_RPC_SERVER_PORT").WithDefault(9402)
); );
return new Variables( return new Variables(

View File

@@ -23,7 +23,6 @@
<ItemGroup> <ItemGroup>
<PackageReference Update="BCrypt.Net-Next.StrongName" Version="4.0.3" /> <PackageReference Update="BCrypt.Net-Next.StrongName" Version="4.0.3" />
<PackageReference Update="MemoryPack" Version="1.10.0" /> <PackageReference Update="MemoryPack" Version="1.10.0" />
<PackageReference Update="NetMQ" Version="4.0.1.13" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -31,25 +31,48 @@ public static class PhantomLogger {
} }
public static ILogger Create<T>() { public static ILogger Create<T>() {
return Create(typeof(T).Name); return Create(TypeName<T>());
} }
public static ILogger Create<T>(string name) { public static ILogger Create<T>(string name) {
return Create(typeof(T).Name, name); return Create(ConcatNames(TypeName<T>(), name));
} }
public static ILogger Create<T>(string name1, string name2) { public static ILogger Create<T>(string name1, string name2) {
return Create(typeof(T).Name, ConcatNames(name1, name2)); return Create(ConcatNames(TypeName<T>(), name1, name2));
} }
public static ILogger Create<T1, T2>() { public static ILogger Create<T1, T2>() {
return Create(typeof(T1).Name, typeof(T2).Name); return Create(ConcatNames(TypeName<T1>(), TypeName<T2>()));
} }
private static string ConcatNames(string name1, string name2) { public static ILogger Create<T1, T2>(string name) {
return Create(ConcatNames(TypeName<T1>(), TypeName<T2>(), name));
}
public static ILogger Create<T1, T2>(string name1, string name2) {
return Create(ConcatNames(TypeName<T1>(), TypeName<T2>(), ConcatNames(name1, name2)));
}
private static string TypeName<T>() {
string typeName = typeof(T).Name;
int genericsStartIndex = typeName.IndexOf('`');
return genericsStartIndex > 0 ? typeName[..genericsStartIndex] : typeName;
}
public static string ConcatNames(string name1, string name2) {
return name1 + ":" + name2; return name1 + ":" + name2;
} }
public static string ConcatNames(string name1, string name2, string name3) {
return ConcatNames(name1, ConcatNames(name2, name3));
}
public static string ShortenGuid(Guid guid) {
var prefix = guid.ToString();
return prefix[..prefix.IndexOf('-')];
}
public static void Dispose() { public static void Dispose() {
Root.Dispose(); Root.Dispose();
Base.Dispose(); Base.Dispose();

View File

@@ -1,18 +1,12 @@
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography; using System.Security.Cryptography;
using MemoryPack;
namespace Phantom.Common.Data; namespace Phantom.Utils.Rpc;
[MemoryPackable(GenerateType.VersionTolerant)] public sealed class AuthToken {
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public sealed partial class AuthToken {
public const int Length = 12; public const int Length = 12;
[MemoryPackOrder(0)] public ImmutableArray<byte> Bytes { get; }
[MemoryPackInclude]
public readonly ImmutableArray<byte> Bytes;
public AuthToken(ImmutableArray<byte> bytes) { public AuthToken(ImmutableArray<byte> bytes) {
if (bytes.Length != Length) { if (bytes.Length != Length) {
@@ -30,10 +24,6 @@ public sealed partial class AuthToken {
return CryptographicOperations.FixedTimeEquals(Bytes.AsSpan(), other); return CryptographicOperations.FixedTimeEquals(Bytes.AsSpan(), other);
} }
internal void WriteTo(Span<byte> span) {
Bytes.CopyTo(span);
}
public static AuthToken Generate() { public static AuthToken Generate() {
return new AuthToken([..RandomNumberGenerator.GetBytes(Length)]); return new AuthToken([..RandomNumberGenerator.GetBytes(Length)]);
} }

View File

@@ -4,11 +4,13 @@ namespace Phantom.Utils.Rpc.Frame;
interface IFrame { interface IFrame {
private const byte TypePingId = 0; private const byte TypePingId = 0;
private const byte TypeMessageId = 1; private const byte TypePongId = 1;
private const byte TypeReplyId = 2; private const byte TypeMessageId = 2;
private const byte TypeErrorId = 3; private const byte TypeReplyId = 3;
private const byte TypeErrorId = 4;
static readonly ReadOnlyMemory<byte> TypePing = new ([TypePingId]); static readonly ReadOnlyMemory<byte> TypePing = new ([TypePingId]);
static readonly ReadOnlyMemory<byte> TypePong = new ([TypePongId]);
static readonly ReadOnlyMemory<byte> TypeMessage = new ([TypeMessageId]); static readonly ReadOnlyMemory<byte> TypeMessage = new ([TypeMessageId]);
static readonly ReadOnlyMemory<byte> TypeReply = new ([TypeReplyId]); static readonly ReadOnlyMemory<byte> TypeReply = new ([TypeReplyId]);
static readonly ReadOnlyMemory<byte> TypeError = new ([TypeErrorId]); static readonly ReadOnlyMemory<byte> TypeError = new ([TypeErrorId]);
@@ -21,11 +23,18 @@ interface IFrame {
switch (oneByteBuffer[0]) { switch (oneByteBuffer[0]) {
case TypePingId: case TypePingId:
var pingTime = await PingFrame.Read(stream, cancellationToken);
await reader.OnPingFrame(pingTime, cancellationToken);
break;
case TypePongId:
var pongFrame = await PongFrame.Read(stream, cancellationToken);
reader.OnPongFrame(pongFrame);
break; break;
case TypeMessageId: case TypeMessageId:
var messageFrame = await MessageFrame.Read(stream, cancellationToken); var messageFrame = await MessageFrame.Read(stream, cancellationToken);
await reader.OnMessageFrame(messageFrame, stream, cancellationToken); await reader.OnMessageFrame(messageFrame, cancellationToken);
break; break;
case TypeReplyId: case TypeReplyId:
@@ -39,13 +48,13 @@ interface IFrame {
break; break;
default: default:
reader.OnUnknownFrameStart(oneByteBuffer[0]); reader.OnUnknownFrameId(oneByteBuffer[0]);
break; break;
} }
} }
} }
ReadOnlyMemory<byte> Type { get; } ReadOnlyMemory<byte> FrameType { get; }
Task Write(Stream stream, CancellationToken cancellationToken); Task Write(Stream stream, CancellationToken cancellationToken = default);
} }

View File

@@ -3,8 +3,10 @@
namespace Phantom.Utils.Rpc.Frame; namespace Phantom.Utils.Rpc.Frame;
interface IFrameReader { interface IFrameReader {
Task OnMessageFrame(MessageFrame frame, Stream stream, CancellationToken cancellationToken); ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken);
void OnPongFrame(PongFrame frame);
Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken);
void OnReplyFrame(ReplyFrame frame); void OnReplyFrame(ReplyFrame frame);
void OnErrorFrame(ErrorFrame frame); void OnErrorFrame(ErrorFrame frame);
void OnUnknownFrameStart(byte id); void OnUnknownFrameId(byte frameId);
} }

View File

@@ -3,7 +3,7 @@
namespace Phantom.Utils.Rpc.Frame.Types; namespace Phantom.Utils.Rpc.Frame.Types;
sealed record ErrorFrame(uint ReplyingToMessageId, RpcError Error) : IFrame { sealed record ErrorFrame(uint ReplyingToMessageId, RpcError Error) : IFrame {
public ReadOnlyMemory<byte> Type => IFrame.TypeError; public ReadOnlyMemory<byte> FrameType => IFrame.TypeError;
public async Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken); await Serialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken);

View File

@@ -5,7 +5,7 @@ namespace Phantom.Utils.Rpc.Frame.Types;
sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<byte> SerializedMessage) : IFrame { sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<byte> SerializedMessage) : IFrame {
public const int MaxMessageBytes = 1024 * 1024 * 8; public const int MaxMessageBytes = 1024 * 1024 * 8;
public ReadOnlyMemory<byte> Type => IFrame.TypeMessage; public ReadOnlyMemory<byte> FrameType => IFrame.TypeMessage;
public async Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
int messageLength = SerializedMessage.Length; int messageLength = SerializedMessage.Length;
@@ -29,7 +29,7 @@ sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<b
private static void CheckMessageLength(int messageLength) { private static void CheckMessageLength(int messageLength) {
if (messageLength < 0) { if (messageLength < 0) {
throw new RpcErrorException("Message length is negative", RpcError.InvalidData); throw new RpcErrorException("Message length is negative.", RpcError.InvalidData);
} }
if (messageLength > MaxMessageBytes) { if (messageLength > MaxMessageBytes) {

View File

@@ -1,11 +1,15 @@
namespace Phantom.Utils.Rpc.Frame.Types; namespace Phantom.Utils.Rpc.Frame.Types;
sealed record PingFrame : IFrame { sealed record PingFrame : IFrame {
public static PingFrame Instance { get; } = new PingFrame(); public static PingFrame Instance { get; } = new ();
public ReadOnlyMemory<byte> Type => IFrame.TypePing; public ReadOnlyMemory<byte> FrameType => IFrame.TypePing;
public Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
return Task.CompletedTask; await Serialization.WriteSignedLong(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), stream, cancellationToken);
}
public static async Task<DateTimeOffset> Read(Stream stream, CancellationToken cancellationToken) {
return DateTimeOffset.FromUnixTimeMilliseconds(await Serialization.ReadSignedLong(stream, cancellationToken));
} }
} }

View File

@@ -0,0 +1,13 @@
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record PongFrame(DateTimeOffset PingTime) : IFrame {
public ReadOnlyMemory<byte> FrameType => IFrame.TypePong;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteSignedLong(PingTime.ToUnixTimeMilliseconds(), stream, cancellationToken);
}
public static async Task<PongFrame> Read(Stream stream, CancellationToken cancellationToken) {
return new PongFrame(DateTimeOffset.FromUnixTimeMilliseconds(await Serialization.ReadSignedLong(stream, cancellationToken)));
}
}

View File

@@ -5,7 +5,7 @@ namespace Phantom.Utils.Rpc.Frame.Types;
sealed record ReplyFrame(uint ReplyingToMessageId, ReadOnlyMemory<byte> SerializedReply) : IFrame { sealed record ReplyFrame(uint ReplyingToMessageId, ReadOnlyMemory<byte> SerializedReply) : IFrame {
public const int MaxReplyBytes = 1024 * 1024 * 32; public const int MaxReplyBytes = 1024 * 1024 * 32;
public ReadOnlyMemory<byte> Type => IFrame.TypeReply; public ReadOnlyMemory<byte> FrameType => IFrame.TypeReply;
public async Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
int replyLength = SerializedReply.Length; int replyLength = SerializedReply.Length;
@@ -27,7 +27,7 @@ sealed record ReplyFrame(uint ReplyingToMessageId, ReadOnlyMemory<byte> Serializ
private static void CheckReplyLength(int replyLength) { private static void CheckReplyLength(int replyLength) {
if (replyLength < 0) { if (replyLength < 0) {
throw new RpcErrorException("Reply length is negative", RpcError.InvalidData); throw new RpcErrorException("Reply length is negative.", RpcError.InvalidData);
} }
if (replyLength > MaxReplyBytes) { if (replyLength > MaxReplyBytes) {

View File

@@ -0,0 +1,5 @@
namespace Phantom.Utils.Rpc;
public interface IRpcListener {
}

View File

@@ -0,0 +1,8 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc;
interface IRpcReplySender {
ValueTask SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken);
ValueTask SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,21 @@
using Phantom.Utils.Actor;
namespace Phantom.Utils.Rpc.Message;
public interface IMessageReceiver<TMessageBase> {
void OnPing();
void OnMessage(TMessageBase message);
Task<TReply> OnMessage<TMessage, TReply>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase, ICanReply<TReply>;
class Actor(ActorRef<TMessageBase> actor) : IMessageReceiver<TMessageBase> {
public virtual void OnPing() {}
public void OnMessage(TMessageBase message) {
actor.Tell(message);
}
public Task<TReply> OnMessage<TMessage, TReply>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase, ICanReply<TReply> {
return actor.Request(message, cancellationToken);
}
}
}

View File

@@ -1,11 +0,0 @@
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Message;
interface MessageHandler<TMessageBase> {
ActorRef<TMessageBase> Actor { get; }
ValueTask OnReply<TMessage, TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply>;
ValueTask OnError(uint messageId, RpcError error, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,13 @@
using Phantom.Utils.Collections;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageReceiveTracker {
private readonly RangeSet<uint> receivedMessageIds = new ();
public bool ReceiveMessage(uint messageId) {
lock (receivedMessageIds) {
return receivedMessageIds.Add(messageId);
}
}
}

View File

@@ -1,4 +1,5 @@
using Phantom.Utils.Actor; using System.Diagnostics.CodeAnalysis;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Frame.Types; using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Runtime;
using Serilog; using Serilog;
@@ -8,11 +9,11 @@ namespace Phantom.Utils.Rpc.Message;
public sealed class MessageRegistry<TMessageBase>(ILogger logger) { public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
private readonly Dictionary<Type, ushort> typeToCodeMapping = new (); private readonly Dictionary<Type, ushort> typeToCodeMapping = new ();
private readonly Dictionary<ushort, Type> codeToTypeMapping = new (); private readonly Dictionary<ushort, Type> codeToTypeMapping = new ();
private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, MessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new (); private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, RpcMessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new ();
public void Add<TMessage>(ushort code) where TMessage : TMessageBase { public void Add<TMessage>(ushort code) where TMessage : TMessageBase {
if (HasReplyType(typeof(TMessage))) { if (HasReplyType(typeof(TMessage))) {
throw new ArgumentException("This overload is for messages without a reply"); throw new ArgumentException("This overload is for messages without a reply.");
} }
AddTypeCodeMapping<TMessage>(code); AddTypeCodeMapping<TMessage>(code);
@@ -36,6 +37,10 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
return messageType.GetInterfaces().Any(type => type.FullName is {} name && name.StartsWith(replyInterfaceName, StringComparison.Ordinal)); return messageType.GetInterfaces().Any(type => type.FullName is {} name && name.StartsWith(replyInterfaceName, StringComparison.Ordinal));
} }
internal bool TryGetType(MessageFrame frame, [NotNullWhen(true)] out Type? type) {
return codeToTypeMapping.TryGetValue(frame.RegistryCode, out type);
}
internal MessageFrame CreateFrame<TMessage>(uint messageId, TMessage message) where TMessage : TMessageBase { internal MessageFrame CreateFrame<TMessage>(uint messageId, TMessage message) where TMessage : TMessageBase {
if (typeToCodeMapping.TryGetValue(typeof(TMessage), out ushort code)) { if (typeToCodeMapping.TryGetValue(typeof(TMessage), out ushort code)) {
return new MessageFrame(messageId, code, Serialization.Serialize(message)); return new MessageFrame(messageId, code, Serialization.Serialize(message));
@@ -45,7 +50,7 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
} }
} }
internal async Task Handle(MessageFrame frame, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) { internal async Task Handle(MessageFrame frame, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) {
uint messageId = frame.MessageId; uint messageId = frame.MessageId;
if (codeToHandlerMapping.TryGetValue(frame.RegistryCode, out var action)) { if (codeToHandlerMapping.TryGetValue(frame.RegistryCode, out var action)) {
@@ -53,47 +58,47 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
} }
else { else {
logger.Error("Unknown message code {Code} for message {MessageId}.", frame.RegistryCode, messageId); logger.Error("Unknown message code {Code} for message {MessageId}.", frame.RegistryCode, messageId);
await handler.OnError(messageId, RpcError.UnknownMessageRegistryCode, cancellationToken); await handler.SendError(messageId, RpcError.UnknownMessageRegistryCode, cancellationToken);
} }
} }
private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase { private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
TMessage message; TMessage message;
try { try {
message = Serialization.Deserialize<TMessage>(serializedMessage); message = Serialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.OnError(messageId, RpcError.MessageDeserializationError, cancellationToken); await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
return; return;
} }
handler.Actor.Tell(message); handler.Receiver.OnMessage(message);
} }
private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> { private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
TMessage message; TMessage message;
try { try {
message = Serialization.Deserialize<TMessage>(serializedMessage); message = Serialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.OnError(messageId, RpcError.MessageDeserializationError, cancellationToken); await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
return; return;
} }
TReply reply; TReply reply;
try { try {
reply = await handler.Actor.Request(message, cancellationToken); reply = await handler.Receiver.OnMessage<TMessage, TReply>(message, cancellationToken);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); logger.Error(e, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.OnError(messageId, RpcError.MessageHandlingError, cancellationToken); await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
return; return;
} }
try { try {
await handler.OnReply<TMessage, TReply>(messageId, reply, cancellationToken); await handler.SendReply(messageId, reply, cancellationToken);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); logger.Error(e, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.OnError(messageId, RpcError.MessageHandlingError, cancellationToken); await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
} }
} }
} }

View File

@@ -8,7 +8,7 @@ namespace Phantom.Utils.Rpc.Message;
sealed class MessageReplyTracker { sealed class MessageReplyTracker {
private readonly ILogger logger; private readonly ILogger logger;
private readonly ConcurrentDictionary<uint, TaskCompletionSource<ReadOnlyMemory<byte>>> replyTasks = new (concurrencyLevel: 4, capacity: 16); private readonly ConcurrentDictionary<uint, TaskCompletionSource<ReadOnlyMemory<byte>>> replyTasks = new (concurrencyLevel: 2, capacity: 16);
internal MessageReplyTracker(string loggerName) { internal MessageReplyTracker(string loggerName) {
this.logger = PhantomLogger.Create<MessageReplyTracker>(loggerName); this.logger = PhantomLogger.Create<MessageReplyTracker>(loggerName);

View File

@@ -7,7 +7,6 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="MemoryPack" /> <PackageReference Include="MemoryPack" />
<PackageReference Include="NetMQ" />
<PackageReference Include="Serilog" /> <PackageReference Include="Serilog" />
</ItemGroup> </ItemGroup>

View File

@@ -0,0 +1,66 @@
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client;
public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> : IDisposable {
public static async Task<RpcClient<TClientToServerMessage, TServerToClientMessage>?> Connect(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, CancellationToken cancellationToken) {
RpcClientToServerConnector connector = new RpcClientToServerConnector(loggerName, connectionParameters);
RpcClientToServerConnector.Connection? connection = await connector.EstablishNewConnection(cancellationToken);
return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions, connector, connection);
}
private readonly string loggerName;
private readonly ILogger logger;
private readonly MessageRegistry<TServerToClientMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly RpcClientToServerConnection connection;
private readonly CancellationTokenSource shutdownCancellationTokenSource = new ();
public RpcSendChannel<TClientToServerMessage> SendChannel { get; }
private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection) {
this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.messageRegistry = messageDefinitions.ToClient;
this.connection = new RpcClientToServerConnection(loggerName, connector, connection);
this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters.Common, this.connection, messageDefinitions.ToServer);
}
public async Task Listen(IMessageReceiver<TServerToClientMessage> receiver) {
var messageHandler = new RpcMessageHandler<TServerToClientMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
try {
await connection.ReadConnection(frameReader, shutdownCancellationTokenSource.Token);
} catch (OperationCanceledException) {
// Ignore.
}
}
public async Task Shutdown() {
logger.Information("Shutting down client...");
try {
await SendChannel.Close();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing send channel.");
}
try {
connection.StopReconnecting();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing connection.");
}
await shutdownCancellationTokenSource.CancelAsync();
logger.Information("Client shut down.");
}
public void Dispose() {
connection.Dispose();
SendChannel.Dispose();
}
}

View File

@@ -1,12 +1,15 @@
using Phantom.Utils.Rpc.Runtime.Tls; using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime.Client;
public readonly record struct RpcClientConnectionParameters( public readonly record struct RpcClientConnectionParameters(
string Host, string Host,
ushort Port, ushort Port,
string DistinguishedName, string DistinguishedName,
RpcCertificateThumbprint CertificateThumbprint, RpcCertificateThumbprint CertificateThumbprint,
AuthToken AuthToken,
ushort SendQueueCapacity, ushort SendQueueCapacity,
TimeSpan PingInterval TimeSpan PingInterval
); ) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}

View File

@@ -1,21 +1,23 @@
using Phantom.Utils.Logging; using System.Net.Sockets;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Serilog; using Serilog;
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime.Client;
sealed class RpcClientConnection(string loggerName, RpcClientConnector connector, RpcClientConnector.Connection initialConnection) : IRpcConnectionProvider, IDisposable { sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection initialConnection) : IRpcConnectionProvider, IDisposable {
private readonly ILogger logger = PhantomLogger.Create<RpcClientConnection>(loggerName); private readonly ILogger logger = PhantomLogger.Create<RpcClientToServerConnection>(loggerName);
private readonly SemaphoreSlim semaphore = new (1); private readonly SemaphoreSlim semaphore = new (1);
private RpcClientConnector.Connection currentConnection = initialConnection; private RpcClientToServerConnector.Connection currentConnection = initialConnection;
private readonly CancellationTokenSource newConnectionCancellationTokenSource = new (); private readonly CancellationTokenSource newConnectionCancellationTokenSource = new ();
public async Task<Stream> GetStream() { public async Task<Stream> GetStream(CancellationToken cancellationToken) {
return (await GetConnection()).Stream; return (await GetConnection()).Stream;
} }
private async Task<RpcClientConnector.Connection> GetConnection() { private async Task<RpcClientToServerConnector.Connection> GetConnection() {
CancellationToken cancellationToken = newConnectionCancellationTokenSource.Token; CancellationToken cancellationToken = newConnectionCancellationTokenSource.Token;
await semaphore.WaitAsync(cancellationToken); await semaphore.WaitAsync(cancellationToken);
@@ -30,8 +32,8 @@ sealed class RpcClientConnection(string loggerName, RpcClientConnector connector
} }
} }
public async Task ReadConnection(Func<Stream, Task> reader) { public async Task ReadConnection(IFrameReader frameReader, CancellationToken cancellationToken) {
RpcClientConnector.Connection? connection = null; RpcClientToServerConnector.Connection? connection = null;
try { try {
while (true) { while (true) {
@@ -48,25 +50,27 @@ sealed class RpcClientConnection(string loggerName, RpcClientConnector connector
} }
try { try {
await reader(connection.Stream); await IFrame.ReadFrom(connection.Stream, frameReader, cancellationToken);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
throw; throw;
} catch (EndOfStreamException) { } catch (EndOfStreamException) {
logger.Warning("Socket was closed."); logger.Warning("Socket was closed.");
} catch (SocketException e) {
logger.Error("Socket reading was interrupted. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Closing socket due to an exception while reading it."); logger.Error(e, "Socket reading was interrupted.");
}
try {
await connection.Shutdown(); try {
} catch (Exception e2) { await connection.Shutdown();
logger.Error(e2, "Caught exception closing the socket."); } catch (Exception e) {
} logger.Error(e, "Caught exception closing the socket.");
} }
} }
} finally { } finally {
if (connection != null) { if (connection != null) {
try { try {
await connection.Disconnect(); // TODO what happens if already disconnected? await connection.Disconnect();
} finally { } finally {
connection.Dispose(); connection.Dispose();
} }
@@ -74,7 +78,7 @@ sealed class RpcClientConnection(string loggerName, RpcClientConnector connector
} }
} }
public void Close() { public void StopReconnecting() {
newConnectionCancellationTokenSource.Cancel(); newConnectionCancellationTokenSource.Cancel();
} }

View File

@@ -7,24 +7,24 @@ using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Tls; using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog; using Serilog;
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime.Client;
internal sealed class RpcClientConnector { internal sealed class RpcClientToServerConnector {
private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(100); private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(100);
private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30); private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30);
private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10); private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10);
private readonly ILogger logger; private readonly ILogger logger;
private readonly Guid sessionId;
private readonly RpcClientConnectionParameters parameters; private readonly RpcClientConnectionParameters parameters;
private readonly RpcClientHandshake handshake;
private readonly SslClientAuthenticationOptions sslOptions; private readonly SslClientAuthenticationOptions sslOptions;
private bool loggedCertificateValidationError = false; private bool loggedCertificateValidationError = false;
public RpcClientConnector(string loggerName, RpcClientConnectionParameters parameters, RpcClientHandshake handshake) { public RpcClientToServerConnector(string loggerName, RpcClientConnectionParameters parameters) {
this.logger = PhantomLogger.Create<RpcClientConnector>(loggerName); this.logger = PhantomLogger.Create<RpcClientToServerConnector>(loggerName);
this.sessionId = Guid.NewGuid();
this.parameters = parameters; this.parameters = parameters;
this.handshake = handshake;
this.sslOptions = new SslClientAuthenticationOptions { this.sslOptions = new SslClientAuthenticationOptions {
AllowRenegotiation = false, AllowRenegotiation = false,
@@ -53,7 +53,8 @@ internal sealed class RpcClientConnector {
return newConnection; return newConnection;
} }
logger.Warning("Failed to connect to server, trying again in {}.", nextAttemptDelay.TotalSeconds.ToString("F1")); cancellationToken.ThrowIfCancellationRequested();
logger.Information("Trying to reconnect in {Seconds}s.", nextAttemptDelay.TotalSeconds.ToString("F1"));
await Task.Delay(nextAttemptDelay, cancellationToken); await Task.Delay(nextAttemptDelay, cancellationToken);
nextAttemptDelay = Comparables.Min(nextAttemptDelay.Multiply(1.5), MaximumRetryDelay); nextAttemptDelay = Comparables.Min(nextAttemptDelay.Multiply(1.5), MaximumRetryDelay);
@@ -66,55 +67,79 @@ internal sealed class RpcClientConnector {
Socket clientSocket = new Socket(SocketType.Stream, ProtocolType.Tcp); Socket clientSocket = new Socket(SocketType.Stream, ProtocolType.Tcp);
try { try {
await clientSocket.ConnectAsync(parameters.Host, parameters.Port, cancellationToken); await clientSocket.ConnectAsync(parameters.Host, parameters.Port, cancellationToken);
} catch (SocketException e) {
logger.Error("Could not connect. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
return null;
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not connect."); logger.Error(e, "Could not connect.");
throw; return null;
} }
SslStream? stream = null; SslStream? stream;
bool handledException = false;
try { try {
stream = new SslStream(new NetworkStream(clientSocket, ownsSocket: false), leaveInnerStreamOpen: false); stream = new SslStream(new NetworkStream(clientSocket, ownsSocket: false), leaveInnerStreamOpen: false);
try { if (await FinalizeConnection(stream, cancellationToken)) {
loggedCertificateValidationError = false; logger.Information("Connected to {Host}:{Port}.", parameters.Host, parameters.Port);
await stream.AuthenticateAsClientAsync(sslOptions, cancellationToken); return new Connection(clientSocket, stream);
} catch (AuthenticationException e) {
if (!loggedCertificateValidationError) {
logger.Error(e, "Could not establish a secure connection.");
}
handledException = true;
throw;
} }
logger.Information("Established a secure connection.");
try {
await handshake.AcceptServer(stream, cancellationToken);
} catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost.");
handledException = true;
throw;
} catch (Exception e) {
logger.Warning(e, "Could not perform application handshake.");
handledException = true;
throw;
}
return new Connection(clientSocket, stream);
} catch (Exception e) { } catch (Exception e) {
if (!handledException) { logger.Error(e, "Caught unhandled exception.");
logger.Error(e, "Caught unhandled exception."); stream = null;
}
try {
await DisconnectSocket(clientSocket, stream);
} finally {
clientSocket.Close();
}
return null;
}
private async Task<bool> FinalizeConnection(SslStream stream, CancellationToken cancellationToken) {
try {
loggedCertificateValidationError = false;
await stream.AuthenticateAsClientAsync(sslOptions, cancellationToken);
} catch (AuthenticationException e) {
if (!loggedCertificateValidationError) {
logger.Error(e, "Could not establish a secure connection.");
} }
try { return false;
await DisconnectSocket(clientSocket, stream); }
} finally {
clientSocket.Close(); logger.Information("Established a secure connection.");
try {
if (!await PerformApplicationHandshake(stream, cancellationToken)) {
return false;
} }
} catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost.");
} catch (Exception e) {
logger.Warning(e, "Could not perform application handshake.");
}
return true;
}
private async Task<bool> PerformApplicationHandshake(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteAuthToken(parameters.AuthToken, stream, cancellationToken);
await Serialization.WriteGuid(sessionId, stream, cancellationToken);
var result = (RpcHandshakeResult) await Serialization.ReadByte(stream, cancellationToken);
switch (result) {
case RpcHandshakeResult.Success:
return true;
return null; case RpcHandshakeResult.InvalidAuthToken:
logger.Error("Server rejected authorization token.");
return false;
default:
logger.Error("Server rejected client due to unknown error: {ErrorId}", result);
return false;
} }
} }

View File

@@ -1,5 +1,5 @@
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
interface IRpcConnectionProvider { interface IRpcConnectionProvider {
Task<Stream> GetStream(); Task<Stream> GetStream(CancellationToken cancellationToken);
} }

View File

@@ -1,97 +0,0 @@
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> : IDisposable {
public static async Task<RpcClient<TClientToServerMessage, TServerToClientMessage>?> Connect(string loggerName, RpcClientConnectionParameters connectionParameters, RpcClientHandshake handshake, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, CancellationToken cancellationToken) {
RpcClientConnector connector = new RpcClientConnector(loggerName, connectionParameters, handshake);
RpcClientConnector.Connection? connection = await connector.EstablishNewConnection(cancellationToken);
return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions, connector, connection);
}
private readonly ILogger logger;
private readonly MessageRegistry<TServerToClientMessage> serverToClientMessageRegistry;
private readonly RpcClientConnection connection;
public RpcSendChannel<TClientToServerMessage> SendChannel { get; }
private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientConnector connector, RpcClientConnector.Connection connection) {
this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.serverToClientMessageRegistry = messageDefinitions.ToClient;
this.connection = new RpcClientConnection(loggerName, connector, connection);
this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters, this.connection, messageDefinitions.ToServer);
}
public async Task Listen(ActorRef<TServerToClientMessage> actor) {
try {
await connection.ReadConnection(stream => Receive(stream, new MessageHandlerImpl(SendChannel, actor)));
} catch (OperationCanceledException) {
// Ignore.
}
}
private async Task Receive(Stream stream, MessageHandlerImpl handler) {
await IFrame.ReadFrom(stream, new FrameReader(this, handler), CancellationToken.None);
}
private sealed class FrameReader(RpcClient<TClientToServerMessage, TServerToClientMessage> client, MessageHandlerImpl handler) : IFrameReader {
public Task OnMessageFrame(MessageFrame frame, Stream stream, CancellationToken cancellationToken) {
return client.serverToClientMessageRegistry.Handle(frame, handler, cancellationToken);
}
public void OnReplyFrame(ReplyFrame frame) {
client.SendChannel.ReceiveReply(frame);
}
public void OnErrorFrame(ErrorFrame frame) {
client.SendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error);
}
public void OnUnknownFrameStart(byte id) {
client.logger.Error("Received unknown frame ID: {Id}", id);
}
}
private sealed class MessageHandlerImpl(RpcSendChannel<TClientToServerMessage> sendChannel, ActorRef<TServerToClientMessage> actor) : MessageHandler<TServerToClientMessage> {
public ActorRef<TServerToClientMessage> Actor => actor;
public ValueTask OnReply<TMessage, TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) where TMessage : TServerToClientMessage, ICanReply<TReply> {
return sendChannel.SendReply(messageId, reply, cancellationToken);
}
public ValueTask OnError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return sendChannel.SendError(messageId, error, cancellationToken);
}
}
public async Task Shutdown() {
logger.Information("Shutting down client...");
try {
await SendChannel.Close();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing send channel.");
}
try {
connection.Close();
} catch (Exception e) {
logger.Error(e, "Caught exception while closing connection.");
}
// TODO disconnection handshake?
logger.Information("Client shut down.");
}
public void Dispose() {
connection.Dispose();
SendChannel.Dispose();
}
}

View File

@@ -1,5 +0,0 @@
namespace Phantom.Utils.Rpc.Runtime;
public abstract class RpcClientHandshake {
protected internal abstract Task<bool> AcceptServer(Stream stream, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,6 @@
namespace Phantom.Utils.Rpc.Runtime;
readonly record struct RpcCommonConnectionParameters(
ushort SendQueueCapacity,
TimeSpan PingInterval
);

View File

@@ -3,7 +3,8 @@
public enum RpcError : byte { public enum RpcError : byte {
InvalidData = 0, InvalidData = 0,
UnknownMessageRegistryCode = 1, UnknownMessageRegistryCode = 1,
MessageDeserializationError = 2, MessageTooLarge = 2,
MessageHandlingError = 3, MessageDeserializationError = 3,
MessageTooLarge = 4, MessageHandlingError = 4,
MessageAlreadyHandled = 5,
} }

View File

@@ -1,13 +1,14 @@
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcErrorException : Exception { sealed class RpcErrorException : Exception {
internal static RpcErrorException From(RpcError error) { internal static RpcErrorException From(RpcError error) {
return error switch { return error switch {
RpcError.InvalidData => new RpcErrorException("Invalid data", error), RpcError.InvalidData => new RpcErrorException("Invalid data", error),
RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error), RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error),
RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error), RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error),
RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error), RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error), RpcError.MessageAlreadyHandled => new RpcErrorException("Message already handled", error),
_ => new RpcErrorException("Unknown error", error), _ => new RpcErrorException("Unknown error", error),
}; };
} }

View File

@@ -0,0 +1,53 @@
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
string loggerName,
MessageRegistry<TReceivedMessage> messageRegistry,
MessageReceiveTracker messageReceiveTracker,
RpcMessageHandler<TReceivedMessage> messageHandler,
RpcSendChannel<TSentMessage> sendChannel
) : IFrameReader {
private readonly ILogger logger = PhantomLogger.Create<RpcFrameReader<TSentMessage, TReceivedMessage>>(loggerName);
public ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken) {
messageHandler.OnPing();
return sendChannel.SendPong(pingTime, cancellationToken);
}
public void OnPongFrame(PongFrame frame) {
sendChannel.ReceivePong(frame);
}
public Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) {
if (!messageReceiveTracker.ReceiveMessage(frame.MessageId)) {
logger.Warning("Received duplicate message {MessageId}.", frame.MessageId);
return messageHandler.SendError(frame.MessageId, RpcError.MessageAlreadyHandled, cancellationToken).AsTask();
}
if (messageRegistry.TryGetType(frame, out var messageType)) {
logger.Verbose("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length);
}
return messageRegistry.Handle(frame, messageHandler, cancellationToken);
}
public void OnReplyFrame(ReplyFrame frame) {
logger.Verbose("Received reply to message {MesageId} ({Bytes} B).", frame.ReplyingToMessageId, frame.SerializedReply.Length);
sendChannel.ReceiveReply(frame);
}
public void OnErrorFrame(ErrorFrame frame) {
logger.Warning("Received error response to message {MesageId}: {Error}", frame.ReplyingToMessageId, frame.Error);
sendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error);
}
public void OnUnknownFrameId(byte frameId) {
logger.Error("Received unknown frame ID: {FrameId}", frameId);
}
}

View File

@@ -1,6 +1,6 @@
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
public enum RpcHandshakeResult : byte { enum RpcHandshakeResult : byte {
UnknownError = 0, Success = 0,
InvalidFormat = 1, InvalidAuthToken = 1,
} }

View File

@@ -0,0 +1,19 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcMessageHandler<TMessageBase>(IMessageReceiver<TMessageBase> receiver, IRpcReplySender replySender) {
public IMessageReceiver<TMessageBase> Receiver => receiver;
public void OnPing() {
receiver.OnPing();
}
public ValueTask SendReply<TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) {
return replySender.SendReply(messageId, reply, cancellationToken);
}
public ValueTask SendError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return replySender.SendError(messageId, error, cancellationToken);
}
}

View File

@@ -9,7 +9,7 @@ using Serilog;
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcSendChannel<TMessageBase> : IDisposable { public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable {
private readonly ILogger logger; private readonly ILogger logger;
private readonly IRpcConnectionProvider connectionProvider; private readonly IRpcConnectionProvider connectionProvider;
private readonly MessageRegistry<TMessageBase> messageRegistry; private readonly MessageRegistry<TMessageBase> messageRegistry;
@@ -19,12 +19,13 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
private readonly Task sendQueueTask; private readonly Task sendQueueTask;
private readonly Task pingTask; private readonly Task pingTask;
private readonly CancellationTokenSource cancellationTokenSource = new (); private readonly CancellationTokenSource sendQueueCancellationTokenSource = new ();
private readonly CancellationTokenSource pingCancellationTokenSource = new (); private readonly CancellationTokenSource pingCancellationTokenSource = new ();
private uint nextMessageId; private uint nextMessageId;
private TaskCompletionSource<DateTimeOffset>? pongTask;
internal RpcSendChannel(string loggerName, RpcClientConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageRegistry<TMessageBase> messageRegistry) { internal RpcSendChannel(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageRegistry<TMessageBase> messageRegistry) {
this.logger = PhantomLogger.Create<RpcSendChannel<TMessageBase>>(loggerName); this.logger = PhantomLogger.Create<RpcSendChannel<TMessageBase>>(loggerName);
this.connectionProvider = connectionProvider; this.connectionProvider = connectionProvider;
this.messageRegistry = messageRegistry; this.messageRegistry = messageRegistry;
@@ -41,11 +42,15 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
this.pingTask = Ping(connectionParameters.PingInterval); this.pingTask = Ping(connectionParameters.PingInterval);
} }
internal async ValueTask SendPong(DateTimeOffset pingTime, CancellationToken cancellationToken) {
await SendFrame(new PongFrame(pingTime), cancellationToken);
}
public bool TrySendMessage<TMessage>(TMessage message) where TMessage : TMessageBase { public bool TrySendMessage<TMessage>(TMessage message) where TMessage : TMessageBase {
return sendQueue.Writer.TryWrite(NextMessageFrame(message)); return sendQueue.Writer.TryWrite(NextMessageFrame(message));
} }
public async ValueTask SendMessage<TMessage>(TMessage message, CancellationToken cancellationToken) where TMessage : TMessageBase { public async ValueTask SendMessage<TMessage>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase {
await SendFrame(NextMessageFrame(message), cancellationToken); await SendFrame(NextMessageFrame(message), cancellationToken);
} }
@@ -64,11 +69,11 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken); return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken);
} }
internal async ValueTask SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) { async ValueTask IRpcReplySender.SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, Serialization.Serialize(reply)), cancellationToken); await SendFrame(new ReplyFrame(replyingToMessageId, Serialization.Serialize(reply)), cancellationToken);
} }
internal async ValueTask SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) { async ValueTask IRpcReplySender.SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) {
await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken); await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken);
} }
@@ -84,25 +89,21 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
} }
private async Task ProcessSendQueue() { private async Task ProcessSendQueue() {
CancellationToken cancellationToken = cancellationTokenSource.Token; CancellationToken cancellationToken = sendQueueCancellationTokenSource.Token;
// TODO figure out cancellation
await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync(cancellationToken)) { await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync(cancellationToken)) {
while (true) { while (true) {
cancellationToken.ThrowIfCancellationRequested();
Stream stream;
try { try {
stream = await connectionProvider.GetStream(); Stream stream = await connectionProvider.GetStream(cancellationToken);
await stream.WriteAsync(frame.FrameType, cancellationToken);
await frame.Write(stream, cancellationToken);
await stream.FlushAsync(cancellationToken);
break;
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
throw; throw;
} catch (Exception) { } catch (Exception) {
continue; // Retry.
} }
await stream.WriteAsync(frame.Type, cancellationToken);
await frame.Write(stream, cancellationToken);
await stream.FlushAsync(cancellationToken);
} }
} }
} }
@@ -114,12 +115,26 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
while (true) { while (true) {
await Task.Delay(interval, cancellationToken); await Task.Delay(interval, cancellationToken);
pongTask = new TaskCompletionSource<DateTimeOffset>();
if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) { if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) {
cancellationToken.ThrowIfCancellationRequested();
logger.Warning("Skipped a ping due to a full queue."); logger.Warning("Skipped a ping due to a full queue.");
continue;
} }
DateTimeOffset pingTime = await pongTask.Task.WaitAsync(cancellationToken);
DateTimeOffset currentTime = DateTimeOffset.UtcNow;
TimeSpan roundTripTime = currentTime - pingTime;
logger.Information("Received pong, round trip time: {RoundTripTime} ms", (long) roundTripTime.TotalMilliseconds);
} }
} }
internal void ReceivePong(PongFrame frame) {
pongTask?.TrySetResult(frame.PingTime);
}
internal void ReceiveReply(ReplyFrame frame) { internal void ReceiveReply(ReplyFrame frame) {
messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply); messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply);
} }
@@ -133,7 +148,16 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
sendQueue.Writer.TryComplete(); sendQueue.Writer.TryComplete();
try { try {
await Task.WhenAll(sendQueueTask, pingTask); await pingTask;
} catch (Exception) {
// Ignore.
}
try {
await sendQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing send queue before timeout, forcibly shutting it down.");
await sendQueueCancellationTokenSource.CancelAsync();
} catch (Exception) { } catch (Exception) {
// Ignore. // Ignore.
} }
@@ -141,7 +165,7 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
public void Dispose() { public void Dispose() {
sendQueueTask.Dispose(); sendQueueTask.Dispose();
cancellationTokenSource.Dispose(); sendQueueCancellationTokenSource.Dispose();
pingCancellationTokenSource.Dispose(); pingCancellationTokenSource.Dispose();
} }
} }

View File

@@ -1,133 +0,0 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcServer(string loggerName, EndPoint endPoint, RpcServerCertificate certificate, RpcServerHandshake handshake) {
private readonly ILogger logger = PhantomLogger.Create<RpcServer>(loggerName);
private readonly List<Client> clients = [];
public async Task<bool> Run(CancellationToken shutdownToken) {
SslServerAuthenticationOptions sslOptions = new () {
AllowRenegotiation = false,
AllowTlsResume = true,
CertificateRevocationCheckMode = X509RevocationMode.NoCheck,
ClientCertificateRequired = false,
EnabledSslProtocols = TlsSupport.SupportedProtocols,
EncryptionPolicy = EncryptionPolicy.RequireEncryption,
ServerCertificate = certificate.Certificate,
};
try {
using var serverSocket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try {
serverSocket.Bind(endPoint);
serverSocket.Listen(5);
} catch (Exception e) {
logger.Error(e, "Could not bind to {EndPoint}.", endPoint);
return false;
}
try {
logger.Information("Server listening on {EndPoint}.", endPoint);
while (!shutdownToken.IsCancellationRequested) {
Socket clientSocket = await serverSocket.AcceptAsync(shutdownToken);
clients.Add(new Client(logger, clientSocket, sslOptions, handshake, shutdownToken));
clients.RemoveAll(static client => client.Task.IsCompleted);
}
} finally {
await Stop(serverSocket);
}
} catch (Exception e) {
logger.Error(e, "Server crashed with uncaught exception.");
return false;
}
return true;
}
private async Task Stop(Socket serverSocket) {
try {
serverSocket.Close();
} catch (Exception e) {
logger.Error(e, "Server socket failed to close.");
return;
}
logger.Information("Server socket closed, waiting for client sockets to close.");
try {
await Task.WhenAll(clients.Select(static client => client.Task));
} catch (Exception) {
// Ignore exceptions when shutting down.
}
logger.Information("Server stopped.");
}
private sealed class Client {
private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10);
public Task Task { get; }
private readonly ILogger logger;
private readonly Socket socket;
private readonly SslServerAuthenticationOptions sslOptions;
private readonly RpcServerHandshake handshake;
private readonly CancellationToken shutdownToken;
public Client(ILogger logger, Socket socket, SslServerAuthenticationOptions sslOptions, RpcServerHandshake handshake, CancellationToken shutdownToken) {
this.logger = logger;
this.socket = socket;
this.sslOptions = sslOptions;
this.handshake = handshake;
this.shutdownToken = shutdownToken;
this.Task = Run();
}
private async Task Run() {
try {
await using var stream = new SslStream(new NetworkStream(socket, ownsSocket: false), leaveInnerStreamOpen: false);
try {
await stream.AuthenticateAsServerAsync(sslOptions, shutdownToken);
} catch (Exception e) {
logger.Error(e, "Could not establish a secure connection.");
return;
}
try {
await handshake.AcceptClient(socket.RemoteEndPoint?.ToString() ?? "<unknown address>", stream, shutdownToken);
} catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost.");
return;
} catch (Exception e) {
logger.Warning(e, "Could not perform application handshake.");
return;
}
byte[] buffer = new byte[1024];
int readBytes;
while ((readBytes = await stream.ReadAsync(buffer, shutdownToken)) > 0) {}
} finally {
try {
using var timeoutTokenSource = new CancellationTokenSource(DisconnectTimeout);
await socket.DisconnectAsync(reuseSocket: false, timeoutTokenSource.Token);
} catch (OperationCanceledException) {
logger.Error("Could not disconnect client socket, disconnection timed out.");
} catch (Exception e) {
logger.Error(e, "Could not disconnect client socket.");
} finally {
socket.Close();
}
}
}
}
}

View File

@@ -1,5 +0,0 @@
namespace Phantom.Utils.Rpc.Runtime;
public abstract class RpcServerHandshake {
protected internal abstract Task<bool> AcceptClient(string remoteAddress, Stream stream, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,7 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime.Server;
public interface IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> {
IMessageReceiver<TClientToServerMessage> Register(RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> connection);
}

View File

@@ -0,0 +1,254 @@
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Logging;
using Phantom.Utils.Monads;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage>(
string loggerName,
RpcServerConnectionParameters connectionParameters,
IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> clientRegistrar
) {
private readonly ILogger logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>>(loggerName);
private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions = new (connectionParameters.Common, messageDefinitions.ToClient);
private readonly List<Client> clients = [];
public async Task<bool> Run(CancellationToken shutdownToken) {
EndPoint endPoint = connectionParameters.EndPoint;
SslServerAuthenticationOptions sslOptions = new () {
AllowRenegotiation = false,
AllowTlsResume = true,
CertificateRevocationCheckMode = X509RevocationMode.NoCheck,
ClientCertificateRequired = false,
EnabledSslProtocols = TlsSupport.SupportedProtocols,
EncryptionPolicy = EncryptionPolicy.RequireEncryption,
ServerCertificate = connectionParameters.Certificate.Certificate,
};
try {
using var serverSocket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try {
serverSocket.Bind(endPoint);
serverSocket.Listen(5);
} catch (Exception e) {
logger.Error(e, "Could not bind to {EndPoint}.", endPoint);
return false;
}
try {
logger.Information("Server listening on {EndPoint}.", endPoint);
while (true) {
Socket clientSocket = await serverSocket.AcceptAsync(shutdownToken);
clients.Add(new Client(loggerName, messageDefinitions, clientRegistrar, clientSessions, clientSocket, sslOptions, connectionParameters.AuthToken, shutdownToken));
clients.RemoveAll(static client => client.Task.IsCompleted);
}
} catch (OperationCanceledException) {
// Ignore.
} finally {
await Stop(serverSocket);
}
} catch (Exception e) {
logger.Error(e, "Server crashed with uncaught exception.");
return false;
}
return true;
}
private async Task Stop(Socket serverSocket) {
try {
serverSocket.Close();
} catch (Exception e) {
logger.Error(e, "Server socket failed to close.");
return;
}
logger.Information("Server socket closed, waiting for client sockets to close.");
try {
await Task.WhenAll(clients.Select(static client => client.Task));
await clientSessions.Shutdown();
} catch (Exception) {
// Ignore exceptions when shutting down.
}
logger.Information("Server stopped.");
}
private sealed class Client {
private static TimeSpan DisconnectTimeout => TimeSpan.FromSeconds(10);
private static string GetAddressDescriptor(Socket socket) {
EndPoint? endPoint = socket.RemoteEndPoint;
return endPoint switch {
IPEndPoint ip => ip.Port.ToString(),
null => "{unknown}",
_ => "{" + endPoint + "}",
};
}
public Task Task { get; }
private string Address => socket.RemoteEndPoint?.ToString() ?? "<unknown address>";
private ILogger logger;
private readonly string serverLoggerName;
private readonly IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions;
private readonly IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> clientRegistrar;
private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions;
private readonly Socket socket;
private readonly SslServerAuthenticationOptions sslOptions;
private readonly AuthToken authToken;
private readonly CancellationToken shutdownToken;
public Client(
string serverLoggerName,
IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> clientRegistrar,
RpcServerClientSessions<TServerToClientMessage> clientSessions,
Socket socket,
SslServerAuthenticationOptions sslOptions,
AuthToken authToken,
CancellationToken shutdownToken
) {
this.logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>, Client>(PhantomLogger.ConcatNames(serverLoggerName, GetAddressDescriptor(socket)));
this.serverLoggerName = serverLoggerName;
this.messageDefinitions = messageDefinitions;
this.clientRegistrar = clientRegistrar;
this.clientSessions = clientSessions;
this.socket = socket;
this.sslOptions = sslOptions;
this.authToken = authToken;
this.shutdownToken = shutdownToken;
this.Task = Run();
}
private async Task Run() {
logger.Information("Accepted client.");
try {
await using var stream = new SslStream(new NetworkStream(socket, ownsSocket: false), leaveInnerStreamOpen: false);
Guid? sessionIdResult;
try {
sessionIdResult = await InitializeConnection(stream);
} catch (OperationCanceledException) {
logger.Warning("Cancelling incoming client due to shutdown.");
return;
}
if (sessionIdResult.HasValue) {
await RunConnectedSession(sessionIdResult.Value, stream);
}
} catch (Exception e) {
logger.Error(e, "Caught exception while processing client.");
} finally {
logger.Information("Disconnecting client...");
try {
using var timeoutTokenSource = new CancellationTokenSource(DisconnectTimeout);
await socket.DisconnectAsync(reuseSocket: false, timeoutTokenSource.Token);
} catch (OperationCanceledException) {
logger.Error("Could not disconnect client socket due to timeout.");
} catch (Exception e) {
logger.Error(e, "Could not disconnect client socket.");
} finally {
socket.Close();
logger.Information("Client socket closed.");
}
}
}
private async Task<Guid?> InitializeConnection(SslStream stream) {
try {
await stream.AuthenticateAsServerAsync(sslOptions, shutdownToken);
} catch (OperationCanceledException) {
throw;
} catch (Exception e) {
logger.Error(e, "Could not establish a secure connection.");
return null;
}
Either<Guid, RpcHandshakeResult> handshakeResult;
try {
handshakeResult = await PerformApplicationHandshake(stream, shutdownToken);
} catch (OperationCanceledException) {
throw;
} catch (EndOfStreamException) {
logger.Warning("Could not perform application handshake, connection lost.");
return null;
} catch (Exception e) {
logger.Warning(e, "Could not perform application handshake.");
return null;
}
switch (handshakeResult) {
case Left<Guid, RpcHandshakeResult>(var clientSessionId):
await Serialization.WriteByte((byte) RpcHandshakeResult.Success, stream, shutdownToken);
return clientSessionId;
case Right<Guid, RpcHandshakeResult>(var error):
await Serialization.WriteByte((byte) error, stream, shutdownToken);
return null;
default:
throw new InvalidOperationException("Invalid handshake result type.");
}
}
private async Task<Either<Guid, RpcHandshakeResult>> PerformApplicationHandshake(Stream stream, CancellationToken cancellationToken) {
var clientAuthToken = await Serialization.ReadAuthToken(stream, cancellationToken);
if (!authToken.FixedTimeEquals(clientAuthToken)) {
logger.Warning("Rejected client {}, invalid authorization token.", Address);
return Either.Right(RpcHandshakeResult.InvalidAuthToken);
}
var sessionId = await Serialization.ReadGuid(stream, cancellationToken);
return Either.Left(sessionId);
}
private async Task RunConnectedSession(Guid sessionId, Stream stream) {
var loggerName = PhantomLogger.ConcatNames(serverLoggerName, clientSessions.NextLoggerName(sessionId));
logger.Information("Client connected with session {SessionId}, new logger name: {LoggerName}", sessionId, loggerName);
logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>, Client>(loggerName);
var session = clientSessions.OnConnected(sessionId, loggerName, stream);
try {
var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(loggerName, clientSessions, sessionId, messageDefinitions.ToServer, stream, session);
IMessageReceiver<TClientToServerMessage> messageReceiver;
try {
messageReceiver = clientRegistrar.Register(connection);
} catch (Exception e) {
logger.Error(e, "Could not register client.");
return;
}
try {
await connection.Listen(messageReceiver);
} catch (EndOfStreamException) {
logger.Warning("Socket reading was interrupted, connection lost.");
} catch (SocketException e) {
logger.Error("Socket reading was interrupted. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
} catch (Exception e) {
logger.Error(e, "Socket reading was interrupted.");
}
} finally {
clientSessions.OnDisconnected(sessionId);
}
}
}
}

View File

@@ -0,0 +1,5 @@
namespace Phantom.Utils.Rpc.Runtime.Server;
public interface RpcServerClientActorFactory {
}

View File

@@ -0,0 +1,50 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProvider {
public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
public MessageReceiveTracker MessageReceiveTracker { get; }
private TaskCompletionSource<Stream> nextStream = new ();
public RpcServerClientSession(string loggerName, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
this.SendChannel = new RpcSendChannel<TServerToClientMessage>(loggerName, connectionParameters, this, messageRegistry);
this.MessageReceiveTracker = new MessageReceiveTracker();
}
public void OnConnected(Stream stream) {
lock (this) {
if (!nextStream.Task.IsCanceled && !nextStream.TrySetResult(stream)) {
nextStream = new TaskCompletionSource<Stream>();
nextStream.SetResult(stream);
}
}
}
public void OnDisconnected() {
lock (this) {
var task = nextStream.Task;
if (task is { IsCompleted: true, IsCanceled: false }) {
nextStream = new TaskCompletionSource<Stream>();
}
}
}
Task<Stream> IRpcConnectionProvider.GetStream(CancellationToken cancellationToken) {
lock (this) {
return nextStream.Task;
}
}
public Task Close() {
lock (this) {
if (!nextStream.TrySetCanceled()) {
nextStream = new TaskCompletionSource<Stream>();
nextStream.SetCanceled();
}
}
return SendChannel.Close();
}
}

View File

@@ -0,0 +1,61 @@
using System.Collections.Concurrent;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSessions<TServerToClientMessage> {
private readonly RpcCommonConnectionParameters connectionParameters;
private readonly MessageRegistry<TServerToClientMessage> messageRegistry;
private readonly ConcurrentDictionary<Guid, RpcServerClientSession<TServerToClientMessage>> sessionsById = new ();
private readonly ConcurrentDictionary<Guid, uint> sessionLoggerSequenceIds = new ();
private readonly Func<Guid, string, RpcServerClientSession<TServerToClientMessage>> createSessionFunction;
public RpcServerClientSessions(RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
this.connectionParameters = connectionParameters;
this.messageRegistry = messageRegistry;
this.createSessionFunction = CreateSession;
}
public string NextLoggerName(Guid sessionId) {
string name = PhantomLogger.ShortenGuid(sessionId);
return name + "/" + sessionLoggerSequenceIds.AddOrUpdate(sessionId, static _ => 1, static (_, prev) => prev + 1);
}
public RpcServerClientSession<TServerToClientMessage> OnConnected(Guid sessionId, string loggerName, Stream stream) {
var session = sessionsById.GetOrAdd(sessionId, createSessionFunction, loggerName);
session.OnConnected(stream);
return session;
}
private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId, string loggerName) {
return new RpcServerClientSession<TServerToClientMessage>(loggerName, connectionParameters, messageRegistry);
}
public void OnDisconnected(Guid sessionId) {
if (sessionsById.TryGetValue(sessionId, out var session)) {
session.OnDisconnected();
}
}
public Task CloseSession(Guid sessionId) {
if (sessionsById.Remove(sessionId, out var session)) {
return session.Close();
}
else {
return Task.CompletedTask;
}
}
public Task Shutdown() {
List<Task> tasks = [];
foreach (Guid guid in sessionsById.Keys) {
tasks.Add(CloseSession(guid));
}
return Task.WhenAll(tasks);
}
}

View File

@@ -0,0 +1,14 @@
using System.Net;
using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime.Server;
public readonly record struct RpcServerConnectionParameters(
EndPoint EndPoint,
RpcServerCertificate Certificate,
AuthToken AuthToken,
ushort SendQueueCapacity,
TimeSpan PingInterval
) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}

View File

@@ -0,0 +1,49 @@
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> {
private readonly string loggerName;
private readonly ILogger logger;
private readonly RpcServerClientSessions<TServerToClientMessage> sessions;
private readonly MessageRegistry<TClientToServerMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker;
private readonly Stream stream;
private readonly CancellationTokenSource closeCancellationTokenSource = new ();
public Guid SessionId { get; }
public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
internal RpcServerToClientConnection(string loggerName, RpcServerClientSessions<TServerToClientMessage> sessions, Guid sessionId, MessageRegistry<TClientToServerMessage> messageRegistry, Stream stream, RpcServerClientSession<TServerToClientMessage> session) {
this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.sessions = sessions;
this.messageRegistry = messageRegistry;
this.messageReceiveTracker = session.MessageReceiveTracker;
this.stream = stream;
this.SessionId = sessionId;
this.SendChannel = session.SendChannel;
}
internal async Task Listen(IMessageReceiver<TClientToServerMessage> receiver) {
var messageHandler = new RpcMessageHandler<TClientToServerMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
try {
await IFrame.ReadFrom(stream, frameReader, closeCancellationTokenSource.Token);
} catch (OperationCanceledException) {
// Ignore.
}
}
public async Task ClientClosedSession() {
logger.Information("Client closed session.");
Task closeSessionTask = sessions.CloseSession(SessionId);
await closeCancellationTokenSource.CancelAsync();
await closeSessionTask;
}
}

View File

@@ -1,8 +1,7 @@
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Monads; using Phantom.Utils.Monads;
using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime.Tls;
public sealed class RpcServerCertificate { public sealed class RpcServerCertificate {
public static byte[] CreateAndExport(string commonName) { public static byte[] CreateAndExport(string commonName) {

View File

@@ -1,135 +0,0 @@
using System.Collections.Concurrent;
using Akka.Actor;
using NetMQ.Sockets;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Serilog;
using Serilog.Events;
namespace Phantom.Utils.Rpc.Runtime2;
public static class RpcServerRuntime {
public static Task Launch<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>(
RpcConfiguration config,
IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions,
IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler,
IActorRefFactory actorSystem,
CancellationToken cancellationToken
) where TRegistrationMessage : TServerMessage where TReplyMessage : TClientMessage, TServerMessage {
return RpcServerRuntime<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.Launch(config, messageDefinitions, registrationHandler, actorSystem, cancellationToken);
}
}
internal sealed class RpcServerRuntime<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage> : RpcRuntime<ServerSocket> where TRegistrationMessage : TServerMessage where TReplyMessage : TClientMessage, TServerMessage {
internal static Task Launch(RpcConfiguration config, IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions, IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler, IActorRefFactory actorSystem, CancellationToken cancellationToken) {
var socket = RpcServerSocket.Connect(config);
return new RpcServerRuntime<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>(socket, messageDefinitions, registrationHandler, actorSystem, cancellationToken).Launch();
}
private readonly string serviceName;
private readonly IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions;
private readonly IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler;
private readonly IActorRefFactory actorSystem;
private readonly CancellationToken cancellationToken;
private RpcServerRuntime(RpcServerSocket socket, IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions, IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler, IActorRefFactory actorSystem, CancellationToken cancellationToken) : base(socket) {
this.serviceName = socket.Config.ServiceName;
this.messageDefinitions = messageDefinitions;
this.registrationHandler = registrationHandler;
this.actorSystem = actorSystem;
this.cancellationToken = cancellationToken;
}
private protected override Task Run(ServerSocket socket) {
var clients = new ConcurrentDictionary<ulong, Client>();
void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) {
if (clients.Remove(e.RoutingId, out var client)) {
client.Connection.Closed -= OnConnectionClosed;
}
}
while (!cancellationToken.IsCancellationRequested) {
var (routingId, data) = socket.Receive(cancellationToken);
if (data.Length == 0) {
LogUnknownMessage(routingId, data);
continue;
}
Type? messageType = messageDefinitions.ToServer.TryGetType(data, out var type) ? type : null;
if (messageType == null) {
LogUnknownMessage(routingId, data);
continue;
}
if (!clients.TryGetValue(routingId, out var client)) {
if (messageType != typeof(TRegistrationMessage)) {
RuntimeLogger.Warning("Received {MessageType} ({Bytes} B) from unregistered client {RoutingId}.", messageType.Name, data.Length, routingId);
continue;
}
var clientLoggerName = LoggerName + ":" + routingId;
var clientActorName = "Rpc-" + serviceName + "-" + routingId;
// TODO add pings and tear down connection after too much inactivity
var connection = new RpcConnectionToClient<TClientMessage>(socket, routingId, messageDefinitions.ToClient, ReplyTracker);
connection.Closed += OnConnectionClosed;
client = new Client(clientLoggerName, clientActorName, connection, actorSystem, messageDefinitions, registrationHandler);
clients[routingId] = client;
}
client.Enqueue(messageType, data);
}
foreach (var client in clients.Values) {
client.Connection.Close();
}
return Task.CompletedTask;
}
private void LogUnknownMessage(uint routingId, ReadOnlyMemory<byte> data) {
RuntimeLogger.Warning("Received unknown message ({Bytes} B) from {RoutingId}.", data.Length, routingId);
}
private protected override Task Disconnect(ServerSocket socket) {
return Task.CompletedTask;
}
private sealed class Client {
public RpcConnectionToClient<TClientMessage> Connection { get; }
private readonly ILogger logger;
private readonly ActorRef<RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.ReceiveMessageCommand> receiverActor;
public Client(string loggerName, string actorName, RpcConnectionToClient<TClientMessage> connection, IActorRefFactory actorSystem, IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions, IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler) {
this.Connection = connection;
this.Connection.Closed += OnConnectionClosed;
this.logger = PhantomLogger.Create(loggerName);
var receiverActorInit = new RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.Init(loggerName, messageDefinitions, registrationHandler, Connection);
this.receiverActor = actorSystem.ActorOf(RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.Factory(receiverActorInit), actorName + "-Receiver");
}
internal void Enqueue(Type messageType, ReadOnlyMemory<byte> data) {
LogMessageType(messageType, data);
receiverActor.Tell(new RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.ReceiveMessageCommand(messageType, data));
}
private void LogMessageType(Type messageType, ReadOnlyMemory<byte> data) {
if (logger.IsEnabled(LogEventLevel.Verbose)) {
logger.Verbose("Received {MessageType} ({Bytes} B).", messageType.Name, data.Length);
}
}
private void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) {
Connection.Closed -= OnConnectionClosed;
logger.Debug("Closing connection...");
receiverActor.Stop();
}
}
}

View File

@@ -5,50 +5,81 @@ using MemoryPack;
namespace Phantom.Utils.Rpc; namespace Phantom.Utils.Rpc;
static class Serialization { static class Serialization {
private const int GuidBytes = 16;
private static readonly MemoryPackSerializerOptions SerializerOptions = MemoryPackSerializerOptions.Utf8; private static readonly MemoryPackSerializerOptions SerializerOptions = MemoryPackSerializerOptions.Utf8;
private static async ValueTask WritePrimitive<T>(T value, int size, Action<Span<byte>, T> writer, Stream stream, CancellationToken cancellationToken) { private static async ValueTask WriteValue<T>(T value, int size, Action<Span<byte>, T> writer, Stream stream, CancellationToken cancellationToken) {
using var buffer = RentedMemory<byte>.Rent(size); using var buffer = RentedMemory<byte>.Rent(size);
writer(buffer.AsSpan, value); writer(buffer.AsSpan, value);
await stream.WriteAsync(buffer.AsMemory, cancellationToken); await stream.WriteAsync(buffer.AsMemory, cancellationToken);
} }
private static async ValueTask<T> ReadPrimitive<T>(Func<ReadOnlySpan<byte>, T> reader, int size, Stream stream, CancellationToken cancellationToken) { private static async ValueTask<T> ReadValue<T>(Func<ReadOnlySpan<byte>, T> reader, int size, Stream stream, CancellationToken cancellationToken) {
using var buffer = RentedMemory<byte>.Rent(size); using var buffer = RentedMemory<byte>.Rent(size);
await stream.ReadExactlyAsync(buffer.AsMemory, cancellationToken); await stream.ReadExactlyAsync(buffer.AsMemory, cancellationToken);
return reader(buffer.AsSpan); return reader(buffer.AsSpan);
} }
public static ValueTask WriteByte(byte value, Stream stream, CancellationToken cancellationToken) { public static ValueTask WriteByte(byte value, Stream stream, CancellationToken cancellationToken) {
return WritePrimitive(value, sizeof(byte), static (span, value) => span[0] = value, stream, cancellationToken); return WriteValue(value, sizeof(byte), static (span, value) => span[0] = value, stream, cancellationToken);
} }
public static ValueTask<byte> ReadByte(Stream stream, CancellationToken cancellationToken) { public static ValueTask<byte> ReadByte(Stream stream, CancellationToken cancellationToken) {
return ReadPrimitive(static span => span[0], sizeof(byte), stream, cancellationToken); return ReadValue(static span => span[0], sizeof(byte), stream, cancellationToken);
} }
public static ValueTask WriteUnsignedShort(ushort value, Stream stream, CancellationToken cancellationToken) { public static ValueTask WriteUnsignedShort(ushort value, Stream stream, CancellationToken cancellationToken) {
return WritePrimitive(value, sizeof(ushort), BinaryPrimitives.WriteUInt16LittleEndian, stream, cancellationToken); return WriteValue(value, sizeof(ushort), BinaryPrimitives.WriteUInt16LittleEndian, stream, cancellationToken);
} }
public static ValueTask<ushort> ReadUnsignedShort(Stream stream, CancellationToken cancellationToken) { public static ValueTask<ushort> ReadUnsignedShort(Stream stream, CancellationToken cancellationToken) {
return ReadPrimitive(BinaryPrimitives.ReadUInt16LittleEndian, sizeof(ushort), stream, cancellationToken); return ReadValue(BinaryPrimitives.ReadUInt16LittleEndian, sizeof(ushort), stream, cancellationToken);
} }
public static ValueTask WriteSignedInt(int value, Stream stream, CancellationToken cancellationToken) { public static ValueTask WriteSignedInt(int value, Stream stream, CancellationToken cancellationToken) {
return WritePrimitive(value, sizeof(int), BinaryPrimitives.WriteInt32LittleEndian, stream, cancellationToken); return WriteValue(value, sizeof(int), BinaryPrimitives.WriteInt32LittleEndian, stream, cancellationToken);
} }
public static ValueTask<int> ReadSignedInt(Stream stream, CancellationToken cancellationToken) { public static ValueTask<int> ReadSignedInt(Stream stream, CancellationToken cancellationToken) {
return ReadPrimitive(BinaryPrimitives.ReadInt32LittleEndian, sizeof(int), stream, cancellationToken); return ReadValue(BinaryPrimitives.ReadInt32LittleEndian, sizeof(int), stream, cancellationToken);
} }
public static ValueTask WriteUnsignedInt(uint value, Stream stream, CancellationToken cancellationToken) { public static ValueTask WriteUnsignedInt(uint value, Stream stream, CancellationToken cancellationToken) {
return WritePrimitive(value, sizeof(uint), BinaryPrimitives.WriteUInt32LittleEndian, stream, cancellationToken); return WriteValue(value, sizeof(uint), BinaryPrimitives.WriteUInt32LittleEndian, stream, cancellationToken);
} }
public static ValueTask<uint> ReadUnsignedInt(Stream stream, CancellationToken cancellationToken) { public static ValueTask<uint> ReadUnsignedInt(Stream stream, CancellationToken cancellationToken) {
return ReadPrimitive(BinaryPrimitives.ReadUInt32LittleEndian, sizeof(uint), stream, cancellationToken); return ReadValue(BinaryPrimitives.ReadUInt32LittleEndian, sizeof(uint), stream, cancellationToken);
}
public static ValueTask WriteSignedLong(long value, Stream stream, CancellationToken cancellationToken) {
return WriteValue(value, sizeof(long), BinaryPrimitives.WriteInt64LittleEndian, stream, cancellationToken);
}
public static ValueTask<long> ReadSignedLong(Stream stream, CancellationToken cancellationToken) {
return ReadValue(BinaryPrimitives.ReadInt64LittleEndian, sizeof(long), stream, cancellationToken);
}
public static ValueTask WriteGuid(Guid guid, Stream stream, CancellationToken cancellationToken) {
static void Write(Span<byte> span, Guid guid) {
if (!guid.TryWriteBytes(span)) {
throw new ArgumentException("Span is not large enough to write a GUID.", nameof(span));
}
}
return WriteValue(guid, size: GuidBytes, Write, stream, cancellationToken);
}
public static ValueTask<Guid> ReadGuid(Stream stream, CancellationToken cancellationToken) {
return ReadValue(static span => new Guid(span), size: GuidBytes, stream, cancellationToken);
}
public static ValueTask WriteAuthToken(AuthToken authToken, Stream stream, CancellationToken cancellationToken) {
return stream.WriteAsync(authToken.Bytes.AsMemory(), cancellationToken);
}
public static ValueTask<AuthToken> ReadAuthToken(Stream stream, CancellationToken cancellationToken) {
return ReadValue(static span => new AuthToken([..span]), AuthToken.Length, stream, cancellationToken);
} }
public static async ValueTask<ReadOnlyMemory<byte>> ReadBytes(int length, Stream stream, CancellationToken cancellationToken) { public static async ValueTask<ReadOnlyMemory<byte>> ReadBytes(int length, Stream stream, CancellationToken cancellationToken) {

View File

@@ -0,0 +1,93 @@
using NUnit.Framework;
using Phantom.Utils.Collections;
using Range = Phantom.Utils.Collections.RangeSet<int>.Range;
namespace Phantom.Utils.Tests.Collections;
[TestFixture]
public class RangeSetTests {
[Test]
public void OneValue() {
var set = new RangeSet<int>();
set.Add(5);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 5),
}));
}
[Test]
public void MultipleDisjointValues() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(7);
set.Add(1);
set.Add(3);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 1, Max: 1),
new Range(Min: 3, Max: 3),
new Range(Min: 5, Max: 5),
new Range(Min: 7, Max: 7),
}));
}
[Test]
public void ExtendMin() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(4);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 4, Max: 5),
}));
}
[Test]
public void ExtendMax() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(6);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 6),
}));
}
[Test]
public void ExtendMaxAndMerge() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(7);
set.Add(6);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 7),
}));
}
[Test]
public void MultipleMergingAndDisjointValues() {
var set = new RangeSet<int>();
set.Add(1);
set.Add(2);
set.Add(5);
set.Add(4);
set.Add(10);
set.Add(7);
set.Add(9);
set.Add(11);
set.Add(16);
set.Add(12);
set.Add(3);
set.Add(14);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 1, Max: 5),
new Range(Min: 7, Max: 7),
new Range(Min: 9, Max: 12),
new Range(Min: 14, Max: 14),
new Range(Min: 16, Max: 16),
}));
}
}

View File

@@ -0,0 +1,69 @@
using System.Collections;
using System.Numerics;
namespace Phantom.Utils.Collections;
public sealed class RangeSet<T> : IEnumerable<RangeSet<T>.Range> where T : IBinaryInteger<T> {
private readonly List<Range> ranges = [];
public bool Add(T value) {
int index = 0;
for (; index < ranges.Count; index++) {
var range = ranges[index];
if (range.Contains(value)) {
return false;
}
if (range.ExtendIfAtEdge(value, out var extendedRange)) {
ranges[index] = extendedRange;
if (index < ranges.Count - 1) {
var nextRange = ranges[index + 1];
if (extendedRange.Max + T.One == nextRange.Min) {
ranges[index] = new Range(extendedRange.Min, nextRange.Max);
ranges.RemoveAt(index + 1);
}
}
return true;
}
if (range.Max > value) {
break;
}
}
ranges.Insert(index, new Range(value, value));
return true;
}
public IEnumerator<Range> GetEnumerator() {
return ranges.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator() {
return GetEnumerator();
}
public readonly record struct Range(T Min, T Max) {
internal bool ExtendIfAtEdge(T value, out Range newRange) {
if (value == Min - T.One) {
newRange = this with { Min = value };
return true;
}
else if (value == Max + T.One) {
newRange = this with { Max = value };
return true;
}
else {
newRange = default;
return false;
}
}
internal bool Contains(T value) {
return value >= Min && value <= Max;
}
}
}

View File

@@ -74,10 +74,6 @@ sealed class Base24 {
} }
public byte[] Decode(ReadOnlySpan<char> data) { public byte[] Decode(ReadOnlySpan<char> data) {
if (data == null) {
throw new ArgumentNullException(nameof(data));
}
if (data.Length % 7 != 0) { if (data.Length % 7 != 0) {
throw new ArgumentException("The data length must be multiple of 7 chars."); throw new ArgumentException("The data length must be multiple of 7 chars.");
} }

View File

@@ -1,6 +1,9 @@
namespace Phantom.Utils.Monads; namespace Phantom.Utils.Monads;
public abstract record Either<TLeft, TRight> { public abstract record Either<TLeft, TRight> {
public abstract bool IsLeft { get; }
public abstract bool IsRight { get; }
public abstract TLeft RequireLeft { get; } public abstract TLeft RequireLeft { get; }
public abstract TRight RequireRight { get; } public abstract TRight RequireRight { get; }

View File

@@ -1,6 +1,9 @@
namespace Phantom.Utils.Monads; namespace Phantom.Utils.Monads;
public sealed record Left<TLeft, TRight>(TLeft Value) : Either<TLeft, TRight> { public sealed record Left<TLeft, TRight>(TLeft Value) : Either<TLeft, TRight> {
public override bool IsLeft => true;
public override bool IsRight => false;
public override TLeft RequireLeft => Value; public override TLeft RequireLeft => Value;
public override TRight RequireRight => throw new InvalidOperationException("Either<" + typeof(TLeft).Name + ", " + typeof(TRight).Name + "> has a left value, but right value was requested."); public override TRight RequireRight => throw new InvalidOperationException("Either<" + typeof(TLeft).Name + ", " + typeof(TRight).Name + "> has a left value, but right value was requested.");

View File

@@ -1,6 +1,9 @@
namespace Phantom.Utils.Monads; namespace Phantom.Utils.Monads;
public sealed record Right<TLeft, TRight>(TRight Value) : Either<TLeft, TRight> { public sealed record Right<TLeft, TRight>(TRight Value) : Either<TLeft, TRight> {
public override bool IsLeft => false;
public override bool IsRight => true;
public override TLeft RequireLeft => throw new InvalidOperationException("Either<" + typeof(TLeft).Name + ", " + typeof(TRight).Name + "> has a right value, but left value was requested."); public override TLeft RequireLeft => throw new InvalidOperationException("Either<" + typeof(TLeft).Name + ", " + typeof(TRight).Name + "> has a right value, but left value was requested.");
public override TRight RequireRight => Value; public override TRight RequireRight => Value;

View File

@@ -14,7 +14,7 @@ namespace Phantom.Web.Services;
public static class PhantomWebServices { public static class PhantomWebServices {
public static void AddPhantomServices(this IServiceCollection services) { public static void AddPhantomServices(this IServiceCollection services) {
services.AddSingleton<ControllerConnection>(); services.AddSingleton<ControllerConnection>();
services.AddSingleton<ControllerMessageHandlerFactory>(); services.AddSingleton<ControllerMessageHandlerActorInitFactory>();
services.AddSingleton<AgentManager>(); services.AddSingleton<AgentManager>();
services.AddSingleton<InstanceManager>(); services.AddSingleton<InstanceManager>();

View File

@@ -6,7 +6,7 @@ namespace Phantom.Web.Services.Rpc;
public sealed class ControllerConnection(RpcSendChannel<IMessageToController> connection) { public sealed class ControllerConnection(RpcSendChannel<IMessageToController> connection) {
public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToController { public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToController {
return connection.SendMessage(message, CancellationToken.None); return connection.SendMessage(message);
} }
public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken = default) where TMessage : IMessageToController, ICanReply<TReply> { public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken = default) where TMessage : IMessageToController, ICanReply<TReply> {

View File

@@ -7,13 +7,12 @@ using Phantom.Web.Services.Instances;
namespace Phantom.Web.Services.Rpc; namespace Phantom.Web.Services.Rpc;
sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> { public sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> {
public readonly record struct Init( public readonly record struct Init(
AgentManager AgentManager, AgentManager AgentManager,
InstanceManager InstanceManager, InstanceManager InstanceManager,
InstanceLogManager InstanceLogManager, InstanceLogManager InstanceLogManager,
UserSessionRefreshManager UserSessionRefreshManager, UserSessionRefreshManager UserSessionRefreshManager
TaskCompletionSource<bool> RegisterSuccessWaiter
); );
public static Props<IMessageToWeb> Factory(Init init) { public static Props<IMessageToWeb> Factory(Init init) {
@@ -24,26 +23,19 @@ sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> {
private readonly InstanceManager instanceManager; private readonly InstanceManager instanceManager;
private readonly InstanceLogManager instanceLogManager; private readonly InstanceLogManager instanceLogManager;
private readonly UserSessionRefreshManager userSessionRefreshManager; private readonly UserSessionRefreshManager userSessionRefreshManager;
private readonly TaskCompletionSource<bool> registerSuccessWaiter;
private ControllerMessageHandlerActor(Init init) { private ControllerMessageHandlerActor(Init init) {
this.agentManager = init.AgentManager; this.agentManager = init.AgentManager;
this.instanceManager = init.InstanceManager; this.instanceManager = init.InstanceManager;
this.instanceLogManager = init.InstanceLogManager; this.instanceLogManager = init.InstanceLogManager;
this.userSessionRefreshManager = init.UserSessionRefreshManager; this.userSessionRefreshManager = init.UserSessionRefreshManager;
this.registerSuccessWaiter = init.RegisterSuccessWaiter;
Receive<RegisterWebResultMessage>(HandleRegisterWebResult);
Receive<RefreshAgentsMessage>(HandleRefreshAgents); Receive<RefreshAgentsMessage>(HandleRefreshAgents);
Receive<RefreshInstancesMessage>(HandleRefreshInstances); Receive<RefreshInstancesMessage>(HandleRefreshInstances);
Receive<InstanceOutputMessage>(HandleInstanceOutput); Receive<InstanceOutputMessage>(HandleInstanceOutput);
Receive<RefreshUserSessionMessage>(HandleRefreshUserSession); Receive<RefreshUserSessionMessage>(HandleRefreshUserSession);
} }
private void HandleRegisterWebResult(RegisterWebResultMessage message) {
registerSuccessWaiter.TrySetResult(message.Success);
}
private void HandleRefreshAgents(RefreshAgentsMessage message) { private void HandleRefreshAgents(RefreshAgentsMessage message) {
agentManager.RefreshAgents(message.Agents); agentManager.RefreshAgents(message.Agents);
} }

View File

@@ -0,0 +1,16 @@
using Phantom.Web.Services.Agents;
using Phantom.Web.Services.Authentication;
using Phantom.Web.Services.Instances;
namespace Phantom.Web.Services.Rpc;
public sealed class ControllerMessageHandlerActorInitFactory(
AgentManager agentManager,
InstanceManager instanceManager,
InstanceLogManager instanceLogManager,
UserSessionRefreshManager userSessionRefreshManager
) {
public ControllerMessageHandlerActor.Init Create() {
return new ControllerMessageHandlerActor.Init(agentManager, instanceManager, instanceLogManager, userSessionRefreshManager);
}
}

View File

@@ -1,28 +0,0 @@
using Akka.Actor;
using Phantom.Common.Messages.Web;
using Phantom.Utils.Actor;
using Phantom.Utils.Tasks;
using Phantom.Web.Services.Agents;
using Phantom.Web.Services.Authentication;
using Phantom.Web.Services.Instances;
namespace Phantom.Web.Services.Rpc;
public sealed class ControllerMessageHandlerFactory(
AgentManager agentManager,
InstanceManager instanceManager,
InstanceLogManager instanceLogManager,
UserSessionRefreshManager userSessionRefreshManager
) {
private readonly TaskCompletionSource<bool> registerSuccessWaiter = AsyncTasks.CreateCompletionSource<bool>();
public Task<bool> RegisterSuccessWaiter => registerSuccessWaiter.Task;
private int messageHandlerId = 0;
public ActorRef<IMessageToWeb> Create(IActorRefFactory actorSystem) {
var init = new ControllerMessageHandlerActor.Init(agentManager, instanceManager, instanceLogManager, userSessionRefreshManager, registerSuccessWaiter);
var name = "ControllerMessageHandler-" + Interlocked.Increment(ref messageHandlerId);
return actorSystem.ActorOf(ControllerMessageHandlerActor.Factory(init), name);
}
}

View File

@@ -5,7 +5,8 @@ using Phantom.Utils.Actor;
using Phantom.Utils.Cryptography; using Phantom.Utils.Cryptography;
using Phantom.Utils.IO; using Phantom.Utils.IO;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Threading; using Phantom.Utils.Threading;
using Phantom.Web; using Phantom.Web;
@@ -48,8 +49,6 @@ try {
string dataProtectionKeysPath = Path.GetFullPath("./keys"); string dataProtectionKeysPath = Path.GetFullPath("./keys");
CreateFolderOrStop(dataProtectionKeysPath, Chmod.URWX); CreateFolderOrStop(dataProtectionKeysPath, Chmod.URWX);
var (certificateThumbprint, authToken) = webKey.Value;
var administratorToken = TokenGenerator.Create(60); var administratorToken = TokenGenerator.Create(60);
var applicationProperties = new ApplicationProperties(fullVersion, TokenGenerator.GetBytesOrThrow(administratorToken)); var applicationProperties = new ApplicationProperties(fullVersion, TokenGenerator.GetBytesOrThrow(administratorToken));
@@ -57,12 +56,13 @@ try {
Host: controllerHost, Host: controllerHost,
Port: controllerPort, Port: controllerPort,
DistinguishedName: "phantom-controller", DistinguishedName: "phantom-controller",
CertificateThumbprint: certificateThumbprint, CertificateThumbprint: webKey.Value.CertificateThumbprint,
AuthToken: webKey.Value.AuthToken,
SendQueueCapacity: 500, SendQueueCapacity: 500,
PingInterval: TimeSpan.FromSeconds(10) PingInterval: TimeSpan.FromSeconds(10)
); );
using var rpcClient = await RpcClient<IMessageToController, IMessageToWeb>.Connect("Controller", rpcClientConnectionParameters, null, WebMessageRegistries.Definitions, shutdownCancellationToken); using var rpcClient = await RpcClient<IMessageToController, IMessageToWeb>.Connect("Controller", rpcClientConnectionParameters, WebMessageRegistries.Definitions, shutdownCancellationToken);
if (rpcClient == null) { if (rpcClient == null) {
return 1; return 1;
} }
@@ -72,11 +72,6 @@ try {
using var actorSystem = ActorSystemFactory.Create("Web"); using var actorSystem = ActorSystemFactory.Create("Web");
ControllerMessageHandlerFactory messageHandlerFactory;
await using (var scope = webApplication.Services.CreateAsyncScope()) {
messageHandlerFactory = scope.ServiceProvider.GetRequiredService<ControllerMessageHandlerFactory>();
}
Task? rpcClientListener = null; Task? rpcClientListener = null;
try { try {
PhantomLogger.Root.InformationHeading("Launching Phantom Panel web..."); PhantomLogger.Root.InformationHeading("Launching Phantom Panel web...");
@@ -85,16 +80,27 @@ try {
await WebLauncher.Launch(webConfiguration, webApplication); await WebLauncher.Launch(webConfiguration, webApplication);
ActorRef<IMessageToWeb> rpcMessageHandlerActor;
await using (var scope = webApplication.Services.CreateAsyncScope()) {
var rpcMessageHandlerInit = scope.ServiceProvider.GetRequiredService<ControllerMessageHandlerActorInitFactory>().Create();
rpcMessageHandlerActor = actorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler");
}
rpcClientListener = rpcClient.Listen(new IMessageReceiver<IMessageToWeb>.Actor(rpcMessageHandlerActor));
PhantomLogger.Root.Information("Phantom Panel web is ready."); PhantomLogger.Root.Information("Phantom Panel web is ready.");
rpcClientListener = rpcClient.Listen(messageHandlerFactory.Create(actorSystem));
await shutdownCancellationToken.WaitHandle.WaitOneAsync(); await shutdownCancellationToken.WaitHandle.WaitOneAsync();
await webApplication.StopAsync();
} finally { } finally {
PhantomLogger.Root.Information("Unregistering web...");
try { try {
await rpcClient.SendChannel.SendMessage(new UnregisterWebMessage(), CancellationToken.None); using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
// TODO wait for acknowledgment await rpcClient.SendChannel.SendMessage(new UnregisterWebMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister web after shutdown.");
} catch (Exception e) { } catch (Exception e) {
PhantomLogger.Root.Warning(e, "Could not unregister agent after shutdown."); PhantomLogger.Root.Warning(e, "Could not unregister web after shutdown.");
} finally { } finally {
await rpcClient.Shutdown(); await rpcClient.Shutdown();

View File

@@ -62,7 +62,7 @@ static class WebLauncher {
application.MapFallbackToPage("/_Host"); application.MapFallbackToPage("/_Host");
logger.Information("Starting Web server on port {Port}...", config.Port); logger.Information("Starting Web server on port {Port}...", config.Port);
return application.RunAsync(config.CancellationToken); return application.StartAsync(config.CancellationToken);
} }
private sealed class NullLifetime : IHostLifetime { private sealed class NullLifetime : IHostLifetime {