1
0
mirror of https://github.com/chylex/Minecraft-Phantom-Panel.git synced 2025-10-16 06:39:35 +02:00

8 Commits

54 changed files with 1270 additions and 409 deletions

View File

@@ -0,0 +1,289 @@
using System.Collections.Immutable;
using NUnit.Framework;
using Phantom.Agent.Minecraft.Java;
using Phantom.Utils.Collections;
namespace Phantom.Agent.Minecraft.Tests.Java;
[TestFixture]
public sealed class JavaPropertiesStreamTests {
public sealed class Reader {
private static async Task<ImmutableArray<KeyValuePair<string, string>>> Parse(string contents) {
using var stream = new MemoryStream(JavaPropertiesStream.Encoding.GetBytes(contents));
using var properties = new JavaPropertiesStream.Reader(stream);
return await properties.ReadProperties(CancellationToken.None).ToImmutableArrayAsync();
}
private static ImmutableArray<KeyValuePair<string, string>> KeyValue(string key, string value) {
return [new KeyValuePair<string, string>(key, value)];
}
[TestCase("")]
[TestCase("\n")]
public async Task EmptyLinesAreIgnored(string contents) {
Assert.That(await Parse(contents), Is.EquivalentTo(ImmutableArray<KeyValuePair<string, string>>.Empty));
}
[TestCase("# Comment")]
[TestCase("! Comment")]
[TestCase("# Comment\n! Comment")]
public async Task CommentsAreIgnored(string contents) {
Assert.That(await Parse(contents), Is.EquivalentTo(ImmutableArray<KeyValuePair<string, string>>.Empty));
}
[TestCase("key=value")]
[TestCase("key= value")]
[TestCase("key =value")]
[TestCase("key = value")]
[TestCase("key:value")]
[TestCase("key: value")]
[TestCase("key :value")]
[TestCase("key : value")]
[TestCase("key value")]
[TestCase("key\tvalue")]
[TestCase("key\fvalue")]
[TestCase("key \t\fvalue")]
public async Task SimpleKeyValue(string contents) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue("key", "value")));
}
[TestCase("key")]
[TestCase(" key")]
[TestCase(" key ")]
[TestCase("key=")]
[TestCase("key:")]
public async Task KeyWithoutValue(string contents) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue("key", "")));
}
[TestCase(@"\#key=value", "#key")]
[TestCase(@"\!key=value", "!key")]
public async Task KeyBeginsWithEscapedComment(string contents, string expectedKey) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue(expectedKey, "value")));
}
[TestCase(@"\=key=value", "=key")]
[TestCase(@"\:key=value", ":key")]
[TestCase(@"\ key=value", " key")]
[TestCase("\\\tkey=value", "\tkey")]
[TestCase("\\\fkey=value", "\fkey")]
public async Task KeyBeginsWithEscapedDelimiter(string contents, string expectedKey) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue(expectedKey, "value")));
}
[TestCase(@"start\=end=value", "start=end")]
[TestCase(@"start\:end:value", "start:end")]
[TestCase(@"start\ end value", "start end")]
[TestCase(@"start\ \:\=end = value", "start :=end")]
[TestCase("start\\ \\\t\\\fend = value", "start \t\fend")]
public async Task KeyContainsEscapedDelimiter(string contents, string expectedKey) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue(expectedKey, "value")));
}
[TestCase(@"key = \ value", " value")]
[TestCase("key = \\\tvalue", "\tvalue")]
[TestCase("key = \\\fvalue", "\fvalue")]
[TestCase("key=\\ \\\t\\\fvalue", " \t\fvalue")]
public async Task ValueBeginsWithEscapedWhitespace(string contents, string expectedValue) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue("key", expectedValue)));
}
[TestCase(@"key = value\", "value")]
public async Task ValueEndsWithTrailingBackslash(string contents, string expectedValue) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue("key", expectedValue)));
}
[TestCase("key=\\\0", "\0")]
[TestCase(@"key=\\", "\\")]
[TestCase(@"key=\t", "\t")]
[TestCase(@"key=\n", "\n")]
[TestCase(@"key=\r", "\r")]
[TestCase(@"key=\f", "\f")]
[TestCase(@"key=\u3053\u3093\u306b\u3061\u306f", "こんにちは")]
[TestCase(@"key=\u3053\u3093\u306B\u3061\u306F", "こんにちは")]
[TestCase("key=\\\0\\\\\\t\\n\\r\\f\\u3053", "\0\\\t\n\r\fこ")]
public async Task ValueContainsEscapedSpecialCharacters(string contents, string expectedValue) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue("key", expectedValue)));
}
[TestCase("key=first\\\nsecond", "first\nsecond")]
[TestCase("key=first\\\n second", "first\nsecond")]
[TestCase("key=first\\\n#second", "first\n#second")]
[TestCase("key=first\\\n!second", "first\n!second")]
public async Task ValueContainsNewLine(string contents, string expectedValue) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue("key", expectedValue)));
}
[TestCase("key=first\\\n \\ second", "first\n second")]
[TestCase("key=first\\\n \\\tsecond", "first\n\tsecond")]
[TestCase("key=first\\\n \\\fsecond", "first\n\fsecond")]
[TestCase("key=first\\\n \t\f\\ second", "first\n second")]
public async Task ValueContainsNewLineWithEscapedLeadingWhitespace(string contents, string expectedValue) {
Assert.That(await Parse(contents), Is.EquivalentTo(KeyValue("key", expectedValue)));
}
[Test]
public async Task ExampleFile() {
// From Wikipedia: https://en.wikipedia.org/wiki/.properties
const string ExampleFile = """
# You are reading a comment in ".properties" file.
! The exclamation mark ('!') can also be used for comments.
# Comments are ignored.
# Blank lines are also ignored.
# Lines with "properties" contain a key and a value separated by a delimiting character.
# There are 3 delimiting characters: equal ('='), colon (':') and whitespace (' ', '\t' and '\f').
website = https://en.wikipedia.org/
language : English
topic .properties files
# A word on a line will just create a key with no value.
empty
# Whitespace that appears between the key, the delimiter and the value is ignored.
# This means that the following are equivalent (other than for readability).
hello=hello
hello = hello
# To start the value with whitespace, escape it with a backslash ('\').
whitespaceStart = \ <-This space is not ignored.
# Keys with the same name will be overwritten by the key that is the furthest in a file.
# For example the final value for "duplicateKey" will be "second".
duplicateKey = first
duplicateKey = second
# To use the delimiter characters inside a key, you need to escape them with a ('\').
# However, there is no need to do this in the value.
delimiterCharacters\:\=\ = This is the value for the key "delimiterCharacters\:\=\ "
# Adding a backslash ('\') at the end of a line means that the value continues on the next line.
multiline = This line \
continues
# If you want your value to include a backslash ('\'), it should be escaped by another backslash ('\').
path = c:\\wiki\\templates
# This means that if the number of backslashes ('\') at the end of the line is even, the next line is not included in the value.
# In the following example, the value for "evenKey" is "This is on one line\".
evenKey = This is on one line\\
# This line is a normal comment and is not included in the value for "evenKey".
# If the number of backslash ('\') is odd, then the next line is included in the value.
# In the following example, the value for "oddKey" is "This is line one and\# This is line two".
oddKey = This is line one and\\\
# This is line two
# Whitespace characters at the beginning of a line is removed.
# Make sure to add the spaces you need before the backslash ('\') on the first line.
# If you add them at the beginning of the next line, they will be removed.
# In the following example, the value for "welcome" is "Welcome to Wikipedia!".
welcome = Welcome to \
Wikipedia!
# If you need to add newlines and carriage returns, they need to be escaped using ('\n') and ('\r') respectively.
# You can also optionally escape tabs with ('\t') for readability purposes.
valueWithEscapes = This is a newline\n and a carriage return\r and a tab\t.
# You can also use Unicode escape characters (maximum of four hexadecimal digits).
# In the following example, the value for "encodedHelloInJapanese" is "こんにちは".
encodedHelloInJapanese = \u3053\u3093\u306b\u3061\u306f
""";
ImmutableArray<KeyValuePair<string, string>> result = [
new ("website", "https://en.wikipedia.org/"),
new ("language", "English"),
new ("topic", ".properties files"),
new ("empty", ""),
new ("hello", "hello"),
new ("hello", "hello"),
new ("whitespaceStart", @" <-This space is not ignored."),
new ("duplicateKey", "first"),
new ("duplicateKey", "second"),
new ("delimiterCharacters:= ", @"This is the value for the key ""delimiterCharacters:= """),
new ("multiline", "This line \ncontinues"),
new ("path", @"c:\wiki\templates"),
new ("evenKey", @"This is on one line\"),
new ("oddKey", "This is line one and\\\n# This is line two"),
new ("welcome", "Welcome to \nWikipedia!"),
new ("valueWithEscapes", "This is a newline\n and a carriage return\r and a tab\t."),
new ("encodedHelloInJapanese", "こんにちは"),
];
Assert.That(await Parse(ExampleFile), Is.EquivalentTo(result));
}
}
public sealed class Writer {
private static async Task<string> Write(Func<JavaPropertiesStream.Writer, Task> write) {
using var stream = new MemoryStream();
await using (var writer = new JavaPropertiesStream.Writer(stream)) {
await write(writer);
}
return JavaPropertiesStream.Encoding.GetString(stream.ToArray());
}
[TestCase("one line comment", "# one line comment\n")]
[TestCase("こんにちは", "# \\u3053\\u3093\\u306B\\u3061\\u306F\n")]
[TestCase("first line\nsecond line\r\nthird line", "# first line\n# second line\n# third line\n")]
public async Task Comment(string comment, string contents) {
Assert.That(await Write(writer => writer.WriteComment(comment, CancellationToken.None)), Is.EqualTo(contents));
}
[TestCase("key", "value", "key=value\n")]
[TestCase("key", "", "key=\n")]
[TestCase("", "value", "=value\n")]
public async Task SimpleKeyValue(string key, string value, string contents) {
Assert.That(await Write(writer => writer.WriteProperty(key, value, CancellationToken.None)), Is.EqualTo(contents));
}
[TestCase("#key", "value", "\\#key=value\n")]
[TestCase("!key", "value", "\\!key=value\n")]
public async Task KeyBeginsWithEscapedComment(string key, string value, string contents) {
Assert.That(await Write(writer => writer.WriteProperty(key, value, CancellationToken.None)), Is.EqualTo(contents));
}
[TestCase("=key", "value", "\\=key=value\n")]
[TestCase(":key", "value", "\\:key=value\n")]
[TestCase(" key", "value", "\\ key=value\n")]
[TestCase("\tkey", "value", "\\tkey=value\n")]
[TestCase("\fkey", "value", "\\fkey=value\n")]
public async Task KeyBeginsWithEscapedDelimiter(string key, string value, string contents) {
Assert.That(await Write(writer => writer.WriteProperty(key, value, CancellationToken.None)), Is.EqualTo(contents));
}
[TestCase("start=end", "value", "start\\=end=value\n")]
[TestCase("start:end", "value", "start\\:end=value\n")]
[TestCase("start end", "value", "start\\ end=value\n")]
[TestCase("start :=end", "value", "start\\ \\:\\=end=value\n")]
[TestCase("start \t\fend", "value", "start\\ \\t\\fend=value\n")]
public async Task KeyContainsEscapedDelimiter(string key, string value, string contents) {
Assert.That(await Write(writer => writer.WriteProperty(key, value, CancellationToken.None)), Is.EqualTo(contents));
}
[TestCase("\\", "value", "\\\\=value\n")]
[TestCase("\t", "value", "\\t=value\n")]
[TestCase("\n", "value", "\\n=value\n")]
[TestCase("\r", "value", "\\r=value\n")]
[TestCase("\f", "value", "\\f=value\n")]
[TestCase("こんにちは", "value", "\\u3053\\u3093\\u306B\\u3061\\u306F=value\n")]
[TestCase("\\\t\n\r\fこ", "value", "\\\\\\t\\n\\r\\f\\u3053=value\n")]
[TestCase("first-line\nsecond-line\r\nthird-line", "value", "first-line\\nsecond-line\\r\\nthird-line=value\n")]
public async Task KeyContainsEscapedSpecialCharacters(string key, string value, string contents) {
Assert.That(await Write(writer => writer.WriteProperty(key, value, CancellationToken.None)), Is.EqualTo(contents));
}
[TestCase("key", "\\", "key=\\\\\n")]
[TestCase("key", "\t", "key=\\t\n")]
[TestCase("key", "\n", "key=\\n\n")]
[TestCase("key", "\r", "key=\\r\n")]
[TestCase("key", "\f", "key=\\f\n")]
[TestCase("key", "こんにちは", "key=\\u3053\\u3093\\u306B\\u3061\\u306F\n")]
[TestCase("key", "\\\t\n\r\fこ", "key=\\\\\\t\\n\\r\\f\\u3053\n")]
[TestCase("key", "first line\nsecond line\r\nthird line", "key=first line\\nsecond line\\r\\nthird line\n")]
public async Task ValueContainsEscapedSpecialCharacters(string key, string value, string contents) {
Assert.That(await Write(writer => writer.WriteProperty(key, value, CancellationToken.None)), Is.EqualTo(contents));
}
[Test]
public async Task ExampleFile() {
string contents = await Write(static async writer => {
await writer.WriteComment("Comment", CancellationToken.None);
await writer.WriteProperty("key", "value", CancellationToken.None);
await writer.WriteProperty("multiline", "first line\nsecond line", CancellationToken.None);
});
Assert.That(contents, Is.EqualTo("# Comment\nkey=value\nmultiline=first line\\nsecond line\n"));
}
}
}

View File

@@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<PropertyGroup>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="NUnit" />
<PackageReference Include="NUnit3TestAdapter" />
<PackageReference Include="NUnit.Analyzers" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Phantom.Agent.Minecraft\Phantom.Agent.Minecraft.csproj" />
</ItemGroup>
</Project>

View File

@@ -1,92 +1,52 @@
using System.Text; namespace Phantom.Agent.Minecraft.Java;
using Kajabity.Tools.Java;
namespace Phantom.Agent.Minecraft.Java;
sealed class JavaPropertiesFileEditor { sealed class JavaPropertiesFileEditor {
private static readonly Encoding Encoding = Encoding.GetEncoding("ISO-8859-1");
private readonly Dictionary<string, string> overriddenProperties = new (); private readonly Dictionary<string, string> overriddenProperties = new ();
public void Set(string key, string value) { public void Set(string key, string value) {
overriddenProperties[key] = value; overriddenProperties[key] = value;
} }
public async Task EditOrCreate(string filePath) { public async Task EditOrCreate(string filePath, string comment, CancellationToken cancellationToken) {
if (File.Exists(filePath)) { if (File.Exists(filePath)) {
string tmpFilePath = filePath + ".tmp"; string tmpFilePath = filePath + ".tmp";
File.Copy(filePath, tmpFilePath, overwrite: true); await Edit(filePath, tmpFilePath, comment, cancellationToken);
await EditFromCopyOrCreate(filePath, tmpFilePath);
File.Move(tmpFilePath, filePath, overwrite: true); File.Move(tmpFilePath, filePath, overwrite: true);
} }
else { else {
await EditFromCopyOrCreate(sourceFilePath: null, filePath); await Create(filePath, comment, cancellationToken);
} }
} }
private async Task EditFromCopyOrCreate(string? sourceFilePath, string targetFilePath) { private async Task Create(string targetFilePath, string comment, CancellationToken cancellationToken) {
var properties = new JavaProperties(); await using var targetWriter = new JavaPropertiesStream.Writer(targetFilePath);
if (sourceFilePath != null) { await targetWriter.WriteComment(comment, cancellationToken);
// TODO replace with custom async parser
await using var sourceStream = new FileStream(sourceFilePath, FileMode.Open, FileAccess.Read, FileShare.Read);
properties.Load(sourceStream, Encoding);
}
foreach (var (key, value) in overriddenProperties) { foreach ((string key, string value) in overriddenProperties) {
properties[key] = value; await targetWriter.WriteProperty(key, value, cancellationToken);
}
await using var targetStream = new FileStream(targetFilePath, FileMode.Create, FileAccess.Write, FileShare.Read);
await using var targetWriter = new StreamWriter(targetStream, Encoding);
await targetWriter.WriteLineAsync("# Properties");
foreach (var (key, value) in properties) {
await WriteProperty(targetWriter, key, value);
} }
} }
private static async Task WriteProperty(StreamWriter writer, string key, string value) { private async Task Edit(string sourceFilePath, string targetFilePath, string comment, CancellationToken cancellationToken) {
await WritePropertyComponent(writer, key, escapeSpaces: true); using var sourceReader = new JavaPropertiesStream.Reader(sourceFilePath);
await writer.WriteAsync('='); await using var targetWriter = new JavaPropertiesStream.Writer(targetFilePath);
await WritePropertyComponent(writer, value, escapeSpaces: false);
await writer.WriteLineAsync(); await targetWriter.WriteComment(comment, cancellationToken);
}
var remainingOverriddenPropertyKeys = new HashSet<string>(overriddenProperties.Keys);
private static async Task WritePropertyComponent(TextWriter writer, string component, bool escapeSpaces) {
for (int index = 0; index < component.Length; index++) { await foreach ((string key, string value) in sourceReader.ReadProperties(cancellationToken)) {
var c = component[index]; if (remainingOverriddenPropertyKeys.Remove(key)) {
switch (c) { await targetWriter.WriteProperty(key, overriddenProperties[key], cancellationToken);
case '\\':
case '#':
case '!':
case '=':
case ':':
case ' ' when escapeSpaces || index == 0:
await writer.WriteAsync('\\');
await writer.WriteAsync(c);
break;
case var _ when c > 31 && c < 127:
await writer.WriteAsync(c);
break;
case '\t':
await writer.WriteAsync("\\t");
break;
case '\n':
await writer.WriteAsync("\\n");
break;
case '\r':
await writer.WriteAsync("\\r");
break;
case '\f':
await writer.WriteAsync("\\f");
break;
default:
await writer.WriteAsync("\\u");
await writer.WriteAsync(((int) c).ToString("X4"));
break;
} }
else {
await targetWriter.WriteProperty(key, value, cancellationToken);
}
}
foreach (string key in remainingOverriddenPropertyKeys) {
await targetWriter.WriteProperty(key, overriddenProperties[key], cancellationToken);
} }
} }
} }

