1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-09-17 21:24:49 +02:00

1 Commits

Author SHA1 Message Date
c9c9b91e34 Replace NetMQ with custom TCP logic 2025-09-17 17:05:25 +02:00
66 changed files with 993 additions and 606 deletions

View File

@@ -19,6 +19,8 @@ public sealed class AgentServices {
public ActorSystem ActorSystem { get; } public ActorSystem ActorSystem { get; }
private ControllerConnection ControllerConnection { get; }
private AgentInfo AgentInfo { get; } private AgentInfo AgentInfo { get; }
private AgentFolders AgentFolders { get; } private AgentFolders AgentFolders { get; }
private AgentState AgentState { get; } private AgentState AgentState { get; }
@@ -31,6 +33,8 @@ public sealed class AgentServices {
public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration, ControllerConnection controllerConnection) { public AgentServices(AgentInfo agentInfo, AgentFolders agentFolders, AgentServiceConfiguration serviceConfiguration, ControllerConnection controllerConnection) {
this.ActorSystem = ActorSystemFactory.Create("Agent"); this.ActorSystem = ActorSystemFactory.Create("Agent");
this.ControllerConnection = controllerConnection;
this.AgentInfo = agentInfo; this.AgentInfo = agentInfo;
this.AgentFolders = agentFolders; this.AgentFolders = agentFolders;
this.AgentState = new AgentState(); this.AgentState = new AgentState();
@@ -49,12 +53,14 @@ public sealed class AgentServices {
} }
} }
public async Task<bool> Register(ControllerConnection connection, CancellationToken cancellationToken) { public async Task<bool> Register(CancellationToken cancellationToken) {
Logger.Information("Registering with the controller..."); Logger.Information("Registering with the controller...");
// TODO NEED TO SEND WHEN SERVER RESTARTS!!!
ImmutableArray<ConfigureInstanceMessage> configureInstanceMessages; ImmutableArray<ConfigureInstanceMessage> configureInstanceMessages;
try { try {
configureInstanceMessages = await connection.Send<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(new RegisterAgentMessage(AgentInfo), TimeSpan.FromMinutes(1), cancellationToken); configureInstanceMessages = await ControllerConnection.Send<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(new RegisterAgentMessage(AgentInfo), TimeSpan.FromMinutes(1), cancellationToken);
} catch (Exception e) { } catch (Exception e) {
Logger.Fatal(e, "Registration failed."); Logger.Fatal(e, "Registration failed.");
return false; return false;
@@ -76,7 +82,7 @@ public sealed class AgentServices {
} }
} }
await connection.Send(new AdvertiseJavaRuntimesMessage(JavaRuntimeRepository.All), cancellationToken); await ControllerConnection.Send(new AdvertiseJavaRuntimesMessage(JavaRuntimeRepository.All), cancellationToken);
InstanceTicketManager.RefreshAgentStatus(); InstanceTicketManager.RefreshAgentStatus();
return true; return true;

View File

@@ -56,11 +56,6 @@ sealed class InstanceManagerActor : ReceiveActor<InstanceManagerActor.ICommand>
ReceiveAsync<ShutdownCommand>(Shutdown); ReceiveAsync<ShutdownCommand>(Shutdown);
} }
private string GetInstanceLoggerName(Guid guid) {
var prefix = guid.ToString();
return prefix[..prefix.IndexOf('-')] + "/" + Interlocked.Increment(ref instanceLoggerSequenceId);
}
private sealed record InstanceInfo(ActorRef<InstanceActor.ICommand> Actor, InstanceConfiguration Configuration, IServerLauncher Launcher); private sealed record InstanceInfo(ActorRef<InstanceActor.ICommand> Actor, InstanceConfiguration Configuration, IServerLauncher Launcher);
public interface ICommand {} public interface ICommand {}
@@ -118,7 +113,8 @@ sealed class InstanceManagerActor : ReceiveActor<InstanceManagerActor.ICommand>
} }
} }
else { else {
var instanceInit = new InstanceActor.Init(agentState, instanceGuid, GetInstanceLoggerName(instanceGuid), instanceServices, instanceTicketManager, shutdownCancellationToken); var instanceLoggerName = PhantomLogger.ShortenGuid(instanceGuid) + "/" + Interlocked.Increment(ref instanceLoggerSequenceId);;
var instanceInit = new InstanceActor.Init(agentState, instanceGuid, instanceLoggerName, instanceServices, instanceTicketManager, shutdownCancellationToken);
instances[instanceGuid] = instance = new InstanceInfo(Context.ActorOf(InstanceActor.Factory(instanceInit), "Instance-" + instanceGuid), configuration, launcher); instances[instanceGuid] = instance = new InstanceInfo(Context.ActorOf(InstanceActor.Factory(instanceInit), "Instance-" + instanceGuid), configuration, launcher);
Logger.Information("Created instance \"{Name}\" (GUID {Guid}).", configuration.InstanceName, instanceGuid); Logger.Information("Created instance \"{Name}\" (GUID {Guid}).", configuration.InstanceName, instanceGuid);

View File

@@ -7,6 +7,7 @@ using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController; using Phantom.Common.Messages.Agent.ToController;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Client; using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Threading; using Phantom.Utils.Threading;
@@ -61,33 +62,34 @@ try {
return 1; return 1;
} }
var controllerConnection = new ControllerConnection(rpcClient.SendChannel);
Task? rpcClientListener = null; Task? rpcClientListener = null;
try { try {
PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent..."); PhantomLogger.Root.InformationHeading("Launching Phantom Panel agent...");
var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts); var agentInfo = new AgentInfo(agentGuid.Value, agentName, ProtocolVersion, fullVersion, maxInstances, maxMemory, allowedServerPorts, allowedRconPorts);
var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), controllerConnection); var agentServices = new AgentServices(agentInfo, folders, new AgentServiceConfiguration(maxConcurrentBackupCompressionTasks), new ControllerConnection(rpcClient.SendChannel));
await agentServices.Initialize(); await agentServices.Initialize();
var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices); var rpcMessageHandlerInit = new ControllerMessageHandlerActor.Init(agentServices);
var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler"); var rpcMessageHandlerActor = agentServices.ActorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler");
rpcClientListener = rpcClient.Listen(rpcMessageHandlerActor); rpcClientListener = rpcClient.Listen(new IMessageReceiver<IMessageToAgent>.Actor(rpcMessageHandlerActor));
if (await agentServices.Register(controllerConnection, shutdownCancellationToken)) { if (await agentServices.Register(shutdownCancellationToken)) {
PhantomLogger.Root.Information("Phantom Panel agent is ready."); PhantomLogger.Root.Information("Phantom Panel agent is ready.");
await shutdownCancellationToken.WaitHandle.WaitOneAsync(); await shutdownCancellationToken.WaitHandle.WaitOneAsync();
} }
await agentServices.Shutdown(); await agentServices.Shutdown();
} finally { } finally {
PhantomLogger.Root.Information("Unregistering agent...");
try { try {
await controllerConnection.Send(new UnregisterAgentMessage(), CancellationToken.None); using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
// TODO wait for acknowledgment await rpcClient.SendChannel.SendMessage(new UnregisterAgentMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister agent after shutdown.");
} catch (Exception e) { } catch (Exception e) {
PhantomLogger.Root.Warning(e, "Could not unregister agent after shutdown."); PhantomLogger.Root.Warning(e, "Could not unregister agent during shutdown.");
} finally { } finally {
await rpcClient.Shutdown(); await rpcClient.Shutdown();

View File

@@ -22,7 +22,6 @@ public static class AgentMessageRegistries {
ToController.Add<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(0); ToController.Add<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(0);
ToController.Add<UnregisterAgentMessage>(1); ToController.Add<UnregisterAgentMessage>(1);
ToController.Add<AgentIsAliveMessage>(2);
ToController.Add<AdvertiseJavaRuntimesMessage>(3); ToController.Add<AdvertiseJavaRuntimesMessage>(3);
ToController.Add<ReportInstanceStatusMessage>(4); ToController.Add<ReportInstanceStatusMessage>(4);
ToController.Add<InstanceOutputMessage>(5); ToController.Add<InstanceOutputMessage>(5);

View File

@@ -1,6 +0,0 @@
using MemoryPack;
namespace Phantom.Common.Messages.Agent.ToController;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record AgentIsAliveMessage : IMessageToController;

View File

@@ -1,8 +0,0 @@
using MemoryPack;
namespace Phantom.Common.Messages.Web.ToWeb;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record RegisterWebResultMessage(
[property: MemoryPackOrder(0)] bool Success
) : IMessageToWeb;

View File

@@ -21,6 +21,7 @@ using Phantom.Utils.Actor.Mailbox;
using Phantom.Utils.Actor.Tasks; using Phantom.Utils.Actor.Tasks;
using Phantom.Utils.Collections; using Phantom.Utils.Collections;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Server;
using Serilog; using Serilog;
namespace Phantom.Controller.Services.Agents; namespace Phantom.Controller.Services.Agents;
@@ -167,13 +168,13 @@ sealed class AgentActor : ReceiveActor<AgentActor.ICommand> {
return configurationMessages.ToImmutable(); return configurationMessages.ToImmutable();
} }
public interface ICommand {} public interface ICommand;
private sealed record InitializeCommand : ICommand; private sealed record InitializeCommand : ICommand;
public sealed record RegisterCommand(AgentConfiguration Configuration, RpcConnectionToClient<IMessageToAgent> Connection) : ICommand, ICanReply<ImmutableArray<ConfigureInstanceMessage>>; public sealed record RegisterCommand(AgentConfiguration Configuration, RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection) : ICommand, ICanReply<ImmutableArray<ConfigureInstanceMessage>>;
public sealed record UnregisterCommand(RpcConnectionToClient<IMessageToAgent> Connection) : ICommand; public sealed record UnregisterCommand(RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection) : ICommand;
private sealed record RefreshConnectionStatusCommand : ICommand; private sealed record RefreshConnectionStatusCommand : ICommand;

View File

@@ -1,35 +1,30 @@
using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Server;
using Serilog; using Serilog;
namespace Phantom.Controller.Services.Agents; namespace Phantom.Controller.Services.Agents;
sealed class AgentConnection { sealed class AgentConnection(Guid agentGuid, string agentName) {
private static readonly ILogger Logger = PhantomLogger.Create<AgentConnection>(); private static readonly ILogger Logger = PhantomLogger.Create<AgentConnection>();
private readonly Guid agentGuid; private string agentName = agentName;
private string agentName; private RpcServerToClientConnection<IMessageToController, IMessageToAgent>? connection;
private RpcConnectionToClient<IMessageToAgent>? connection; public void UpdateConnection(RpcServerToClientConnection<IMessageToController, IMessageToAgent> newConnection, string newAgentName) {
public AgentConnection(Guid agentGuid, string agentName) {
this.agentName = agentName;
this.agentGuid = agentGuid;
}
public void UpdateConnection(RpcConnectionToClient<IMessageToAgent> newConnection, string newAgentName) {
lock (this) { lock (this) {
connection?.Close(); connection?.ClientClosedSession();
connection = newConnection; connection = newConnection;
agentName = newAgentName; agentName = newAgentName;
} }
} }
public bool CloseIfSame(RpcConnectionToClient<IMessageToAgent> expected) { public bool CloseIfSame(RpcServerToClientConnection<IMessageToController, IMessageToAgent> expected) {
lock (this) { lock (this) {
if (connection != null && connection.IsSame(expected)) { if (connection != null && ReferenceEquals(connection, expected)) {
connection.Close(); connection.ClientClosedSession();
connection = null;
return true; return true;
} }
} }
@@ -44,7 +39,7 @@ sealed class AgentConnection {
return Task.CompletedTask; return Task.CompletedTask;
} }
return connection.Send(message); return connection.SendChannel.SendMessage(message).AsTask();
} }
} }
@@ -52,10 +47,10 @@ sealed class AgentConnection {
lock (this) { lock (this) {
if (connection == null) { if (connection == null) {
LogAgentOffline(); LogAgentOffline();
return Task.FromResult<TReply?>(default); return Task.FromResult<TReply?>(null);
} }
return connection.Send<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!; return connection.SendChannel.SendMessage<TMessage, TReply>(message, waitForReplyTime, waitForReplyCancellationToken)!;
} }
} }

View File

@@ -13,7 +13,7 @@ using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Users.Sessions; using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc; using Phantom.Utils.Rpc.Runtime.Server;
using Serilog; using Serilog;
namespace Phantom.Controller.Services.Agents; namespace Phantom.Controller.Services.Agents;
@@ -22,19 +22,18 @@ sealed class AgentManager {
private static readonly ILogger Logger = PhantomLogger.Create<AgentManager>(); private static readonly ILogger Logger = PhantomLogger.Create<AgentManager>();
private readonly IActorRefFactory actorSystem; private readonly IActorRefFactory actorSystem;
private readonly AuthToken authToken;
private readonly ControllerState controllerState; private readonly ControllerState controllerState;
private readonly MinecraftVersions minecraftVersions; private readonly MinecraftVersions minecraftVersions;
private readonly UserLoginManager userLoginManager; private readonly UserLoginManager userLoginManager;
private readonly IDbContextProvider dbProvider; private readonly IDbContextProvider dbProvider;
private readonly CancellationToken cancellationToken; private readonly CancellationToken cancellationToken;
private readonly ConcurrentDictionary<Guid, ActorRef<AgentActor.ICommand>> agentsByGuid = new (); private readonly ConcurrentDictionary<Guid, ActorRef<AgentActor.ICommand>> agentsByAgentGuid = new ();
private readonly ConcurrentDictionary<Guid, Guid> agentGuidsBySessionGuid = new ();
private readonly Func<Guid, AgentConfiguration, ActorRef<AgentActor.ICommand>> addAgentActorFactory; private readonly Func<Guid, AgentConfiguration, ActorRef<AgentActor.ICommand>> addAgentActorFactory;
public AgentManager(IActorRefFactory actorSystem, AuthToken authToken, ControllerState controllerState, MinecraftVersions minecraftVersions, UserLoginManager userLoginManager, IDbContextProvider dbProvider, CancellationToken cancellationToken) { public AgentManager(IActorRefFactory actorSystem, ControllerState controllerState, MinecraftVersions minecraftVersions, UserLoginManager userLoginManager, IDbContextProvider dbProvider, CancellationToken cancellationToken) {
this.actorSystem = actorSystem; this.actorSystem = actorSystem;
this.authToken = authToken;
this.controllerState = controllerState; this.controllerState = controllerState;
this.minecraftVersions = minecraftVersions; this.minecraftVersions = minecraftVersions;
this.userLoginManager = userLoginManager; this.userLoginManager = userLoginManager;
@@ -57,28 +56,21 @@ sealed class AgentManager {
var agentGuid = entity.AgentGuid; var agentGuid = entity.AgentGuid;
var agentConfiguration = new AgentConfiguration(entity.Name, entity.ProtocolVersion, entity.BuildVersion, entity.MaxInstances, entity.MaxMemory); var agentConfiguration = new AgentConfiguration(entity.Name, entity.ProtocolVersion, entity.BuildVersion, entity.MaxInstances, entity.MaxMemory);
if (agentsByGuid.TryAdd(agentGuid, CreateAgentActor(agentGuid, agentConfiguration))) { if (agentsByAgentGuid.TryAdd(agentGuid, CreateAgentActor(agentGuid, agentConfiguration))) {
Logger.Information("Loaded agent \"{AgentName}\" (GUID {AgentGuid}) from database.", agentConfiguration.AgentName, agentGuid); Logger.Information("Loaded agent \"{AgentName}\" (GUID {AgentGuid}) from database.", agentConfiguration.AgentName, agentGuid);
} }
} }
} }
public async Task<bool> RegisterAgent(AuthToken authToken, AgentInfo agentInfo, RpcConnectionToClient<IMessageToAgent> connection) { public async Task<ImmutableArray<ConfigureInstanceMessage>> RegisterAgent(AgentInfo agentInfo, RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection) {
if (!this.authToken.FixedTimeEquals(authToken)) {
await connection.Send(new RegisterAgentFailureMessage(RegisterAgentError.InvalidToken));
return false;
}
var agentProperties = AgentConfiguration.From(agentInfo); var agentProperties = AgentConfiguration.From(agentInfo);
var agentActor = agentsByGuid.GetOrAdd(agentInfo.AgentGuid, addAgentActorFactory, agentProperties); var agentActor = agentsByAgentGuid.GetOrAdd(agentInfo.AgentGuid, addAgentActorFactory, agentProperties);
var configureInstanceMessages = await agentActor.Request(new AgentActor.RegisterCommand(agentProperties, connection), cancellationToken); agentGuidsBySessionGuid[connection.SessionId] = agentInfo.AgentGuid;
await connection.Send(new RegisterAgentSuccessMessage(configureInstanceMessages)); return await agentActor.Request(new AgentActor.RegisterCommand(agentProperties, connection), cancellationToken);
return true;
} }
public bool TellAgent(Guid agentGuid, AgentActor.ICommand command) { public bool TellAgent(Guid agentGuid, AgentActor.ICommand command) {
if (agentsByGuid.TryGetValue(agentGuid, out var agent)) { if (agentsByAgentGuid.TryGetValue(agentGuid, out var agent)) {
agent.Tell(command); agent.Tell(command);
return true; return true;
} }
@@ -94,7 +86,7 @@ sealed class AgentManager {
return (UserInstanceActionFailure) UserActionFailure.NotAuthorized; return (UserInstanceActionFailure) UserActionFailure.NotAuthorized;
} }
if (!agentsByGuid.TryGetValue(agentGuid, out var agent)) { if (!agentsByAgentGuid.TryGetValue(agentGuid, out var agent)) {
return (UserInstanceActionFailure) InstanceActionFailure.AgentDoesNotExist; return (UserInstanceActionFailure) InstanceActionFailure.AgentDoesNotExist;
} }
@@ -102,4 +94,12 @@ sealed class AgentManager {
var result = await agent.Request(command, cancellationToken); var result = await agent.Request(command, cancellationToken);
return result.MapError(static error => (UserInstanceActionFailure) error); return result.MapError(static error => (UserInstanceActionFailure) error);
} }
public bool TryGetAgentGuidBySessionGuid(Guid sessionGuid, out Guid agentGuid) {
return agentGuidsBySessionGuid.TryGetValue(sessionGuid, out agentGuid);
}
public void OnSessionClosed(Guid sessionId, Guid agentGuid) {
agentGuidsBySessionGuid.TryRemove(new KeyValuePair<Guid, Guid>(sessionId, agentGuid));
}
} }

