1
0
mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2025-08-18 13:31:42 +02:00

5 Commits

28 changed files with 299 additions and 74 deletions

View File

@@ -1,4 +1,3 @@
using System;
using Avalonia; using Avalonia;
using Avalonia.Controls.ApplicationLifetimes; using Avalonia.Controls.ApplicationLifetimes;
using Avalonia.Markup.Xaml; using Avalonia.Markup.Xaml;
@@ -13,7 +12,7 @@ sealed class App : Application {
public override void OnFrameworkInitializationCompleted() { public override void OnFrameworkInitializationCompleted() {
if (ApplicationLifetime is IClassicDesktopStyleApplicationLifetime desktop) { if (ApplicationLifetime is IClassicDesktopStyleApplicationLifetime desktop) {
desktop.MainWindow = new MainWindow(new Arguments(desktop.Args ?? Array.Empty<string>())); desktop.MainWindow = new MainWindow(Program.Arguments);
} }
base.OnFrameworkInitializationCompleted(); base.OnFrameworkInitializationCompleted();

View File

@@ -5,26 +5,33 @@ namespace DHT.Desktop;
sealed class Arguments { sealed class Arguments {
private static readonly Log Log = Log.ForType<Arguments>(); private static readonly Log Log = Log.ForType<Arguments>();
private const int FirstArgument = 1;
public static Arguments Empty => new(Array.Empty<string>()); public static Arguments Empty => new(Array.Empty<string>());
public bool Console { get; }
public string? DatabaseFile { get; } public string? DatabaseFile { get; }
public ushort? ServerPort { get; } public ushort? ServerPort { get; }
public string? ServerToken { get; } public string? ServerToken { get; }
public Arguments(string[] args) { public Arguments(string[] args) {
for (int i = 0; i < args.Length; i++) { for (int i = FirstArgument; i < args.Length; i++) {
string key = args[i]; string key = args[i];
switch (key) { switch (key) {
case "-debug": case "-debug":
Log.IsDebugEnabled = true; Log.IsDebugEnabled = true;
continue; continue;
case "-console":
Console = true;
continue;
} }
string value; string value;
if (i == 0 && !key.StartsWith('-')) { if (i == FirstArgument && !key.StartsWith('-')) {
value = key; value = key;
key = "-db"; key = "-db";
} }

View File

@@ -65,6 +65,8 @@ sealed class ViewerPageModel : BaseModel, IDisposable {
string indexFile = await Resources.ReadTextAsync("Viewer/index.html"); string indexFile = await Resources.ReadTextAsync("Viewer/index.html");
string viewerTemplate = indexFile.Replace("/*[JS]*/", await Resources.ReadJoinedAsync("Viewer/scripts/", '\n')) string viewerTemplate = indexFile.Replace("/*[JS]*/", await Resources.ReadJoinedAsync("Viewer/scripts/", '\n'))
.Replace("/*[CSS]*/", await Resources.ReadJoinedAsync("Viewer/styles/", '\n')); .Replace("/*[CSS]*/", await Resources.ReadJoinedAsync("Viewer/styles/", '\n'));
viewerTemplate = strategy.ProcessViewerTemplate(viewerTemplate);
int viewerArchiveTagStart = viewerTemplate.IndexOf(ArchiveTag); int viewerArchiveTagStart = viewerTemplate.IndexOf(ArchiveTag);
int viewerArchiveTagEnd = viewerArchiveTagStart + ArchiveTag.Length; int viewerArchiveTagEnd = viewerArchiveTagStart + ArchiveTag.Length;

View File

@@ -1,6 +1,8 @@
using System.Globalization; using System;
using System.Globalization;
using System.Reflection; using System.Reflection;
using Avalonia; using Avalonia;
using DHT.Utils.Logging;
using DHT.Utils.Resources; using DHT.Utils.Resources;
namespace DHT.Desktop; namespace DHT.Desktop;
@@ -9,6 +11,7 @@ static class Program {
public static string Version { get; } public static string Version { get; }
public static CultureInfo Culture { get; } public static CultureInfo Culture { get; }
public static ResourceLoader Resources { get; } public static ResourceLoader Resources { get; }
public static Arguments Arguments { get; }
static Program() { static Program() {
var assembly = Assembly.GetExecutingAssembly(); var assembly = Assembly.GetExecutingAssembly();
@@ -25,10 +28,21 @@ static class Program {
CultureInfo.DefaultThreadCurrentUICulture = CultureInfo.InvariantCulture; CultureInfo.DefaultThreadCurrentUICulture = CultureInfo.InvariantCulture;
Resources = new ResourceLoader(assembly); Resources = new ResourceLoader(assembly);
Arguments = new Arguments(Environment.GetCommandLineArgs());
} }
public static void Main(string[] args) { public static void Main(string[] args) {
BuildAvaloniaApp().StartWithClassicDesktopLifetime(args); if (Arguments.Console && OperatingSystem.IsWindows()) {
WindowsConsole.AllocConsole();
}
try {
BuildAvaloniaApp().StartWithClassicDesktopLifetime(args);
} finally {
if (Arguments.Console && OperatingSystem.IsWindows()) {
WindowsConsole.FreeConsole();
}
}
} }
private static AppBuilder BuildAvaloniaApp() { private static AppBuilder BuildAvaloniaApp() {

View File

@@ -6,6 +6,8 @@
<script type="text/javascript"> <script type="text/javascript">
window.DHT_EMBEDDED = "/*[ARCHIVE]*/"; window.DHT_EMBEDDED = "/*[ARCHIVE]*/";
window.DHT_SERVER_URL = "/*[SERVER_URL]*/";
window.DHT_SERVER_TOKEN = "/*[SERVER_TOKEN]*/";
/*[JS]*/ /*[JS]*/
</script> </script>
<style> <style>

View File

@@ -182,15 +182,32 @@ const STATE = (function() {
return null; return null;
}; };
const getMessageList = function() { const getMessageList = async function(abortSignal) {
if (!loadedMessages) { if (!loadedMessages) {
return []; return [];
} }
const messages = getMessages(selectedChannel); const messages = getMessages(selectedChannel);
const startIndex = messagesPerPage * (root.getCurrentPage() - 1); const startIndex = messagesPerPage * (root.getCurrentPage() - 1);
const slicedMessages = loadedMessages.slice(startIndex, !messagesPerPage ? undefined : startIndex + messagesPerPage);
return loadedMessages.slice(startIndex, !messagesPerPage ? undefined : startIndex + messagesPerPage).map(key => { let messageTexts = null;
if (window.DHT_SERVER_URL !== null) {
const messageIds = new Set(slicedMessages);
for (const key of slicedMessages) {
const message = messages[key];
if ("r" in message) {
messageIds.add(message.r);
}
}
messageTexts = await getMessageTextsFromServer(messageIds, abortSignal);
}
return slicedMessages.map(key => {
/** /**
* @type {{}} * @type {{}}
* @property {Number} u * @property {Number} u
@@ -216,6 +233,9 @@ const STATE = (function() {
if ("m" in message) { if ("m" in message) {
obj["contents"] = message.m; obj["contents"] = message.m;
} }
else if (messageTexts && key in messageTexts) {
obj["contents"] = messageTexts[key];
}
if ("e" in message) { if ("e" in message) {
obj["embeds"] = message.e.map(embed => JSON.parse(embed)); obj["embeds"] = message.e.map(embed => JSON.parse(embed));
@@ -230,15 +250,16 @@ const STATE = (function() {
} }
if ("r" in message) { if ("r" in message) {
const replyMessage = getMessageById(message.r); const replyId = message.r;
const replyMessage = getMessageById(replyId);
const replyUser = replyMessage ? getUser(replyMessage.u) : null; const replyUser = replyMessage ? getUser(replyMessage.u) : null;
const replyAvatar = replyUser && replyUser.avatar ? { id: getUserId(replyMessage.u), path: replyUser.avatar } : null; const replyAvatar = replyUser && replyUser.avatar ? { id: getUserId(replyMessage.u), path: replyUser.avatar } : null;
obj["reply"] = replyMessage ? { obj["reply"] = replyMessage ? {
"id": message.r, "id": replyId,
"user": replyUser, "user": replyUser,
"avatar": replyAvatar, "avatar": replyAvatar,
"contents": replyMessage.m "contents": messageTexts != null && replyId in messageTexts ? messageTexts[replyId] : replyMessage.m,
} : null; } : null;
} }
@@ -250,9 +271,35 @@ const STATE = (function() {
}); });
}; };
const getMessageTextsFromServer = async function(messageIds, abortSignal) {
let idParams = "";
for (const messageId of messageIds) {
idParams += "id=" + encodeURIComponent(messageId) + "&";
}
const response = await fetch(DHT_SERVER_URL + "/get-messages?" + idParams + "token=" + encodeURIComponent(DHT_SERVER_TOKEN), {
method: "GET",
headers: {
"Content-Type": "application/json",
},
credentials: "omit",
redirect: "error",
signal: abortSignal
});
if (response.status === 200) {
return response.json();
}
else {
throw new Error("Server returned status " + response.status + " " + response.statusText);
}
};
let eventOnUsersRefreshed; let eventOnUsersRefreshed;
let eventOnChannelsRefreshed; let eventOnChannelsRefreshed;
let eventOnMessagesRefreshed; let eventOnMessagesRefreshed;
let messageLoaderAborter = null;
const triggerUsersRefreshed = function() { const triggerUsersRefreshed = function() {
eventOnUsersRefreshed && eventOnUsersRefreshed(getUserList()); eventOnUsersRefreshed && eventOnUsersRefreshed(getUserList());
@@ -263,7 +310,22 @@ const STATE = (function() {
}; };
const triggerMessagesRefreshed = function() { const triggerMessagesRefreshed = function() {
eventOnMessagesRefreshed && eventOnMessagesRefreshed(getMessageList()); if (!eventOnMessagesRefreshed) {
return;
}
if (messageLoaderAborter != null) {
messageLoaderAborter.abort();
}
const aborter = new AbortController();
messageLoaderAborter = aborter;
getMessageList(aborter.signal).then(eventOnMessagesRefreshed).finally(() => {
if (messageLoaderAborter === aborter) {
messageLoaderAborter = null;
}
});
}; };
const getFilteredMessageKeys = function(channel) { const getFilteredMessageKeys = function(channel) {

View File

@@ -44,7 +44,7 @@ public sealed class DummyDatabaseFile : IDatabaseFile {
return 0; return 0;
} }
public List<Message> GetMessages(MessageFilter? filter = null) { public List<Message> GetMessages(MessageFilter? filter = null, bool includeText = true) {
return new(); return new();
} }

View File

@@ -3,5 +3,7 @@ using DHT.Server.Data;
namespace DHT.Server.Database.Export.Strategy; namespace DHT.Server.Database.Export.Strategy;
public interface IViewerExportStrategy { public interface IViewerExportStrategy {
bool IncludeMessageText { get; }
string ProcessViewerTemplate(string template);
string GetAttachmentUrl(Attachment attachment); string GetAttachmentUrl(Attachment attachment);
} }

View File

@@ -12,6 +12,13 @@ public sealed class LiveViewerExportStrategy : IViewerExportStrategy {
this.safeToken = WebUtility.UrlEncode(token); this.safeToken = WebUtility.UrlEncode(token);
} }
public bool IncludeMessageText => false;
public string ProcessViewerTemplate(string template) {
return template.Replace("/*[SERVER_URL]*/", "http://127.0.0.1:" + safePort)
.Replace("/*[SERVER_TOKEN]*/", WebUtility.UrlEncode(safeToken));
}
public string GetAttachmentUrl(Attachment attachment) { public string GetAttachmentUrl(Attachment attachment) {
return "http://127.0.0.1:" + safePort + "/get-attachment/" + WebUtility.UrlEncode(attachment.NormalizedUrl) + "?token=" + safeToken; return "http://127.0.0.1:" + safePort + "/get-attachment/" + WebUtility.UrlEncode(attachment.NormalizedUrl) + "?token=" + safeToken;
} }

View File

@@ -7,6 +7,13 @@ public sealed class StandaloneViewerExportStrategy : IViewerExportStrategy {
private StandaloneViewerExportStrategy() {} private StandaloneViewerExportStrategy() {}
public bool IncludeMessageText => true;
public string ProcessViewerTemplate(string template) {
return template.Replace("\"/*[SERVER_URL]*/\"", "null")
.Replace("\"/*[SERVER_TOKEN]*/\"", "null");
}
public string GetAttachmentUrl(Attachment attachment) { public string GetAttachmentUrl(Attachment attachment) {
// The normalized URL will not load files from Discord CDN once the time limit is enforced. // The normalized URL will not load files from Discord CDN once the time limit is enforced.

View File

@@ -21,7 +21,7 @@ public static class ViewerJsonExport {
var includedChannelIds = new HashSet<ulong>(); var includedChannelIds = new HashSet<ulong>();
var includedServerIds = new HashSet<ulong>(); var includedServerIds = new HashSet<ulong>();
var includedMessages = db.GetMessages(filter); var includedMessages = db.GetMessages(filter, strategy.IncludeMessageText);
var includedChannels = new List<Channel>(); var includedChannels = new List<Channel>();
foreach (var message in includedMessages) { foreach (var message in includedMessages) {

View File

@@ -23,7 +23,7 @@ public interface IDatabaseFile : IDisposable {
void AddMessages(Message[] messages); void AddMessages(Message[] messages);
int CountMessages(MessageFilter? filter = null); int CountMessages(MessageFilter? filter = null);
List<Message> GetMessages(MessageFilter? filter = null); List<Message> GetMessages(MessageFilter? filter = null, bool includeText = true);
HashSet<ulong> GetMessageIds(MessageFilter? filter = null); HashSet<ulong> GetMessageIds(MessageFilter? filter = null);
void RemoveMessages(MessageFilter filter, FilterRemovalMode mode); void RemoveMessages(MessageFilter filter, FilterRemovalMode mode);

View File

@@ -360,7 +360,7 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
return reader.Read() ? reader.GetInt32(0) : 0; return reader.Read() ? reader.GetInt32(0) : 0;
} }
public List<Message> GetMessages(MessageFilter? filter = null) { public List<Message> GetMessages(MessageFilter? filter = null, bool includeText = true) {
var perf = log.Start(); var perf = log.Start();
var list = new List<Message>(); var list = new List<Message>();
@@ -370,7 +370,7 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
using var conn = pool.Take(); using var conn = pool.Take();
using var cmd = conn.Command($""" using var cmd = conn.Command($"""
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id SELECT m.message_id, m.sender_id, m.channel_id, {(includeText ? "m.text" : "NULL")}, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id LEFT JOIN replied_to rt ON m.message_id = rt.message_id
@@ -385,7 +385,7 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
Id = id, Id = id,
Sender = reader.GetUint64(1), Sender = reader.GetUint64(1),
Channel = reader.GetUint64(2), Channel = reader.GetUint64(2),
Text = reader.GetString(3), Text = includeText ? reader.GetString(3) : string.Empty,
Timestamp = reader.GetInt64(4), Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5), EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6), RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),

View File

@@ -3,12 +3,9 @@ using System.Net;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http; using DHT.Utils.Http;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.Extensions.Primitives;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
@@ -16,25 +13,14 @@ abstract class BaseEndpoint {
private static readonly Log Log = Log.ForType<BaseEndpoint>(); private static readonly Log Log = Log.ForType<BaseEndpoint>();
protected IDatabaseFile Db { get; } protected IDatabaseFile Db { get; }
protected ServerParameters Parameters { get; }
protected BaseEndpoint(IDatabaseFile db, ServerParameters parameters) { protected BaseEndpoint(IDatabaseFile db) {
this.Db = db; this.Db = db;
this.Parameters = parameters;
} }
private async Task Handle(HttpContext ctx, StringValues token) { public async Task Handle(HttpContext ctx) {
var request = ctx.Request;
var response = ctx.Response; var response = ctx.Response;
Log.Info("Request: " + request.GetDisplayUrl() + " (" + request.ContentLength + " B)");
if (token.Count != 1 || token[0] != Parameters.Token) {
Log.Error("Token: " + (token.Count == 1 ? token[0] : "<missing>"));
response.StatusCode = (int) HttpStatusCode.Forbidden;
return;
}
try { try {
response.StatusCode = (int) HttpStatusCode.OK; response.StatusCode = (int) HttpStatusCode.OK;
var output = await Respond(ctx); var output = await Respond(ctx);
@@ -49,14 +35,6 @@ abstract class BaseEndpoint {
} }
} }
public async Task HandleGet(HttpContext ctx) {
await Handle(ctx, ctx.Request.Query["token"]);
}
public async Task HandlePost(HttpContext ctx) {
await Handle(ctx, ctx.Request.Headers["X-DHT-Token"]);
}
protected abstract Task<IHttpOutput> Respond(HttpContext ctx); protected abstract Task<IHttpOutput> Respond(HttpContext ctx);
protected static async Task<JsonElement> ReadJson(HttpContext ctx) { protected static async Task<JsonElement> ReadJson(HttpContext ctx) {

View File

@@ -2,14 +2,13 @@ using System.Net;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http; using DHT.Utils.Http;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class GetAttachmentEndpoint : BaseEndpoint { sealed class GetAttachmentEndpoint : BaseEndpoint {
public GetAttachmentEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {} public GetAttachmentEndpoint(IDatabaseFile db) : base(db) {}
protected override Task<IHttpOutput> Respond(HttpContext ctx) { protected override Task<IHttpOutput> Respond(HttpContext ctx) {
string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!); string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!);

View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
using GetMessagesJsonContext = DHT.Server.Endpoints.Responses.GetMessagesJsonContext;
namespace DHT.Server.Endpoints;
sealed class GetMessagesEndpoint : BaseEndpoint {
public GetMessagesEndpoint(IDatabaseFile db) : base(db) {}
protected override Task<IHttpOutput> Respond(HttpContext ctx) {
HashSet<ulong> messageIdSet;
try {
var messageIds = ctx.Request.Query["id"];
messageIdSet = messageIds.Select(ulong.Parse!).ToHashSet();
} catch (Exception) {
throw new HttpException(HttpStatusCode.BadRequest, "Invalid message ids.");
}
var messageFilter = new MessageFilter {
MessageIds = messageIdSet
};
var messages = Db.GetMessages(messageFilter).ToDictionary(static message => message.Id, static message => message.Text);
var response = new HttpOutput.Json<Dictionary<ulong, string>>(messages, GetMessagesJsonContext.Default.DictionaryUInt64String);
return Task.FromResult<IHttpOutput>(response);
}
}

View File

@@ -13,12 +13,16 @@ namespace DHT.Server.Endpoints;
sealed class GetTrackingScriptEndpoint : BaseEndpoint { sealed class GetTrackingScriptEndpoint : BaseEndpoint {
private static ResourceLoader Resources { get; } = new (Assembly.GetExecutingAssembly()); private static ResourceLoader Resources { get; } = new (Assembly.GetExecutingAssembly());
public GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {} private readonly ServerParameters serverParameters;
public GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db) {
serverParameters = parameters;
}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) { protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
string bootstrap = await Resources.ReadTextAsync("Tracker/bootstrap.js"); string bootstrap = await Resources.ReadTextAsync("Tracker/bootstrap.js");
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + Parameters.Port + ";") string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + serverParameters.Port + ";")
.Replace("/*[TOKEN]*/", HttpUtility.JavaScriptStringEncode(Parameters.Token)) .Replace("/*[TOKEN]*/", HttpUtility.JavaScriptStringEncode(serverParameters.Token))
.Replace("/*[IMPORTS]*/", await Resources.ReadJoinedAsync("Tracker/scripts/", '\n')) .Replace("/*[IMPORTS]*/", await Resources.ReadJoinedAsync("Tracker/scripts/", '\n'))
.Replace("/*[CSS-CONTROLLER]*/", await Resources.ReadTextAsync("Tracker/styles/controller.css")) .Replace("/*[CSS-CONTROLLER]*/", await Resources.ReadTextAsync("Tracker/styles/controller.css"))
.Replace("/*[CSS-SETTINGS]*/", await Resources.ReadTextAsync("Tracker/styles/settings.css")) .Replace("/*[CSS-SETTINGS]*/", await Resources.ReadTextAsync("Tracker/styles/settings.css"))

View File

@@ -0,0 +1,8 @@
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace DHT.Server.Endpoints.Responses;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(Dictionary<ulong, string>))]
sealed partial class GetMessagesJsonContext : JsonSerializerContext {}

View File

@@ -3,14 +3,13 @@ using System.Text.Json;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http; using DHT.Utils.Http;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class TrackChannelEndpoint : BaseEndpoint { sealed class TrackChannelEndpoint : BaseEndpoint {
public TrackChannelEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {} public TrackChannelEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) { protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx); var root = await ReadJson(ctx);

View File

@@ -9,7 +9,6 @@ using DHT.Server.Data;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Download; using DHT.Server.Download;
using DHT.Server.Service;
using DHT.Utils.Collections; using DHT.Utils.Collections;
using DHT.Utils.Http; using DHT.Utils.Http;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
@@ -20,7 +19,7 @@ sealed class TrackMessagesEndpoint : BaseEndpoint {
private const string HasNewMessages = "1"; private const string HasNewMessages = "1";
private const string NoNewMessages = "0"; private const string NoNewMessages = "0";
public TrackMessagesEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {} public TrackMessagesEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) { protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx); var root = await ReadJson(ctx);

View File

@@ -3,14 +3,13 @@ using System.Text.Json;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http; using DHT.Utils.Http;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class TrackUsersEndpoint : BaseEndpoint { sealed class TrackUsersEndpoint : BaseEndpoint {
public TrackUsersEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {} public TrackUsersEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) { protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx); var root = await ReadJson(ctx);

View File

@@ -0,0 +1,44 @@
using System.Net;
using System.Threading.Tasks;
using DHT.Utils.Logging;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
namespace DHT.Server.Service.Middlewares;
sealed class ServerAuthorizationMiddleware {
private static readonly Log Log = Log.ForType<ServerAuthorizationMiddleware>();
private readonly RequestDelegate next;
private readonly ServerParameters serverParameters;
public ServerAuthorizationMiddleware(RequestDelegate next, ServerParameters serverParameters) {
this.next = next;
this.serverParameters = serverParameters;
}
public async Task InvokeAsync(HttpContext context) {
var request = context.Request;
bool success = HttpMethods.IsGet(request.Method)
? CheckToken(request.Query["token"])
: CheckToken(request.Headers["X-DHT-Token"]);
if (success) {
await next(context);
}
else {
context.Response.StatusCode = (int) HttpStatusCode.Forbidden;
}
}
private bool CheckToken(StringValues token) {
if (token.Count == 1 && token[0] == serverParameters.Token) {
return true;
}
else {
Log.Error("Invalid token: " + (token.Count == 1 ? token[0] : "<missing>"));
return false;
}
}
}

View File

@@ -0,0 +1,29 @@
using System.Diagnostics;
using System.Threading.Tasks;
using DHT.Utils.Logging;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
namespace DHT.Server.Service.Middlewares;
sealed class ServerLoggingMiddleware {
private static readonly Log Log = Log.ForType<ServerLoggingMiddleware>();
private readonly RequestDelegate next;
public ServerLoggingMiddleware(RequestDelegate next) {
this.next = next;
}
public async Task InvokeAsync(HttpContext context) {
var stopwatch = Stopwatch.StartNew();
await next(context);
stopwatch.Stop();
var request = context.Request;
var requestLength = request.ContentLength ?? 0L;
var responseStatus = context.Response.StatusCode;
var elapsedMs = stopwatch.ElapsedMilliseconds;
Log.Debug("Request to " + request.GetEncodedPathAndQuery() + " (" + requestLength + " B) returned " + responseStatus + ", took " + elapsedMs + " ms");
}
}

View File

@@ -4,7 +4,6 @@ using System.Diagnostics.CodeAnalysis;
using System.Threading; using System.Threading;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using Microsoft.AspNetCore;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
@@ -75,11 +74,11 @@ public static class ServerLauncher {
options.ListenLocalhost(port, static listenOptions => listenOptions.Protocols = HttpProtocols.Http1); options.ListenLocalhost(port, static listenOptions => listenOptions.Protocols = HttpProtocols.Http1);
} }
Server = WebHost.CreateDefaultBuilder() Server = new WebHostBuilder()
.ConfigureServices(AddServices) .ConfigureServices(AddServices)
.UseKestrel(SetKestrelOptions) .UseKestrel(SetKestrelOptions)
.UseStartup<Startup>() .UseStartup<Startup>()
.Build(); .Build();
Server.Start(); Server.Start();

View File

@@ -2,6 +2,7 @@ using System.Diagnostics.CodeAnalysis;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Endpoints; using DHT.Server.Endpoints;
using DHT.Server.Service.Middlewares;
using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http.Json; using Microsoft.AspNetCore.Http.Json;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
@@ -15,6 +16,7 @@ sealed class Startup {
"https://ptb.discord.com", "https://ptb.discord.com",
"https://canary.discord.com", "https://canary.discord.com",
"https://discordapp.com", "https://discordapp.com",
"null" // For file:// protocol in the Viewer
}; };
public void ConfigureServices(IServiceCollection services) { public void ConfigureServices(IServiceCollection services) {
@@ -27,27 +29,24 @@ sealed class Startup {
builder.WithOrigins(AllowedOrigins).AllowCredentials().AllowAnyMethod().AllowAnyHeader().WithExposedHeaders("X-DHT"); builder.WithOrigins(AllowedOrigins).AllowCredentials().AllowAnyMethod().AllowAnyHeader().WithExposedHeaders("X-DHT");
}); });
}); });
services.AddRoutingCore();
} }
[SuppressMessage("ReSharper", "UnusedMember.Global")] [SuppressMessage("ReSharper", "UnusedMember.Global")]
public void Configure(IApplicationBuilder app, IHostApplicationLifetime lifetime, IDatabaseFile db, ServerParameters parameters) { public void Configure(IApplicationBuilder app, IHostApplicationLifetime lifetime, IDatabaseFile db, ServerParameters parameters) {
app.UseRouting(); app.UseMiddleware<ServerLoggingMiddleware>();
app.UseCors(); app.UseCors();
app.UseMiddleware<ServerAuthorizationMiddleware>();
app.UseRouting();
app.UseEndpoints(endpoints => { app.UseEndpoints(endpoints => {
GetTrackingScriptEndpoint getTrackingScript = new (db, parameters); endpoints.MapGet("/get-tracking-script", new GetTrackingScriptEndpoint(db, parameters).Handle);
endpoints.MapGet("/get-tracking-script", context => getTrackingScript.HandleGet(context)); endpoints.MapGet("/get-messages", new GetMessagesEndpoint(db).Handle);
endpoints.MapGet("/get-attachment/{url}", new GetAttachmentEndpoint(db).Handle);
TrackChannelEndpoint trackChannel = new (db, parameters); endpoints.MapPost("/track-channel", new TrackChannelEndpoint(db).Handle);
endpoints.MapPost("/track-channel", context => trackChannel.HandlePost(context)); endpoints.MapPost("/track-users", new TrackUsersEndpoint(db).Handle);
endpoints.MapPost("/track-messages", new TrackMessagesEndpoint(db).Handle);
TrackUsersEndpoint trackUsers = new (db, parameters);
endpoints.MapPost("/track-users", context => trackUsers.HandlePost(context));
TrackMessagesEndpoint trackMessages = new (db, parameters);
endpoints.MapPost("/track-messages", context => trackMessages.HandlePost(context));
GetAttachmentEndpoint getAttachment = new (db, parameters);
endpoints.MapGet("/get-attachment/{url}", context => getAttachment.HandleGet(context));
}); });
} }
} }

View File

@@ -1,4 +1,5 @@
using System.Text; using System.Text;
using System.Text.Json.Serialization.Metadata;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
@@ -25,6 +26,20 @@ public static class HttpOutput {
} }
} }
public sealed class Json<TValue> : IHttpOutput {
private readonly TValue value;
private readonly JsonTypeInfo<TValue> typeInfo;
public Json(TValue value, JsonTypeInfo<TValue> typeInfo) {
this.value = value;
this.typeInfo = typeInfo;
}
public Task WriteTo(HttpResponse response) {
return response.WriteAsJsonAsync(value, typeInfo);
}
}
public sealed class File : IHttpOutput { public sealed class File : IHttpOutput {
private readonly string? contentType; private readonly string? contentType;
private readonly byte[] bytes; private readonly byte[] bytes;

View File

@@ -0,0 +1,13 @@
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
namespace DHT.Utils.Logging;
[SupportedOSPlatform("windows")]
public static partial class WindowsConsole {
[LibraryImport("kernel32.dll", SetLastError = true)]
public static partial void AllocConsole();
[LibraryImport("kernel32.dll", SetLastError = true)]
public static partial void FreeConsole();
}

View File

@@ -6,6 +6,10 @@
<PackageId>DiscordHistoryTrackerUtils</PackageId> <PackageId>DiscordHistoryTrackerUtils</PackageId>
</PropertyGroup> </PropertyGroup>
<PropertyGroup>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup> <ItemGroup>
<FrameworkReference Include="Microsoft.AspNetCore.App" /> <FrameworkReference Include="Microsoft.AspNetCore.App" />
</ItemGroup> </ItemGroup>