View File

@@ -0,0 +1,284 @@
using System.Buffers;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Text;
using Phantom.Utils.Collections;
namespace Phantom.Agent.Minecraft.Java;
static class JavaPropertiesStream {
internal static readonly Encoding Encoding = Encoding.GetEncoding("ISO-8859-1");
private static FileStreamOptions CreateFileStreamOptions(FileMode mode, FileAccess access) {
return new FileStreamOptions {
Mode = mode,
Access = access,
Share = FileShare.Read,
Options = FileOptions.SequentialScan,
};
}
internal sealed class Reader : IDisposable {
private static readonly SearchValues<char> LineStartWhitespace = SearchValues.Create(' ', '\t', '\f');
private static readonly SearchValues<char> KeyValueDelimiter = SearchValues.Create('=', ':', ' ', '\t', '\f');
private static readonly SearchValues<char> Backslash = SearchValues.Create('\\');
private readonly StreamReader reader;
public Reader(Stream stream) {
this.reader = new StreamReader(stream, Encoding, leaveOpen: false);
}
public Reader(string path) {
this.reader = new StreamReader(path, Encoding, detectEncodingFromByteOrderMarks: false, CreateFileStreamOptions(FileMode.Open, FileAccess.Read));
}
public async IAsyncEnumerable<KeyValuePair<string, string>> ReadProperties([EnumeratorCancellation] CancellationToken cancellationToken) {
await foreach (string line in ReadLogicalLines(cancellationToken)) {
yield return ParseLine(line.AsSpan());
}
}
private async IAsyncEnumerable<string> ReadLogicalLines([EnumeratorCancellation] CancellationToken cancellationToken) {
StringBuilder nextLogicalLine = new StringBuilder();
while (await reader.ReadLineAsync(cancellationToken) is {} line) {
var span = line.AsSpan();
int startIndex = span.IndexOfAnyExcept(LineStartWhitespace);
if (startIndex == -1) {
continue;
}
if (nextLogicalLine.Length == 0 && (span[0] == '#' || span[0] == '!')) {
continue;
}
span = span[startIndex..];
if (IsEndEscaped(span)) {
nextLogicalLine.Append(span[..^1]);
nextLogicalLine.Append('\n');
}
else {
nextLogicalLine.Append(span);
yield return nextLogicalLine.ToString();
nextLogicalLine.Clear();
}
}
if (nextLogicalLine.Length > 0) {
yield return nextLogicalLine.ToString(startIndex: 0, nextLogicalLine.Length - 1); // Remove trailing new line.
}
}
private static KeyValuePair<string, string> ParseLine(ReadOnlySpan<char> line) {
int delimiterIndex = -1;
foreach (int candidateIndex in line.IndicesOf(KeyValueDelimiter)) {
if (candidateIndex == 0 || !IsEndEscaped(line[..candidateIndex])) {
delimiterIndex = candidateIndex;
break;
}
}
if (delimiterIndex == -1) {
return new KeyValuePair<string, string>(line.ToString(), string.Empty);
}
string key = ReadPropertyComponent(line[..delimiterIndex]);
line = line[(delimiterIndex + 1)..];
int valueStartIndex = line.IndexOfAnyExcept(KeyValueDelimiter);
string value = valueStartIndex == -1 ? string.Empty : ReadPropertyComponent(line[valueStartIndex..]);
return new KeyValuePair<string, string>(key, value);
}
private static string ReadPropertyComponent(ReadOnlySpan<char> component) {
StringBuilder builder = new StringBuilder();
int nextStartIndex = 0;
foreach (int backslashIndex in component.IndicesOf(Backslash)) {
if (backslashIndex == component.Length - 1) {
break;
}
if (backslashIndex < nextStartIndex) {
continue;
}
builder.Append(component[nextStartIndex..backslashIndex]);
int escapedIndex = backslashIndex + 1;
int escapedLength = 1;
char c = component[escapedIndex];
switch (c) {
case 't':
builder.Append('\t');
break;
case 'n':
builder.Append('\n');
break;
case 'r':
builder.Append('\r');
break;
case 'f':
builder.Append('\f');
break;
case 'u':
escapedLength += 4;
int hexRangeStart = escapedIndex + 1;
int hexRangeEnd = hexRangeStart + 4;
if (hexRangeEnd - 1 < component.Length) {
var hexString = component[hexRangeStart..hexRangeEnd];
int hexValue = int.Parse(hexString, NumberStyles.HexNumber);
builder.Append((char) hexValue);
}
else {
throw new FormatException("Malformed \\uxxxx encoding.");
}
break;
default:
builder.Append(c);
break;
}
nextStartIndex = escapedIndex + escapedLength;
}
builder.Append(component[nextStartIndex..]);
return builder.ToString();
}
private static bool IsEndEscaped(ReadOnlySpan<char> span) {
if (span.EndsWith('\\')) {
int trailingBackslashCount = span.Length - span.TrimEnd('\\').Length;
return trailingBackslashCount % 2 == 1;
}
else {
return false;
}
}
public void Dispose() {
reader.Dispose();
}
}
internal sealed class Writer : IAsyncDisposable {
private const string CommentStart = "# ";
private readonly StreamWriter writer;
private readonly Memory<char> oneCharBuffer = new char[1];
public Writer(Stream stream) {
this.writer = new StreamWriter(stream, Encoding, leaveOpen: false);
}
public Writer(string path) {
this.writer = new StreamWriter(path, Encoding, CreateFileStreamOptions(FileMode.Create, FileAccess.Write));
}
public async Task WriteComment(string comment, CancellationToken cancellationToken) {
await Write(CommentStart, cancellationToken);
for (int index = 0; index < comment.Length; index++) {
char c = comment[index];
switch (c) {
case var _ when c > 31 && c < 127:
await Write(c, cancellationToken);
break;
case '\n':
case '\r':
await Write(c: '\n', cancellationToken);
await Write(CommentStart, cancellationToken);
if (index < comment.Length - 1 && comment[index + 1] == '\n') {
index++;
}
break;
default:
await Write("\\u", cancellationToken);
await Write(((int) c).ToString("X4"), cancellationToken);
break;
}
}
await Write(c: '\n', cancellationToken);
}
public async Task WriteProperty(string key, string value, CancellationToken cancellationToken) {
await WritePropertyComponent(key, escapeSpaces: true, cancellationToken);
await Write(c: '=', cancellationToken);
await WritePropertyComponent(value, escapeSpaces: false, cancellationToken);
await Write(c: '\n', cancellationToken);
}
private async Task WritePropertyComponent(string component, bool escapeSpaces, CancellationToken cancellationToken) {
for (int index = 0; index < component.Length; index++) {
char c = component[index];
switch (c) {
case '\\':
case '#':
case '!':
case '=':
case ':':
case ' ' when escapeSpaces || index == 0:
await Write(c: '\\', cancellationToken);
await Write(c, cancellationToken);
break;
case var _ when c > 31 && c < 127:
await Write(c, cancellationToken);
break;
case '\t':
await Write("\\t", cancellationToken);
break;
case '\n':
await Write("\\n", cancellationToken);
break;
case '\r':
await Write("\\r", cancellationToken);
break;
case '\f':
await Write("\\f", cancellationToken);
break;
default:
await Write("\\u", cancellationToken);
await Write(((int) c).ToString("X4"), cancellationToken);
break;
}
}
}
private Task Write(char c, CancellationToken cancellationToken) {
oneCharBuffer.Span[0] = c;
return writer.WriteAsync(oneCharBuffer, cancellationToken);
}
private Task Write(string value, CancellationToken cancellationToken) {
return writer.WriteAsync(value.AsMemory(), cancellationToken);
}
public async ValueTask DisposeAsync() {
await writer.DisposeAsync();
}
}
}

View File

@@ -43,7 +43,7 @@ public abstract class BaseLauncher : IServerLauncher {
try { try {
await AcceptEula(instanceProperties); await AcceptEula(instanceProperties);
await UpdateServerProperties(instanceProperties); await UpdateServerProperties(instanceProperties, cancellationToken);
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Caught exception while configuring the server."); logger.Error(e, "Caught exception while configuring the server.");
return new LaunchResult.CouldNotConfigureMinecraftServer(); return new LaunchResult.CouldNotConfigureMinecraftServer();
@@ -108,9 +108,9 @@ public abstract class BaseLauncher : IServerLauncher {
await File.WriteAllLinesAsync(eulaFilePath, ["# EULA", "eula=true"], Encoding.UTF8); await File.WriteAllLinesAsync(eulaFilePath, ["# EULA", "eula=true"], Encoding.UTF8);
} }
private static async Task UpdateServerProperties(InstanceProperties instanceProperties) { private static async Task UpdateServerProperties(InstanceProperties instanceProperties, CancellationToken cancellationToken) {
var serverPropertiesEditor = new JavaPropertiesFileEditor(); var serverPropertiesEditor = new JavaPropertiesFileEditor();
instanceProperties.ServerProperties.SetTo(serverPropertiesEditor); instanceProperties.ServerProperties.SetTo(serverPropertiesEditor);
await serverPropertiesEditor.EditOrCreate(Path.Combine(instanceProperties.InstanceFolder, "server.properties")); await serverPropertiesEditor.EditOrCreate(Path.Combine(instanceProperties.InstanceFolder, "server.properties"), comment: "server.properties", cancellationToken);
} }
} }

View File