View File

@@ -1,8 +1,4 @@
using Akka.Actor; using Akka.Actor;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToController;
using Phantom.Controller.Database; using Phantom.Controller.Database;
using Phantom.Controller.Minecraft; using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents; using Phantom.Controller.Services.Agents;
@@ -13,8 +9,8 @@ using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions; using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc; using Phantom.Utils.Rpc;
using IMessageFromAgentToController = Phantom.Common.Messages.Agent.IMessageToController; using IRpcAgentRegistrar = Phantom.Utils.Rpc.Runtime.Server.IRpcServerClientRegistrar<Phantom.Common.Messages.Agent.IMessageToController, Phantom.Common.Messages.Agent.IMessageToAgent>;
using IMessageFromWebToController = Phantom.Common.Messages.Web.IMessageToController; using IRpcWebRegistrar = Phantom.Utils.Rpc.Runtime.Server.IRpcServerClientRegistrar<Phantom.Common.Messages.Web.IMessageToController, Phantom.Common.Messages.Web.IMessageToWeb>;
namespace Phantom.Controller.Services; namespace Phantom.Controller.Services;
@@ -37,13 +33,13 @@ public sealed class ControllerServices : IDisposable {
private AuditLogManager AuditLogManager { get; } private AuditLogManager AuditLogManager { get; }
private EventLogManager EventLogManager { get; } private EventLogManager EventLogManager { get; }
public IRegistrationHandler<IMessageToAgent, IMessageFromAgentToController, RegisterAgentMessage> AgentRegistrationHandler { get; } public IRpcAgentRegistrar AgentRegistrar { get; }
public IRegistrationHandler<IMessageToWeb, IMessageFromWebToController, RegisterWebMessage> WebRegistrationHandler { get; } public IRpcWebRegistrar WebRegistrar { get; }
private readonly IDbContextProvider dbProvider; private readonly IDbContextProvider dbProvider;
private readonly CancellationToken cancellationToken; private readonly CancellationToken cancellationToken;
public ControllerServices(IDbContextProvider dbProvider, AuthToken agentAuthToken, AuthToken webAuthToken, CancellationToken shutdownCancellationToken) { public ControllerServices(IDbContextProvider dbProvider, AuthToken agentAuthToken, CancellationToken shutdownCancellationToken) {
this.dbProvider = dbProvider; this.dbProvider = dbProvider;
this.cancellationToken = shutdownCancellationToken; this.cancellationToken = shutdownCancellationToken;
@@ -59,14 +55,14 @@ public sealed class ControllerServices : IDisposable {
this.UserLoginManager = new UserLoginManager(AuthenticatedUserCache, UserManager, dbProvider); this.UserLoginManager = new UserLoginManager(AuthenticatedUserCache, UserManager, dbProvider);
this.PermissionManager = new PermissionManager(dbProvider); this.PermissionManager = new PermissionManager(dbProvider);
this.AgentManager = new AgentManager(ActorSystem, agentAuthToken, ControllerState, MinecraftVersions, UserLoginManager, dbProvider, cancellationToken); this.AgentManager = new AgentManager(ActorSystem, ControllerState, MinecraftVersions, UserLoginManager, dbProvider, cancellationToken);
this.InstanceLogManager = new InstanceLogManager(); this.InstanceLogManager = new InstanceLogManager();
this.AuditLogManager = new AuditLogManager(dbProvider); this.AuditLogManager = new AuditLogManager(dbProvider);
this.EventLogManager = new EventLogManager(ControllerState, ActorSystem, dbProvider, shutdownCancellationToken); this.EventLogManager = new EventLogManager(ControllerState, ActorSystem, dbProvider, shutdownCancellationToken);
this.AgentRegistrationHandler = new AgentRegistrationHandler(AgentManager, InstanceLogManager, EventLogManager); this.AgentRegistrar = new AgentClientRegistrar(ActorSystem, AgentManager, InstanceLogManager, EventLogManager);
this.WebRegistrationHandler = new WebRegistrationHandler(webAuthToken, ControllerState, InstanceLogManager, UserManager, RoleManager, UserRoleManager, UserLoginManager, AuditLogManager, AgentManager, MinecraftVersions, EventLogManager); this.WebRegistrar = new WebClientRegistrar(ActorSystem, ControllerState, InstanceLogManager, UserManager, RoleManager, UserRoleManager, UserLoginManager, AuditLogManager, AgentManager, MinecraftVersions, EventLogManager);
} }
public async Task Initialize() { public async Task Initialize() {

View File

@@ -0,0 +1,31 @@
using Akka.Actor;
using Phantom.Common.Messages.Agent;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc;
sealed class AgentClientRegistrar(
IActorRefFactory actorSystem,
AgentManager agentManager,
InstanceLogManager instanceLogManager,
EventLogManager eventLogManager
) : IRpcServerClientRegistrar<IMessageToController, IMessageToAgent> {
public IMessageReceiver<IMessageToController> Register(RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection) {
var name = "AgentClient-" + connection.SessionId;
var init = new AgentMessageHandlerActor.Init(connection, agentManager, instanceLogManager, eventLogManager);
return new Receiver(connection.SessionId, agentManager, actorSystem.ActorOf(AgentMessageHandlerActor.Factory(init), name));
}
private sealed class Receiver(Guid sessionId, AgentManager agentManager, ActorRef<IMessageToController> actor) : IMessageReceiver<IMessageToController>.Actor(actor) {
public override void OnPing() {
if (agentManager.TryGetAgentGuidBySessionGuid(sessionId, out var agentGuid)) {
agentManager.TellAgent(agentGuid, new AgentActor.NotifyIsAliveCommand());
}
}
}
}

View File

@@ -1,4 +1,5 @@
using Phantom.Common.Data.Replies; using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Common.Messages.Agent; using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToAgent; using Phantom.Common.Messages.Agent.ToAgent;
using Phantom.Common.Messages.Agent.ToController; using Phantom.Common.Messages.Agent.ToController;
@@ -6,34 +7,32 @@ using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events; using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances; using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc; namespace Phantom.Controller.Services.Rpc;
sealed class AgentMessageHandlerActor : ReceiveActor<IMessageToController> { sealed class AgentMessageHandlerActor : ReceiveActor<IMessageToController> {
public readonly record struct Init(Guid AgentGuid, RpcConnectionToClient<IMessageToAgent> Connection, AgentRegistrationHandler AgentRegistrationHandler, AgentManager AgentManager, InstanceLogManager InstanceLogManager, EventLogManager EventLogManager); public readonly record struct Init(RpcServerToClientConnection<IMessageToController, IMessageToAgent> Connection, AgentManager AgentManager, InstanceLogManager InstanceLogManager, EventLogManager EventLogManager);
public static Props<IMessageToController> Factory(Init init) { public static Props<IMessageToController> Factory(Init init) {
return Props<IMessageToController>.Create(() => new AgentMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<IMessageToController>.Create(() => new AgentMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
private readonly Guid agentGuid; private readonly RpcServerToClientConnection<IMessageToController, IMessageToAgent> connection;
private readonly RpcConnectionToClient<IMessageToAgent> connection;
private readonly AgentRegistrationHandler agentRegistrationHandler;
private readonly AgentManager agentManager; private readonly AgentManager agentManager;
private readonly InstanceLogManager instanceLogManager; private readonly InstanceLogManager instanceLogManager;
private readonly EventLogManager eventLogManager; private readonly EventLogManager eventLogManager;
private Guid? registeredAgentGuid;
private AgentMessageHandlerActor(Init init) { private AgentMessageHandlerActor(Init init) {
this.agentGuid = init.AgentGuid;
this.connection = init.Connection; this.connection = init.Connection;
this.agentRegistrationHandler = init.AgentRegistrationHandler;
this.agentManager = init.AgentManager; this.agentManager = init.AgentManager;
this.instanceLogManager = init.InstanceLogManager; this.instanceLogManager = init.InstanceLogManager;
this.eventLogManager = init.EventLogManager; this.eventLogManager = init.EventLogManager;
ReceiveAsync<RegisterAgentMessage>(HandleRegisterAgent); ReceiveAsyncAndReply<RegisterAgentMessage, ImmutableArray<ConfigureInstanceMessage>>(HandleRegisterAgent);
Receive<UnregisterAgentMessage>(HandleUnregisterAgent); ReceiveAsync<UnregisterAgentMessage>(HandleUnregisterAgent);
Receive<AgentIsAliveMessage>(HandleAgentIsAlive);
Receive<AdvertiseJavaRuntimesMessage>(HandleAdvertiseJavaRuntimes); Receive<AdvertiseJavaRuntimesMessage>(HandleAdvertiseJavaRuntimes);
Receive<ReportAgentStatusMessage>(HandleReportAgentStatus); Receive<ReportAgentStatusMessage>(HandleReportAgentStatus);
Receive<ReportInstanceStatusMessage>(HandleReportInstanceStatus); Receive<ReportInstanceStatusMessage>(HandleReportInstanceStatus);
@@ -42,42 +41,42 @@ sealed class AgentMessageHandlerActor : ReceiveActor<IMessageToController> {
Receive<InstanceOutputMessage>(HandleInstanceOutput); Receive<InstanceOutputMessage>(HandleInstanceOutput);
} }
private async Task HandleRegisterAgent(RegisterAgentMessage message) { private Guid RequireAgentGuid() {
if (agentGuid != message.AgentInfo.AgentGuid) { return registeredAgentGuid ?? throw new InvalidOperationException("Agent has not registered yet.");
await connection.Send(new RegisterAgentFailureMessage(RegisterAgentError.ConnectionAlreadyHasAnAgent));
}
else {
await agentRegistrationHandler.TryRegisterImpl(connection, message);
}
} }
private void HandleUnregisterAgent(UnregisterAgentMessage message) { private Task<ImmutableArray<ConfigureInstanceMessage>> HandleRegisterAgent(RegisterAgentMessage message) {
registeredAgentGuid = message.AgentInfo.AgentGuid;
return agentManager.RegisterAgent(message.AgentInfo, connection);
}
private Task HandleUnregisterAgent(UnregisterAgentMessage message) {
Guid agentGuid = RequireAgentGuid();
agentManager.TellAgent(agentGuid, new AgentActor.UnregisterCommand(connection)); agentManager.TellAgent(agentGuid, new AgentActor.UnregisterCommand(connection));
connection.Close(); agentManager.OnSessionClosed(connection.SessionId, agentGuid);
}
Self.Tell(PoisonPill.Instance);
private void HandleAgentIsAlive(AgentIsAliveMessage message) { return connection.ClientClosedSession();
agentManager.TellAgent(agentGuid, new AgentActor.NotifyIsAliveCommand());
} }
private void HandleAdvertiseJavaRuntimes(AdvertiseJavaRuntimesMessage message) { private void HandleAdvertiseJavaRuntimes(AdvertiseJavaRuntimesMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateJavaRuntimesCommand(message.Runtimes)); agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateJavaRuntimesCommand(message.Runtimes));
} }
private void HandleReportAgentStatus(ReportAgentStatusMessage message) { private void HandleReportAgentStatus(ReportAgentStatusMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateStatsCommand(message.RunningInstanceCount, message.RunningInstanceMemory)); agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateStatsCommand(message.RunningInstanceCount, message.RunningInstanceMemory));
} }
private void HandleReportInstanceStatus(ReportInstanceStatusMessage message) { private void HandleReportInstanceStatus(ReportInstanceStatusMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateInstanceStatusCommand(message.InstanceGuid, message.InstanceStatus)); agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateInstanceStatusCommand(message.InstanceGuid, message.InstanceStatus));
} }
private void HandleReportInstancePlayerCounts(ReportInstancePlayerCountsMessage message) { private void HandleReportInstancePlayerCounts(ReportInstancePlayerCountsMessage message) {
agentManager.TellAgent(agentGuid, new AgentActor.UpdateInstancePlayerCountsCommand(message.InstanceGuid, message.PlayerCounts)); agentManager.TellAgent(RequireAgentGuid(), new AgentActor.UpdateInstancePlayerCountsCommand(message.InstanceGuid, message.PlayerCounts));
} }
private void HandleReportInstanceEvent(ReportInstanceEventMessage message) { private void HandleReportInstanceEvent(ReportInstanceEventMessage message) {
message.Event.Accept(eventLogManager.CreateInstanceEventVisitor(message.EventGuid, message.UtcTime, agentGuid, message.InstanceGuid)); message.Event.Accept(eventLogManager.CreateInstanceEventVisitor(message.EventGuid, message.UtcTime, RequireAgentGuid(), message.InstanceGuid));
} }
private void HandleInstanceOutput(InstanceOutputMessage message) { private void HandleInstanceOutput(InstanceOutputMessage message) {

View File

@@ -1,34 +0,0 @@
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Agent.ToController;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime2;
namespace Phantom.Controller.Services.Rpc;
sealed class AgentRegistrationHandler : IRegistrationHandler<IMessageToAgent, IMessageToController, RegisterAgentMessage> {
private readonly AgentManager agentManager;
private readonly InstanceLogManager instanceLogManager;
private readonly EventLogManager eventLogManager;
public AgentRegistrationHandler(AgentManager agentManager, InstanceLogManager instanceLogManager, EventLogManager eventLogManager) {
this.agentManager = agentManager;
this.instanceLogManager = instanceLogManager;
this.eventLogManager = eventLogManager;
}
async Task<Props<IMessageToController>?> IRegistrationHandler<IMessageToAgent, IMessageToController, RegisterAgentMessage>.TryRegister(RpcConnectionToClient<IMessageToAgent> connection, RegisterAgentMessage message) {
return await TryRegisterImpl(connection, message) ? CreateMessageHandlerActorProps(message.AgentInfo.AgentGuid, connection) : null;
}
public Task<bool> TryRegisterImpl(RpcConnectionToClient<IMessageToAgent> connection, RegisterAgentMessage message) {
return agentManager.RegisterAgent(message.AuthToken, message.AgentInfo, connection);
}
private Props<IMessageToController> CreateMessageHandlerActorProps(Guid agentGuid, RpcConnectionToClient<IMessageToAgent> connection) {
var init = new AgentMessageHandlerActor.Init(agentGuid, connection, this, agentManager, instanceLogManager, eventLogManager);
return AgentMessageHandlerActor.Factory(init);
}
}

View File

@@ -0,0 +1,33 @@
using Akka.Actor;
using Phantom.Common.Messages.Web;
using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc;
sealed class WebClientRegistrar(
IActorRefFactory actorSystem,
ControllerState controllerState,
InstanceLogManager instanceLogManager,
UserManager userManager,
RoleManager roleManager,
UserRoleManager userRoleManager,
UserLoginManager userLoginManager,
AuditLogManager auditLogManager,
AgentManager agentManager,
MinecraftVersions minecraftVersions,
EventLogManager eventLogManager
) : IRpcServerClientRegistrar<IMessageToController, IMessageToWeb> {
public IMessageReceiver<IMessageToController> Register(RpcServerToClientConnection<IMessageToController, IMessageToWeb> connection) {
var name = "WebClient-" + connection.SessionId;
var init = new WebMessageHandlerActor.Init(connection, controllerState, instanceLogManager, userManager, roleManager, userRoleManager, userLoginManager, auditLogManager, agentManager, minecraftVersions, eventLogManager);
return new IMessageReceiver<IMessageToController>.Actor(actorSystem.ActorOf(WebMessageHandlerActor.Factory(init), name));
}
}

View File

@@ -5,17 +5,18 @@ using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToWeb; using Phantom.Common.Messages.Web.ToWeb;
using Phantom.Controller.Services.Instances; using Phantom.Controller.Services.Instances;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Controller.Services.Rpc; namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdateSenderActor.ICommand> { sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdateSenderActor.ICommand> {
public readonly record struct Init(RpcConnectionToClient<IMessageToWeb> Connection, ControllerState ControllerState, InstanceLogManager InstanceLogManager); public readonly record struct Init(RpcSendChannel<IMessageToWeb> Connection, ControllerState ControllerState, InstanceLogManager InstanceLogManager);
public static Props<ICommand> Factory(Init init) { public static Props<ICommand> Factory(Init init) {
return Props<ICommand>.Create(() => new WebMessageDataUpdateSenderActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<ICommand>.Create(() => new WebMessageDataUpdateSenderActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
private readonly RpcConnectionToClient<IMessageToWeb> connection; private readonly RpcSendChannel<IMessageToWeb> connection;
private readonly ControllerState controllerState; private readonly ControllerState controllerState;
private readonly InstanceLogManager instanceLogManager; private readonly InstanceLogManager instanceLogManager;
private readonly ActorRef<ICommand> selfCached; private readonly ActorRef<ICommand> selfCached;
@@ -69,18 +70,18 @@ sealed class WebMessageDataUpdateSenderActor : ReceiveActor<WebMessageDataUpdate
private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand; private sealed record RefreshUserSessionCommand(Guid UserGuid) : ICommand;
private Task RefreshAgents(RefreshAgentsCommand command) { private Task RefreshAgents(RefreshAgentsCommand command) {
return connection.Send(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())); return connection.SendMessage(new RefreshAgentsMessage(command.Agents.Values.ToImmutableArray())).AsTask();
} }
private Task RefreshInstances(RefreshInstancesCommand command) { private Task RefreshInstances(RefreshInstancesCommand command) {
return connection.Send(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())); return connection.SendMessage(new RefreshInstancesMessage(command.Instances.Values.ToImmutableArray())).AsTask();
} }
private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) { private Task ReceiveInstanceLogs(ReceiveInstanceLogsCommand command) {
return connection.Send(new InstanceOutputMessage(command.InstanceGuid, command.Lines)); return connection.SendMessage(new InstanceOutputMessage(command.InstanceGuid, command.Lines)).AsTask();
} }
private Task RefreshUserSession(RefreshUserSessionCommand command) { private Task RefreshUserSession(RefreshUserSessionCommand command) {
return connection.Send(new RefreshUserSessionMessage(command.UserGuid)); return connection.SendMessage(new RefreshUserSessionMessage(command.UserGuid)).AsTask();
} }
} }

View File

@@ -1,4 +1,5 @@
using System.Collections.Immutable; using System.Collections.Immutable;
using Akka.Actor;
using Phantom.Common.Data; using Phantom.Common.Data;
using Phantom.Common.Data.Java; using Phantom.Common.Data.Java;
using Phantom.Common.Data.Minecraft; using Phantom.Common.Data.Minecraft;
@@ -16,13 +17,13 @@ using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users; using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions; using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime.Server;
namespace Phantom.Controller.Services.Rpc; namespace Phantom.Controller.Services.Rpc;
sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> { sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
public readonly record struct Init( public readonly record struct Init(
RpcConnectionToClient<IMessageToWeb> Connection, RpcServerToClientConnection<IMessageToController, IMessageToWeb> Connection,
WebRegistrationHandler WebRegistrationHandler,
ControllerState ControllerState, ControllerState ControllerState,
InstanceLogManager InstanceLogManager, InstanceLogManager InstanceLogManager,
UserManager UserManager, UserManager UserManager,
@@ -39,8 +40,7 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
return Props<IMessageToController>.Create(() => new WebMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<IMessageToController>.Create(() => new WebMessageHandlerActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
private readonly RpcConnectionToClient<IMessageToWeb> connection; private readonly RpcServerToClientConnection<IMessageToController, IMessageToWeb> connection;
private readonly WebRegistrationHandler webRegistrationHandler;
private readonly ControllerState controllerState; private readonly ControllerState controllerState;
private readonly UserManager userManager; private readonly UserManager userManager;
private readonly RoleManager roleManager; private readonly RoleManager roleManager;
@@ -53,7 +53,6 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
private WebMessageHandlerActor(Init init) { private WebMessageHandlerActor(Init init) {
this.connection = init.Connection; this.connection = init.Connection;
this.webRegistrationHandler = init.WebRegistrationHandler;
this.controllerState = init.ControllerState; this.controllerState = init.ControllerState;
this.userManager = init.UserManager; this.userManager = init.UserManager;
this.roleManager = init.RoleManager; this.roleManager = init.RoleManager;
@@ -64,11 +63,10 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
this.minecraftVersions = init.MinecraftVersions; this.minecraftVersions = init.MinecraftVersions;
this.eventLogManager = init.EventLogManager; this.eventLogManager = init.EventLogManager;
var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection, controllerState, init.InstanceLogManager); var senderActorInit = new WebMessageDataUpdateSenderActor.Init(connection.SendChannel, controllerState, init.InstanceLogManager);
Context.ActorOf(WebMessageDataUpdateSenderActor.Factory(senderActorInit), "DataUpdateSender"); Context.ActorOf(WebMessageDataUpdateSenderActor.Factory(senderActorInit), "DataUpdateSender");
ReceiveAsync<RegisterWebMessage>(HandleRegisterWeb); ReceiveAsync<UnregisterWebMessage>(HandleUnregisterWeb);
Receive<UnregisterWebMessage>(HandleUnregisterWeb);
ReceiveAndReplyLater<LogInMessage, Optional<LogInSuccess>>(HandleLogIn); ReceiveAndReplyLater<LogInMessage, Optional<LogInSuccess>>(HandleLogIn);
Receive<LogOutMessage>(HandleLogOut); Receive<LogOutMessage>(HandleLogOut);
ReceiveAndReply<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>(GetAuthenticatedUser); ReceiveAndReply<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>(GetAuthenticatedUser);
@@ -89,12 +87,9 @@ sealed class WebMessageHandlerActor : ReceiveActor<IMessageToController> {
ReceiveAndReplyLater<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(HandleGetEventLog); ReceiveAndReplyLater<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(HandleGetEventLog);
} }
private async Task HandleRegisterWeb(RegisterWebMessage message) { private Task HandleUnregisterWeb(UnregisterWebMessage message) {
await webRegistrationHandler.TryRegisterImpl(connection, message); Self.Tell(PoisonPill.Instance);
} return connection.ClientClosedSession();
private void HandleUnregisterWeb(UnregisterWebMessage message) {
connection.Close();
} }
private Task<Optional<LogInSuccess>> HandleLogIn(LogInMessage message) { private Task<Optional<LogInSuccess>> HandleLogIn(LogInMessage message) {

View File

@@ -1,67 +0,0 @@
using Phantom.Common.Messages.Web;
using Phantom.Common.Messages.Web.ToController;
using Phantom.Common.Messages.Web.ToWeb;
using Phantom.Controller.Minecraft;
using Phantom.Controller.Services.Agents;
using Phantom.Controller.Services.Events;
using Phantom.Controller.Services.Instances;
using Phantom.Controller.Services.Users;
using Phantom.Controller.Services.Users.Sessions;
using Phantom.Utils.Actor;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc;
using Serilog;
namespace Phantom.Controller.Services.Rpc;
sealed class WebRegistrationHandler : IRegistrationHandler<IMessageToWeb, IMessageToController, RegisterWebMessage> {
private static readonly ILogger Logger = PhantomLogger.Create<WebRegistrationHandler>();
private readonly AuthToken webAuthToken;
private readonly ControllerState controllerState;
private readonly InstanceLogManager instanceLogManager;
private readonly UserManager userManager;
private readonly RoleManager roleManager;
private readonly UserRoleManager userRoleManager;
private readonly UserLoginManager userLoginManager;
private readonly AuditLogManager auditLogManager;
private readonly AgentManager agentManager;
private readonly MinecraftVersions minecraftVersions;
private readonly EventLogManager eventLogManager;
public WebRegistrationHandler(AuthToken webAuthToken, ControllerState controllerState, InstanceLogManager instanceLogManager, UserManager userManager, RoleManager roleManager, UserRoleManager userRoleManager, UserLoginManager userLoginManager, AuditLogManager auditLogManager, AgentManager agentManager, MinecraftVersions minecraftVersions, EventLogManager eventLogManager) {
this.webAuthToken = webAuthToken;
this.controllerState = controllerState;
this.userManager = userManager;
this.roleManager = roleManager;
this.userRoleManager = userRoleManager;
this.userLoginManager = userLoginManager;
this.auditLogManager = auditLogManager;
this.agentManager = agentManager;
this.minecraftVersions = minecraftVersions;
this.eventLogManager = eventLogManager;
this.instanceLogManager = instanceLogManager;
}
async Task<Props<IMessageToController>?> IRegistrationHandler<IMessageToWeb, IMessageToController, RegisterWebMessage>.TryRegister(RpcConnectionToClient<IMessageToWeb> connection, RegisterWebMessage message) {
return await TryRegisterImpl(connection, message) ? CreateMessageHandlerActorProps(connection) : null;
}
public async Task<bool> TryRegisterImpl(RpcConnectionToClient<IMessageToWeb> connection, RegisterWebMessage message) {
if (webAuthToken.FixedTimeEquals(message.AuthToken)) {
Logger.Information("Web authorized successfully.");
await connection.Send(new RegisterWebResultMessage(true));
return true;
}
else {
Logger.Warning("Web failed to authorize, invalid token.");
await connection.Send(new RegisterWebResultMessage(false));
return false;
}
}
private Props<IMessageToController> CreateMessageHandlerActorProps(RpcConnectionToClient<IMessageToWeb> connection) {
var init = new WebMessageHandlerActor.Init(connection, this, controllerState, instanceLogManager, userManager, roleManager, userRoleManager, userLoginManager, auditLogManager, agentManager, minecraftVersions, eventLogManager);
return WebMessageHandlerActor.Factory(init);
}
}

View File

@@ -1,4 +1,6 @@
using System.Reflection; using System.Reflection;
using Phantom.Common.Messages.Agent;
using Phantom.Common.Messages.Web;
using Phantom.Controller; using Phantom.Controller;
using Phantom.Controller.Database.Postgres; using Phantom.Controller.Database.Postgres;
using Phantom.Controller.Services; using Phantom.Controller.Services;
@@ -7,6 +9,8 @@ using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Runtime.Server; using Phantom.Utils.Rpc.Runtime.Server;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Tasks; using Phantom.Utils.Tasks;
using RpcAgentServer = Phantom.Utils.Rpc.Runtime.Server.RpcServer<Phantom.Common.Messages.Agent.IMessageToController, Phantom.Common.Messages.Agent.IMessageToAgent>;
using RpcWebServer = Phantom.Utils.Rpc.Runtime.Server.RpcServer<Phantom.Common.Messages.Web.IMessageToController, Phantom.Common.Messages.Web.IMessageToWeb>;
var shutdownCancellationTokenSource = new CancellationTokenSource(); var shutdownCancellationTokenSource = new CancellationTokenSource();
var shutdownCancellationToken = shutdownCancellationTokenSource.Token; var shutdownCancellationToken = shutdownCancellationTokenSource.Token;
@@ -53,12 +57,28 @@ try {
var dbContextFactory = new ApplicationDbContextFactory(sqlConnectionString); var dbContextFactory = new ApplicationDbContextFactory(sqlConnectionString);
using var controllerServices = new ControllerServices(dbContextFactory, agentKeyData.AuthToken, webKeyData.AuthToken, shutdownCancellationToken); using var controllerServices = new ControllerServices(dbContextFactory, agentKeyData.AuthToken, shutdownCancellationToken);
await controllerServices.Initialize(); await controllerServices.Initialize();
var agentConnectionParameters = new RpcServerConnectionParameters(
EndPoint: agentRpcServerHost,
Certificate: agentKeyData.Certificate,
AuthToken: agentKeyData.AuthToken,
SendQueueCapacity: 100,
PingInterval: TimeSpan.FromSeconds(10)
);
var webConnectionParameters = new RpcServerConnectionParameters(
EndPoint: webRpcServerHost,
Certificate: webKeyData.Certificate,
AuthToken: webKeyData.AuthToken,
SendQueueCapacity: 500,
PingInterval: TimeSpan.FromMinutes(1)
);
LinkedTasks<bool> rpcServerTasks = new LinkedTasks<bool>([ LinkedTasks<bool> rpcServerTasks = new LinkedTasks<bool>([
new RpcServer("Agent", agentRpcServerHost, agentKeyData.AuthToken, agentKeyData.Certificate).Run(shutdownCancellationToken), new RpcAgentServer("Agent", agentConnectionParameters, AgentMessageRegistries.Definitions, controllerServices.AgentRegistrar).Run(shutdownCancellationToken),
new RpcServer("Web", webRpcServerHost, agentKeyData.AuthToken, webKeyData.Certificate).Run(shutdownCancellationToken), new RpcWebServer("Web", webConnectionParameters, WebMessageRegistries.Definitions, controllerServices.WebRegistrar).Run(shutdownCancellationToken),
]); ]);
// If either RPC server crashes, stop the whole process. // If either RPC server crashes, stop the whole process.
@@ -71,11 +91,6 @@ try {
} }
return 0; return 0;
// await Task.WhenAll(
// RpcServerRuntime.Launch(ConfigureRpc("Agent", agentRpcServerHost, agentRpcServerPort, agentKeyData), AgentMessageRegistries.Definitions, controllerServices.AgentRegistrationHandler, controllerServices.ActorSystem, shutdownCancellationToken),
// RpcServerRuntime.Launch(ConfigureRpc("Web", webRpcServerHost, webRpcServerPort, webKeyData), WebMessageRegistries.Definitions, controllerServices.WebRegistrationHandler, controllerServices.ActorSystem, shutdownCancellationToken)
// );
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
return 0; return 0;
} catch (StopProcedureException) { } catch (StopProcedureException) {

View File

@@ -26,7 +26,7 @@ sealed record Variables(
EndPoint webRpcServerHost = new IPEndPoint( EndPoint webRpcServerHost = new IPEndPoint(
EnvironmentVariables.GetIpAddress("WEB_RPC_SERVER_HOST").WithDefault(IPAddress.Any), EnvironmentVariables.GetIpAddress("WEB_RPC_SERVER_HOST").WithDefault(IPAddress.Any),
EnvironmentVariables.GetPortNumber("WEB_RPC_SERVER_PORT").WithDefault(9401) EnvironmentVariables.GetPortNumber("WEB_RPC_SERVER_PORT").WithDefault(9402)
); );
return new Variables( return new Variables(

View File

@@ -23,7 +23,6 @@
<ItemGroup> <ItemGroup>
<PackageReference Update="BCrypt.Net-Next.StrongName" Version="4.0.3" /> <PackageReference Update="BCrypt.Net-Next.StrongName" Version="4.0.3" />
<PackageReference Update="MemoryPack" Version="1.10.0" /> <PackageReference Update="MemoryPack" Version="1.10.0" />
<PackageReference Update="NetMQ" Version="4.0.1.13" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -31,29 +31,48 @@ public static class PhantomLogger {
} }
public static ILogger Create<T>() { public static ILogger Create<T>() {
return Create(typeof(T).Name); return Create(TypeName<T>());
} }
public static ILogger Create<T>(string name) { public static ILogger Create<T>(string name) {
return Create(ConcatNames(typeof(T).Name, name)); return Create(ConcatNames(TypeName<T>(), name));
} }
public static ILogger Create<T>(string name1, string name2) { public static ILogger Create<T>(string name1, string name2) {
return Create(ConcatNames(typeof(T).Name, ConcatNames(name1, name2))); return Create(ConcatNames(TypeName<T>(), name1, name2));
} }
public static ILogger Create<T1, T2>() { public static ILogger Create<T1, T2>() {
return Create(ConcatNames(typeof(T1).Name, typeof(T2).Name)); return Create(ConcatNames(TypeName<T1>(), TypeName<T2>()));
} }
public static ILogger Create<T1, T2>(string name) { public static ILogger Create<T1, T2>(string name) {
return Create(ConcatNames(typeof(T1).Name, ConcatNames(typeof(T2).Name, name))); return Create(ConcatNames(TypeName<T1>(), TypeName<T2>(), name));
} }
private static string ConcatNames(string name1, string name2) { public static ILogger Create<T1, T2>(string name1, string name2) {
return Create(ConcatNames(TypeName<T1>(), TypeName<T2>(), ConcatNames(name1, name2)));
}
private static string TypeName<T>() {
string typeName = typeof(T).Name;
int genericsStartIndex = typeName.IndexOf('`');
return genericsStartIndex > 0 ? typeName[..genericsStartIndex] : typeName;
}
public static string ConcatNames(string name1, string name2) {
return name1 + ":" + name2; return name1 + ":" + name2;
} }
public static string ConcatNames(string name1, string name2, string name3) {
return ConcatNames(name1, ConcatNames(name2, name3));
}
public static string ShortenGuid(Guid guid) {
var prefix = guid.ToString();
return prefix[..prefix.IndexOf('-')];
}
public static void Dispose() { public static void Dispose() {
Root.Dispose(); Root.Dispose();
Base.Dispose(); Base.Dispose();

View File

@@ -4,11 +4,13 @@ namespace Phantom.Utils.Rpc.Frame;
interface IFrame { interface IFrame {
private const byte TypePingId = 0; private const byte TypePingId = 0;
private const byte TypeMessageId = 1; private const byte TypePongId = 1;
private const byte TypeReplyId = 2; private const byte TypeMessageId = 2;
private const byte TypeErrorId = 3; private const byte TypeReplyId = 3;
private const byte TypeErrorId = 4;
static readonly ReadOnlyMemory<byte> TypePing = new ([TypePingId]); static readonly ReadOnlyMemory<byte> TypePing = new ([TypePingId]);
static readonly ReadOnlyMemory<byte> TypePong = new ([TypePongId]);
static readonly ReadOnlyMemory<byte> TypeMessage = new ([TypeMessageId]); static readonly ReadOnlyMemory<byte> TypeMessage = new ([TypeMessageId]);
static readonly ReadOnlyMemory<byte> TypeReply = new ([TypeReplyId]); static readonly ReadOnlyMemory<byte> TypeReply = new ([TypeReplyId]);
static readonly ReadOnlyMemory<byte> TypeError = new ([TypeErrorId]); static readonly ReadOnlyMemory<byte> TypeError = new ([TypeErrorId]);
@@ -21,11 +23,18 @@ interface IFrame {
switch (oneByteBuffer[0]) { switch (oneByteBuffer[0]) {
case TypePingId: case TypePingId:
var pingTime = await PingFrame.Read(stream, cancellationToken);
await reader.OnPingFrame(pingTime, cancellationToken);
break;
case TypePongId:
var pongFrame = await PongFrame.Read(stream, cancellationToken);
reader.OnPongFrame(pongFrame);
break; break;
case TypeMessageId: case TypeMessageId:
var messageFrame = await MessageFrame.Read(stream, cancellationToken); var messageFrame = await MessageFrame.Read(stream, cancellationToken);
await reader.OnMessageFrame(messageFrame, stream, cancellationToken); await reader.OnMessageFrame(messageFrame, cancellationToken);
break; break;
case TypeReplyId: case TypeReplyId:
@@ -39,13 +48,13 @@ interface IFrame {
break; break;
default: default:
reader.OnUnknownFrameStart(oneByteBuffer[0]); reader.OnUnknownFrameId(oneByteBuffer[0]);
break; break;
} }
} }
} }
ReadOnlyMemory<byte> Type { get; } ReadOnlyMemory<byte> FrameType { get; }
Task Write(Stream stream, CancellationToken cancellationToken = default); Task Write(Stream stream, CancellationToken cancellationToken = default);
} }

View File

@@ -3,8 +3,10 @@
namespace Phantom.Utils.Rpc.Frame; namespace Phantom.Utils.Rpc.Frame;
interface IFrameReader { interface IFrameReader {
Task OnMessageFrame(MessageFrame frame, Stream stream, CancellationToken cancellationToken); ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken);
void OnPongFrame(PongFrame frame);
Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken);
void OnReplyFrame(ReplyFrame frame); void OnReplyFrame(ReplyFrame frame);
void OnErrorFrame(ErrorFrame frame); void OnErrorFrame(ErrorFrame frame);
void OnUnknownFrameStart(byte id); void OnUnknownFrameId(byte frameId);
} }

View File

@@ -3,7 +3,7 @@
namespace Phantom.Utils.Rpc.Frame.Types; namespace Phantom.Utils.Rpc.Frame.Types;
sealed record ErrorFrame(uint ReplyingToMessageId, RpcError Error) : IFrame { sealed record ErrorFrame(uint ReplyingToMessageId, RpcError Error) : IFrame {
public ReadOnlyMemory<byte> Type => IFrame.TypeError; public ReadOnlyMemory<byte> FrameType => IFrame.TypeError;
public async Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken); await Serialization.WriteUnsignedInt(ReplyingToMessageId, stream, cancellationToken);

View File

@@ -5,7 +5,7 @@ namespace Phantom.Utils.Rpc.Frame.Types;
sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<byte> SerializedMessage) : IFrame { sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<byte> SerializedMessage) : IFrame {
public const int MaxMessageBytes = 1024 * 1024 * 8; public const int MaxMessageBytes = 1024 * 1024 * 8;
public ReadOnlyMemory<byte> Type => IFrame.TypeMessage; public ReadOnlyMemory<byte> FrameType => IFrame.TypeMessage;
public async Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
int messageLength = SerializedMessage.Length; int messageLength = SerializedMessage.Length;

View File

@@ -1,11 +1,15 @@
namespace Phantom.Utils.Rpc.Frame.Types; namespace Phantom.Utils.Rpc.Frame.Types;
sealed record PingFrame : IFrame { sealed record PingFrame : IFrame {
public static PingFrame Instance { get; } = new PingFrame(); public static PingFrame Instance { get; } = new ();
public ReadOnlyMemory<byte> Type => IFrame.TypePing; public ReadOnlyMemory<byte> FrameType => IFrame.TypePing;
public Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
return Task.CompletedTask; await Serialization.WriteSignedLong(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), stream, cancellationToken);
}
public static async Task<DateTimeOffset> Read(Stream stream, CancellationToken cancellationToken) {
return DateTimeOffset.FromUnixTimeMilliseconds(await Serialization.ReadSignedLong(stream, cancellationToken));
} }
} }

View File

@@ -0,0 +1,13 @@
namespace Phantom.Utils.Rpc.Frame.Types;
sealed record PongFrame(DateTimeOffset PingTime) : IFrame {
public ReadOnlyMemory<byte> FrameType => IFrame.TypePong;
public async Task Write(Stream stream, CancellationToken cancellationToken) {
await Serialization.WriteSignedLong(PingTime.ToUnixTimeMilliseconds(), stream, cancellationToken);
}
public static async Task<PongFrame> Read(Stream stream, CancellationToken cancellationToken) {
return new PongFrame(DateTimeOffset.FromUnixTimeMilliseconds(await Serialization.ReadSignedLong(stream, cancellationToken)));
}
}

View File

@@ -5,7 +5,7 @@ namespace Phantom.Utils.Rpc.Frame.Types;
sealed record ReplyFrame(uint ReplyingToMessageId, ReadOnlyMemory<byte> SerializedReply) : IFrame { sealed record ReplyFrame(uint ReplyingToMessageId, ReadOnlyMemory<byte> SerializedReply) : IFrame {
public const int MaxReplyBytes = 1024 * 1024 * 32; public const int MaxReplyBytes = 1024 * 1024 * 32;
public ReadOnlyMemory<byte> Type => IFrame.TypeReply; public ReadOnlyMemory<byte> FrameType => IFrame.TypeReply;
public async Task Write(Stream stream, CancellationToken cancellationToken) { public async Task Write(Stream stream, CancellationToken cancellationToken) {
int replyLength = SerializedReply.Length; int replyLength = SerializedReply.Length;

View File

@@ -0,0 +1,5 @@
namespace Phantom.Utils.Rpc;
public interface IRpcListener {
}

View File

@@ -0,0 +1,8 @@
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc;
interface IRpcReplySender {
ValueTask SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken);
ValueTask SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,21 @@
using Phantom.Utils.Actor;
namespace Phantom.Utils.Rpc.Message;
public interface IMessageReceiver<TMessageBase> {
void OnPing();
void OnMessage(TMessageBase message);
Task<TReply> OnMessage<TMessage, TReply>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase, ICanReply<TReply>;
class Actor(ActorRef<TMessageBase> actor) : IMessageReceiver<TMessageBase> {
public virtual void OnPing() {}
public void OnMessage(TMessageBase message) {
actor.Tell(message);
}
public Task<TReply> OnMessage<TMessage, TReply>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase, ICanReply<TReply> {
return actor.Request(message, cancellationToken);
}
}
}

View File

@@ -1,11 +0,0 @@
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Message;
interface MessageHandler<TMessageBase> {
ActorRef<TMessageBase> Actor { get; }
ValueTask OnReply<TMessage, TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply>;
ValueTask OnError(uint messageId, RpcError error, CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,13 @@
using Phantom.Utils.Collections;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageReceiveTracker {
private readonly RangeSet<uint> receivedMessageIds = new ();
public bool ReceiveMessage(uint messageId) {
lock (receivedMessageIds) {
return receivedMessageIds.Add(messageId);
}
}
}

View File

@@ -1,4 +1,5 @@
using Phantom.Utils.Actor; using System.Diagnostics.CodeAnalysis;
using Phantom.Utils.Actor;
using Phantom.Utils.Rpc.Frame.Types; using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Runtime; using Phantom.Utils.Rpc.Runtime;
using Serilog; using Serilog;
@@ -8,7 +9,7 @@ namespace Phantom.Utils.Rpc.Message;
public sealed class MessageRegistry<TMessageBase>(ILogger logger) { public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
private readonly Dictionary<Type, ushort> typeToCodeMapping = new (); private readonly Dictionary<Type, ushort> typeToCodeMapping = new ();
private readonly Dictionary<ushort, Type> codeToTypeMapping = new (); private readonly Dictionary<ushort, Type> codeToTypeMapping = new ();
private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, MessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new (); private readonly Dictionary<ushort, Func<uint, ReadOnlyMemory<byte>, RpcMessageHandler<TMessageBase>, CancellationToken, Task>> codeToHandlerMapping = new ();
public void Add<TMessage>(ushort code) where TMessage : TMessageBase { public void Add<TMessage>(ushort code) where TMessage : TMessageBase {
if (HasReplyType(typeof(TMessage))) { if (HasReplyType(typeof(TMessage))) {
@@ -36,6 +37,10 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
return messageType.GetInterfaces().Any(type => type.FullName is {} name && name.StartsWith(replyInterfaceName, StringComparison.Ordinal)); return messageType.GetInterfaces().Any(type => type.FullName is {} name && name.StartsWith(replyInterfaceName, StringComparison.Ordinal));
} }
internal bool TryGetType(MessageFrame frame, [NotNullWhen(true)] out Type? type) {
return codeToTypeMapping.TryGetValue(frame.RegistryCode, out type);
}
internal MessageFrame CreateFrame<TMessage>(uint messageId, TMessage message) where TMessage : TMessageBase { internal MessageFrame CreateFrame<TMessage>(uint messageId, TMessage message) where TMessage : TMessageBase {
if (typeToCodeMapping.TryGetValue(typeof(TMessage), out ushort code)) { if (typeToCodeMapping.TryGetValue(typeof(TMessage), out ushort code)) {
return new MessageFrame(messageId, code, Serialization.Serialize(message)); return new MessageFrame(messageId, code, Serialization.Serialize(message));
@@ -45,7 +50,7 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
} }
} }
internal async Task Handle(MessageFrame frame, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) { internal async Task Handle(MessageFrame frame, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) {
uint messageId = frame.MessageId; uint messageId = frame.MessageId;
if (codeToHandlerMapping.TryGetValue(frame.RegistryCode, out var action)) { if (codeToHandlerMapping.TryGetValue(frame.RegistryCode, out var action)) {
@@ -53,47 +58,47 @@ public sealed class MessageRegistry<TMessageBase>(ILogger logger) {
} }
else { else {
logger.Error("Unknown message code {Code} for message {MessageId}.", frame.RegistryCode, messageId); logger.Error("Unknown message code {Code} for message {MessageId}.", frame.RegistryCode, messageId);
await handler.OnError(messageId, RpcError.UnknownMessageRegistryCode, cancellationToken); await handler.SendError(messageId, RpcError.UnknownMessageRegistryCode, cancellationToken);
} }
} }
private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase { private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
TMessage message; TMessage message;
try { try {
message = Serialization.Deserialize<TMessage>(serializedMessage); message = Serialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.OnError(messageId, RpcError.MessageDeserializationError, cancellationToken); await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
return; return;
} }
handler.Actor.Tell(message); handler.Receiver.OnMessage(message);
} }
private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> { private async Task DeserializationHandler<TMessage, TReply>(uint messageId, ReadOnlyMemory<byte> serializedMessage, RpcMessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase, ICanReply<TReply> {
TMessage message; TMessage message;
try { try {
message = Serialization.Deserialize<TMessage>(serializedMessage); message = Serialization.Deserialize<TMessage>(serializedMessage);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); logger.Error(e, "Could not deserialize message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.OnError(messageId, RpcError.MessageDeserializationError, cancellationToken); await handler.SendError(messageId, RpcError.MessageDeserializationError, cancellationToken);
return; return;
} }
TReply reply; TReply reply;
try { try {
reply = await handler.Actor.Request(message, cancellationToken); reply = await handler.Receiver.OnMessage<TMessage, TReply>(message, cancellationToken);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); logger.Error(e, "Could not handle message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.OnError(messageId, RpcError.MessageHandlingError, cancellationToken); await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
return; return;
} }
try { try {
await handler.OnReply<TMessage, TReply>(messageId, reply, cancellationToken); await handler.SendReply(messageId, reply, cancellationToken);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name); logger.Error(e, "Could not reply to message {MessageId} ({MessageType}).", messageId, typeof(TMessage).Name);
await handler.OnError(messageId, RpcError.MessageHandlingError, cancellationToken); await handler.SendError(messageId, RpcError.MessageHandlingError, cancellationToken);
} }
} }
} }

View File

@@ -8,7 +8,7 @@ namespace Phantom.Utils.Rpc.Message;
sealed class MessageReplyTracker { sealed class MessageReplyTracker {
private readonly ILogger logger; private readonly ILogger logger;
private readonly ConcurrentDictionary<uint, TaskCompletionSource<ReadOnlyMemory<byte>>> replyTasks = new (concurrencyLevel: 4, capacity: 16); private readonly ConcurrentDictionary<uint, TaskCompletionSource<ReadOnlyMemory<byte>>> replyTasks = new (concurrencyLevel: 2, capacity: 16);
internal MessageReplyTracker(string loggerName) { internal MessageReplyTracker(string loggerName) {
this.logger = PhantomLogger.Create<MessageReplyTracker>(loggerName); this.logger = PhantomLogger.Create<MessageReplyTracker>(loggerName);

View File

@@ -7,7 +7,6 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="MemoryPack" /> <PackageReference Include="MemoryPack" />
<PackageReference Include="NetMQ" />
<PackageReference Include="Serilog" /> <PackageReference Include="Serilog" />
</ItemGroup> </ItemGroup>

View File

@@ -1,7 +1,4 @@
using Phantom.Utils.Actor; using Phantom.Utils.Logging;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Message;
using Serilog; using Serilog;
@@ -9,67 +6,39 @@ namespace Phantom.Utils.Rpc.Runtime.Client;
public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> : IDisposable { public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> : IDisposable {
public static async Task<RpcClient<TClientToServerMessage, TServerToClientMessage>?> Connect(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, CancellationToken cancellationToken) { public static async Task<RpcClient<TClientToServerMessage, TServerToClientMessage>?> Connect(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, CancellationToken cancellationToken) {
RpcClientConnector connector = new RpcClientConnector(loggerName, connectionParameters); RpcClientToServerConnector connector = new RpcClientToServerConnector(loggerName, connectionParameters);
RpcClientConnector.Connection? connection = await connector.EstablishNewConnection(cancellationToken); RpcClientToServerConnector.Connection? connection = await connector.EstablishNewConnection(cancellationToken);
return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions, connector, connection); return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions, connector, connection);
} }
private readonly string loggerName;
private readonly ILogger logger; private readonly ILogger logger;
private readonly MessageRegistry<TServerToClientMessage> serverToClientMessageRegistry; private readonly MessageRegistry<TServerToClientMessage> messageRegistry;
private readonly RpcClientConnection connection; private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly RpcClientToServerConnection connection;
private readonly CancellationTokenSource shutdownCancellationTokenSource = new ();
public RpcSendChannel<TClientToServerMessage> SendChannel { get; } public RpcSendChannel<TClientToServerMessage> SendChannel { get; }
private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientConnector connector, RpcClientConnector.Connection connection) { private RpcClient(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection) {
this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName); this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.serverToClientMessageRegistry = messageDefinitions.ToClient; this.messageRegistry = messageDefinitions.ToClient;
this.connection = new RpcClientConnection(loggerName, connector, connection); this.connection = new RpcClientToServerConnection(loggerName, connector, connection);
this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters, this.connection, messageDefinitions.ToServer); this.SendChannel = new RpcSendChannel<TClientToServerMessage>(loggerName, connectionParameters.Common, this.connection, messageDefinitions.ToServer);
} }
public async Task Listen(ActorRef<TServerToClientMessage> actor) { public async Task Listen(IMessageReceiver<TServerToClientMessage> receiver) {
var messageHandler = new RpcMessageHandler<TServerToClientMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
try { try {
await connection.ReadConnection(stream => Receive(stream, new MessageHandlerImpl(SendChannel, actor))); await connection.ReadConnection(frameReader, shutdownCancellationTokenSource.Token);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
// Ignore. // Ignore.
} }
} }
private async Task Receive(Stream stream, MessageHandlerImpl handler) {
await IFrame.ReadFrom(stream, new FrameReader(this, handler), CancellationToken.None);
}
private sealed class FrameReader(RpcClient<TClientToServerMessage, TServerToClientMessage> client, MessageHandlerImpl handler) : IFrameReader {
public Task OnMessageFrame(MessageFrame frame, Stream stream, CancellationToken cancellationToken) {
return client.serverToClientMessageRegistry.Handle(frame, handler, cancellationToken);
}
public void OnReplyFrame(ReplyFrame frame) {
client.SendChannel.ReceiveReply(frame);
}
public void OnErrorFrame(ErrorFrame frame) {
client.SendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error);
}
public void OnUnknownFrameStart(byte id) {
client.logger.Error("Received unknown frame ID: {Id}", id);
}
}
private sealed class MessageHandlerImpl(RpcSendChannel<TClientToServerMessage> sendChannel, ActorRef<TServerToClientMessage> actor) : MessageHandler<TServerToClientMessage> {
public ActorRef<TServerToClientMessage> Actor => actor;
public ValueTask OnReply<TMessage, TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) where TMessage : TServerToClientMessage, ICanReply<TReply> {
return sendChannel.SendReply(messageId, reply, cancellationToken);
}
public ValueTask OnError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return sendChannel.SendError(messageId, error, cancellationToken);
}
}
public async Task Shutdown() { public async Task Shutdown() {
logger.Information("Shutting down client..."); logger.Information("Shutting down client...");
@@ -80,11 +49,13 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
} }
try { try {
connection.Close(); connection.StopReconnecting();
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Caught exception while closing connection."); logger.Error(e, "Caught exception while closing connection.");
} }
await shutdownCancellationTokenSource.CancelAsync();
logger.Information("Client shut down."); logger.Information("Client shut down.");
} }

View File

@@ -10,4 +10,6 @@ public readonly record struct RpcClientConnectionParameters(
AuthToken AuthToken, AuthToken AuthToken,
ushort SendQueueCapacity, ushort SendQueueCapacity,
TimeSpan PingInterval TimeSpan PingInterval
); ) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}

View File

@@ -1,21 +1,23 @@
using Phantom.Utils.Logging; using System.Net.Sockets;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Serilog; using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client; namespace Phantom.Utils.Rpc.Runtime.Client;
sealed class RpcClientConnection(string loggerName, RpcClientConnector connector, RpcClientConnector.Connection initialConnection) : IRpcConnectionProvider, IDisposable { sealed class RpcClientToServerConnection(string loggerName, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection initialConnection) : IRpcConnectionProvider, IDisposable {
private readonly ILogger logger = PhantomLogger.Create<RpcClientConnection>(loggerName); private readonly ILogger logger = PhantomLogger.Create<RpcClientToServerConnection>(loggerName);
private readonly SemaphoreSlim semaphore = new (1); private readonly SemaphoreSlim semaphore = new (1);
private RpcClientConnector.Connection currentConnection = initialConnection; private RpcClientToServerConnector.Connection currentConnection = initialConnection;
private readonly CancellationTokenSource newConnectionCancellationTokenSource = new (); private readonly CancellationTokenSource newConnectionCancellationTokenSource = new ();
public async Task<Stream> GetStream() { public async Task<Stream> GetStream(CancellationToken cancellationToken) {
return (await GetConnection()).Stream; return (await GetConnection()).Stream;
} }
private async Task<RpcClientConnector.Connection> GetConnection() { private async Task<RpcClientToServerConnector.Connection> GetConnection() {
CancellationToken cancellationToken = newConnectionCancellationTokenSource.Token; CancellationToken cancellationToken = newConnectionCancellationTokenSource.Token;
await semaphore.WaitAsync(cancellationToken); await semaphore.WaitAsync(cancellationToken);
@@ -30,8 +32,8 @@ sealed class RpcClientConnection(string loggerName, RpcClientConnector connector
} }
} }
public async Task ReadConnection(Func<Stream, Task> reader) { public async Task ReadConnection(IFrameReader frameReader, CancellationToken cancellationToken) {
RpcClientConnector.Connection? connection = null; RpcClientToServerConnector.Connection? connection = null;
try { try {
while (true) { while (true) {
@@ -48,25 +50,27 @@ sealed class RpcClientConnection(string loggerName, RpcClientConnector connector
} }
try { try {
await reader(connection.Stream); await IFrame.ReadFrom(connection.Stream, frameReader, cancellationToken);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
throw; throw;
} catch (EndOfStreamException) { } catch (EndOfStreamException) {
logger.Warning("Socket was closed."); logger.Warning("Socket was closed.");
} catch (SocketException e) {
logger.Error("Socket reading was interrupted. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Closing socket due to an exception while reading it."); logger.Error(e, "Socket reading was interrupted.");
}
try {
await connection.Shutdown(); try {
} catch (Exception e2) { await connection.Shutdown();
logger.Error(e2, "Caught exception closing the socket."); } catch (Exception e) {
} logger.Error(e, "Caught exception closing the socket.");
} }
} }
} finally { } finally {
if (connection != null) { if (connection != null) {
try { try {
await connection.Disconnect(); // TODO what happens if already disconnected? await connection.Disconnect();
} finally { } finally {
connection.Dispose(); connection.Dispose();
} }
@@ -74,7 +78,7 @@ sealed class RpcClientConnection(string loggerName, RpcClientConnector connector
} }
} }
public void Close() { public void StopReconnecting() {
newConnectionCancellationTokenSource.Cancel(); newConnectionCancellationTokenSource.Cancel();
} }

View File

@@ -9,7 +9,7 @@ using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client; namespace Phantom.Utils.Rpc.Runtime.Client;
internal sealed class RpcClientConnector { internal sealed class RpcClientToServerConnector {
private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(100); private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(100);
private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30); private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30);
private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10); private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10);
@@ -21,8 +21,8 @@ internal sealed class RpcClientConnector {
private bool loggedCertificateValidationError = false; private bool loggedCertificateValidationError = false;
public RpcClientConnector(string loggerName, RpcClientConnectionParameters parameters) { public RpcClientToServerConnector(string loggerName, RpcClientConnectionParameters parameters) {
this.logger = PhantomLogger.Create<RpcClientConnector>(loggerName); this.logger = PhantomLogger.Create<RpcClientToServerConnector>(loggerName);
this.sessionId = Guid.NewGuid(); this.sessionId = Guid.NewGuid();
this.parameters = parameters; this.parameters = parameters;
@@ -53,7 +53,8 @@ internal sealed class RpcClientConnector {
return newConnection; return newConnection;
} }
logger.Warning("Failed to connect to server, trying again in {}.", nextAttemptDelay.TotalSeconds.ToString("F1")); cancellationToken.ThrowIfCancellationRequested();
logger.Information("Trying to reconnect in {Seconds}s.", nextAttemptDelay.TotalSeconds.ToString("F1"));
await Task.Delay(nextAttemptDelay, cancellationToken); await Task.Delay(nextAttemptDelay, cancellationToken);
nextAttemptDelay = Comparables.Min(nextAttemptDelay.Multiply(1.5), MaximumRetryDelay); nextAttemptDelay = Comparables.Min(nextAttemptDelay.Multiply(1.5), MaximumRetryDelay);
@@ -66,9 +67,12 @@ internal sealed class RpcClientConnector {
Socket clientSocket = new Socket(SocketType.Stream, ProtocolType.Tcp); Socket clientSocket = new Socket(SocketType.Stream, ProtocolType.Tcp);
try { try {
await clientSocket.ConnectAsync(parameters.Host, parameters.Port, cancellationToken); await clientSocket.ConnectAsync(parameters.Host, parameters.Port, cancellationToken);
} catch (SocketException e) {
logger.Error("Could not connect. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
return null;
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Could not connect."); logger.Error(e, "Could not connect.");
throw; return null;
} }
SslStream? stream; SslStream? stream;

View File

@@ -1,5 +1,5 @@
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
interface IRpcConnectionProvider { interface IRpcConnectionProvider {
Task<Stream> GetStream(); Task<Stream> GetStream(CancellationToken cancellationToken);
} }

View File

@@ -0,0 +1,6 @@
namespace Phantom.Utils.Rpc.Runtime;
readonly record struct RpcCommonConnectionParameters(
ushort SendQueueCapacity,
TimeSpan PingInterval
);

View File

@@ -3,7 +3,8 @@
public enum RpcError : byte { public enum RpcError : byte {
InvalidData = 0, InvalidData = 0,
UnknownMessageRegistryCode = 1, UnknownMessageRegistryCode = 1,
MessageDeserializationError = 2, MessageTooLarge = 2,
MessageHandlingError = 3, MessageDeserializationError = 3,
MessageTooLarge = 4, MessageHandlingError = 4,
MessageAlreadyHandled = 5,
} }

View File

@@ -1,13 +1,14 @@
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcErrorException : Exception { sealed class RpcErrorException : Exception {
internal static RpcErrorException From(RpcError error) { internal static RpcErrorException From(RpcError error) {
return error switch { return error switch {
RpcError.InvalidData => new RpcErrorException("Invalid data", error), RpcError.InvalidData => new RpcErrorException("Invalid data", error),
RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error), RpcError.UnknownMessageRegistryCode => new RpcErrorException("Unknown message registry code", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error),
RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error), RpcError.MessageDeserializationError => new RpcErrorException("Message deserialization error", error),
RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error), RpcError.MessageHandlingError => new RpcErrorException("Message handling error", error),
RpcError.MessageTooLarge => new RpcErrorException("Message is too large", error), RpcError.MessageAlreadyHandled => new RpcErrorException("Message already handled", error),
_ => new RpcErrorException("Unknown error", error), _ => new RpcErrorException("Unknown error", error),
}; };
} }

View File

@@ -0,0 +1,53 @@
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
string loggerName,
MessageRegistry<TReceivedMessage> messageRegistry,
MessageReceiveTracker messageReceiveTracker,
RpcMessageHandler<TReceivedMessage> messageHandler,
RpcSendChannel<TSentMessage> sendChannel
) : IFrameReader {
private readonly ILogger logger = PhantomLogger.Create<RpcFrameReader<TSentMessage, TReceivedMessage>>(loggerName);
public ValueTask OnPingFrame(DateTimeOffset pingTime, CancellationToken cancellationToken) {
messageHandler.OnPing();
return sendChannel.SendPong(pingTime, cancellationToken);
}
public void OnPongFrame(PongFrame frame) {
sendChannel.ReceivePong(frame);
}
public Task OnMessageFrame(MessageFrame frame, CancellationToken cancellationToken) {
if (!messageReceiveTracker.ReceiveMessage(frame.MessageId)) {
logger.Warning("Received duplicate message {MessageId}.", frame.MessageId);
return messageHandler.SendError(frame.MessageId, RpcError.MessageAlreadyHandled, cancellationToken).AsTask();
}
if (messageRegistry.TryGetType(frame, out var messageType)) {
logger.Verbose("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length);
}
return messageRegistry.Handle(frame, messageHandler, cancellationToken);
}
public void OnReplyFrame(ReplyFrame frame) {
logger.Verbose("Received reply to message {MesageId} ({Bytes} B).", frame.ReplyingToMessageId, frame.SerializedReply.Length);
sendChannel.ReceiveReply(frame);
}
public void OnErrorFrame(ErrorFrame frame) {
logger.Warning("Received error response to message {MesageId}: {Error}", frame.ReplyingToMessageId, frame.Error);
sendChannel.ReceiveError(frame.ReplyingToMessageId, frame.Error);
}
public void OnUnknownFrameId(byte frameId) {
logger.Error("Received unknown frame ID: {FrameId}", frameId);
}
}

View File

@@ -0,0 +1,19 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcMessageHandler<TMessageBase>(IMessageReceiver<TMessageBase> receiver, IRpcReplySender replySender) {
public IMessageReceiver<TMessageBase> Receiver => receiver;
public void OnPing() {
receiver.OnPing();
}
public ValueTask SendReply<TReply>(uint messageId, TReply reply, CancellationToken cancellationToken) {
return replySender.SendReply(messageId, reply, cancellationToken);
}
public ValueTask SendError(uint messageId, RpcError error, CancellationToken cancellationToken) {
return replySender.SendError(messageId, error, cancellationToken);
}
}

View File

@@ -5,12 +5,11 @@ using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame; using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Frame.Types; using Phantom.Utils.Rpc.Frame.Types;
using Phantom.Utils.Rpc.Message; using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Client;
using Serilog; using Serilog;
namespace Phantom.Utils.Rpc.Runtime; namespace Phantom.Utils.Rpc.Runtime;
public sealed class RpcSendChannel<TMessageBase> : IDisposable { public sealed class RpcSendChannel<TMessageBase> : IRpcReplySender, IDisposable {
private readonly ILogger logger; private readonly ILogger logger;
private readonly IRpcConnectionProvider connectionProvider; private readonly IRpcConnectionProvider connectionProvider;
private readonly MessageRegistry<TMessageBase> messageRegistry; private readonly MessageRegistry<TMessageBase> messageRegistry;
@@ -20,11 +19,13 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
private readonly Task sendQueueTask; private readonly Task sendQueueTask;
private readonly Task pingTask; private readonly Task pingTask;
private readonly CancellationTokenSource sendQueueCancellationTokenSource = new ();
private readonly CancellationTokenSource pingCancellationTokenSource = new (); private readonly CancellationTokenSource pingCancellationTokenSource = new ();
private uint nextMessageId; private uint nextMessageId;
private TaskCompletionSource<DateTimeOffset>? pongTask;
internal RpcSendChannel(string loggerName, RpcClientConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageRegistry<TMessageBase> messageRegistry) { internal RpcSendChannel(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageRegistry<TMessageBase> messageRegistry) {
this.logger = PhantomLogger.Create<RpcSendChannel<TMessageBase>>(loggerName); this.logger = PhantomLogger.Create<RpcSendChannel<TMessageBase>>(loggerName);
this.connectionProvider = connectionProvider; this.connectionProvider = connectionProvider;
this.messageRegistry = messageRegistry; this.messageRegistry = messageRegistry;
@@ -41,11 +42,15 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
this.pingTask = Ping(connectionParameters.PingInterval); this.pingTask = Ping(connectionParameters.PingInterval);
} }
internal async ValueTask SendPong(DateTimeOffset pingTime, CancellationToken cancellationToken) {
await SendFrame(new PongFrame(pingTime), cancellationToken);
}
public bool TrySendMessage<TMessage>(TMessage message) where TMessage : TMessageBase { public bool TrySendMessage<TMessage>(TMessage message) where TMessage : TMessageBase {
return sendQueue.Writer.TryWrite(NextMessageFrame(message)); return sendQueue.Writer.TryWrite(NextMessageFrame(message));
} }
public async ValueTask SendMessage<TMessage>(TMessage message, CancellationToken cancellationToken) where TMessage : TMessageBase { public async ValueTask SendMessage<TMessage>(TMessage message, CancellationToken cancellationToken = default) where TMessage : TMessageBase {
await SendFrame(NextMessageFrame(message), cancellationToken); await SendFrame(NextMessageFrame(message), cancellationToken);
} }
@@ -64,11 +69,11 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken); return await messageReplyTracker.WaitForReply<TReply>(messageId, waitForReplyTime, cancellationToken);
} }
internal async ValueTask SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) { async ValueTask IRpcReplySender.SendReply<TReply>(uint replyingToMessageId, TReply reply, CancellationToken cancellationToken) {
await SendFrame(new ReplyFrame(replyingToMessageId, Serialization.Serialize(reply)), cancellationToken); await SendFrame(new ReplyFrame(replyingToMessageId, Serialization.Serialize(reply)), cancellationToken);
} }
internal async ValueTask SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) { async ValueTask IRpcReplySender.SendError(uint replyingToMessageId, RpcError error, CancellationToken cancellationToken) {
await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken); await SendFrame(new ErrorFrame(replyingToMessageId, error), cancellationToken);
} }
@@ -84,20 +89,21 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
} }
private async Task ProcessSendQueue() { private async Task ProcessSendQueue() {
await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync()) { CancellationToken cancellationToken = sendQueueCancellationTokenSource.Token;
await foreach (IFrame frame in sendQueue.Reader.ReadAllAsync(cancellationToken)) {
while (true) { while (true) {
Stream stream;
try { try {
stream = await connectionProvider.GetStream(); Stream stream = await connectionProvider.GetStream(cancellationToken);
await stream.WriteAsync(frame.FrameType, cancellationToken);
await frame.Write(stream, cancellationToken);
await stream.FlushAsync(cancellationToken);
break;
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
throw; throw;
} catch (Exception) { } catch (Exception) {
continue; // Retry.
} }
await stream.WriteAsync(frame.Type);
await frame.Write(stream);
await stream.FlushAsync();
} }
} }
} }
@@ -109,12 +115,26 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
while (true) { while (true) {
await Task.Delay(interval, cancellationToken); await Task.Delay(interval, cancellationToken);
pongTask = new TaskCompletionSource<DateTimeOffset>();
if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) { if (!sendQueue.Writer.TryWrite(PingFrame.Instance)) {
cancellationToken.ThrowIfCancellationRequested();
logger.Warning("Skipped a ping due to a full queue."); logger.Warning("Skipped a ping due to a full queue.");
continue;
} }
DateTimeOffset pingTime = await pongTask.Task.WaitAsync(cancellationToken);
DateTimeOffset currentTime = DateTimeOffset.UtcNow;
TimeSpan roundTripTime = currentTime - pingTime;
logger.Information("Received pong, round trip time: {RoundTripTime} ms", (long) roundTripTime.TotalMilliseconds);
} }
} }
internal void ReceivePong(PongFrame frame) {
pongTask?.TrySetResult(frame.PingTime);
}
internal void ReceiveReply(ReplyFrame frame) { internal void ReceiveReply(ReplyFrame frame) {
messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply); messageReplyTracker.ReceiveReply(frame.ReplyingToMessageId, frame.SerializedReply);
} }
@@ -128,7 +148,16 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
sendQueue.Writer.TryComplete(); sendQueue.Writer.TryComplete();
try { try {
await Task.WhenAll(sendQueueTask, pingTask); await pingTask;
} catch (Exception) {
// Ignore.
}
try {
await sendQueueTask.WaitAsync(TimeSpan.FromSeconds(15));
} catch (TimeoutException) {
logger.Warning("Could not finish processing send queue before timeout, forcibly shutting it down.");
await sendQueueCancellationTokenSource.CancelAsync();
} catch (Exception) { } catch (Exception) {
// Ignore. // Ignore.
} }
@@ -136,6 +165,7 @@ public sealed class RpcSendChannel<TMessageBase> : IDisposable {
public void Dispose() { public void Dispose() {
sendQueueTask.Dispose(); sendQueueTask.Dispose();
sendQueueCancellationTokenSource.Dispose();
pingCancellationTokenSource.Dispose(); pingCancellationTokenSource.Dispose();
} }
} }

View File

@@ -0,0 +1,7 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime.Server;
public interface IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> {
IMessageReceiver<TClientToServerMessage> Register(RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> connection);
}

View File

@@ -4,17 +4,25 @@ using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Monads; using Phantom.Utils.Monads;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Tls; using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog; using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server; namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServer(string loggerName, EndPoint endPoint, AuthToken authToken, RpcServerCertificate certificate) { public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage>(
private readonly ILogger logger = PhantomLogger.Create<RpcServer>(loggerName); string loggerName,
private readonly RpcServerClientManager clientManager = new (); RpcServerConnectionParameters connectionParameters,
IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> clientRegistrar
) {
private readonly ILogger logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>>(loggerName);
private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions = new (connectionParameters.Common, messageDefinitions.ToClient);
private readonly List<Client> clients = []; private readonly List<Client> clients = [];
public async Task<bool> Run(CancellationToken shutdownToken) { public async Task<bool> Run(CancellationToken shutdownToken) {
EndPoint endPoint = connectionParameters.EndPoint;
SslServerAuthenticationOptions sslOptions = new () { SslServerAuthenticationOptions sslOptions = new () {
AllowRenegotiation = false, AllowRenegotiation = false,
AllowTlsResume = true, AllowTlsResume = true,
@@ -22,7 +30,7 @@ public sealed class RpcServer(string loggerName, EndPoint endPoint, AuthToken au
ClientCertificateRequired = false, ClientCertificateRequired = false,
EnabledSslProtocols = TlsSupport.SupportedProtocols, EnabledSslProtocols = TlsSupport.SupportedProtocols,
EncryptionPolicy = EncryptionPolicy.RequireEncryption, EncryptionPolicy = EncryptionPolicy.RequireEncryption,
ServerCertificate = certificate.Certificate, ServerCertificate = connectionParameters.Certificate.Certificate,
}; };
try { try {
@@ -39,11 +47,13 @@ public sealed class RpcServer(string loggerName, EndPoint endPoint, AuthToken au
try { try {
logger.Information("Server listening on {EndPoint}.", endPoint); logger.Information("Server listening on {EndPoint}.", endPoint);
while (!shutdownToken.IsCancellationRequested) { while (true) {
Socket clientSocket = await serverSocket.AcceptAsync(shutdownToken); Socket clientSocket = await serverSocket.AcceptAsync(shutdownToken);
clients.Add(new Client(clientManager, clientSocket, sslOptions, authToken, shutdownToken)); clients.Add(new Client(loggerName, messageDefinitions, clientRegistrar, clientSessions, clientSocket, sslOptions, connectionParameters.AuthToken, shutdownToken));
clients.RemoveAll(static client => client.Task.IsCompleted); clients.RemoveAll(static client => client.Task.IsCompleted);
} }
} catch (OperationCanceledException) {
// Ignore.
} finally { } finally {
await Stop(serverSocket); await Stop(serverSocket);
} }
@@ -67,6 +77,7 @@ public sealed class RpcServer(string loggerName, EndPoint endPoint, AuthToken au
try { try {
await Task.WhenAll(clients.Select(static client => client.Task)); await Task.WhenAll(clients.Select(static client => client.Task));
await clientSessions.Shutdown();
} catch (Exception) { } catch (Exception) {
// Ignore exceptions when shutting down. // Ignore exceptions when shutting down.
} }
@@ -75,26 +86,54 @@ public sealed class RpcServer(string loggerName, EndPoint endPoint, AuthToken au
} }
private sealed class Client { private sealed class Client {
private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10); private static TimeSpan DisconnectTimeout => TimeSpan.FromSeconds(10);
private static string GetAddressDescriptor(Socket socket) {
EndPoint? endPoint = socket.RemoteEndPoint;
return endPoint switch {
IPEndPoint ip => ip.Port.ToString(),
null => "{unknown}",
_ => "{" + endPoint + "}",
};
}
public Task Task { get; } public Task Task { get; }
private string Address => socket.RemoteEndPoint?.ToString() ?? "<unknown address>"; private string Address => socket.RemoteEndPoint?.ToString() ?? "<unknown address>";
private readonly ILogger logger; private ILogger logger;
private readonly RpcServerClientManager clientManager; private readonly string serverLoggerName;
private readonly IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions;
private readonly IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> clientRegistrar;
private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions;
private readonly Socket socket; private readonly Socket socket;
private readonly SslServerAuthenticationOptions sslOptions; private readonly SslServerAuthenticationOptions sslOptions;
private readonly AuthToken authToken; private readonly AuthToken authToken;
private readonly CancellationToken shutdownToken; private readonly CancellationToken shutdownToken;
public Client(RpcServerClientManager clientManager, Socket socket, SslServerAuthenticationOptions sslOptions, AuthToken authToken, CancellationToken shutdownToken) { public Client(
this.logger = PhantomLogger.Create<RpcServer, Client>(Address); string serverLoggerName,
this.clientManager = clientManager; IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage> clientRegistrar,
RpcServerClientSessions<TServerToClientMessage> clientSessions,
Socket socket,
SslServerAuthenticationOptions sslOptions,
AuthToken authToken,
CancellationToken shutdownToken
) {
this.logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>, Client>(PhantomLogger.ConcatNames(serverLoggerName, GetAddressDescriptor(socket)));
this.serverLoggerName = serverLoggerName;
this.messageDefinitions = messageDefinitions;
this.clientRegistrar = clientRegistrar;
this.clientSessions = clientSessions;
this.socket = socket; this.socket = socket;
this.sslOptions = sslOptions; this.sslOptions = sslOptions;
this.authToken = authToken; this.authToken = authToken;
this.shutdownToken = shutdownToken; this.shutdownToken = shutdownToken;
this.Task = Run(); this.Task = Run();
} }
@@ -103,21 +142,19 @@ public sealed class RpcServer(string loggerName, EndPoint endPoint, AuthToken au
try { try {
await using var stream = new SslStream(new NetworkStream(socket, ownsSocket: false), leaveInnerStreamOpen: false); await using var stream = new SslStream(new NetworkStream(socket, ownsSocket: false), leaveInnerStreamOpen: false);
Guid? sessionId; Guid? sessionIdResult;
try { try {
sessionId = await InitializeConnection(stream); sessionIdResult = await InitializeConnection(stream);
} catch (OperationCanceledException) { } catch (OperationCanceledException) {
logger.Warning("Cancelling incoming client due to shutdown."); logger.Warning("Cancelling incoming client due to shutdown.");
return; return;
} }
if (!sessionId.HasValue) { if (sessionIdResult.HasValue) {
return; await RunConnectedSession(sessionIdResult.Value, stream);
} }
} catch (Exception e) {
logger.Information("Client connected."); logger.Error(e, "Caught exception while processing client.");
clientManager.SetConnection(sessionId.Value, stream);
} finally { } finally {
logger.Information("Disconnecting client..."); logger.Information("Disconnecting client...");
try { try {
@@ -129,6 +166,7 @@ public sealed class RpcServer(string loggerName, EndPoint endPoint, AuthToken au
logger.Error(e, "Could not disconnect client socket."); logger.Error(e, "Could not disconnect client socket.");
} finally { } finally {
socket.Close(); socket.Close();
logger.Information("Client socket closed.");
} }
} }
} }
@@ -180,5 +218,37 @@ public sealed class RpcServer(string loggerName, EndPoint endPoint, AuthToken au
var sessionId = await Serialization.ReadGuid(stream, cancellationToken); var sessionId = await Serialization.ReadGuid(stream, cancellationToken);
return Either.Left(sessionId); return Either.Left(sessionId);
} }
private async Task RunConnectedSession(Guid sessionId, Stream stream) {
var loggerName = PhantomLogger.ConcatNames(serverLoggerName, clientSessions.NextLoggerName(sessionId));
logger.Information("Client connected with session {SessionId}, new logger name: {LoggerName}", sessionId, loggerName);
logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage>, Client>(loggerName);
var session = clientSessions.OnConnected(sessionId, loggerName, stream);
try {
var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(loggerName, clientSessions, sessionId, messageDefinitions.ToServer, stream, session);
IMessageReceiver<TClientToServerMessage> messageReceiver;
try {
messageReceiver = clientRegistrar.Register(connection);
} catch (Exception e) {
logger.Error(e, "Could not register client.");
return;
}
try {
await connection.Listen(messageReceiver);
} catch (EndOfStreamException) {
logger.Warning("Socket reading was interrupted, connection lost.");
} catch (SocketException e) {
logger.Error("Socket reading was interrupted. Socket error {ErrorCode} ({ErrorCodeName}), reason: {ErrorMessage}", e.ErrorCode, e.SocketErrorCode, e.Message);
} catch (Exception e) {
logger.Error(e, "Socket reading was interrupted.");
}
} finally {
clientSessions.OnDisconnected(sessionId);
}
}
} }
} }

View File

@@ -1,4 +0,0 @@
namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServerClientConnection {
}

View File

@@ -1,13 +0,0 @@
namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientManager {
private readonly Dictionary<Guid, RpcServerClientConnection> connectionsBySessionId = new ();
internal void SetConnection(Guid sessionId, Stream stream) {
connectionsBySessionId.AddOrUpdate(sessionId, id => {
return new RpcServerClientConnection(id, stream);
}, connection => {
return connection;
});
}
}

View File

@@ -0,0 +1,50 @@
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProvider {
public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
public MessageReceiveTracker MessageReceiveTracker { get; }
private TaskCompletionSource<Stream> nextStream = new ();
public RpcServerClientSession(string loggerName, RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
this.SendChannel = new RpcSendChannel<TServerToClientMessage>(loggerName, connectionParameters, this, messageRegistry);
this.MessageReceiveTracker = new MessageReceiveTracker();
}
public void OnConnected(Stream stream) {
lock (this) {
if (!nextStream.Task.IsCanceled && !nextStream.TrySetResult(stream)) {
nextStream = new TaskCompletionSource<Stream>();
nextStream.SetResult(stream);
}
}
}
public void OnDisconnected() {
lock (this) {
var task = nextStream.Task;
if (task is { IsCompleted: true, IsCanceled: false }) {
nextStream = new TaskCompletionSource<Stream>();
}
}
}
Task<Stream> IRpcConnectionProvider.GetStream(CancellationToken cancellationToken) {
lock (this) {
return nextStream.Task;
}
}
public Task Close() {
lock (this) {
if (!nextStream.TrySetCanceled()) {
nextStream = new TaskCompletionSource<Stream>();
nextStream.SetCanceled();
}
}
return SendChannel.Close();
}
}

View File

@@ -0,0 +1,61 @@
using System.Collections.Concurrent;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSessions<TServerToClientMessage> {
private readonly RpcCommonConnectionParameters connectionParameters;
private readonly MessageRegistry<TServerToClientMessage> messageRegistry;
private readonly ConcurrentDictionary<Guid, RpcServerClientSession<TServerToClientMessage>> sessionsById = new ();
private readonly ConcurrentDictionary<Guid, uint> sessionLoggerSequenceIds = new ();
private readonly Func<Guid, string, RpcServerClientSession<TServerToClientMessage>> createSessionFunction;
public RpcServerClientSessions(RpcCommonConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) {
this.connectionParameters = connectionParameters;
this.messageRegistry = messageRegistry;
this.createSessionFunction = CreateSession;
}
public string NextLoggerName(Guid sessionId) {
string name = PhantomLogger.ShortenGuid(sessionId);
return name + "/" + sessionLoggerSequenceIds.AddOrUpdate(sessionId, static _ => 1, static (_, prev) => prev + 1);
}
public RpcServerClientSession<TServerToClientMessage> OnConnected(Guid sessionId, string loggerName, Stream stream) {
var session = sessionsById.GetOrAdd(sessionId, createSessionFunction, loggerName);
session.OnConnected(stream);
return session;
}
private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId, string loggerName) {
return new RpcServerClientSession<TServerToClientMessage>(loggerName, connectionParameters, messageRegistry);
}
public void OnDisconnected(Guid sessionId) {
if (sessionsById.TryGetValue(sessionId, out var session)) {
session.OnDisconnected();
}
}
public Task CloseSession(Guid sessionId) {
if (sessionsById.Remove(sessionId, out var session)) {
return session.Close();
}
else {
return Task.CompletedTask;
}
}
public Task Shutdown() {
List<Task> tasks = [];
foreach (Guid guid in sessionsById.Keys) {
tasks.Add(CloseSession(guid));
}
return Task.WhenAll(tasks);
}
}

View File

@@ -0,0 +1,14 @@
using System.Net;
using Phantom.Utils.Rpc.Runtime.Tls;
namespace Phantom.Utils.Rpc.Runtime.Server;
public readonly record struct RpcServerConnectionParameters(
EndPoint EndPoint,
RpcServerCertificate Certificate,
AuthToken AuthToken,
ushort SendQueueCapacity,
TimeSpan PingInterval
) {
internal RpcCommonConnectionParameters Common => new (SendQueueCapacity, PingInterval);
}

View File

@@ -0,0 +1,49 @@
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame;
using Phantom.Utils.Rpc.Message;
using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> {
private readonly string loggerName;
private readonly ILogger logger;
private readonly RpcServerClientSessions<TServerToClientMessage> sessions;
private readonly MessageRegistry<TClientToServerMessage> messageRegistry;
private readonly MessageReceiveTracker messageReceiveTracker;
private readonly Stream stream;
private readonly CancellationTokenSource closeCancellationTokenSource = new ();
public Guid SessionId { get; }
public RpcSendChannel<TServerToClientMessage> SendChannel { get; }
internal RpcServerToClientConnection(string loggerName, RpcServerClientSessions<TServerToClientMessage> sessions, Guid sessionId, MessageRegistry<TClientToServerMessage> messageRegistry, Stream stream, RpcServerClientSession<TServerToClientMessage> session) {
this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.sessions = sessions;
this.messageRegistry = messageRegistry;
this.messageReceiveTracker = session.MessageReceiveTracker;
this.stream = stream;
this.SessionId = sessionId;
this.SendChannel = session.SendChannel;
}
internal async Task Listen(IMessageReceiver<TClientToServerMessage> receiver) {
var messageHandler = new RpcMessageHandler<TClientToServerMessage>(receiver, SendChannel);
var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(loggerName, messageRegistry, messageReceiveTracker, messageHandler, SendChannel);
try {
await IFrame.ReadFrom(stream, frameReader, closeCancellationTokenSource.Token);
} catch (OperationCanceledException) {
// Ignore.
}
}
public async Task ClientClosedSession() {
logger.Information("Client closed session.");
Task closeSessionTask = sessions.CloseSession(SessionId);
await closeCancellationTokenSource.CancelAsync();
await closeSessionTask;
}
}

View File

@@ -1,135 +0,0 @@
// using System.Collections.Concurrent;
// using Akka.Actor;
// using NetMQ.Sockets;
// using Phantom.Utils.Actor;
// using Phantom.Utils.Logging;
// using Serilog;
// using Serilog.Events;
//
// namespace Phantom.Utils.Rpc.Runtime2;
//
// public static class RpcServerRuntime {
// public static Task Launch<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>(
// RpcConfiguration config,
// IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions,
// IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler,
// IActorRefFactory actorSystem,
// CancellationToken cancellationToken
// ) where TRegistrationMessage : TServerMessage where TReplyMessage : TClientMessage, TServerMessage {
// return RpcServerRuntime<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.Launch(config, messageDefinitions, registrationHandler, actorSystem, cancellationToken);
// }
// }
//
// internal sealed class RpcServerRuntime<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage> : RpcRuntime<ServerSocket> where TRegistrationMessage : TServerMessage where TReplyMessage : TClientMessage, TServerMessage {
// internal static Task Launch(RpcConfiguration config, IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions, IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler, IActorRefFactory actorSystem, CancellationToken cancellationToken) {
// var socket = RpcServerSocket.Connect(config);
// return new RpcServerRuntime<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>(socket, messageDefinitions, registrationHandler, actorSystem, cancellationToken).Launch();
// }
//
// private readonly string serviceName;
// private readonly IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions;
// private readonly IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler;
// private readonly IActorRefFactory actorSystem;
// private readonly CancellationToken cancellationToken;
//
// private RpcServerRuntime(RpcServerSocket socket, IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions, IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler, IActorRefFactory actorSystem, CancellationToken cancellationToken) : base(socket) {
// this.serviceName = socket.Config.ServiceName;
// this.messageDefinitions = messageDefinitions;
// this.registrationHandler = registrationHandler;
// this.actorSystem = actorSystem;
// this.cancellationToken = cancellationToken;
// }
//
// private protected override Task Run(ServerSocket socket) {
// var clients = new ConcurrentDictionary<ulong, Client>();
//
// void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) {
// if (clients.Remove(e.RoutingId, out var client)) {
// client.Connection.Closed -= OnConnectionClosed;
// }
// }
//
// while (!cancellationToken.IsCancellationRequested) {
// var (routingId, data) = socket.Receive(cancellationToken);
//
// if (data.Length == 0) {
// LogUnknownMessage(routingId, data);
// continue;
// }
//
// Type? messageType = messageDefinitions.ToServer.TryGetType(data, out var type) ? type : null;
// if (messageType == null) {
// LogUnknownMessage(routingId, data);
// continue;
// }
//
// if (!clients.TryGetValue(routingId, out var client)) {
// if (messageType != typeof(TRegistrationMessage)) {
// RuntimeLogger.Warning("Received {MessageType} ({Bytes} B) from unregistered client {RoutingId}.", messageType.Name, data.Length, routingId);
// continue;
// }
//
// var clientLoggerName = LoggerName + ":" + routingId;
// var clientActorName = "Rpc-" + serviceName + "-" + routingId;
//
// // TODO add pings and tear down connection after too much inactivity
// var connection = new RpcConnectionToClient<TClientMessage>(socket, routingId, messageDefinitions.ToClient, ReplyTracker);
// connection.Closed += OnConnectionClosed;
//
// client = new Client(clientLoggerName, clientActorName, connection, actorSystem, messageDefinitions, registrationHandler);
// clients[routingId] = client;
// }
//
// client.Enqueue(messageType, data);
// }
//
// foreach (var client in clients.Values) {
// client.Connection.Close();
// }
//
// return Task.CompletedTask;
// }
//
// private void LogUnknownMessage(uint routingId, ReadOnlyMemory<byte> data) {
// RuntimeLogger.Warning("Received unknown message ({Bytes} B) from {RoutingId}.", data.Length, routingId);
// }
//
// private protected override Task Disconnect(ServerSocket socket) {
// return Task.CompletedTask;
// }
//
// private sealed class Client {
// public RpcConnectionToClient<TClientMessage> Connection { get; }
//
// private readonly ILogger logger;
// private readonly ActorRef<RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.ReceiveMessageCommand> receiverActor;
//
// public Client(string loggerName, string actorName, RpcConnectionToClient<TClientMessage> connection, IActorRefFactory actorSystem, IMessageDefinitions<TClientMessage, TServerMessage, TReplyMessage> messageDefinitions, IRegistrationHandler<TClientMessage, TServerMessage, TRegistrationMessage> registrationHandler) {
// this.Connection = connection;
// this.Connection.Closed += OnConnectionClosed;
//
// this.logger = PhantomLogger.Create(loggerName);
//
// var receiverActorInit = new RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.Init(loggerName, messageDefinitions, registrationHandler, Connection);
// this.receiverActor = actorSystem.ActorOf(RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.Factory(receiverActorInit), actorName + "-Receiver");
// }
//
// internal void Enqueue(Type messageType, ReadOnlyMemory<byte> data) {
// LogMessageType(messageType, data);
// receiverActor.Tell(new RpcReceiverActor<TClientMessage, TServerMessage, TRegistrationMessage, TReplyMessage>.ReceiveMessageCommand(messageType, data));
// }
//
// private void LogMessageType(Type messageType, ReadOnlyMemory<byte> data) {
// if (logger.IsEnabled(LogEventLevel.Verbose)) {
// logger.Verbose("Received {MessageType} ({Bytes} B).", messageType.Name, data.Length);
// }
// }
//
// private void OnConnectionClosed(object? sender, RpcClientConnectionClosedEventArgs e) {
// Connection.Closed -= OnConnectionClosed;
//
// logger.Debug("Closing connection...");
// receiverActor.Stop();
// }
// }
// }

View File

@@ -52,6 +52,14 @@ static class Serialization {
return ReadValue(BinaryPrimitives.ReadUInt32LittleEndian, sizeof(uint), stream, cancellationToken); return ReadValue(BinaryPrimitives.ReadUInt32LittleEndian, sizeof(uint), stream, cancellationToken);
} }
public static ValueTask WriteSignedLong(long value, Stream stream, CancellationToken cancellationToken) {
return WriteValue(value, sizeof(long), BinaryPrimitives.WriteInt64LittleEndian, stream, cancellationToken);
}
public static ValueTask<long> ReadSignedLong(Stream stream, CancellationToken cancellationToken) {
return ReadValue(BinaryPrimitives.ReadInt64LittleEndian, sizeof(long), stream, cancellationToken);
}
public static ValueTask WriteGuid(Guid guid, Stream stream, CancellationToken cancellationToken) { public static ValueTask WriteGuid(Guid guid, Stream stream, CancellationToken cancellationToken) {
static void Write(Span<byte> span, Guid guid) { static void Write(Span<byte> span, Guid guid) {
if (!guid.TryWriteBytes(span)) { if (!guid.TryWriteBytes(span)) {

View File

@@ -0,0 +1,93 @@
using NUnit.Framework;
using Phantom.Utils.Collections;
using Range = Phantom.Utils.Collections.RangeSet<int>.Range;
namespace Phantom.Utils.Tests.Collections;
[TestFixture]
public class RangeSetTests {
[Test]
public void OneValue() {
var set = new RangeSet<int>();
set.Add(5);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 5),
}));
}
[Test]
public void MultipleDisjointValues() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(7);
set.Add(1);
set.Add(3);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 1, Max: 1),
new Range(Min: 3, Max: 3),
new Range(Min: 5, Max: 5),
new Range(Min: 7, Max: 7),
}));
}
[Test]
public void ExtendMin() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(4);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 4, Max: 5),
}));
}
[Test]
public void ExtendMax() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(6);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 6),
}));
}
[Test]
public void ExtendMaxAndMerge() {
var set = new RangeSet<int>();
set.Add(5);
set.Add(7);
set.Add(6);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 5, Max: 7),
}));
}
[Test]
public void MultipleMergingAndDisjointValues() {
var set = new RangeSet<int>();
set.Add(1);
set.Add(2);
set.Add(5);
set.Add(4);
set.Add(10);
set.Add(7);
set.Add(9);
set.Add(11);
set.Add(16);
set.Add(12);
set.Add(3);
set.Add(14);
Assert.That(set, Is.EqualTo(new[] {
new Range(Min: 1, Max: 5),
new Range(Min: 7, Max: 7),
new Range(Min: 9, Max: 12),
new Range(Min: 14, Max: 14),
new Range(Min: 16, Max: 16),
}));
}
}