@@ -6,7 +6,7 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Kajabity.Tools.Java" /> <InternalsVisibleTo Include="Phantom.Agent.Minecraft.Tests" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -65,7 +65,7 @@ try {
MaxConcurrentlyHandledMessages: 50 MaxConcurrentlyHandledMessages: 50
); );
using var rpcClient = await RpcClient<IMessageToController, IMessageToAgent>.Connect("Controller", rpcClientConnectionParameters, AgentMessageRegistries.Definitions, shutdownCancellationToken); using var rpcClient = await RpcClient<IMessageToController, IMessageToAgent>.Connect("Controller", rpcClientConnectionParameters, AgentMessageRegistries.Registries, shutdownCancellationToken);
if (rpcClient == null) { if (rpcClient == null) {
PhantomLogger.Root.Fatal("Could not connect to Phantom Controller, shutting down."); PhantomLogger.Root.Fatal("Could not connect to Phantom Controller, shutting down.");
return 1; return 1;

View File

@@ -1,19 +1,16 @@
using System.Collections.Immutable; using System.Collections.Immutable;
using MemoryPack; using MemoryPack;
using Phantom.Common.Data.Web.Users.AddUserErrors;
namespace Phantom.Common.Data.Web.Users { namespace Phantom.Common.Data.Web.Users;
[MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(NameIsInvalid))]
[MemoryPackUnion(tag: 1, typeof(PasswordIsInvalid))]
[MemoryPackUnion(tag: 2, typeof(NameAlreadyExists))]
[MemoryPackUnion(tag: 3, typeof(UnknownError))]
public abstract partial record AddUserError {
internal AddUserError() {}
}
}
namespace Phantom.Common.Data.Web.Users.AddUserErrors { [MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(NameIsInvalid))]
[MemoryPackUnion(tag: 1, typeof(PasswordIsInvalid))]
[MemoryPackUnion(tag: 2, typeof(NameAlreadyExists))]
[MemoryPackUnion(tag: 3, typeof(UnknownError))]
public abstract partial record AddUserError {
private AddUserError() {}
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record NameIsInvalid([property: MemoryPackOrder(0)] UsernameRequirementViolation Violation) : AddUserError; public sealed partial record NameIsInvalid([property: MemoryPackOrder(0)] UsernameRequirementViolation Violation) : AddUserError;

View File

@@ -1,19 +1,16 @@
using MemoryPack; using MemoryPack;
using Phantom.Common.Data.Web.Users.CreateOrUpdateAdministratorUserResults;
namespace Phantom.Common.Data.Web.Users { namespace Phantom.Common.Data.Web.Users;
[MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(Success))]
[MemoryPackUnion(tag: 1, typeof(CreationFailed))]
[MemoryPackUnion(tag: 2, typeof(UpdatingFailed))]
[MemoryPackUnion(tag: 3, typeof(AddingToRoleFailed))]
[MemoryPackUnion(tag: 4, typeof(UnknownError))]
public abstract partial record CreateOrUpdateAdministratorUserResult {
internal CreateOrUpdateAdministratorUserResult() {}
}
}
namespace Phantom.Common.Data.Web.Users.CreateOrUpdateAdministratorUserResults { [MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(Success))]
[MemoryPackUnion(tag: 1, typeof(CreationFailed))]
[MemoryPackUnion(tag: 2, typeof(UpdatingFailed))]
[MemoryPackUnion(tag: 3, typeof(AddingToRoleFailed))]
[MemoryPackUnion(tag: 4, typeof(UnknownError))]
public abstract partial record CreateOrUpdateAdministratorUserResult {
private CreateOrUpdateAdministratorUserResult() {}
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record Success([property: MemoryPackOrder(0)] UserInfo User) : CreateOrUpdateAdministratorUserResult; public sealed partial record Success([property: MemoryPackOrder(0)] UserInfo User) : CreateOrUpdateAdministratorUserResult;

View File

@@ -1,17 +1,14 @@
using MemoryPack; using MemoryPack;
using Phantom.Common.Data.Web.Users.CreateUserResults;
namespace Phantom.Common.Data.Web.Users { namespace Phantom.Common.Data.Web.Users;
[MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(Success))]
[MemoryPackUnion(tag: 1, typeof(CreationFailed))]
[MemoryPackUnion(tag: 2, typeof(UnknownError))]
public abstract partial record CreateUserResult {
internal CreateUserResult() {}
}
}
namespace Phantom.Common.Data.Web.Users.CreateUserResults { [MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(Success))]
[MemoryPackUnion(tag: 1, typeof(CreationFailed))]
[MemoryPackUnion(tag: 2, typeof(UnknownError))]
public abstract partial record CreateUserResult {
private CreateUserResult() {}
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record Success([property: MemoryPackOrder(0)] UserInfo User) : CreateUserResult; public sealed partial record Success([property: MemoryPackOrder(0)] UserInfo User) : CreateUserResult;

View File

@@ -1,18 +1,15 @@
using MemoryPack; using MemoryPack;
using Phantom.Common.Data.Web.Users.PasswordRequirementViolations;
namespace Phantom.Common.Data.Web.Users { namespace Phantom.Common.Data.Web.Users;
[MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(TooShort))]
[MemoryPackUnion(tag: 1, typeof(MustContainLowercaseLetter))]
[MemoryPackUnion(tag: 2, typeof(MustContainUppercaseLetter))]
[MemoryPackUnion(tag: 3, typeof(MustContainDigit))]
public abstract partial record PasswordRequirementViolation {
internal PasswordRequirementViolation() {}
}
}
namespace Phantom.Common.Data.Web.Users.PasswordRequirementViolations { [MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(TooShort))]
[MemoryPackUnion(tag: 1, typeof(MustContainLowercaseLetter))]
[MemoryPackUnion(tag: 2, typeof(MustContainUppercaseLetter))]
[MemoryPackUnion(tag: 3, typeof(MustContainDigit))]
public abstract partial record PasswordRequirementViolation {
private PasswordRequirementViolation() {}
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record TooShort([property: MemoryPackOrder(0)] int MinimumLength) : PasswordRequirementViolation; public sealed partial record TooShort([property: MemoryPackOrder(0)] int MinimumLength) : PasswordRequirementViolation;

View File

@@ -1,18 +1,15 @@
using System.Collections.Immutable; using System.Collections.Immutable;
using MemoryPack; using MemoryPack;
using Phantom.Common.Data.Web.Users.SetUserPasswordErrors;
namespace Phantom.Common.Data.Web.Users { namespace Phantom.Common.Data.Web.Users;
[MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(UserNotFound))]
[MemoryPackUnion(tag: 1, typeof(PasswordIsInvalid))]
[MemoryPackUnion(tag: 2, typeof(UnknownError))]
public abstract partial record SetUserPasswordError {
internal SetUserPasswordError() {}
}
}
namespace Phantom.Common.Data.Web.Users.SetUserPasswordErrors { [MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(UserNotFound))]
[MemoryPackUnion(tag: 1, typeof(PasswordIsInvalid))]
[MemoryPackUnion(tag: 2, typeof(UnknownError))]
public abstract partial record SetUserPasswordError {
private SetUserPasswordError() {}
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record UserNotFound : SetUserPasswordError; public sealed partial record UserNotFound : SetUserPasswordError;

View File

@@ -4,22 +4,22 @@ using Phantom.Common.Data.Replies;
namespace Phantom.Common.Data.Web.Users; namespace Phantom.Common.Data.Web.Users;
[MemoryPackable] [MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(OfUserActionFailure))] [MemoryPackUnion(tag: 0, typeof(User))]
[MemoryPackUnion(tag: 1, typeof(OfInstanceActionFailure))] [MemoryPackUnion(tag: 1, typeof(Instance))]
public abstract partial record UserInstanceActionFailure { public abstract partial record UserInstanceActionFailure {
internal UserInstanceActionFailure() {} private UserInstanceActionFailure() {}
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record User([property: MemoryPackOrder(0)] UserActionFailure Failure) : UserInstanceActionFailure;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record Instance([property: MemoryPackOrder(0)] InstanceActionFailure Failure) : UserInstanceActionFailure;
public static implicit operator UserInstanceActionFailure(UserActionFailure failure) { public static implicit operator UserInstanceActionFailure(UserActionFailure failure) {
return new OfUserActionFailure(failure); return new User(failure);
} }
public static implicit operator UserInstanceActionFailure(InstanceActionFailure failure) { public static implicit operator UserInstanceActionFailure(InstanceActionFailure failure) {
return new OfInstanceActionFailure(failure); return new Instance(failure);
} }
} }
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record OfUserActionFailure([property: MemoryPackOrder(0)] UserActionFailure Failure) : UserInstanceActionFailure;
[MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record OfInstanceActionFailure([property: MemoryPackOrder(0)] InstanceActionFailure Failure) : UserInstanceActionFailure;

View File

@@ -1,16 +1,13 @@
using MemoryPack; using MemoryPack;
using Phantom.Common.Data.Web.Users.UsernameRequirementViolations;
namespace Phantom.Common.Data.Web.Users { namespace Phantom.Common.Data.Web.Users;
[MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(IsEmpty))]
[MemoryPackUnion(tag: 1, typeof(TooLong))]
public abstract partial record UsernameRequirementViolation {
internal UsernameRequirementViolation() {}
}
}
namespace Phantom.Common.Data.Web.Users.UsernameRequirementViolations { [MemoryPackable]
[MemoryPackUnion(tag: 0, typeof(IsEmpty))]
[MemoryPackUnion(tag: 1, typeof(TooLong))]
public abstract partial record UsernameRequirementViolation {
private UsernameRequirementViolation() {}
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public sealed partial record IsEmpty : UsernameRequirementViolation; public sealed partial record IsEmpty : UsernameRequirementViolation;

View File

@@ -3,7 +3,7 @@
namespace Phantom.Common.Data.Instance; namespace Phantom.Common.Data.Instance;
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public readonly partial record struct InstancePlayerCounts( public sealed partial record InstancePlayerCounts(
[property: MemoryPackOrder(0)] int Online, [property: MemoryPackOrder(0)] int Online,
[property: MemoryPackOrder(1)] int Maximum [property: MemoryPackOrder(1)] int Maximum
); );

View File

@@ -3,7 +3,7 @@
namespace Phantom.Common.Data.Minecraft; namespace Phantom.Common.Data.Minecraft;
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable(GenerateType.VersionTolerant)]
public readonly partial record struct MinecraftStopStrategy( public sealed partial record MinecraftStopStrategy(
[property: MemoryPackOrder(0)] ushort Seconds [property: MemoryPackOrder(0)] ushort Seconds
) { ) {
public static MinecraftStopStrategy Instant => new (0); public static MinecraftStopStrategy Instant => new (0);

View File

@@ -3,7 +3,7 @@ using MemoryPack;
namespace Phantom.Common.Data; namespace Phantom.Common.Data;
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable]
readonly partial record struct PortRange( readonly partial record struct PortRange(
[property: MemoryPackOrder(0)] ushort FirstPort, [property: MemoryPackOrder(0)] ushort FirstPort,
[property: MemoryPackOrder(1)] ushort LastPort [property: MemoryPackOrder(1)] ushort LastPort

View File

@@ -6,7 +6,7 @@ namespace Phantom.Common.Data;
/// <summary> /// <summary>
/// Represents a number of RAM allocation units, using the conversion factor of 256 MB per unit. Supports allocations up to 16 TB minus 256 MB (65535 units). /// Represents a number of RAM allocation units, using the conversion factor of 256 MB per unit. Supports allocations up to 16 TB minus 256 MB (65535 units).
/// </summary> /// </summary>
[MemoryPackable(GenerateType.VersionTolerant)] [MemoryPackable]
[SuppressMessage("ReSharper", "MemberCanBePrivate.Global")] [SuppressMessage("ReSharper", "MemberCanBePrivate.Global")]
public readonly partial record struct RamAllocationUnits( public readonly partial record struct RamAllocationUnits(
[property: MemoryPackOrder(0)] ushort RawValue [property: MemoryPackOrder(0)] ushort RawValue

View File

@@ -10,23 +10,18 @@ public static class AgentMessageRegistries {
public static MessageRegistry<IMessageToAgent> ToAgent { get; } = new (nameof(ToAgent)); public static MessageRegistry<IMessageToAgent> ToAgent { get; } = new (nameof(ToAgent));
public static MessageRegistry<IMessageToController> ToController { get; } = new (nameof(ToController)); public static MessageRegistry<IMessageToController> ToController { get; } = new (nameof(ToController));
public static IMessageDefinitions<IMessageToController, IMessageToAgent> Definitions { get; } = new MessageDefinitions(); public static MessageRegistries<IMessageToController, IMessageToAgent> Registries => new (ToAgent, ToController);
static AgentMessageRegistries() { static AgentMessageRegistries() {
ToAgent.Add<ConfigureInstanceMessage, Result<ConfigureInstanceResult, InstanceActionFailure>>(1); ToAgent.Add<ConfigureInstanceMessage, Result<ConfigureInstanceResult, InstanceActionFailure>>();
ToAgent.Add<LaunchInstanceMessage, Result<LaunchInstanceResult, InstanceActionFailure>>(2); ToAgent.Add<LaunchInstanceMessage, Result<LaunchInstanceResult, InstanceActionFailure>>();
ToAgent.Add<StopInstanceMessage, Result<StopInstanceResult, InstanceActionFailure>>(3); ToAgent.Add<StopInstanceMessage, Result<StopInstanceResult, InstanceActionFailure>>();
ToAgent.Add<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, InstanceActionFailure>>(4); ToAgent.Add<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, InstanceActionFailure>>();
ToController.Add<ReportInstanceStatusMessage>(1); ToController.Add<ReportInstanceStatusMessage>();
ToController.Add<InstanceOutputMessage>(2); ToController.Add<InstanceOutputMessage>();
ToController.Add<ReportAgentStatusMessage>(3); ToController.Add<ReportAgentStatusMessage>();
ToController.Add<ReportInstanceEventMessage>(4); ToController.Add<ReportInstanceEventMessage>();
ToController.Add<ReportInstancePlayerCountsMessage>(5); ToController.Add<ReportInstancePlayerCountsMessage>();
}
private sealed class MessageDefinitions : IMessageDefinitions<IMessageToController, IMessageToAgent> {
public MessageRegistry<IMessageToAgent> ToClient => ToAgent;
public MessageRegistry<IMessageToController> ToServer => ToController;
} }
} }

View File

@@ -17,36 +17,31 @@ public static class WebMessageRegistries {
public static MessageRegistry<IMessageToController> ToController { get; } = new (nameof(ToController)); public static MessageRegistry<IMessageToController> ToController { get; } = new (nameof(ToController));
public static MessageRegistry<IMessageToWeb> ToWeb { get; } = new (nameof(ToWeb)); public static MessageRegistry<IMessageToWeb> ToWeb { get; } = new (nameof(ToWeb));
public static IMessageDefinitions<IMessageToController, IMessageToWeb> Definitions { get; } = new MessageDefinitions(); public static MessageRegistries<IMessageToController, IMessageToWeb> Registries => new (ToWeb, ToController);
static WebMessageRegistries() { static WebMessageRegistries() {
ToController.Add<LogInMessage, Optional<LogInSuccess>>(1); ToController.Add<LogInMessage, Optional<LogInSuccess>>();
ToController.Add<LogOutMessage>(2); ToController.Add<LogOutMessage>();
ToController.Add<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>(3); ToController.Add<GetAuthenticatedUser, Optional<AuthenticatedUserInfo>>();
ToController.Add<CreateOrUpdateAdministratorUserMessage, CreateOrUpdateAdministratorUserResult>(4); ToController.Add<CreateOrUpdateAdministratorUserMessage, CreateOrUpdateAdministratorUserResult>();
ToController.Add<CreateUserMessage, Result<CreateUserResult, UserActionFailure>>(5); ToController.Add<CreateUserMessage, Result<CreateUserResult, UserActionFailure>>();
ToController.Add<DeleteUserMessage, Result<DeleteUserResult, UserActionFailure>>(6); ToController.Add<DeleteUserMessage, Result<DeleteUserResult, UserActionFailure>>();
ToController.Add<GetUsersMessage, ImmutableArray<UserInfo>>(7); ToController.Add<GetUsersMessage, ImmutableArray<UserInfo>>();
ToController.Add<GetRolesMessage, ImmutableArray<RoleInfo>>(8); ToController.Add<GetRolesMessage, ImmutableArray<RoleInfo>>();
ToController.Add<GetUserRolesMessage, ImmutableDictionary<Guid, ImmutableArray<Guid>>>(9); ToController.Add<GetUserRolesMessage, ImmutableDictionary<Guid, ImmutableArray<Guid>>>();
ToController.Add<ChangeUserRolesMessage, Result<ChangeUserRolesResult, UserActionFailure>>(10); ToController.Add<ChangeUserRolesMessage, Result<ChangeUserRolesResult, UserActionFailure>>();
ToController.Add<CreateOrUpdateInstanceMessage, Result<CreateOrUpdateInstanceResult, UserInstanceActionFailure>>(11); ToController.Add<CreateOrUpdateInstanceMessage, Result<CreateOrUpdateInstanceResult, UserInstanceActionFailure>>();
ToController.Add<LaunchInstanceMessage, Result<LaunchInstanceResult, UserInstanceActionFailure>>(12); ToController.Add<LaunchInstanceMessage, Result<LaunchInstanceResult, UserInstanceActionFailure>>();
ToController.Add<StopInstanceMessage, Result<StopInstanceResult, UserInstanceActionFailure>>(13); ToController.Add<StopInstanceMessage, Result<StopInstanceResult, UserInstanceActionFailure>>();
ToController.Add<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, UserInstanceActionFailure>>(14); ToController.Add<SendCommandToInstanceMessage, Result<SendCommandToInstanceResult, UserInstanceActionFailure>>();
ToController.Add<GetMinecraftVersionsMessage, ImmutableArray<MinecraftVersion>>(15); ToController.Add<GetMinecraftVersionsMessage, ImmutableArray<MinecraftVersion>>();
ToController.Add<GetAgentJavaRuntimesMessage, ImmutableDictionary<Guid, ImmutableArray<TaggedJavaRuntime>>>(16); ToController.Add<GetAgentJavaRuntimesMessage, ImmutableDictionary<Guid, ImmutableArray<TaggedJavaRuntime>>>();
ToController.Add<GetAuditLogMessage, Result<ImmutableArray<AuditLogItem>, UserActionFailure>>(17); ToController.Add<GetAuditLogMessage, Result<ImmutableArray<AuditLogItem>, UserActionFailure>>();
ToController.Add<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>(18); ToController.Add<GetEventLogMessage, Result<ImmutableArray<EventLogItem>, UserActionFailure>>();
ToWeb.Add<RefreshAgentsMessage>(1); ToWeb.Add<RefreshAgentsMessage>();
ToWeb.Add<RefreshInstancesMessage>(2); ToWeb.Add<RefreshInstancesMessage>();
ToWeb.Add<InstanceOutputMessage>(3); ToWeb.Add<InstanceOutputMessage>();
ToWeb.Add<RefreshUserSessionMessage>(4); ToWeb.Add<RefreshUserSessionMessage>();
}
private sealed class MessageDefinitions : IMessageDefinitions<IMessageToController, IMessageToWeb> {
public MessageRegistry<IMessageToWeb> ToClient => ToWeb;
public MessageRegistry<IMessageToController> ToServer => ToController;
} }
} }

View File

@@ -2,9 +2,6 @@
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Phantom.Common.Data; using Phantom.Common.Data;
using Phantom.Common.Data.Web.Users; using Phantom.Common.Data.Web.Users;
using Phantom.Common.Data.Web.Users.AddUserErrors;
using Phantom.Common.Data.Web.Users.PasswordRequirementViolations;
using Phantom.Common.Data.Web.Users.UsernameRequirementViolations;
using Phantom.Controller.Database.Entities; using Phantom.Controller.Database.Entities;
using Phantom.Utils.Collections; using Phantom.Utils.Collections;
@@ -16,10 +13,10 @@ public sealed class UserRepository {
private static UsernameRequirementViolation? CheckUsernameRequirements(string username) { private static UsernameRequirementViolation? CheckUsernameRequirements(string username) {
if (string.IsNullOrWhiteSpace(username)) { if (string.IsNullOrWhiteSpace(username)) {
return new IsEmpty(); return new UsernameRequirementViolation.IsEmpty();
} }
else if (username.Length > MaxUserNameLength) { else if (username.Length > MaxUserNameLength) {
return new TooLong(MaxUserNameLength); return new UsernameRequirementViolation.TooLong(MaxUserNameLength);
} }
else { else {
return null; return null;
@@ -30,19 +27,19 @@ public sealed class UserRepository {
var violations = ImmutableArray.CreateBuilder<PasswordRequirementViolation>(); var violations = ImmutableArray.CreateBuilder<PasswordRequirementViolation>();
if (password.Length < MinimumPasswordLength) { if (password.Length < MinimumPasswordLength) {
violations.Add(new TooShort(MinimumPasswordLength)); violations.Add(new PasswordRequirementViolation.TooShort(MinimumPasswordLength));
} }
if (!password.Any(char.IsLower)) { if (!password.Any(char.IsLower)) {
violations.Add(new MustContainLowercaseLetter()); violations.Add(new PasswordRequirementViolation.MustContainLowercaseLetter());
} }
if (!password.Any(char.IsUpper)) { if (!password.Any(char.IsUpper)) {
violations.Add(new MustContainUppercaseLetter()); violations.Add(new PasswordRequirementViolation.MustContainUppercaseLetter());
} }
if (!password.Any(char.IsDigit)) { if (!password.Any(char.IsDigit)) {
violations.Add(new MustContainDigit()); violations.Add(new PasswordRequirementViolation.MustContainDigit());
} }
return violations.ToImmutable(); return violations.ToImmutable();
@@ -73,16 +70,16 @@ public sealed class UserRepository {
public async Task<Result<UserEntity, AddUserError>> CreateUser(string username, string password) { public async Task<Result<UserEntity, AddUserError>> CreateUser(string username, string password) {
var usernameRequirementViolation = CheckUsernameRequirements(username); var usernameRequirementViolation = CheckUsernameRequirements(username);
if (usernameRequirementViolation != null) { if (usernameRequirementViolation != null) {
return new NameIsInvalid(usernameRequirementViolation); return new AddUserError.NameIsInvalid(usernameRequirementViolation);
} }
var passwordRequirementViolations = CheckPasswordRequirements(password); var passwordRequirementViolations = CheckPasswordRequirements(password);
if (!passwordRequirementViolations.IsEmpty) { if (!passwordRequirementViolations.IsEmpty) {
return new PasswordIsInvalid(passwordRequirementViolations); return new AddUserError.PasswordIsInvalid(passwordRequirementViolations);
} }
if (await db.Ctx.Users.AnyAsync(user => user.Name == username)) { if (await db.Ctx.Users.AnyAsync(user => user.Name == username)) {
return new NameAlreadyExists(); return new AddUserError.NameAlreadyExists();
} }
var user = new UserEntity(Guid.NewGuid(), username, UserPasswords.Hash(password)); var user = new UserEntity(Guid.NewGuid(), username, UserPasswords.Hash(password));
@@ -95,7 +92,7 @@ public sealed class UserRepository {
public Result<SetUserPasswordError> SetUserPassword(UserEntity user, string password) { public Result<SetUserPasswordError> SetUserPassword(UserEntity user, string password) {
var requirementViolations = CheckPasswordRequirements(password); var requirementViolations = CheckPasswordRequirements(password);
if (!requirementViolations.IsEmpty) { if (!requirementViolations.IsEmpty) {
return new Common.Data.Web.Users.SetUserPasswordErrors.PasswordIsInvalid(requirementViolations); return new SetUserPasswordError.PasswordIsInvalid(requirementViolations);
} }
user.PasswordHash = UserPasswords.Hash(password); user.PasswordHash = UserPasswords.Hash(password);

View File

@@ -26,7 +26,7 @@ using Serilog;
namespace Phantom.Controller.Services.Agents; namespace Phantom.Controller.Services.Agents;
sealed class AgentActor : ReceiveActor<AgentActor.ICommand> { sealed class AgentActor : ReceiveActor<AgentActor.ICommand>, IWithTimers {
private static readonly ILogger Logger = PhantomLogger.Create<AgentActor>(); private static readonly ILogger Logger = PhantomLogger.Create<AgentActor>();
private static readonly TimeSpan DisconnectionRecheckInterval = TimeSpan.FromSeconds(5); private static readonly TimeSpan DisconnectionRecheckInterval = TimeSpan.FromSeconds(5);
@@ -38,6 +38,8 @@ sealed class AgentActor : ReceiveActor<AgentActor.ICommand> {
return Props<ICommand>.Create(() => new AgentActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume, MailboxType = UnboundedJumpAheadMailbox.Name }); return Props<ICommand>.Create(() => new AgentActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume, MailboxType = UnboundedJumpAheadMailbox.Name });
} }
public ITimerScheduler Timers { get; set; } = null!;
private readonly ControllerState controllerState; private readonly ControllerState controllerState;
private readonly MinecraftVersions minecraftVersions; private readonly MinecraftVersions minecraftVersions;
private readonly IDbContextProvider dbProvider; private readonly IDbContextProvider dbProvider;
@@ -110,7 +112,7 @@ sealed class AgentActor : ReceiveActor<AgentActor.ICommand> {
protected override void PreStart() { protected override void PreStart() {
Self.Tell(new InitializeCommand()); Self.Tell(new InitializeCommand());
Context.System.Scheduler.ScheduleTellRepeatedly(DisconnectionRecheckInterval, DisconnectionRecheckInterval, Self, new RefreshConnectionStatusCommand(), Self); Timers.StartPeriodicTimer("RefreshConnectionStatus", new RefreshConnectionStatusCommand(), DisconnectionRecheckInterval, Self);
} }
private ActorRef<InstanceActor.ICommand> CreateNewInstance(Instance instance) { private ActorRef<InstanceActor.ICommand> CreateNewInstance(Instance instance) {

View File

@@ -1,4 +1,5 @@
using Phantom.Common.Data.Web.Agent; using Akka.Actor;
using Phantom.Common.Data.Web.Agent;
using Phantom.Controller.Database; using Phantom.Controller.Database;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
@@ -6,7 +7,7 @@ using Serilog;
namespace Phantom.Controller.Services.Agents; namespace Phantom.Controller.Services.Agents;
sealed class AgentDatabaseStorageActor : ReceiveActor<AgentDatabaseStorageActor.ICommand> { sealed class AgentDatabaseStorageActor : ReceiveActor<AgentDatabaseStorageActor.ICommand>, IWithTimers {
private static readonly ILogger Logger = PhantomLogger.Create<AgentDatabaseStorageActor>(); private static readonly ILogger Logger = PhantomLogger.Create<AgentDatabaseStorageActor>();
public readonly record struct Init(Guid AgentGuid, IDbContextProvider DbProvider, CancellationToken CancellationToken); public readonly record struct Init(Guid AgentGuid, IDbContextProvider DbProvider, CancellationToken CancellationToken);
@@ -15,6 +16,8 @@ sealed class AgentDatabaseStorageActor : ReceiveActor<AgentDatabaseStorageActor.
return Props<ICommand>.Create(() => new AgentDatabaseStorageActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<ICommand>.Create(() => new AgentDatabaseStorageActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
public ITimerScheduler Timers { get; set; } = null!;
private readonly Guid agentGuid; private readonly Guid agentGuid;
private readonly IDbContextProvider dbProvider; private readonly IDbContextProvider dbProvider;
private readonly CancellationToken cancellationToken; private readonly CancellationToken cancellationToken;
@@ -74,7 +77,7 @@ sealed class AgentDatabaseStorageActor : ReceiveActor<AgentDatabaseStorageActor.
private void ScheduleFlush(TimeSpan delay) { private void ScheduleFlush(TimeSpan delay) {
if (!hasScheduledFlush) { if (!hasScheduledFlush) {
hasScheduledFlush = true; hasScheduledFlush = true;
Context.System.Scheduler.ScheduleTellOnce(delay, Self, new FlushChangesCommand(), Self); Timers.StartSingleTimer("FlushChanges", new FlushChangesCommand(), delay, Self);
} }
} }
} }

View File

@@ -1,4 +1,5 @@
using Phantom.Common.Data.Web.EventLog; using Akka.Actor;
using Phantom.Common.Data.Web.EventLog;
using Phantom.Controller.Database; using Phantom.Controller.Database;
using Phantom.Controller.Database.Repositories; using Phantom.Controller.Database.Repositories;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
@@ -7,7 +8,7 @@ using Serilog;
namespace Phantom.Controller.Services.Events; namespace Phantom.Controller.Services.Events;
sealed class EventLogDatabaseStorageActor : ReceiveActor<EventLogDatabaseStorageActor.ICommand> { sealed class EventLogDatabaseStorageActor : ReceiveActor<EventLogDatabaseStorageActor.ICommand>, IWithTimers {
private static readonly ILogger Logger = PhantomLogger.Create<EventLogDatabaseStorageActor>(); private static readonly ILogger Logger = PhantomLogger.Create<EventLogDatabaseStorageActor>();
public readonly record struct Init(IDbContextProvider DbProvider, CancellationToken CancellationToken); public readonly record struct Init(IDbContextProvider DbProvider, CancellationToken CancellationToken);
@@ -16,6 +17,8 @@ sealed class EventLogDatabaseStorageActor : ReceiveActor<EventLogDatabaseStorage
return Props<ICommand>.Create(() => new EventLogDatabaseStorageActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume }); return Props<ICommand>.Create(() => new EventLogDatabaseStorageActor(init), new ActorConfiguration { SupervisorStrategy = SupervisorStrategies.Resume });
} }
public ITimerScheduler Timers { get; set; } = null!;
private readonly IDbContextProvider dbProvider; private readonly IDbContextProvider dbProvider;
private readonly CancellationToken cancellationToken; private readonly CancellationToken cancellationToken;
@@ -71,7 +74,7 @@ sealed class EventLogDatabaseStorageActor : ReceiveActor<EventLogDatabaseStorage
private void ScheduleFlush(TimeSpan delay) { private void ScheduleFlush(TimeSpan delay) {
if (!hasScheduledFlush) { if (!hasScheduledFlush) {
hasScheduledFlush = true; hasScheduledFlush = true;
Context.System.Scheduler.ScheduleTellOnce(delay, Self, new FlushChangesCommand(), Self); Timers.StartSingleTimer("FlushChanges", new FlushChangesCommand(), delay, Self);
} }
} }
} }

View File

@@ -1,7 +1,6 @@
using System.Collections.Immutable; using System.Collections.Immutable;
using Phantom.Common.Data; using Phantom.Common.Data;
using Phantom.Common.Data.Web.Users; using Phantom.Common.Data.Web.Users;
using Phantom.Common.Data.Web.Users.CreateOrUpdateAdministratorUserResults;
using Phantom.Controller.Database; using Phantom.Controller.Database;
using Phantom.Controller.Database.Entities; using Phantom.Controller.Database.Entities;
using Phantom.Controller.Database.Repositories; using Phantom.Controller.Database.Repositories;
@@ -57,12 +56,12 @@ sealed class UserManager {
wasCreated = true; wasCreated = true;
} }
else { else {
return new CreationFailed(result.Error); return new CreateOrUpdateAdministratorUserResult.CreationFailed(result.Error);
} }
} }
else { else {
if (userRepository.SetUserPassword(user, password).TryGetError(out var error)) { if (userRepository.SetUserPassword(user, password).TryGetError(out var error)) {
return new UpdatingFailed(error); return new CreateOrUpdateAdministratorUserResult.UpdatingFailed(error);
} }
auditLogWriter.AdministratorUserModified(user); auditLogWriter.AdministratorUserModified(user);
@@ -71,7 +70,7 @@ sealed class UserManager {
var role = await new RoleRepository(db).GetByGuid(Role.Administrator.Guid); var role = await new RoleRepository(db).GetByGuid(Role.Administrator.Guid);
if (role == null) { if (role == null) {
return new AddingToRoleFailed(); return new CreateOrUpdateAdministratorUserResult.AddingToRoleFailed();
} }
await new UserRoleRepository(db).Add(user, role); await new UserRoleRepository(db).Add(user, role);
@@ -85,10 +84,10 @@ sealed class UserManager {
Logger.Information("Updated administrator user \"{Username}\" (GUID {Guid}).", username, user.UserGuid); Logger.Information("Updated administrator user \"{Username}\" (GUID {Guid}).", username, user.UserGuid);
} }
return new Success(user.ToUserInfo()); return new CreateOrUpdateAdministratorUserResult.Success(user.ToUserInfo());
} catch (Exception e) { } catch (Exception e) {
Logger.Error(e, "Could not create or update administrator user \"{Username}\".", username); Logger.Error(e, "Could not create or update administrator user \"{Username}\".", username);
return new UnknownError(); return new CreateOrUpdateAdministratorUserResult.UnknownError();
} }
} }
@@ -104,7 +103,7 @@ sealed class UserManager {
try { try {
var result = await userRepository.CreateUser(username, password); var result = await userRepository.CreateUser(username, password);
if (!result) { if (!result) {
return new Common.Data.Web.Users.CreateUserResults.CreationFailed(result.Error); return new CreateUserResult.CreationFailed(result.Error);
} }
var user = result.Value; var user = result.Value;
@@ -113,10 +112,10 @@ sealed class UserManager {
await db.Ctx.SaveChangesAsync(); await db.Ctx.SaveChangesAsync();
Logger.Information("Created user \"{Username}\" (GUID {Guid}).", username, user.UserGuid); Logger.Information("Created user \"{Username}\" (GUID {Guid}).", username, user.UserGuid);
return new Common.Data.Web.Users.CreateUserResults.Success(user.ToUserInfo()); return new CreateUserResult.Success(user.ToUserInfo());
} catch (Exception e) { } catch (Exception e) {
Logger.Error(e, "Could not create user \"{Username}\".", username); Logger.Error(e, "Could not create user \"{Username}\".", username);
return new Common.Data.Web.Users.CreateUserResults.UnknownError(); return new CreateUserResult.UnknownError();
} }
} }

View File

@@ -81,8 +81,8 @@ try {
); );
var rpcServerTasks = new LinkedTasks<bool>([ var rpcServerTasks = new LinkedTasks<bool>([
new RpcAgentServer("Agent", agentConnectionParameters, AgentMessageRegistries.Definitions, controllerServices.AgentHandshake, controllerServices.AgentRegistrar).Run(shutdownCancellationToken), new RpcAgentServer("Agent", agentConnectionParameters, AgentMessageRegistries.Registries, controllerServices.AgentHandshake, controllerServices.AgentRegistrar).Run(shutdownCancellationToken),
new RpcWebServer("Web", webConnectionParameters, WebMessageRegistries.Definitions, new RpcServerClientHandshake.NoOp(), controllerServices.WebRegistrar).Run(shutdownCancellationToken), new RpcWebServer("Web", webConnectionParameters, WebMessageRegistries.Registries, new RpcServerClientHandshake.NoOp(), controllerServices.WebRegistrar).Run(shutdownCancellationToken),
]); ]);
// If either RPC server crashes, stop the whole process. // If either RPC server crashes, stop the whole process.

View File

@@ -1,35 +1,31 @@
<Project> <Project>
<ItemGroup> <ItemGroup>
<PackageReference Update="Microsoft.AspNetCore.Components.Authorization" Version="8.0.0" /> <PackageReference Update="Microsoft.AspNetCore.Components.Authorization" Version="9.0.9" />
<PackageReference Update="Microsoft.AspNetCore.Components.Web" Version="8.0.0" /> <PackageReference Update="Microsoft.AspNetCore.Components.Web" Version="9.0.9" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Update="Microsoft.EntityFrameworkCore.Relational" Version="8.0.0" /> <PackageReference Update="Microsoft.EntityFrameworkCore.Relational" Version="9.0.9" />
<PackageReference Update="Microsoft.EntityFrameworkCore.Tools" Version="8.0.0" /> <PackageReference Update="Microsoft.EntityFrameworkCore.Tools" Version="9.0.9" />
<PackageReference Update="Npgsql.EntityFrameworkCore.PostgreSQL" Version="8.0.0" /> <PackageReference Update="Npgsql.EntityFrameworkCore.PostgreSQL" Version="9.0.4" />
<PackageReference Update="System.Linq.Async" Version="6.0.1" /> <PackageReference Update="System.Linq.Async" Version="6.0.3" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Update="Kajabity.Tools.Java" Version="0.3.8607.38728" /> <PackageReference Update="Akka" Version="1.5.51" />
</ItemGroup>
<ItemGroup>
<PackageReference Update="Akka" Version="1.5.17.1" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Update="BCrypt.Net-Next.StrongName" Version="4.0.3" /> <PackageReference Update="BCrypt.Net-Next.StrongName" Version="4.0.3" />
<PackageReference Update="MemoryPack" Version="1.10.0" /> <PackageReference Update="MemoryPack" Version="1.21.4" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Update="Serilog" Version="3.1.1" /> <PackageReference Update="Serilog" Version="4.3.0" />
<PackageReference Update="Serilog.AspNetCore" Version="8.0.0" /> <PackageReference Update="Serilog.AspNetCore" Version="9.0.0" />
<PackageReference Update="Serilog.Sinks.Async" Version="1.5.0" /> <PackageReference Update="Serilog.Sinks.Async" Version="2.1.0" />
<PackageReference Update="Serilog.Sinks.Console" Version="5.0.1" /> <PackageReference Update="Serilog.Sinks.Console" Version="6.0.0" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -2,6 +2,8 @@
Microsoft Visual Studio Solution File, Format Version 12.00 Microsoft Visual Studio Solution File, Format Version 12.00
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Agent", "Agent", "{F5878792-64C8-4ECF-A075-66341FF97127}" Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Agent", "Agent", "{F5878792-64C8-4ECF-A075-66341FF97127}"
EndProject EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Agent.Tests", "Agent.Tests", "{94C1E464-3F91-49EA-99FF-3A3082C54CE8}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Common", "Common", "{01CB1A81-8950-471C-BFDF-F135FDDB2C18}" Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Common", "Common", "{01CB1A81-8950-471C-BFDF-F135FDDB2C18}"
EndProject EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Common.Tests", "Common.Tests", "{D781E00D-8563-4102-A0CD-477A679193B5}" Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Common.Tests", "Common.Tests", "{D781E00D-8563-4102-A0CD-477A679193B5}"
@@ -18,6 +20,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent", "Agent\Phan
EndProject EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent.Minecraft", "Agent\Phantom.Agent.Minecraft\Phantom.Agent.Minecraft.csproj", "{9FE000D0-91AC-4CB4-8956-91CCC0270015}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent.Minecraft", "Agent\Phantom.Agent.Minecraft\Phantom.Agent.Minecraft.csproj", "{9FE000D0-91AC-4CB4-8956-91CCC0270015}"
EndProject EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent.Minecraft.Tests", "Agent\Phantom.Agent.Minecraft.Tests\Phantom.Agent.Minecraft.Tests.csproj", "{065FFFA0-DFF4-43DB-AB3D-B92EE9848DDB}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent.Services", "Agent\Phantom.Agent.Services\Phantom.Agent.Services.csproj", "{AEE8B77E-AB07-423F-9981-8CD829ACB834}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Agent.Services", "Agent\Phantom.Agent.Services\Phantom.Agent.Services.csproj", "{AEE8B77E-AB07-423F-9981-8CD829ACB834}"
EndProject EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Common.Data", "Common\Phantom.Common.Data\Phantom.Common.Data.csproj", "{6C3DB1E5-F695-4D70-8F3A-78C2957274BE}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Phantom.Common.Data", "Common\Phantom.Common.Data\Phantom.Common.Data.csproj", "{6C3DB1E5-F695-4D70-8F3A-78C2957274BE}"
@@ -74,6 +78,10 @@ Global
{9FE000D0-91AC-4CB4-8956-91CCC0270015}.Debug|Any CPU.Build.0 = Debug|Any CPU {9FE000D0-91AC-4CB4-8956-91CCC0270015}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9FE000D0-91AC-4CB4-8956-91CCC0270015}.Release|Any CPU.ActiveCfg = Release|Any CPU {9FE000D0-91AC-4CB4-8956-91CCC0270015}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9FE000D0-91AC-4CB4-8956-91CCC0270015}.Release|Any CPU.Build.0 = Release|Any CPU {9FE000D0-91AC-4CB4-8956-91CCC0270015}.Release|Any CPU.Build.0 = Release|Any CPU
{065FFFA0-DFF4-43DB-AB3D-B92EE9848DDB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{065FFFA0-DFF4-43DB-AB3D-B92EE9848DDB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{065FFFA0-DFF4-43DB-AB3D-B92EE9848DDB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{065FFFA0-DFF4-43DB-AB3D-B92EE9848DDB}.Release|Any CPU.Build.0 = Release|Any CPU
{AEE8B77E-AB07-423F-9981-8CD829ACB834}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {AEE8B77E-AB07-423F-9981-8CD829ACB834}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{AEE8B77E-AB07-423F-9981-8CD829ACB834}.Debug|Any CPU.Build.0 = Debug|Any CPU {AEE8B77E-AB07-423F-9981-8CD829ACB834}.Debug|Any CPU.Build.0 = Debug|Any CPU
{AEE8B77E-AB07-423F-9981-8CD829ACB834}.Release|Any CPU.ActiveCfg = Release|Any CPU {AEE8B77E-AB07-423F-9981-8CD829ACB834}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -158,6 +166,7 @@ Global
GlobalSection(NestedProjects) = preSolution GlobalSection(NestedProjects) = preSolution
{418BE1BF-9F63-4B46-B4E4-DF64C3B3DDA7} = {F5878792-64C8-4ECF-A075-66341FF97127} {418BE1BF-9F63-4B46-B4E4-DF64C3B3DDA7} = {F5878792-64C8-4ECF-A075-66341FF97127}
{9FE000D0-91AC-4CB4-8956-91CCC0270015} = {F5878792-64C8-4ECF-A075-66341FF97127} {9FE000D0-91AC-4CB4-8956-91CCC0270015} = {F5878792-64C8-4ECF-A075-66341FF97127}
{065FFFA0-DFF4-43DB-AB3D-B92EE9848DDB} = {94C1E464-3F91-49EA-99FF-3A3082C54CE8}
{AEE8B77E-AB07-423F-9981-8CD829ACB834} = {F5878792-64C8-4ECF-A075-66341FF97127} {AEE8B77E-AB07-423F-9981-8CD829ACB834} = {F5878792-64C8-4ECF-A075-66341FF97127}
{6C3DB1E5-F695-4D70-8F3A-78C2957274BE} = {01CB1A81-8950-471C-BFDF-F135FDDB2C18} {6C3DB1E5-F695-4D70-8F3A-78C2957274BE} = {01CB1A81-8950-471C-BFDF-F135FDDB2C18}
{95B55357-F8F0-48C2-A1C2-5EA997651783} = {01CB1A81-8950-471C-BFDF-F135FDDB2C18} {95B55357-F8F0-48C2-A1C2-5EA997651783} = {01CB1A81-8950-471C-BFDF-F135FDDB2C18}

View File

@@ -3,7 +3,7 @@ using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Frame.Types; namespace Phantom.Utils.Rpc.Frame.Types;
sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<byte> SerializedMessage) : IFrame { sealed record MessageFrame(uint MessageId, byte MessageTypeCode, ReadOnlyMemory<byte> SerializedMessage) : IFrame {
public const int MaxMessageBytes = 1024 * 1024 * 8; public const int MaxMessageBytes = 1024 * 1024 * 8;
public ReadOnlyMemory<byte> FrameType => IFrame.TypeMessage; public ReadOnlyMemory<byte> FrameType => IFrame.TypeMessage;
@@ -13,19 +13,19 @@ sealed record MessageFrame(uint MessageId, ushort RegistryCode, ReadOnlyMemory<b
CheckMessageLength(serializedMessageLength); CheckMessageLength(serializedMessageLength);
await stream.WriteUnsignedInt(MessageId, cancellationToken); await stream.WriteUnsignedInt(MessageId, cancellationToken);
await stream.WriteUnsignedShort(RegistryCode, cancellationToken); await stream.WriteByte(MessageTypeCode, cancellationToken);
await stream.WriteUnsignedInt(serializedMessageLength, cancellationToken); await stream.WriteUnsignedInt(serializedMessageLength, cancellationToken);
await stream.WriteBytes(SerializedMessage, cancellationToken); await stream.WriteBytes(SerializedMessage, cancellationToken);
} }
public static async Task<MessageFrame> Read(RpcStream stream, CancellationToken cancellationToken) { public static async Task<MessageFrame> Read(RpcStream stream, CancellationToken cancellationToken) {
var messageId = await stream.ReadUnsignedInt(cancellationToken); var messageId = await stream.ReadUnsignedInt(cancellationToken);
var registryCode = await stream.ReadUnsignedShort(cancellationToken); var messageTypeCode = await stream.ReadByte(cancellationToken);
var serializedMessageLength = await stream.ReadUnsignedInt(cancellationToken); var serializedMessageLength = await stream.ReadUnsignedInt(cancellationToken);
CheckMessageLength(serializedMessageLength); CheckMessageLength(serializedMessageLength);
var serializedMessage = await stream.ReadBytes(serializedMessageLength, cancellationToken); var serializedMessage = await stream.ReadBytes(serializedMessageLength, cancellationToken);
return new MessageFrame(messageId, registryCode, serializedMessage); return new MessageFrame(messageId, messageTypeCode, serializedMessage);
} }
private static void CheckMessageLength(uint messageLength) { private static void CheckMessageLength(uint messageLength) {

View File

@@ -1,6 +0,0 @@
namespace Phantom.Utils.Rpc.Message;
public interface IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> {
MessageRegistry<TServerToClientMessage> ToClient { get; }
MessageRegistry<TClientToServerMessage> ToServer { get; }
}

View File

@@ -0,0 +1,15 @@
namespace Phantom.Utils.Rpc.Message;
public readonly record struct MessageRegistries<TClientToServerMessage, TServerToClientMessage>(
MessageRegistry<TServerToClientMessage> ToClient,
MessageRegistry<TClientToServerMessage> ToServer
) {
internal WithMapping CreateMapping() {
return new WithMapping(ToClient.CreateMapping(), ToServer.CreateMapping());
}
internal readonly record struct WithMapping(
MessageRegistry<TServerToClientMessage>.WithMapping ToClient,
MessageRegistry<TClientToServerMessage>.WithMapping ToServer
);
}

View File

@@ -1,75 +1,106 @@
using System.Diagnostics.CodeAnalysis; using System.Collections.Immutable;
using Phantom.Utils.Actor; using Phantom.Utils.Actor;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame.Types; using Phantom.Utils.Rpc.Runtime;
using Serilog; using Serilog;
namespace Phantom.Utils.Rpc.Message; namespace Phantom.Utils.Rpc.Message;
public sealed class MessageRegistry<TMessageBase>(string loggerName) { public sealed class MessageRegistry<TMessageBase>(string loggerName) {
private readonly ILogger logger = PhantomLogger.Create<MessageRegistry<TMessageBase>>(loggerName); private readonly ILogger logger = PhantomLogger.Create<MessageRegistry<TMessageBase>>(loggerName);
private readonly Dictionary<Type, ushort> typeToCodeMapping = new (); private readonly List<MessageInfo> messageInfoList = [];
private readonly Dictionary<ushort, Registration> codeToRegistrationMapping = new ();
private readonly record struct Registration(Type MessageType, Func<uint, ReadOnlyMemory<byte>, MessageHandler<TMessageBase>, CancellationToken, Task> Handler); private readonly record struct MessageInfo(Type Type, MessageTypeName TypeName, DeserializeAndHandleFunc Action);
public void Add<TMessage>(ushort code) where TMessage : TMessageBase { internal delegate Task DeserializeAndHandleFunc(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken);
Type messageType = typeof(TMessage);
public void Add<TMessage>() where TMessage : TMessageBase {
if (HasReplyType(messageType)) { if (HasReplyType(typeof(TMessage))) {
throw new ArgumentException("This overload is for messages without a reply."); throw new ArgumentException("This overload is for messages without a reply.");
} }
typeToCodeMapping.Add(messageType, code); AddImpl(typeof(TMessage), DeserializationHandler<TMessage>);
codeToRegistrationMapping.Add(code, new Registration(messageType, DeserializationHandler<TMessage>));
} }
public void Add<TMessage, TReply>(ushort code) where TMessage : TMessageBase, ICanReply<TReply> { public void Add<TMessage, TReply>() where TMessage : TMessageBase, ICanReply<TReply> {
Type messageType = typeof(TMessage); AddImpl(typeof(TMessage), DeserializationHandler<TMessage, TReply>);
typeToCodeMapping.Add(messageType, code);
codeToRegistrationMapping.Add(code, new Registration(messageType, DeserializationHandler<TMessage, TReply>));
} }
private bool HasReplyType(Type messageType) { private void AddImpl(Type messageType, DeserializeAndHandleFunc action) {
messageInfoList.Add(new MessageInfo(messageType, new MessageTypeName(messageType.Name), action));
}
private static bool HasReplyType(Type messageType) {
string replyInterfaceName = typeof(ICanReply<object>).FullName!; string replyInterfaceName = typeof(ICanReply<object>).FullName!;
replyInterfaceName = replyInterfaceName[..(replyInterfaceName.IndexOf('`') + 1)]; replyInterfaceName = replyInterfaceName[..(replyInterfaceName.IndexOf('`') + 1)];
return messageType.GetInterfaces().Any(type => type.FullName is {} name && name.StartsWith(replyInterfaceName, StringComparison.Ordinal)); return messageType.GetInterfaces().Any(type => type.FullName is {} name && name.StartsWith(replyInterfaceName, StringComparison.Ordinal));
} }
internal bool TryGetType(MessageFrame frame, [NotNullWhen(true)] out Type? type) { internal WithMapping CreateMapping() {
if (codeToRegistrationMapping.TryGetValue(frame.RegistryCode, out var registration)) { var messageTypeNames = ImmutableArray.CreateBuilder<MessageTypeName>();
type = registration.MessageType; var messageTypeMapping = new MessageTypeMapping<TMessageBase>.Builder();
return true;
}
else {
type = null;
return false;
}
}
internal MessageFrame CreateFrame<TMessage>(uint messageId, TMessage message) where TMessage : TMessageBase {
if (typeToCodeMapping.TryGetValue(typeof(TMessage), out ushort code)) {
return new MessageFrame(messageId, code, MessageSerialization.Serialize(message));
}
else {
throw new ArgumentException("Unknown message type: " + typeof(TMessage));
}
}
internal async Task Handle(MessageFrame frame, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) {
uint messageId = frame.MessageId;
if (codeToRegistrationMapping.TryGetValue(frame.RegistryCode, out var registration)) { int nextMessageCode = 0;
await registration.Handler(messageId, frame.SerializedMessage, handler, cancellationToken);
foreach ((Type messageType, MessageTypeName messageTypeName, DeserializeAndHandleFunc action) in messageInfoList) {
if (nextMessageCode == byte.MaxValue) {
throw new InvalidOperationException("Trying to register too many messages (" + (nextMessageCode + 1) + ").");
}
messageTypeNames.Add(messageTypeName);
messageTypeMapping.Add((byte) nextMessageCode++, messageType, action);
} }
else {
logger.Error("Unknown message code {Code} for message {MessageId}.", frame.RegistryCode, messageId); return new WithMapping(messageTypeNames.ToImmutable(), messageTypeMapping.Build(loggerName));
await handler.SendError(messageId, MessageError.UnknownMessageRegistryCode, cancellationToken); }
internal sealed class WithMapping(ImmutableArray<MessageTypeName> messageTypeNames, MessageTypeMapping<TMessageBase> mapping) {
public MessageTypeMapping<TMessageBase> Mapping => mapping;
public async ValueTask Write(RpcStream stream, CancellationToken cancellationToken) {
foreach (MessageTypeName typeName in messageTypeNames) {
await typeName.Write(stream, cancellationToken);
}
await MessageTypeName.WriteEnd(stream, cancellationToken);
} }
} }
internal async ValueTask<ReadMappingResult> ReadMapping(RpcStream stream, CancellationToken cancellationToken) {
var messageTypeNameToInfoMapping = messageInfoList.ToImmutableDictionary(static item => item.TypeName, static item => item);
var messageTypeMapping = new MessageTypeMapping<TMessageBase>.Builder();
var supportedMessages = ImmutableSortedDictionary.CreateBuilder<byte, MessageTypeName>();
var unsupportedMessages = ImmutableSortedDictionary.CreateBuilder<byte, MessageTypeName>();
byte nextMessageCode = 0;
while (await MessageTypeName.Read(stream, cancellationToken) is {} messageTypeName) {
if (nextMessageCode == byte.MaxValue) {
throw new InvalidOperationException("Trying to register too many messages (" + (nextMessageCode + 1) + ").");
}
if (messageTypeNameToInfoMapping.TryGetValue(messageTypeName, out var messageInfo)) {
messageTypeMapping.Add(nextMessageCode, messageInfo.Type, messageInfo.Action);
supportedMessages.Add(nextMessageCode, messageTypeName);
}
else {
unsupportedMessages.Add(nextMessageCode, messageTypeName);
}
++nextMessageCode;
}
return new ReadMappingResult(messageTypeMapping.Build(loggerName), supportedMessages.ToImmutable(), unsupportedMessages.ToImmutable());
}
internal readonly record struct ReadMappingResult(
MessageTypeMapping<TMessageBase> TypeMapping,
ImmutableSortedDictionary<byte, MessageTypeName> SupportedMessages,
ImmutableSortedDictionary<byte, MessageTypeName> UnsupportedMessages
);
private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase { private async Task DeserializationHandler<TMessage>(uint messageId, ReadOnlyMemory<byte> serializedMessage, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) where TMessage : TMessageBase {
TMessage message; TMessage message;
try { try {

View File

@@ -0,0 +1,68 @@
using System.Collections.Frozen;
using System.Diagnostics.CodeAnalysis;
using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Frame.Types;
using Serilog;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageTypeMapping<TMessageBase> {
private readonly ILogger logger;
private readonly FrozenDictionary<Type, byte> messageTypeToTypeCodeMapping;
private readonly FrozenDictionary<byte, Registration> messageTypeCodeToRegistrationMapping;
private MessageTypeMapping(string loggerName, FrozenDictionary<Type, byte> messageTypeToTypeCodeMapping, FrozenDictionary<byte, Registration> messageTypeCodeToRegistrationMapping) {
this.logger = PhantomLogger.Create<MessageTypeMapping<TMessageBase>>(loggerName);
this.messageTypeToTypeCodeMapping = messageTypeToTypeCodeMapping;
this.messageTypeCodeToRegistrationMapping = messageTypeCodeToRegistrationMapping;
}
private readonly record struct Registration(Type MessageType, MessageRegistry<TMessageBase>.DeserializeAndHandleFunc Action);
public bool TryGetType(MessageFrame frame, [NotNullWhen(true)] out Type? type) {
if (messageTypeCodeToRegistrationMapping.TryGetValue(frame.MessageTypeCode, out var registration)) {
type = registration.MessageType;
return true;
}
else {
type = null;
return false;
}
}
public MessageFrame CreateFrame<TMessage>(uint messageId, TMessage message) where TMessage : TMessageBase {
if (messageTypeToTypeCodeMapping.TryGetValue(typeof(TMessage), out byte messageTypeCode)) {
return new MessageFrame(messageId, messageTypeCode, MessageSerialization.Serialize(message));
}
else {
throw new ArgumentException("Unknown message type: " + typeof(TMessage));
}
}
public async Task Handle(MessageFrame frame, MessageHandler<TMessageBase> handler, CancellationToken cancellationToken) {
uint messageId = frame.MessageId;
if (messageTypeCodeToRegistrationMapping.TryGetValue(frame.MessageTypeCode, out var registration)) {
await registration.Action(messageId, frame.SerializedMessage, handler, cancellationToken);
}
else {
logger.Error("Unknown message code {Code} for message {MessageId}.", frame.MessageTypeCode, messageId);
await handler.SendError(messageId, MessageError.UnknownMessageRegistryCode, cancellationToken);
}
}
public sealed class Builder {
private readonly Dictionary<Type, byte> messageTypeToTypeCodeMapping = new ();
private readonly Dictionary<byte, Registration> messageTypeCodeToRegistrationMapping = new ();
public void Add(byte messageTypeCode, Type messageType, MessageRegistry<TMessageBase>.DeserializeAndHandleFunc action) {
messageTypeToTypeCodeMapping.Add(messageType, messageTypeCode);
messageTypeCodeToRegistrationMapping.Add(messageTypeCode, new Registration(messageType, action));
}
public MessageTypeMapping<TMessageBase> Build(string loggerName) {
return new MessageTypeMapping<TMessageBase>(loggerName, messageTypeToTypeCodeMapping.ToFrozenDictionary(), messageTypeCodeToRegistrationMapping.ToFrozenDictionary());
}
}
}

View File

@@ -0,0 +1,6 @@
namespace Phantom.Utils.Rpc.Message;
readonly record struct MessageTypeMappings<TClientToServerMessage, TServerToClientMessage>(
MessageTypeMapping<TServerToClientMessage> ToClient,
MessageTypeMapping<TClientToServerMessage> ToServer
);

View File

@@ -0,0 +1,58 @@
using System.Text;
using Phantom.Utils.Rpc.Runtime;
namespace Phantom.Utils.Rpc.Message;
sealed class MessageTypeName {
private readonly string stringValue;
private readonly ReadOnlyMemory<byte> serializedBytes;
public MessageTypeName(string name) {
this.stringValue = name;
this.serializedBytes = Encoding.ASCII.GetBytes(name);
if (serializedBytes.Length is 0 or > byte.MaxValue) {
throw new ArgumentOutOfRangeException(nameof(name), "Message name must be between 0 and " + byte.MaxValue + " bytes.");
}
}
private MessageTypeName(ReadOnlyMemory<byte> serializedBytes) {
this.stringValue = Encoding.ASCII.GetString(serializedBytes.Span);
this.serializedBytes = serializedBytes;
}
public async ValueTask Write(RpcStream stream, CancellationToken cancellationToken) {
await stream.WriteByte((byte) serializedBytes.Length, cancellationToken);
await stream.WriteBytes(serializedBytes, cancellationToken);
}
public static async ValueTask WriteEnd(RpcStream stream, CancellationToken cancellationToken) {
await stream.WriteByte(value: 0, cancellationToken);
}
public static async ValueTask<MessageTypeName?> Read(RpcStream stream, CancellationToken cancellationToken) {
byte serializedBytesLength = await stream.ReadByte(cancellationToken);
if (serializedBytesLength == 0) {
return null;
}
var serializedBytes = await stream.ReadBytes(serializedBytesLength, cancellationToken);
return new MessageTypeName(serializedBytes);
}
public override bool Equals(object? obj) {
if (ReferenceEquals(this, obj)) {
return true;
}
return obj is MessageTypeName other && stringValue == other.stringValue;
}
public override int GetHashCode() {
return stringValue.GetHashCode();
}
public override string ToString() {
return stringValue;
}
}

View File

@@ -7,21 +7,25 @@ using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Client; namespace Phantom.Utils.Rpc.Runtime.Client;
public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> : IRpcConnectionProvider, IDisposable { public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> : IRpcConnectionProvider, IDisposable {
public static async Task<RpcClient<TClientToServerMessage, TServerToClientMessage>?> Connect(string loggerName, RpcClientConnectionParameters connectionParameters, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, CancellationToken cancellationToken) { public static async Task<RpcClient<TClientToServerMessage, TServerToClientMessage>?> Connect(
RpcClientToServerConnector connector = new RpcClientToServerConnector(loggerName, connectionParameters); string loggerName,
RpcClientToServerConnector.Connection? connection = await connector.ConnectWithRetries(maxAttempts: 10, cancellationToken); RpcClientConnectionParameters connectionParameters,
return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, connector, connection, messageDefinitions); MessageRegistries<TClientToServerMessage, TServerToClientMessage> messageRegistries,
CancellationToken cancellationToken
) {
var connector = new RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageRegistries);
var connection = await connector.ConnectWithRetries(maxAttempts: 10, cancellationToken);
return connection == null ? null : new RpcClient<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, connector, connection);
} }
private readonly string loggerName; private readonly string loggerName;
private readonly ILogger logger; private readonly ILogger logger;
private readonly RpcCommonConnectionParameters connectionParameters; private readonly RpcCommonConnectionParameters connectionParameters;
private readonly RpcClientToServerConnector connector; private readonly RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage> connector;
private readonly IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions;
private readonly IRpcFrameSenderProvider<TClientToServerMessage>.Mutable frameSenderProvider = new (); private readonly IRpcFrameSenderProvider<TClientToServerMessage>.Mutable frameSenderProvider = new ();
private RpcClientToServerConnector.Connection currentConnection; private RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage>.Connection currentConnection;
private readonly SemaphoreSlim currentConnectionSemaphore = new (1); private readonly SemaphoreSlim currentConnectionSemaphore = new (1);
private Task? listenerTask; private Task? listenerTask;
@@ -30,14 +34,18 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
public MessageSender<TClientToServerMessage> MessageSender { get; } public MessageSender<TClientToServerMessage> MessageSender { get; }
private RpcClient(string loggerName, RpcCommonConnectionParameters connectionParameters, RpcClientToServerConnector connector, RpcClientToServerConnector.Connection connection, IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions) { private RpcClient(
string loggerName,
RpcCommonConnectionParameters connectionParameters,
RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage> connector,
RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage>.Connection connection
) {
this.loggerName = loggerName; this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName); this.logger = PhantomLogger.Create<RpcClient<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.connectionParameters = connectionParameters; this.connectionParameters = connectionParameters;
this.connector = connector; this.connector = connector;
this.currentConnection = connection; this.currentConnection = connection;
this.messageDefinitions = messageDefinitions;
this.MessageSender = new MessageSender<TClientToServerMessage>(loggerName, connectionParameters, frameSenderProvider); this.MessageSender = new MessageSender<TClientToServerMessage>(loggerName, connectionParameters, frameSenderProvider);
} }
@@ -46,7 +54,7 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
return (await GetConnection(cancellationToken)).Stream; return (await GetConnection(cancellationToken)).Stream;
} }
private async Task<RpcClientToServerConnector.Connection> GetConnection(CancellationToken cancellationToken) { private async Task<RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage>.Connection> GetConnection(CancellationToken cancellationToken) {
await currentConnectionSemaphore.WaitAsync(cancellationToken); await currentConnectionSemaphore.WaitAsync(cancellationToken);
try { try {
if (!currentConnection.Socket.Connected) { if (!currentConnection.Socket.Connected) {
@@ -70,7 +78,7 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
private async Task Listen(IMessageReceiver<TServerToClientMessage> messageReceiver) { private async Task Listen(IMessageReceiver<TServerToClientMessage> messageReceiver) {
CancellationToken cancellationToken = shutdownCancellationTokenSource.Token; CancellationToken cancellationToken = shutdownCancellationTokenSource.Token;
RpcClientToServerConnector.Connection? connection = null; RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage>.Connection? connection = null;
SessionState? sessionState = null; SessionState? sessionState = null;
try { try {
@@ -138,10 +146,10 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
} }
} }
private SessionState NewSessionState(RpcClientToServerConnector.Connection connection, IMessageReceiver<TServerToClientMessage> messageReceiver) { private SessionState NewSessionState(RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage>.Connection connection, IMessageReceiver<TServerToClientMessage> messageReceiver) {
var frameSender = new RpcFrameSender<TClientToServerMessage>(loggerName, connectionParameters, this, messageDefinitions.ToServer, connection.PingInterval); var frameSender = new RpcFrameSender<TClientToServerMessage>(loggerName, connectionParameters, this, connection.MessageTypeMappings.ToServer, connection.PingInterval);
var messageHandler = new MessageHandler<TServerToClientMessage>(messageReceiver, frameSender); var messageHandler = new MessageHandler<TServerToClientMessage>(messageReceiver, frameSender);
var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, messageDefinitions.ToClient, messageHandler, MessageSender, frameSender); var frameReader = new RpcFrameReader<TClientToServerMessage, TServerToClientMessage>(loggerName, connectionParameters, connection.MessageTypeMappings.ToClient, messageHandler, MessageSender, frameSender);
frameSenderProvider.SetNewValue(frameSender); frameSenderProvider.SetNewValue(frameSender);
messageReceiver.OnSessionRestarted(); messageReceiver.OnSessionRestarted();
@@ -150,7 +158,7 @@ public sealed class RpcClient<TClientToServerMessage, TServerToClientMessage> :
} }
private readonly record struct SessionState(RpcFrameSender<TClientToServerMessage> FrameSender, RpcFrameReader<TClientToServerMessage, TServerToClientMessage> FrameReader) { private readonly record struct SessionState(RpcFrameSender<TClientToServerMessage> FrameSender, RpcFrameReader<TClientToServerMessage, TServerToClientMessage> FrameReader) {
public void Update(ILogger logger, RpcClientToServerConnector.Connection connection) { public void Update(ILogger logger, RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage>.Connection connection) {
TimeSpan currentPingInterval = FrameSender.PingInterval; TimeSpan currentPingInterval = FrameSender.PingInterval;
if (currentPingInterval != connection.PingInterval) { if (currentPingInterval != connection.PingInterval) {
logger.Warning("Server requested a different ping interval ({ServerPingInterval}s) than currently set ({ClientPingInterval}s), but ping interval cannot be updated for existing sessions.", connection.PingInterval.TotalSeconds, currentPingInterval.TotalSeconds); logger.Warning("Server requested a different ping interval ({ServerPingInterval}s) than currently set ({ClientPingInterval}s), but ping interval cannot be updated for existing sessions.", connection.PingInterval.TotalSeconds, currentPingInterval.TotalSeconds);

View File

@@ -1,30 +1,36 @@
using System.Net.Security; using System.Diagnostics.CodeAnalysis;
using System.Net.Security;
using System.Net.Sockets; using System.Net.Sockets;
using System.Security.Authentication; using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using Phantom.Utils.Collections; using Phantom.Utils.Collections;
using Phantom.Utils.Logging; using Phantom.Utils.Logging;
using Phantom.Utils.Rpc.Message;
using Phantom.Utils.Rpc.Runtime.Tls; using Phantom.Utils.Rpc.Runtime.Tls;
using Serilog; using Serilog;
using Serilog.Events;
namespace Phantom.Utils.Rpc.Runtime.Client; namespace Phantom.Utils.Rpc.Runtime.Client;
sealed class RpcClientToServerConnector { [SuppressMessage("ReSharper", "StaticMemberInGenericType")]
sealed class RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage> {
private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(500); private static readonly TimeSpan InitialRetryDelay = TimeSpan.FromMilliseconds(500);
private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30); private static readonly TimeSpan MaximumRetryDelay = TimeSpan.FromSeconds(30);
private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10); private static readonly TimeSpan DisconnectTimeout = TimeSpan.FromSeconds(10);
private readonly ILogger logger; private readonly ILogger logger;
private readonly Guid sessionId;
private readonly RpcClientConnectionParameters parameters; private readonly RpcClientConnectionParameters parameters;
private readonly MessageRegistries<TClientToServerMessage, TServerToClientMessage> messageRegistries;
private readonly Guid sessionId;
private readonly SslClientAuthenticationOptions sslOptions; private readonly SslClientAuthenticationOptions sslOptions;
private bool loggedCertificateValidationError = false; private bool loggedCertificateValidationError = false;
public RpcClientToServerConnector(string loggerName, RpcClientConnectionParameters parameters) { public RpcClientToServerConnector(string loggerName, RpcClientConnectionParameters parameters, MessageRegistries<TClientToServerMessage, TServerToClientMessage> messageRegistries) {
this.logger = PhantomLogger.Create<RpcClientToServerConnector>(loggerName); this.logger = PhantomLogger.Create<RpcClientToServerConnector<TClientToServerMessage, TServerToClientMessage>>(loggerName);
this.sessionId = Guid.NewGuid();
this.parameters = parameters; this.parameters = parameters;
this.messageRegistries = messageRegistries;
this.sessionId = Guid.NewGuid();
this.sslOptions = new SslClientAuthenticationOptions { this.sslOptions = new SslClientAuthenticationOptions {
AllowRenegotiation = false, AllowRenegotiation = false,
@@ -114,7 +120,7 @@ sealed class RpcClientToServerConnector {
if (await AuthenticateAndPerformHandshake(stream, cancellationToken) is {} result) { if (await AuthenticateAndPerformHandshake(stream, cancellationToken) is {} result) {
logger.Information("Connected to {Host}:{Port}.", parameters.Host, parameters.Port); logger.Information("Connected to {Host}:{Port}.", parameters.Host, parameters.Port);
return new Connection(clientSocket, stream, result.IsNewSession, result.PingInterval); return new Connection(clientSocket, stream, result.IsNewSession, result.PingInterval, result.MessageTypeMappings);
} }
} catch (Exception e) { } catch (Exception e) {
logger.Error(e, "Caught unhandled exception."); logger.Error(e, "Caught unhandled exception.");
@@ -167,13 +173,12 @@ sealed class RpcClientToServerConnector {
await stream.WriteGuid(sessionId, cancellationToken); await stream.WriteGuid(sessionId, cancellationToken);
await stream.Flush(cancellationToken); await stream.Flush(cancellationToken);
ushort pingIntervalSeconds = await stream.ReadUnsignedShort(cancellationToken); var pingInterval = await ReadPingInterval(stream, cancellationToken);
if (pingIntervalSeconds == 0) { if (pingInterval == null) {
logger.Error("Server sent invalid ping interval.");
return null; return null;
} }
logger.Debug("Server requested a ping interval of {PingInterval}s.", pingIntervalSeconds); var mappedMessageDefinitions = await ReadMessageMappings(stream, cancellationToken);
await parameters.Handshake.Perform(stream, cancellationToken); await parameters.Handshake.Perform(stream, cancellationToken);
@@ -183,10 +188,42 @@ sealed class RpcClientToServerConnector {
return null; return null;
} }
return new ConnectionResult(finalHandshakeResult == RpcFinalHandshakeResult.NewSession, TimeSpan.FromSeconds(pingIntervalSeconds)); return new ConnectionResult(finalHandshakeResult == RpcFinalHandshakeResult.NewSession, pingInterval.Value, mappedMessageDefinitions);
} }
private readonly record struct ConnectionResult(bool IsNewSession, TimeSpan PingInterval); private async Task<TimeSpan?> ReadPingInterval(RpcStream stream, CancellationToken cancellationToken) {
ushort pingIntervalSeconds = await stream.ReadUnsignedShort(cancellationToken);
if (pingIntervalSeconds == 0) {
logger.Error("Server sent invalid ping interval.");
return null;
}
logger.Debug("Server requested a ping interval of {PingInterval}s.", pingIntervalSeconds);
return TimeSpan.FromSeconds(pingIntervalSeconds);
}
private async Task<MessageTypeMappings<TClientToServerMessage, TServerToClientMessage>> ReadMessageMappings(RpcStream stream, CancellationToken cancellationToken) {
var toClient = await ReadMessageMapping(messageRegistries.ToClient, stream, cancellationToken);
var toServer = await ReadMessageMapping(messageRegistries.ToServer, stream, cancellationToken);
return new MessageTypeMappings<TClientToServerMessage, TServerToClientMessage>(toClient, toServer);
}
private async Task<MessageTypeMapping<TMessageBase>> ReadMessageMapping<TMessageBase>(MessageRegistry<TMessageBase> messageRegistry, RpcStream stream, CancellationToken cancellationToken) {
var result = await messageRegistry.ReadMapping(stream, cancellationToken);
if (logger.IsEnabled(LogEventLevel.Debug)) {
foreach ((byte messageTypeCode, MessageTypeName messageTypeName) in result.SupportedMessages) {
logger.Debug("Server requested code {MessageCode} for message {MessageBaseTypeName}:{MessageTypeName}.", messageTypeCode, typeof(TMessageBase).Name, messageTypeName);
}
}
foreach ((byte messageTypeCode, MessageTypeName messageTypeName) in result.UnsupportedMessages) {
logger.Warning("Server requested code {MessageCode} for message {MessageBaseTypeName}:{MessageTypeName} that the client does not support.", messageTypeCode, typeof(TMessageBase).Name, messageTypeName);
}
return result.TypeMapping;
}
private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) { private bool ValidateServerCertificate(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) {
if (certificate == null || sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable)) { if (certificate == null || sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable)) {
@@ -221,7 +258,9 @@ sealed class RpcClientToServerConnector {
await socket.DisconnectAsync(reuseSocket: false, timeoutTokenSource.Token); await socket.DisconnectAsync(reuseSocket: false, timeoutTokenSource.Token);
} }
internal sealed record Connection(Socket Socket, RpcStream Stream, bool IsNewSession, TimeSpan PingInterval) : IAsyncDisposable { private readonly record struct ConnectionResult(bool IsNewSession, TimeSpan PingInterval, MessageTypeMappings<TClientToServerMessage, TServerToClientMessage> MessageTypeMappings);
internal sealed record Connection(Socket Socket, RpcStream Stream, bool IsNewSession, TimeSpan PingInterval, MessageTypeMappings<TClientToServerMessage, TServerToClientMessage> MessageTypeMappings) : IAsyncDisposable {
public async Task Disconnect() { public async Task Disconnect() {
await DisconnectSocket(Socket, Stream); await DisconnectSocket(Socket, Stream);
} }

View File

@@ -9,7 +9,7 @@ namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameReader<TSentMessage, TReceivedMessage>( sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
string loggerName, string loggerName,
RpcCommonConnectionParameters connectionParameters, RpcCommonConnectionParameters connectionParameters,
MessageRegistry<TReceivedMessage> messageRegistry, MessageTypeMapping<TReceivedMessage> messageTypeMapping,
MessageHandler<TReceivedMessage> messageHandler, MessageHandler<TReceivedMessage> messageHandler,
MessageSender<TSentMessage> messageSender, MessageSender<TSentMessage> messageSender,
RpcFrameSender<TSentMessage> frameSender RpcFrameSender<TSentMessage> frameSender
@@ -38,7 +38,7 @@ sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
return; return;
} }
if (messageRegistry.TryGetType(frame, out var messageType)) { if (messageTypeMapping.TryGetType(frame, out var messageType)) {
logger.Debug("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length); logger.Debug("Received message {MesageId} of type {MessageType} ({Bytes} B).", frame.MessageId, messageType.Name, frame.SerializedMessage.Length);
} }
@@ -58,7 +58,7 @@ sealed class RpcFrameReader<TSentMessage, TReceivedMessage>(
private async Task HandleMessage(MessageFrame frame, CancellationToken cancellationToken) { private async Task HandleMessage(MessageFrame frame, CancellationToken cancellationToken) {
try { try {
await messageRegistry.Handle(frame, messageHandler, cancellationToken); await messageTypeMapping.Handle(frame, messageHandler, cancellationToken);
} finally { } finally {
messageHandlingSemaphore.Release(); messageHandlingSemaphore.Release();
} }

View File

@@ -12,7 +12,7 @@ namespace Phantom.Utils.Rpc.Runtime;
sealed class RpcFrameSender<TMessageBase> : IMessageReplySender { sealed class RpcFrameSender<TMessageBase> : IMessageReplySender {
private readonly ILogger logger; private readonly ILogger logger;
private readonly IRpcConnectionProvider connectionProvider; private readonly IRpcConnectionProvider connectionProvider;
private readonly MessageRegistry<TMessageBase> messageRegistry; private readonly MessageTypeMapping<TMessageBase> messageTypeMapping;
private readonly MessageReceiveTracker messageReceiveTracker = new (); private readonly MessageReceiveTracker messageReceiveTracker = new ();
private readonly Channel<IFrame> frameQueue; private readonly Channel<IFrame> frameQueue;
@@ -27,10 +27,10 @@ sealed class RpcFrameSender<TMessageBase> : IMessageReplySender {
internal TimeSpan PingInterval { get; } internal TimeSpan PingInterval { get; }
internal RpcFrameSender(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageRegistry<TMessageBase> messageRegistry, TimeSpan pingInterval) { internal RpcFrameSender(string loggerName, RpcCommonConnectionParameters connectionParameters, IRpcConnectionProvider connectionProvider, MessageTypeMapping<TMessageBase> messageTypeMapping, TimeSpan pingInterval) {
this.logger = PhantomLogger.Create<RpcFrameSender<TMessageBase>>(loggerName); this.logger = PhantomLogger.Create<RpcFrameSender<TMessageBase>>(loggerName);
this.connectionProvider = connectionProvider; this.connectionProvider = connectionProvider;
this.messageRegistry = messageRegistry; this.messageTypeMapping = messageTypeMapping;
this.frameQueue = Channel.CreateBounded<IFrame>(new BoundedChannelOptions(connectionParameters.FrameQueueCapacity) { this.frameQueue = Channel.CreateBounded<IFrame>(new BoundedChannelOptions(connectionParameters.FrameQueueCapacity) {
AllowSynchronousContinuations = false, AllowSynchronousContinuations = false,
@@ -50,7 +50,7 @@ sealed class RpcFrameSender<TMessageBase> : IMessageReplySender {
} }
public async ValueTask SendMessage<TMessage>(uint messageId, TMessage message, CancellationToken cancellationToken) where TMessage : TMessageBase { public async ValueTask SendMessage<TMessage>(uint messageId, TMessage message, CancellationToken cancellationToken) where TMessage : TMessageBase {
var frame = messageRegistry.CreateFrame(messageId, message); var frame = messageTypeMapping.CreateFrame(messageId, message);
logger.Debug("Sending message {MesageId} of type {MessageType} ({MessageBytes} B).", messageId, typeof(TMessage).Name, frame.SerializedMessage.Length); logger.Debug("Sending message {MesageId} of type {MessageType} ({MessageBytes} B).", messageId, typeof(TMessage).Name, frame.SerializedMessage.Length);
await SendFrame(frame, cancellationToken); await SendFrame(frame, cancellationToken);
} }

View File

@@ -10,17 +10,33 @@ using Serilog;
namespace Phantom.Utils.Rpc.Runtime.Server; namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>( public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult> {
string loggerName, private readonly string loggerName;
RpcServerConnectionParameters connectionParameters, private readonly ILogger logger;
IMessageDefinitions<TClientToServerMessage, TServerToClientMessage> messageDefinitions, private readonly RpcServerConnectionParameters connectionParameters;
IRpcServerClientHandshake<THandshakeResult> clientHandshake, private readonly MessageRegistries<TClientToServerMessage, TServerToClientMessage>.WithMapping messageRegistries;
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> clientRegistrar private readonly IRpcServerClientHandshake<THandshakeResult> clientHandshake;
) { private readonly IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> clientRegistrar;
private readonly ILogger logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>>(loggerName);
private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions = new (loggerName, connectionParameters, messageDefinitions.ToClient); private readonly RpcServerClientSessions<TServerToClientMessage> clientSessions;
private readonly List<Client> clients = []; private readonly List<Client> clients = [];
public RpcServer(
string loggerName,
RpcServerConnectionParameters connectionParameters,
MessageRegistries<TClientToServerMessage, TServerToClientMessage> messageRegistries,
IRpcServerClientHandshake<THandshakeResult> clientHandshake,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> clientRegistrar
) {
this.loggerName = loggerName;
this.logger = PhantomLogger.Create<RpcServer<TClientToServerMessage, TServerToClientMessage, THandshakeResult>>(loggerName);
this.connectionParameters = connectionParameters;
this.messageRegistries = messageRegistries.CreateMapping();
this.clientHandshake = clientHandshake;
this.clientRegistrar = clientRegistrar;
this.clientSessions = new RpcServerClientSessions<TServerToClientMessage>(loggerName, connectionParameters, this.messageRegistries.ToClient.Mapping);
}
public async Task<bool> Run(CancellationToken shutdownToken) { public async Task<bool> Run(CancellationToken shutdownToken) {
EndPoint endPoint = connectionParameters.EndPoint; EndPoint endPoint = connectionParameters.EndPoint;
@@ -36,7 +52,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
var serverData = new SharedData( var serverData = new SharedData(
connectionParameters, connectionParameters,
messageDefinitions.ToServer, messageRegistries,
clientHandshake, clientHandshake,
clientRegistrar, clientRegistrar,
clientSessions clientSessions
@@ -94,7 +110,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
private readonly record struct SharedData( private readonly record struct SharedData(
RpcServerConnectionParameters ConnectionParameters, RpcServerConnectionParameters ConnectionParameters,
MessageRegistry<TClientToServerMessage> MessageRegistry, MessageRegistries<TClientToServerMessage, TServerToClientMessage>.WithMapping MessageDefinitions,
IRpcServerClientHandshake<THandshakeResult> ClientHandshake, IRpcServerClientHandshake<THandshakeResult> ClientHandshake,
IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> ClientRegistrar, IRpcServerClientRegistrar<TClientToServerMessage, TServerToClientMessage, THandshakeResult> ClientRegistrar,
RpcServerClientSessions<TServerToClientMessage> ClientSessions RpcServerClientSessions<TServerToClientMessage> ClientSessions
@@ -226,6 +242,8 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
} }
await stream.WriteUnsignedShort(sharedData.ConnectionParameters.PingIntervalSeconds, cancellationToken); await stream.WriteUnsignedShort(sharedData.ConnectionParameters.PingIntervalSeconds, cancellationToken);
await sharedData.MessageDefinitions.ToClient.Write(stream, cancellationToken);
await sharedData.MessageDefinitions.ToServer.Write(stream, cancellationToken);
await stream.Flush(cancellationToken); await stream.Flush(cancellationToken);
var sessionId = await stream.ReadGuid(cancellationToken); var sessionId = await stream.ReadGuid(cancellationToken);
@@ -263,7 +281,7 @@ public sealed class RpcServer<TClientToServerMessage, TServerToClientMessage, TH
switch (await sharedData.ClientHandshake.Perform(session.IsNew, stream, cancellationToken)) { switch (await sharedData.ClientHandshake.Perform(session.IsNew, stream, cancellationToken)) {
case Left<THandshakeResult, Exception>(var handshakeResult): case Left<THandshakeResult, Exception>(var handshakeResult):
try { try {
var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(sharedData.ConnectionParameters, sharedData.MessageRegistry, session, stream); var connection = new RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>(sharedData.ConnectionParameters, sharedData.MessageDefinitions.ToServer.Mapping, session, stream);
var messageReceiver = sharedData.ClientRegistrar.Register(connection, handshakeResult); var messageReceiver = sharedData.ClientRegistrar.Register(connection, handshakeResult);
return new EstablishedConnection(session, connection, messageReceiver); return new EstablishedConnection(session, connection, messageReceiver);

View File

@@ -28,12 +28,12 @@ sealed class RpcServerClientSession<TServerToClientMessage> : IRpcConnectionProv
public CancellationToken CloseCancellationToken => closeCancellationTokenSource.Token; public CancellationToken CloseCancellationToken => closeCancellationTokenSource.Token;
public RpcServerClientSession(string loggerName, RpcServerConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry, RpcServerClientSessions<TServerToClientMessage> sessions, Guid sessionId) { public RpcServerClientSession(string loggerName, RpcServerConnectionParameters connectionParameters, MessageTypeMapping<TServerToClientMessage> messageTypeMapping, RpcServerClientSessions<TServerToClientMessage> sessions, Guid sessionId) {
this.logger = PhantomLogger.Create<RpcServerClientSession<TServerToClientMessage>>(loggerName); this.logger = PhantomLogger.Create<RpcServerClientSession<TServerToClientMessage>>(loggerName);
this.LoggerName = loggerName; this.LoggerName = loggerName;
this.sessions = sessions; this.sessions = sessions;
this.SessionId = sessionId; this.SessionId = sessionId;
this.FrameSender = new RpcFrameSender<TServerToClientMessage>(loggerName, connectionParameters, this, messageRegistry, connectionParameters.PingInterval); this.FrameSender = new RpcFrameSender<TServerToClientMessage>(loggerName, connectionParameters, this, messageTypeMapping, connectionParameters.PingInterval);
this.MessageSender = new MessageSender<TServerToClientMessage>(loggerName, connectionParameters, new IRpcFrameSenderProvider<TServerToClientMessage>.Constant(FrameSender)); this.MessageSender = new MessageSender<TServerToClientMessage>(loggerName, connectionParameters, new IRpcFrameSenderProvider<TServerToClientMessage>.Constant(FrameSender));
this.closeAfterDisconnectionTimer = new Timer(DisconnectedSessionTimeout) { AutoReset = false }; this.closeAfterDisconnectionTimer = new Timer(DisconnectedSessionTimeout) { AutoReset = false };

View File

@@ -7,7 +7,7 @@ namespace Phantom.Utils.Rpc.Runtime.Server;
sealed class RpcServerClientSessions<TServerToClientMessage> { sealed class RpcServerClientSessions<TServerToClientMessage> {
private readonly string loggerName; private readonly string loggerName;
private readonly RpcServerConnectionParameters connectionParameters; private readonly RpcServerConnectionParameters connectionParameters;
private readonly MessageRegistry<TServerToClientMessage> messageRegistry; private readonly MessageTypeMapping<TServerToClientMessage> messageTypeMapping;
private readonly ConcurrentDictionary<Guid, RpcServerClientSession<TServerToClientMessage>> sessionsById = new (); private readonly ConcurrentDictionary<Guid, RpcServerClientSession<TServerToClientMessage>> sessionsById = new ();
@@ -16,10 +16,10 @@ sealed class RpcServerClientSessions<TServerToClientMessage> {
public int Count => sessionsById.Count; public int Count => sessionsById.Count;
public RpcServerClientSessions(string loggerName, RpcServerConnectionParameters connectionParameters, MessageRegistry<TServerToClientMessage> messageRegistry) { public RpcServerClientSessions(string loggerName, RpcServerConnectionParameters connectionParameters, MessageTypeMapping<TServerToClientMessage> messageTypeMapping) {
this.loggerName = loggerName; this.loggerName = loggerName;
this.connectionParameters = connectionParameters; this.connectionParameters = connectionParameters;
this.messageRegistry = messageRegistry; this.messageTypeMapping = messageTypeMapping;
this.createSessionFunction = CreateSession; this.createSessionFunction = CreateSession;
} }
@@ -28,7 +28,7 @@ sealed class RpcServerClientSessions<TServerToClientMessage> {
} }
private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId) { private RpcServerClientSession<TServerToClientMessage> CreateSession(Guid sessionId) {
return new RpcServerClientSession<TServerToClientMessage>(NextLoggerName(sessionId), connectionParameters, messageRegistry, this, sessionId); return new RpcServerClientSession<TServerToClientMessage>(NextLoggerName(sessionId), connectionParameters, messageTypeMapping, this, sessionId);
} }
private string NextLoggerName(Guid sessionId) { private string NextLoggerName(Guid sessionId) {

View File

@@ -8,24 +8,29 @@ namespace Phantom.Utils.Rpc.Runtime.Server;
public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> { public sealed class RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage> {
private readonly ILogger logger; private readonly ILogger logger;
private readonly RpcCommonConnectionParameters connectionParameters; private readonly RpcCommonConnectionParameters connectionParameters;
private readonly MessageRegistry<TClientToServerMessage> messageRegistry; private readonly MessageTypeMapping<TClientToServerMessage> messageTypeMapping;
private readonly RpcServerClientSession<TServerToClientMessage> session; private readonly RpcServerClientSession<TServerToClientMessage> session;
private readonly RpcStream stream; private readonly RpcStream stream;
public Guid SessionId => session.SessionId; public Guid SessionId => session.SessionId;
public MessageSender<TServerToClientMessage> MessageSender => session.MessageSender; public MessageSender<TServerToClientMessage> MessageSender => session.MessageSender;
internal RpcServerToClientConnection(RpcCommonConnectionParameters connectionParameters, MessageRegistry<TClientToServerMessage> messageRegistry, RpcServerClientSession<TServerToClientMessage> session, RpcStream stream) { internal RpcServerToClientConnection(
RpcCommonConnectionParameters connectionParameters,
MessageTypeMapping<TClientToServerMessage> messageTypeMapping,
RpcServerClientSession<TServerToClientMessage> session,
RpcStream stream
) {
this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(session.LoggerName); this.logger = PhantomLogger.Create<RpcServerToClientConnection<TClientToServerMessage, TServerToClientMessage>>(session.LoggerName);
this.connectionParameters = connectionParameters; this.connectionParameters = connectionParameters;
this.messageRegistry = messageRegistry; this.messageTypeMapping = messageTypeMapping;
this.session = session; this.session = session;
this.stream = stream; this.stream = stream;
} }
internal async Task Listen(IMessageReceiver<TClientToServerMessage> messageReceiver) { internal async Task Listen(IMessageReceiver<TClientToServerMessage> messageReceiver) {
var messageHandler = new MessageHandler<TClientToServerMessage>(messageReceiver, session.FrameSender); var messageHandler = new MessageHandler<TClientToServerMessage>(messageReceiver, session.FrameSender);
var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(session.LoggerName, connectionParameters, messageRegistry, messageHandler, MessageSender, session.FrameSender); var frameReader = new RpcFrameReader<TServerToClientMessage, TClientToServerMessage>(session.LoggerName, connectionParameters, messageTypeMapping, messageHandler, MessageSender, session.FrameSender);
try { try {
await IFrame.ReadFrom(stream, frameReader, session.CloseCancellationToken); await IFrame.ReadFrom(stream, frameReader, session.CloseCancellationToken);

View File

@@ -0,0 +1,52 @@
using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using NUnit.Framework;
using Phantom.Utils.Collections;
namespace Phantom.Utils.Tests.Collections;
[TestFixture]
[SuppressMessage("Performance", "CA1861")]
public sealed class SpanIndexEnumeratorTests {
private static SearchValues<char> Search => SearchValues.Create(' ', '-');
private static List<int> Indices(string str) {
List<int> indices = [];
foreach (int index in str.AsSpan().IndicesOf(Search)) {
indices.Add(index);
}
return indices;
}
[Test]
public void Empty() {
Assert.That(Indices(""), Is.EquivalentTo(Array.Empty<int>()));
}
[Test]
public void OnlyFirstIndex() {
Assert.That(Indices(" "), Is.EquivalentTo(new[] { 0 }));
}
[Test]
public void OnlyMiddleIndex() {
Assert.That(Indices("ab cd"), Is.EquivalentTo(new[] { 2 }));
}
[Test]
public void OnlyLastIndex() {
Assert.That(Indices("abc "), Is.EquivalentTo(new[] { 3 }));
}
[Test]
public void FirstAndLastIndex() {
Assert.That(Indices(" abc-"), Is.EquivalentTo(new[] { 0, 4 }));
}
[Test]
public void AllIndices() {
Assert.That(Indices("- - -"), Is.EquivalentTo(new[] { 0, 1, 2, 3, 4, 5, 6 }));
}
}

View File

@@ -0,0 +1,37 @@
using System.Buffers;
using System.Collections;
namespace Phantom.Utils.Collections;
public ref struct SpanIndexEnumerator<T>(ReadOnlySpan<T> span, SearchValues<T> searchValues) : IEnumerator<int> where T : IEquatable<T> {
private readonly ReadOnlySpan<T> span = span;
public int Current { get; private set; } = -1;
readonly object IEnumerator.Current => Current;
public readonly SpanIndexEnumerator<T> GetEnumerator() => this;
public bool MoveNext() {
int startIndex = Current + 1;
int relativeIndex = span[startIndex..].IndexOfAny(searchValues);
if (relativeIndex == -1) {
return false;
}
Current = startIndex + relativeIndex;
return true;
}
public void Reset() {
Current = -1;
}
public void Dispose() {}
}
public static class SpanIndexEnumeratorExtensions {
public static SpanIndexEnumerator<T> IndicesOf<T>(this ReadOnlySpan<T> span, SearchValues<T> searchValues) where T : IEquatable<T> {
return new SpanIndexEnumerator<T>(span, searchValues);
}
}

View File

@@ -1,12 +1,12 @@
@page "/instances/{InstanceGuid:guid}" @page "/instances/{InstanceGuid:guid}"
@attribute [Authorize(Permission.ViewInstancesPolicy)] @attribute [Authorize(Permission.ViewInstancesPolicy)]
@using Phantom.Common.Data.Instance
@using Phantom.Common.Data.Replies @using Phantom.Common.Data.Replies
@using Phantom.Common.Data.Web.Instance @using Phantom.Common.Data.Web.Instance
@using Phantom.Common.Data.Web.Users @using Phantom.Common.Data.Web.Users
@using Phantom.Utils.Result @using Phantom.Utils.Result
@using Phantom.Common.Data.Instance
@using Phantom.Web.Services.Instances
@using Phantom.Web.Services.Authorization @using Phantom.Web.Services.Authorization
@using Phantom.Web.Services.Instances
@inherits PhantomComponent @inherits PhantomComponent
@inject InstanceManager InstanceManager @inject InstanceManager InstanceManager
@@ -101,11 +101,11 @@
lastError = launchInstanceResult.ToSentence(); lastError = launchInstanceResult.ToSentence();
break; break;
case Err<UserInstanceActionFailure>(OfInstanceActionFailure(var failure)): case Err<UserInstanceActionFailure>(UserInstanceActionFailure.Instance(var failure)):
lastError = failure.ToSentence(); lastError = failure.ToSentence();
break; break;
case Err<UserInstanceActionFailure>(OfUserActionFailure(UserActionFailure.NotAuthorized)): case Err<UserInstanceActionFailure>(UserInstanceActionFailure.User(UserActionFailure.NotAuthorized)):
lastError = "You do not have permission to launch this instance."; lastError = "You do not have permission to launch this instance.";
break; break;

View File

@@ -3,7 +3,6 @@
@using System.Security.Cryptography @using System.Security.Cryptography
@using Phantom.Common.Data @using Phantom.Common.Data
@using Phantom.Common.Data.Web.Users @using Phantom.Common.Data.Web.Users
@using Phantom.Common.Data.Web.Users.CreateOrUpdateAdministratorUserResults
@using Phantom.Common.Messages.Web.ToController @using Phantom.Common.Messages.Web.ToController
@using Phantom.Utils.Cryptography @using Phantom.Utils.Cryptography
@using Phantom.Web.Services @using Phantom.Web.Services
@@ -91,12 +90,12 @@
private async Task<Result<string>> CreateOrUpdateAdministrator() { private async Task<Result<string>> CreateOrUpdateAdministrator() {
var reply = await ControllerConnection.Send<CreateOrUpdateAdministratorUserMessage, CreateOrUpdateAdministratorUserResult>(new CreateOrUpdateAdministratorUserMessage(form.Username, form.Password), Timeout.InfiniteTimeSpan); var reply = await ControllerConnection.Send<CreateOrUpdateAdministratorUserMessage, CreateOrUpdateAdministratorUserResult>(new CreateOrUpdateAdministratorUserMessage(form.Username, form.Password), Timeout.InfiniteTimeSpan);
return reply switch { return reply switch {
Success => Result.Ok, CreateOrUpdateAdministratorUserResult.Success => Result.Ok,
CreationFailed fail => fail.Error.ToSentences("\n"), CreateOrUpdateAdministratorUserResult.CreationFailed fail => fail.Error.ToSentences("\n"),
UpdatingFailed fail => fail.Error.ToSentences("\n"), CreateOrUpdateAdministratorUserResult.UpdatingFailed fail => fail.Error.ToSentences("\n"),
AddingToRoleFailed => "Could not assign administrator role to user.", CreateOrUpdateAdministratorUserResult.AddingToRoleFailed => "Could not assign administrator role to user.",
null => "Timed out.", null => "Timed out.",
_ => "Unknown error.", _ => "Unknown error.",
}; };
} }

View File

@@ -63,7 +63,7 @@ try {
MaxConcurrentlyHandledMessages: 100 MaxConcurrentlyHandledMessages: 100
); );
using var rpcClient = await RpcClient<IMessageToController, IMessageToWeb>.Connect("Controller", rpcClientConnectionParameters, WebMessageRegistries.Definitions, shutdownCancellationToken); using var rpcClient = await RpcClient<IMessageToController, IMessageToWeb>.Connect("Controller", rpcClientConnectionParameters, WebMessageRegistries.Registries, shutdownCancellationToken);
if (rpcClient == null) { if (rpcClient == null) {
PhantomLogger.Root.Fatal("Could not connect to Phantom Controller, shutting down."); PhantomLogger.Root.Fatal("Could not connect to Phantom Controller, shutting down.");
return 1; return 1;

View File

@@ -351,11 +351,11 @@
form.SubmitModel.StopSubmitting(createOrUpdateInstanceResult.ToSentence()); form.SubmitModel.StopSubmitting(createOrUpdateInstanceResult.ToSentence());
break; break;
case Err<UserInstanceActionFailure>(OfInstanceActionFailure(var failure)): case Err<UserInstanceActionFailure>(UserInstanceActionFailure.Instance(var failure)):
form.SubmitModel.StopSubmitting(failure.ToSentence()); form.SubmitModel.StopSubmitting(failure.ToSentence());
break; break;
case Err<UserInstanceActionFailure>(OfUserActionFailure(UserActionFailure.NotAuthorized)): case Err<UserInstanceActionFailure>(UserInstanceActionFailure.User(UserActionFailure.NotAuthorized)):
form.SubmitModel.StopSubmitting("You do not have permission to create or edit instances."); form.SubmitModel.StopSubmitting("You do not have permission to create or edit instances.");
break; break;

View File

@@ -2,7 +2,7 @@
@using Phantom.Common.Data.Web.Users @using Phantom.Common.Data.Web.Users
@using Phantom.Utils.Result @using Phantom.Utils.Result
@using Phantom.Web.Services.Instances @using Phantom.Web.Services.Instances
@inherits Phantom.Web.Components.PhantomComponent @inherits PhantomComponent
@inject InstanceManager InstanceManager @inject InstanceManager InstanceManager
<Form Model="form" OnSubmit="ExecuteCommand"> <Form Model="form" OnSubmit="ExecuteCommand">
@@ -49,11 +49,11 @@
form.SubmitModel.StopSubmitting(sendCommandToInstanceResult.ToSentence()); form.SubmitModel.StopSubmitting(sendCommandToInstanceResult.ToSentence());
break; break;
case Err<UserInstanceActionFailure>(OfInstanceActionFailure(var failure)): case Err<UserInstanceActionFailure>(UserInstanceActionFailure.Instance(var failure)):
form.SubmitModel.StopSubmitting(failure.ToSentence()); form.SubmitModel.StopSubmitting(failure.ToSentence());
break; break;
case Err<UserInstanceActionFailure>(OfUserActionFailure(UserActionFailure.NotAuthorized)): case Err<UserInstanceActionFailure>(UserInstanceActionFailure.User(UserActionFailure.NotAuthorized)):
form.SubmitModel.StopSubmitting("You do not have permission to send commands to this instance."); form.SubmitModel.StopSubmitting("You do not have permission to send commands to this instance.");
break; break;

View File

@@ -1,10 +1,10 @@
@using Phantom.Common.Data.Replies @using System.ComponentModel.DataAnnotations
@using Phantom.Common.Data.Minecraft
@using Phantom.Common.Data.Replies
@using Phantom.Common.Data.Web.Users @using Phantom.Common.Data.Web.Users
@using Phantom.Utils.Result @using Phantom.Utils.Result
@using Phantom.Web.Services.Instances @using Phantom.Web.Services.Instances
@using System.ComponentModel.DataAnnotations @inherits PhantomComponent
@using Phantom.Common.Data.Minecraft
@inherits Phantom.Web.Components.PhantomComponent
@inject IJSRuntime Js; @inject IJSRuntime Js;
@inject InstanceManager InstanceManager; @inject InstanceManager InstanceManager;
@@ -66,11 +66,11 @@
form.SubmitModel.StopSubmitting(stopInstanceResult.ToSentence()); form.SubmitModel.StopSubmitting(stopInstanceResult.ToSentence());
break; break;
case Err<UserInstanceActionFailure>(OfInstanceActionFailure(var failure)): case Err<UserInstanceActionFailure>(UserInstanceActionFailure.Instance(var failure)):
form.SubmitModel.StopSubmitting(failure.ToSentence()); form.SubmitModel.StopSubmitting(failure.ToSentence());
break; break;
case Err<UserInstanceActionFailure>(OfUserActionFailure(UserActionFailure.NotAuthorized)): case Err<UserInstanceActionFailure>(UserInstanceActionFailure.User(UserActionFailure.NotAuthorized)):
form.SubmitModel.StopSubmitting("You do not have permission to stop this instance."); form.SubmitModel.StopSubmitting("You do not have permission to stop this instance.");
break; break;

View File

@@ -1,9 +1,8 @@
@using Phantom.Common.Data.Web.Users @using System.ComponentModel.DataAnnotations
@using Phantom.Common.Data.Web.Users.CreateUserResults @using Phantom.Common.Data.Web.Users
@using Phantom.Utils.Result @using Phantom.Utils.Result
@using Phantom.Web.Services.Users @using Phantom.Web.Services.Users
@using System.ComponentModel.DataAnnotations @inherits PhantomComponent
@inherits Phantom.Web.Components.PhantomComponent
@inject IJSRuntime Js; @inject IJSRuntime Js;
@inject UserManager UserManager; @inject UserManager UserManager;
@@ -56,13 +55,13 @@
var result = await UserManager.Create(await GetAuthenticatedUser(), form.Username, form.Password, CancellationToken); var result = await UserManager.Create(await GetAuthenticatedUser(), form.Username, form.Password, CancellationToken);
switch (result.Variant()) { switch (result.Variant()) {
case Ok<CreateUserResult>(Success success): case Ok<CreateUserResult>(CreateUserResult.Success success):
await UserAdded.InvokeAsync(success.User); await UserAdded.InvokeAsync(success.User);
await Js.InvokeVoidAsync("closeModal", ModalId); await Js.InvokeVoidAsync("closeModal", ModalId);
form.SubmitModel.StopSubmitting(); form.SubmitModel.StopSubmitting();
break; break;
case Ok<CreateUserResult>(CreationFailed fail): case Ok<CreateUserResult>(CreateUserResult.CreationFailed fail):
form.SubmitModel.StopSubmitting(fail.Error.ToSentences("\n")); form.SubmitModel.StopSubmitting(fail.Error.ToSentences("\n"));
break; break;

View File

@@ -2,47 +2,42 @@
using Phantom.Common.Data.Replies; using Phantom.Common.Data.Replies;
using Phantom.Common.Data.Web.Minecraft; using Phantom.Common.Data.Web.Minecraft;
using Phantom.Common.Data.Web.Users; using Phantom.Common.Data.Web.Users;
using Phantom.Common.Data.Web.Users.AddUserErrors;
using Phantom.Common.Data.Web.Users.PasswordRequirementViolations;
using Phantom.Common.Data.Web.Users.SetUserPasswordErrors;
using Phantom.Common.Data.Web.Users.UsernameRequirementViolations;
using PasswordIsInvalid = Phantom.Common.Data.Web.Users.AddUserErrors.PasswordIsInvalid;
namespace Phantom.Web.Utils; namespace Phantom.Web.Utils;
static class Messages { static class Messages {
public static string ToSentences(this AddUserError error, string delimiter) { public static string ToSentences(this AddUserError error, string delimiter) {
return error switch { return error switch {
NameIsInvalid e => e.Violation.ToSentence(), AddUserError.NameIsInvalid e => e.Violation.ToSentence(),
PasswordIsInvalid e => string.Join(delimiter, e.Violations.Select(static v => v.ToSentence())), AddUserError.PasswordIsInvalid e => string.Join(delimiter, e.Violations.Select(static v => v.ToSentence())),
NameAlreadyExists => "Username is already occupied.", AddUserError.NameAlreadyExists => "Username is already occupied.",
_ => "Unknown error.", _ => "Unknown error.",
}; };
} }
public static string ToSentences(this SetUserPasswordError error, string delimiter) { public static string ToSentences(this SetUserPasswordError error, string delimiter) {
return error switch { return error switch {
UserNotFound => "User not found.", SetUserPasswordError.UserNotFound => "User not found.",
Common.Data.Web.Users.SetUserPasswordErrors.PasswordIsInvalid e => string.Join(delimiter, e.Violations.Select(static v => v.ToSentence())), SetUserPasswordError.PasswordIsInvalid e => string.Join(delimiter, e.Violations.Select(static v => v.ToSentence())),
_ => "Unknown error.", _ => "Unknown error.",
}; };
} }
public static string ToSentence(this UsernameRequirementViolation violation) { public static string ToSentence(this UsernameRequirementViolation violation) {
return violation switch { return violation switch {
IsEmpty => "Username must not be empty.", UsernameRequirementViolation.IsEmpty => "Username must not be empty.",
TooLong v => "Username must not be longer than " + v.MaxLength + " character(s).", UsernameRequirementViolation.TooLong v => "Username must not be longer than " + v.MaxLength + " character(s).",
_ => "Unknown error.", _ => "Unknown error.",
}; };
} }
public static string ToSentence(this PasswordRequirementViolation violation) { public static string ToSentence(this PasswordRequirementViolation violation) {
return violation switch { return violation switch {
TooShort v => "Password must be at least " + v.MinimumLength + " character(s) long.", PasswordRequirementViolation.TooShort v => "Password must be at least " + v.MinimumLength + " character(s) long.",
MustContainLowercaseLetter => "Password must contain a lowercase letter.", PasswordRequirementViolation.MustContainLowercaseLetter => "Password must contain a lowercase letter.",
MustContainUppercaseLetter => "Password must contain an uppercase letter.", PasswordRequirementViolation.MustContainUppercaseLetter => "Password must contain an uppercase letter.",
MustContainDigit => "Password must contain a digit.", PasswordRequirementViolation.MustContainDigit => "Password must contain a digit.",
_ => "Unknown error.", _ => "Unknown error.",
}; };
} }

View File

@@ -53,10 +53,10 @@ static class WebLauncher {
application.UseExceptionHandler("/_Error"); application.UseExceptionHandler("/_Error");
} }
application.UseStaticFiles();
application.UseRouting(); application.UseRouting();
application.UsePhantomServices(); application.UsePhantomServices();
application.MapStaticAssets();
application.MapControllers(); application.MapControllers();
application.MapBlazorHub(); application.MapBlazorHub();
application.MapFallbackToPage("/_Host"); application.MapFallbackToPage("/_Host");