View File

@@ -0,0 +1,69 @@
using System.Collections;
using System.Numerics;
namespace Phantom.Utils.Collections;
public sealed class RangeSet<T> : IEnumerable<RangeSet<T>.Range> where T : IBinaryInteger<T> {
private readonly List<Range> ranges = [];
public bool Add(T value) {
int index = 0;
for (; index < ranges.Count; index++) {
var range = ranges[index];
if (range.Contains(value)) {
return false;
}
if (range.ExtendIfAtEdge(value, out var extendedRange)) {
ranges[index] = extendedRange;
if (index < ranges.Count - 1) {
var nextRange = ranges[index + 1];
if (extendedRange.Max + T.One == nextRange.Min) {
ranges[index] = new Range(extendedRange.Min, nextRange.Max);
ranges.RemoveAt(index + 1);
}
}
return true;
}
if (range.Max > value) {
break;
}
}
ranges.Insert(index, new Range(value, value));
return true;
}
public IEnumerator<Range> GetEnumerator() {
return ranges.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator() {
return GetEnumerator();
}
public readonly record struct Range(T Min, T Max) {
internal bool ExtendIfAtEdge(T value, out Range newRange) {
if (value == Min - T.One) {
newRange = this with { Min = value };
return true;
}
else if (value == Max + T.One) {
newRange = this with { Max = value };
return true;
}
else {
newRange = default;
return false;
}
}
internal bool Contains(T value) {
return value >= Min && value <= Max;
}
}
}

View File

@@ -14,7 +14,7 @@ namespace Phantom.Web.Services;
public static class PhantomWebServices { public static class PhantomWebServices {
public static void AddPhantomServices(this IServiceCollection services) { public static void AddPhantomServices(this IServiceCollection services) {
services.AddSingleton<ControllerConnection>(); services.AddSingleton<ControllerConnection>();
services.AddSingleton<ControllerMessageHandlerFactory>(); services.AddSingleton<ControllerMessageHandlerActorInitFactory>();
services.AddSingleton<AgentManager>(); services.AddSingleton<AgentManager>();
services.AddSingleton<InstanceManager>(); services.AddSingleton<InstanceManager>();

View File

@@ -6,7 +6,7 @@ namespace Phantom.Web.Services.Rpc;
public sealed class ControllerConnection(RpcSendChannel<IMessageToController> connection) { public sealed class ControllerConnection(RpcSendChannel<IMessageToController> connection) {
public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToController { public ValueTask Send<TMessage>(TMessage message) where TMessage : IMessageToController {
return connection.SendMessage(message, CancellationToken.None); return connection.SendMessage(message);
} }
public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken = default) where TMessage : IMessageToController, ICanReply<TReply> { public Task<TReply> Send<TMessage, TReply>(TMessage message, TimeSpan waitForReplyTime, CancellationToken waitForReplyCancellationToken = default) where TMessage : IMessageToController, ICanReply<TReply> {

View File

@@ -7,13 +7,12 @@ using Phantom.Web.Services.Instances;
namespace Phantom.Web.Services.Rpc; namespace Phantom.Web.Services.Rpc;
sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> { public sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> {
public readonly record struct Init( public readonly record struct Init(
AgentManager AgentManager, AgentManager AgentManager,
InstanceManager InstanceManager, InstanceManager InstanceManager,
InstanceLogManager InstanceLogManager, InstanceLogManager InstanceLogManager,
UserSessionRefreshManager UserSessionRefreshManager, UserSessionRefreshManager UserSessionRefreshManager
TaskCompletionSource<bool> RegisterSuccessWaiter
); );
public static Props<IMessageToWeb> Factory(Init init) { public static Props<IMessageToWeb> Factory(Init init) {
@@ -24,26 +23,19 @@ sealed class ControllerMessageHandlerActor : ReceiveActor<IMessageToWeb> {
private readonly InstanceManager instanceManager; private readonly InstanceManager instanceManager;
private readonly InstanceLogManager instanceLogManager; private readonly InstanceLogManager instanceLogManager;
private readonly UserSessionRefreshManager userSessionRefreshManager; private readonly UserSessionRefreshManager userSessionRefreshManager;
private readonly TaskCompletionSource<bool> registerSuccessWaiter;
private ControllerMessageHandlerActor(Init init) { private ControllerMessageHandlerActor(Init init) {
this.agentManager = init.AgentManager; this.agentManager = init.AgentManager;
this.instanceManager = init.InstanceManager; this.instanceManager = init.InstanceManager;
this.instanceLogManager = init.InstanceLogManager; this.instanceLogManager = init.InstanceLogManager;
this.userSessionRefreshManager = init.UserSessionRefreshManager; this.userSessionRefreshManager = init.UserSessionRefreshManager;
this.registerSuccessWaiter = init.RegisterSuccessWaiter;
Receive<RegisterWebResultMessage>(HandleRegisterWebResult);
Receive<RefreshAgentsMessage>(HandleRefreshAgents); Receive<RefreshAgentsMessage>(HandleRefreshAgents);
Receive<RefreshInstancesMessage>(HandleRefreshInstances); Receive<RefreshInstancesMessage>(HandleRefreshInstances);
Receive<InstanceOutputMessage>(HandleInstanceOutput); Receive<InstanceOutputMessage>(HandleInstanceOutput);
Receive<RefreshUserSessionMessage>(HandleRefreshUserSession); Receive<RefreshUserSessionMessage>(HandleRefreshUserSession);
} }
private void HandleRegisterWebResult(RegisterWebResultMessage message) {
registerSuccessWaiter.TrySetResult(message.Success);
}
private void HandleRefreshAgents(RefreshAgentsMessage message) { private void HandleRefreshAgents(RefreshAgentsMessage message) {
agentManager.RefreshAgents(message.Agents); agentManager.RefreshAgents(message.Agents);
} }

View File

@@ -0,0 +1,16 @@
using Phantom.Web.Services.Agents;
using Phantom.Web.Services.Authentication;
using Phantom.Web.Services.Instances;
namespace Phantom.Web.Services.Rpc;
public sealed class ControllerMessageHandlerActorInitFactory(
AgentManager agentManager,
InstanceManager instanceManager,
InstanceLogManager instanceLogManager,
UserSessionRefreshManager userSessionRefreshManager
) {
public ControllerMessageHandlerActor.Init Create() {
return new ControllerMessageHandlerActor.Init(agentManager, instanceManager, instanceLogManager, userSessionRefreshManager);
}
}

View File

@@ -1,28 +0,0 @@
using Akka.Actor;
using Phantom.Common.Messages.Web;
using Phantom.Utils.Actor;
using Phantom.Utils.Tasks;
using Phantom.Web.Services.Agents;
using Phantom.Web.Services.Authentication;
using Phantom.Web.Services.Instances;
namespace Phantom.Web.Services.Rpc;
public sealed class ControllerMessageHandlerFactory(
AgentManager agentManager,
InstanceManager instanceManager,
InstanceLogManager instanceLogManager,
UserSessionRefreshManager userSessionRefreshManager
) {
private readonly TaskCompletionSource<bool> registerSuccessWaiter = AsyncTasks.CreateCompletionSource<bool>();
public Task<bool> RegisterSuccessWaiter => registerSuccessWaiter.Task;
private int messageHandlerId = 0;
public ActorRef<IMessageToWeb> Create(IActorRefFactory actorSystem) {
var init = new ControllerMessageHandlerActor.Init(agentManager, instanceManager, instanceLogManager, userSessionRefreshManager, registerSuccessWaiter);
var name = "ControllerMessageHandler-" + Interlocked.Increment(ref messageHandlerId);
return actorSystem.ActorOf(ControllerMessageHandlerActor.Factory(init), name);
}
}

View File

@@ -5,6 +5,7 @@ using Phantom.Utils.Actor;
using Phantom.Utils.Cryptography; using Phantom.Utils.Cryptography;
using Phantom.Utils.IO; using Phantom.Utils.IO;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Client; using Phantom.Utils.Rpc.Runtime.Client;
using Phantom.Utils.Runtime; using Phantom.Utils.Runtime;
using Phantom.Utils.Threading; using Phantom.Utils.Threading;
@@ -71,11 +72,6 @@ try {
using var actorSystem = ActorSystemFactory.Create("Web"); using var actorSystem = ActorSystemFactory.Create("Web");
ControllerMessageHandlerFactory messageHandlerFactory;
await using (var scope = webApplication.Services.CreateAsyncScope()) {
messageHandlerFactory = scope.ServiceProvider.GetRequiredService<ControllerMessageHandlerFactory>();
}
Task? rpcClientListener = null; Task? rpcClientListener = null;
try { try {
PhantomLogger.Root.InformationHeading("Launching Phantom Panel web..."); PhantomLogger.Root.InformationHeading("Launching Phantom Panel web...");
@@ -83,17 +79,28 @@ try {
PhantomLogger.Root.Information("For administrator setup, visit: {HttpUrl}{SetupPath}", webConfiguration.HttpUrl, webConfiguration.BasePath + "setup"); PhantomLogger.Root.Information("For administrator setup, visit: {HttpUrl}{SetupPath}", webConfiguration.HttpUrl, webConfiguration.BasePath + "setup");
await WebLauncher.Launch(webConfiguration, webApplication); await WebLauncher.Launch(webConfiguration, webApplication);
rpcClientListener = rpcClient.Listen(messageHandlerFactory.Create(actorSystem));
ActorRef<IMessageToWeb> rpcMessageHandlerActor;
await using (var scope = webApplication.Services.CreateAsyncScope()) {
var rpcMessageHandlerInit = scope.ServiceProvider.GetRequiredService<ControllerMessageHandlerActorInitFactory>().Create();
rpcMessageHandlerActor = actorSystem.ActorOf(ControllerMessageHandlerActor.Factory(rpcMessageHandlerInit), "ControllerMessageHandler");
}
rpcClientListener = rpcClient.Listen(new IMessageReceiver<IMessageToWeb>.Actor(rpcMessageHandlerActor));
PhantomLogger.Root.Information("Phantom Panel web is ready."); PhantomLogger.Root.Information("Phantom Panel web is ready.");
await shutdownCancellationToken.WaitHandle.WaitOneAsync(); await shutdownCancellationToken.WaitHandle.WaitOneAsync();
await webApplication.StopAsync();
} finally { } finally {
PhantomLogger.Root.Information("Unregistering web...");
try { try {
await rpcClient.SendChannel.SendMessage(new UnregisterWebMessage(), CancellationToken.None); using var unregisterCancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10));
// TODO wait for acknowledgment await rpcClient.SendChannel.SendMessage(new UnregisterWebMessage(), unregisterCancellationTokenSource.Token);
} catch (OperationCanceledException) {
PhantomLogger.Root.Warning("Could not unregister web after shutdown.");
} catch (Exception e) { } catch (Exception e) {
PhantomLogger.Root.Warning(e, "Could not unregister agent after shutdown."); PhantomLogger.Root.Warning(e, "Could not unregister web after shutdown.");
} finally { } finally {
await rpcClient.Shutdown(); await rpcClient.Shutdown();

View File

@@ -62,7 +62,7 @@ static class WebLauncher {
application.MapFallbackToPage("/_Host"); application.MapFallbackToPage("/_Host");
logger.Information("Starting Web server on port {Port}...", config.Port); logger.Information("Starting Web server on port {Port}...", config.Port);
return application.RunAsync(config.CancellationToken); return application.StartAsync(config.CancellationToken);
} }
private sealed class NullLifetime : IHostLifetime { private sealed class NullLifetime : IHostLifetime {