1
0
mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2025-04-18 10:15:44 +02:00

Compare commits

..

2 Commits

20 changed files with 195 additions and 127 deletions

View File

@ -51,6 +51,7 @@
<Style Selector="Expander"> <Style Selector="Expander">
<Setter Property="MinHeight" Value="40" /> <Setter Property="MinHeight" Value="40" />
<Setter Property="Padding" Value="12" />
<Setter Property="HorizontalAlignment" Value="Stretch" /> <Setter Property="HorizontalAlignment" Value="Stretch" />
</Style> </Style>

View File

@ -11,25 +11,41 @@
<pages:TrackingPageModel /> <pages:TrackingPageModel />
</Design.DataContext> </Design.DataContext>
<StackPanel Spacing="10"> <UserControl.Styles>
<TextBlock TextWrapping="Wrap"> <Style Selector="TextBlock">
<TextBlock.Text> <Setter Property="TextWrapping" Value="Wrap" />
<MultiBinding StringFormat="To start tracking messages, copy the tracking script and paste it into the console of either the Discord app, or your browser. The console is usually opened by pressing {0}."> </Style>
<Binding Path="OpenDevToolsShortcutText" /> <Style Selector="WrapPanel > Button">
</MultiBinding> <Setter Property="Margin" Value="0 0 10 10" />
</TextBlock.Text> </Style>
</TextBlock> </UserControl.Styles>
<StackPanel DockPanel.Dock="Left" Orientation="Horizontal" Spacing="10">
<Button x:Name="CopyTrackingScript" Click="CopyTrackingScriptButton_OnClick" IsEnabled="{Binding IsCopyTrackingScriptButtonEnabled}">Copy Tracking Script</Button> <StackPanel Spacing="25">
</StackPanel> <Expander Header="Method 1: Manual" IsExpanded="True">
<TextBlock TextWrapping="Wrap" Margin="0 5 0 0"> <StackPanel Orientation="Vertical" Spacing="10">
<TextBlock.Text> <TextBlock>
<MultiBinding StringFormat="By default, the Discord app does not allow opening the console. The button below will change a hidden setting in the Discord app that controls whether the {0} shortcut is enabled."> <TextBlock.Text>
<Binding Path="OpenDevToolsShortcutText" /> <MultiBinding StringFormat="Use {0} to open the Dev Tools in your browser or the Discord app. Copy the tracking script, and paste it into the console.">
</MultiBinding> <Binding Path="OpenDevToolsShortcutText" />
</TextBlock.Text> </MultiBinding>
</TextBlock> </TextBlock.Text>
<Button DockPanel.Dock="Right" Command="{Binding OnClickToggleAppDevTools}" Content="{Binding ToggleAppDevToolsButtonText}" IsEnabled="{Binding IsToggleAppDevToolsButtonEnabled}" /> </TextBlock>
<Button x:Name="CopyTrackingScript" Click="CopyTrackingScriptButton_OnClick">Copy Tracking Script</Button>
<TextBlock Margin="0 5 0 0">
By default, the Discord app blocks the shortcut for opening Dev Tools. The button below will change a hidden setting in the Discord app that controls it.
</TextBlock>
<Button Command="{Binding OnClickToggleAppDevTools}" Content="{Binding ToggleAppDevToolsButtonText}" IsEnabled="{Binding IsToggleAppDevToolsButtonEnabled}" />
</StackPanel>
</Expander>
<Expander Header="Method 2: Userscript" IsExpanded="True" Padding="12 12 12 2.5">
<StackPanel Orientation="Vertical" Spacing="10">
<TextBlock>If your browser has a userscript manager, you can install a userscript that adds a DHT icon next to the Help icon. Clicking the DHT icon opens a prompt, where you paste the Connection Code.</TextBlock>
<WrapPanel>
<Button Command="{Binding OnClickInstallOrUpdateUserscript}">Install or Update Userscript</Button>
<Button x:Name="CopyConnectionCode" Click="CopyConnectionScriptButton_OnClick">Copy Connection Code</Button>
</WrapPanel>
</StackPanel>
</Expander>
</StackPanel> </StackPanel>
</UserControl> </UserControl>

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
@ -8,24 +9,36 @@ namespace DHT.Desktop.Main.Pages;
[SuppressMessage("ReSharper", "MemberCanBeInternal")] [SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class TrackingPage : UserControl { public sealed partial class TrackingPage : UserControl {
private bool isCopyingScript; private readonly HashSet<Button> copyingButtons = new (ReferenceEqualityComparer.Instance);
public TrackingPage() { public TrackingPage() {
InitializeComponent(); InitializeComponent();
} }
public async void CopyTrackingScriptButton_OnClick(object? sender, RoutedEventArgs e) { public async void CopyTrackingScriptButton_OnClick(object? sender, RoutedEventArgs e) {
await HandleCopyButton(CopyTrackingScript, "Script Copied!", static model => model.OnClickCopyTrackingScript());
}
public async void CopyConnectionScriptButton_OnClick(object? sender, RoutedEventArgs e) {
await HandleCopyButton(CopyConnectionCode, "Code Copied!", static model => model.OnClickCopyConnectionCode());
}
private async Task HandleCopyButton(Button button, string copiedText, Func<TrackingPageModel, Task<bool>> onClick) {
if (DataContext is TrackingPageModel model) { if (DataContext is TrackingPageModel model) {
object? originalText = CopyTrackingScript.Content; object? originalText = button.Content;
CopyTrackingScript.MinWidth = CopyTrackingScript.Bounds.Width; button.MinWidth = button.Bounds.Width;
if (await model.OnClickCopyTrackingScript() && !isCopyingScript) { if (await onClick(model) && copyingButtons.Add(button)) {
isCopyingScript = true; button.IsEnabled = false;
CopyTrackingScript.Content = "Script Copied!"; button.Content = copiedText;
await Task.Delay(TimeSpan.FromSeconds(2)); try {
CopyTrackingScript.Content = originalText; await Task.Delay(TimeSpan.FromSeconds(2));
isCopyingScript = false; } finally {
copyingButtons.Remove(button);
button.IsEnabled = true;
button.Content = originalText;
}
} }
} }
} }

View File

@ -1,20 +1,22 @@
using System; using System;
using System.Text.RegularExpressions;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Web; using System.Web;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.Input.Platform; using Avalonia.Input.Platform;
using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Discord; using DHT.Desktop.Discord;
using DHT.Desktop.Server; using DHT.Desktop.Server;
using DHT.Utils.Logging;
using PropertyChanged.SourceGenerator; using PropertyChanged.SourceGenerator;
using static DHT.Desktop.Program; using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed partial class TrackingPageModel { sealed partial class TrackingPageModel {
[Notify(Setter.Private)] private static readonly Log Log = Log.ForType<TrackingPageModel>();
private bool isCopyTrackingScriptButtonEnabled = true;
[Notify(Setter.Private)] [Notify(Setter.Private)]
private bool? areDevToolsEnabled = null; private bool? areDevToolsEnabled = null;
@ -51,32 +53,9 @@ sealed partial class TrackingPageModel {
} }
public async Task<bool> OnClickCopyTrackingScript() { public async Task<bool> OnClickCopyTrackingScript() {
IsCopyTrackingScriptButtonEnabled = false; string url = ServerConfiguration.HttpHost + $"/get-tracking-script?token={HttpUtility.UrlEncode(ServerConfiguration.Token)}";
try {
return await CopyTrackingScript();
} finally {
IsCopyTrackingScriptButtonEnabled = true;
}
}
private async Task<bool> CopyTrackingScript() {
string url = $"http://127.0.0.1:{ServerConfiguration.Port}/get-tracking-script?token={HttpUtility.UrlEncode(ServerConfiguration.Token)}";
string script = (await Resources.ReadTextAsync("tracker-loader.js")).Trim().Replace("{url}", url); string script = (await Resources.ReadTextAsync("tracker-loader.js")).Trim().Replace("{url}", url);
return await TryCopy(script, "Copy Tracking Script");
IClipboard? clipboard = window.Clipboard;
if (clipboard == null) {
await Dialog.ShowOk(window, "Copy Tracking Script", "Clipboard is not available on this system.");
return false;
}
try {
await clipboard.SetTextAsync(script);
return true;
} catch {
await Dialog.ShowOk(window, "Copy Tracking Script", "An error occurred while copying to clipboard.");
return false;
}
} }
private async Task InitializeDevToolsToggle() { private async Task InitializeDevToolsToggle() {
@ -132,4 +111,44 @@ sealed partial class TrackingPageModel {
throw new ArgumentOutOfRangeException(); throw new ArgumentOutOfRangeException();
} }
} }
public async Task OnClickInstallOrUpdateUserscript() {
try {
SystemUtils.OpenUrl(ServerConfiguration.HttpHost + "/get-userscript/dht.user.js");
} catch (Exception e) {
await Dialog.ShowOk(window, "Install or Update Userscript", "Could not open the browser: " + e.Message);
}
}
[GeneratedRegex("^[a-zA-Z0-9]{1,100}$")]
private static partial Regex ConnectionCodeTokenRegex();
public async Task<bool> OnClickCopyConnectionCode() {
const string Title = "Copy Connection Code";
if (ConnectionCodeTokenRegex().IsMatch(ServerConfiguration.Token)) {
return await TryCopy(ServerConfiguration.Port + ":" + ServerConfiguration.Token, Title);
}
else {
await Dialog.ShowOk(window, Title, "The internal server token cannot be used to create a connection code. Check the 'Advanced' tab and ensure the configured token is 1-100 characters long, and only contains plain letters and numbers.");
return false;
}
}
private async Task<bool> TryCopy(string script, string errorDialogTitle) {
IClipboard? clipboard = window.Clipboard;
if (clipboard == null) {
await Dialog.ShowOk(window, errorDialogTitle, "Clipboard is not available on this system.");
return false;
}
try {
await clipboard.SetTextAsync(script);
return true;
} catch (Exception e) {
Log.Error(e);
await Dialog.ShowOk(window, errorDialogTitle, "An error occurred while copying to clipboard.");
return false;
}
}
} }

View File

@ -51,10 +51,9 @@ sealed partial class ViewerPageModel : IDisposable {
public async void OnClickOpenViewer() { public async void OnClickOpenViewer() {
try { try {
string serverUrl = "http://127.0.0.1:" + ServerConfiguration.Port;
string serverToken = ServerConfiguration.Token; string serverToken = ServerConfiguration.Token;
string sessionId = state.ViewerSessions.Register(new ViewerSession(FilterModel.CreateFilter())).ToString(); string sessionId = state.ViewerSessions.Register(new ViewerSession(FilterModel.CreateFilter())).ToString();
SystemUtils.OpenUrl(serverUrl + "/viewer/?token=" + HttpUtility.UrlEncode(serverToken) + "&session=" + HttpUtility.UrlEncode(sessionId)); SystemUtils.OpenUrl(ServerConfiguration.HttpHost + "/viewer/?token=" + HttpUtility.UrlEncode(serverToken) + "&session=" + HttpUtility.UrlEncode(sessionId));
} catch (Exception e) { } catch (Exception e) {
await Dialog.ShowOk(window, "Open Viewer", "Could not open viewer: " + e.Message); await Dialog.ShowOk(window, "Open Viewer", "Could not open viewer: " + e.Message);
} }

View File

@ -5,4 +5,6 @@ namespace DHT.Desktop.Server;
static class ServerConfiguration { static class ServerConfiguration {
public static ushort Port { get; set; } = ServerUtils.FindAvailablePort(min: 50000, max: 60000); public static ushort Port { get; set; } = ServerUtils.FindAvailablePort(min: 50000, max: 60000);
public static string Token { get; set; } = ServerUtils.GenerateRandomToken(20); public static string Token { get; set; } = ServerUtils.GenerateRandomToken(20);
public static string HttpHost => "http://127.0.0.1:" + Port;
} }

View File

@ -159,11 +159,13 @@ function showConnectDialog() {
dialogElement.close(); dialogElement.close();
}); });
codeInputElement.addEventListener("paste", async function() { codeInputElement.addEventListener("paste", function() {
const code = parseConnectionCode(codeInputElement.value); setTimeout(async function() {
if (code !== null) { const code = parseConnectionCode(codeInputElement.value);
await onSubmit(code); if (code !== null) {
} await onSubmit(code);
}
}, 0);
}); });
formElement.addEventListener("submit", async function(e) { formElement.addEventListener("submit", async function(e) {
@ -211,7 +213,7 @@ function parseConnectionCode(code) {
return null; return null;
} }
const match = code.match(/^(\d{1,5}):([a-z0-9]{1,100})$/); const match = code.match(/^(\d{1,5}):([a-zA-Z0-9]{1,100})$/);
if (!match) { if (!match) {
return null; return null;
} }

View File

@ -3,18 +3,17 @@ using System.Net;
using System.Text.Json; using System.Text.Json;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Database;
using DHT.Utils.Http; using DHT.Utils.Http;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.StaticFiles;
using Microsoft.Extensions.Primitives; using Microsoft.Extensions.Primitives;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
abstract class BaseEndpoint(IDatabaseFile db) { abstract class BaseEndpoint {
private static readonly Log Log = Log.ForType<BaseEndpoint>(); private static readonly Log Log = Log.ForType<BaseEndpoint>();
private static readonly FileExtensionContentTypeProvider ContentTypeProvider = new ();
protected IDatabaseFile Db { get; } = db;
public async Task Handle(HttpContext ctx) { public async Task Handle(HttpContext ctx) {
HttpResponse response = ctx.Response; HttpResponse response = ctx.Response;
@ -49,6 +48,16 @@ abstract class BaseEndpoint(IDatabaseFile db) {
} }
} }
protected static async Task WriteFileIfFound(HttpResponse response, string relativeFilePath, byte[]? bytes, CancellationToken cancellationToken) {
if (bytes == null) {
throw new HttpException(HttpStatusCode.NotFound, "File not found: " + relativeFilePath);
}
else {
string? contentType = ContentTypeProvider.TryGetContentType(relativeFilePath, out string? type) ? type : null;
await response.WriteFileAsync(contentType, bytes, cancellationToken);
}
}
protected static Guid GetSessionId(HttpRequest request) { protected static Guid GetSessionId(HttpRequest request) {
if (request.Query.TryGetValue("session", out StringValues sessionIdValue) && sessionIdValue.Count == 1 && Guid.TryParse(sessionIdValue[0], out Guid sessionId)) { if (request.Query.TryGetValue("session", out StringValues sessionIdValue) && sessionIdValue.Count == 1 && Guid.TryParse(sessionIdValue[0], out Guid sessionId)) {
return sessionId; return sessionId;

View File

@ -10,12 +10,12 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class GetDownloadedFileEndpoint(IDatabaseFile db) : BaseEndpoint(db) { sealed class GetDownloadedFileEndpoint(IDatabaseFile db) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) { protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
string url = WebUtility.UrlDecode((string) request.RouteValues["url"]!); string url = WebUtility.UrlDecode((string) request.RouteValues["url"]!);
string normalizedUrl = DiscordCdn.NormalizeUrl(url); string normalizedUrl = DiscordCdn.NormalizeUrl(url);
if (!await Db.Downloads.GetSuccessfulDownloadWithData(normalizedUrl, WriteDataTo(response), cancellationToken)) { if (!await db.Downloads.GetSuccessfulDownloadWithData(normalizedUrl, WriteDataTo(response), cancellationToken)) {
response.Redirect(url, permanent: false); response.Redirect(url, permanent: false);
} }
} }

View File

@ -2,7 +2,6 @@ using System.Net.Mime;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Web; using System.Web;
using DHT.Server.Database;
using DHT.Server.Service; using DHT.Server.Service;
using DHT.Utils.Http; using DHT.Utils.Http;
using DHT.Utils.Resources; using DHT.Utils.Resources;
@ -10,7 +9,7 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters, ResourceLoader resources) : BaseEndpoint(db) { sealed class GetTrackingScriptEndpoint(ServerParameters parameters, ResourceLoader resources) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) { protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
string bootstrap = await resources.ReadTextAsync("Tracker/bootstrap.js"); string bootstrap = await resources.ReadTextAsync("Tracker/bootstrap.js");
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + parameters.Port + ";") string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + parameters.Port + ";")

View File

@ -0,0 +1,18 @@
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Service.Middlewares;
using DHT.Utils.Resources;
using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
[ServerAuthorizationMiddleware.NoAuthorization]
sealed class GetUserscriptEndpoint(ResourceLoader resources) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
const string FileName = "dht.user.js";
const string ResourcePath = "Tracker/loader/" + FileName;
byte[]? resourceBytes = await resources.ReadBytesAsyncIfExists(ResourcePath);
await WriteFileIfFound(response, FileName, resourceBytes, cancellationToken);
}
}

View File

@ -8,12 +8,12 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class GetViewerMessagesEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint(db) { sealed class GetViewerMessagesEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint {
protected override Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) { protected override Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
Guid sessionId = GetSessionId(request); Guid sessionId = GetSessionId(request);
ViewerSession session = viewerSessions.Get(sessionId); ViewerSession session = viewerSessions.Get(sessionId);
response.ContentType = "application/x-ndjson"; response.ContentType = "application/x-ndjson";
return ViewerJsonExport.GetMessages(response.Body, Db, session.MessageFilter, cancellationToken); return ViewerJsonExport.GetMessages(response.Body, db, session.MessageFilter, cancellationToken);
} }
} }

View File

@ -9,12 +9,12 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class GetViewerMetadataEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint(db) { sealed class GetViewerMetadataEndpoint(IDatabaseFile db, ViewerSessions viewerSessions) : BaseEndpoint {
protected override Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) { protected override Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
Guid sessionId = GetSessionId(request); Guid sessionId = GetSessionId(request);
ViewerSession session = viewerSessions.Get(sessionId); ViewerSession session = viewerSessions.Get(sessionId);
response.ContentType = MediaTypeNames.Application.Json; response.ContentType = MediaTypeNames.Application.Json;
return ViewerJsonExport.GetMetadata(response.Body, Db, session.MessageFilter, cancellationToken); return ViewerJsonExport.GetMetadata(response.Body, db, session.MessageFilter, cancellationToken);
} }
} }

View File

@ -9,14 +9,14 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class TrackChannelEndpoint(IDatabaseFile db) : BaseEndpoint(db) { sealed class TrackChannelEndpoint(IDatabaseFile db) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) { protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
JsonElement root = await ReadJson(request); JsonElement root = await ReadJson(request);
Data.Server server = ReadServer(root.RequireObject("server"), "server"); Data.Server server = ReadServer(root.RequireObject("server"), "server");
Channel channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id); Channel channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id);
await Db.Servers.Add([server]); await db.Servers.Add([server]);
await Db.Channels.Add([channel]); await db.Channels.Add([channel]);
} }
private static Data.Server ReadServer(JsonElement json, string path) { private static Data.Server ReadServer(JsonElement json, string path) {

View File

@ -16,7 +16,7 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) { sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint {
private const string HasNewMessages = "1"; private const string HasNewMessages = "1";
private const string NoNewMessages = "0"; private const string NoNewMessages = "0";
@ -38,9 +38,9 @@ sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
} }
var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds }; var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds };
bool anyNewMessages = await Db.Messages.Count(addedMessageFilter, CancellationToken.None) < addedMessageIds.Count; bool anyNewMessages = await db.Messages.Count(addedMessageFilter, CancellationToken.None) < addedMessageIds.Count;
await Db.Messages.Add(messages); await db.Messages.Add(messages);
await response.WriteTextAsync(anyNewMessages ? HasNewMessages : NoNewMessages, cancellationToken); await response.WriteTextAsync(anyNewMessages ? HasNewMessages : NoNewMessages, cancellationToken);
} }

View File

@ -9,7 +9,7 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) { sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint {
protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) { protected override async Task Respond(HttpRequest request, HttpResponse response, CancellationToken cancellationToken) {
JsonElement root = await ReadJson(request); JsonElement root = await ReadJson(request);
@ -24,7 +24,7 @@ sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
users[i++] = ReadUser(user, "user"); users[i++] = ReadUser(user, "user");
} }
await Db.Users.Add(users); await db.Users.Add(users);
} }
private static User ReadUser(JsonElement json, string path) { private static User ReadUser(JsonElement json, string path) {

View File

@ -1,18 +1,14 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Net;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Database; using DHT.Server.Service.Middlewares;
using DHT.Utils.Http;
using DHT.Utils.Resources; using DHT.Utils.Resources;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.StaticFiles;
namespace DHT.Server.Endpoints; namespace DHT.Server.Endpoints;
sealed class ViewerEndpoint(IDatabaseFile db, ResourceLoader resources) : BaseEndpoint(db) { [ServerAuthorizationMiddleware.NoAuthorization]
private static readonly FileExtensionContentTypeProvider ContentTypeProvider = new (); sealed class ViewerEndpoint(ResourceLoader resources) : BaseEndpoint {
private readonly Dictionary<string, byte[]?> cache = new (); private readonly Dictionary<string, byte[]?> cache = new ();
private readonly SemaphoreSlim cacheSemaphore = new (1); private readonly SemaphoreSlim cacheSemaphore = new (1);
@ -31,12 +27,6 @@ sealed class ViewerEndpoint(IDatabaseFile db, ResourceLoader resources) : BaseEn
cacheSemaphore.Release(); cacheSemaphore.Release();
} }
if (resourceBytes == null) { await WriteFileIfFound(response, path, resourceBytes, cancellationToken);
throw new HttpException(HttpStatusCode.NotFound, "File not found: " + path);
}
else {
string? contentType = ContentTypeProvider.TryGetContentType(path, out string? type) ? type : null;
await response.WriteFileAsync(contentType, resourceBytes, cancellationToken);
}
} }
} }

View File

@ -29,6 +29,11 @@
<Link>Resources/Tracker/%(RecursiveDir)%(Filename)%(Extension)</Link> <Link>Resources/Tracker/%(RecursiveDir)%(Filename)%(Extension)</Link>
<Visible>false</Visible> <Visible>false</Visible>
</EmbeddedResource> </EmbeddedResource>
<EmbeddedResource Include="../Resources/Tracker/loader/**">
<LogicalName>Tracker\loader\%(RecursiveDir)%(Filename)%(Extension)</LogicalName>
<Link>Resources/Tracker/loader/%(RecursiveDir)%(Filename)%(Extension)</Link>
<Visible>false</Visible>
</EmbeddedResource>
<EmbeddedResource Include="../Resources/Tracker/scripts/**"> <EmbeddedResource Include="../Resources/Tracker/scripts/**">
<LogicalName>Tracker\scripts\%(RecursiveDir)%(Filename)%(Extension)</LogicalName> <LogicalName>Tracker\scripts\%(RecursiveDir)%(Filename)%(Extension)</LogicalName>
<Link>Resources/Tracker/scripts/%(RecursiveDir)%(Filename)%(Extension)</Link> <Link>Resources/Tracker/scripts/%(RecursiveDir)%(Filename)%(Extension)</Link>

View File

@ -1,4 +1,6 @@
using System;
using System.Net; using System.Net;
using System.Reflection;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
@ -6,25 +8,11 @@ using Microsoft.Extensions.Primitives;
namespace DHT.Server.Service.Middlewares; namespace DHT.Server.Service.Middlewares;
sealed class ServerAuthorizationMiddleware { sealed class ServerAuthorizationMiddleware(RequestDelegate next, ServerParameters serverParameters) {
private static readonly Log Log = Log.ForType<ServerAuthorizationMiddleware>(); private static readonly Log Log = Log.ForType<ServerAuthorizationMiddleware>();
private readonly RequestDelegate next;
private readonly ServerParameters serverParameters;
public ServerAuthorizationMiddleware(RequestDelegate next, ServerParameters serverParameters) {
this.next = next;
this.serverParameters = serverParameters;
}
public async Task InvokeAsync(HttpContext context) { public async Task InvokeAsync(HttpContext context) {
HttpRequest request = context.Request; if (SkipAuthorization(context) || CheckToken(context.Request)) {
bool success = HttpMethods.IsGet(request.Method)
? CheckToken(request.Query["token"])
: CheckToken(request.Headers["X-DHT-Token"]);
if (success) {
await next(context); await next(context);
} }
else { else {
@ -32,6 +20,16 @@ sealed class ServerAuthorizationMiddleware {
} }
} }
private static bool SkipAuthorization(HttpContext context) {
return context.GetEndpoint()?.RequestDelegate?.Target?.GetType().GetCustomAttribute<NoAuthorization>() != null;
}
private bool CheckToken(HttpRequest request) {
return HttpMethods.IsGet(request.Method)
? CheckToken(request.Query["token"])
: CheckToken(request.Headers["X-DHT-Token"]);
}
private bool CheckToken(StringValues token) { private bool CheckToken(StringValues token) {
if (token.Count == 1 && token[0] == serverParameters.Token) { if (token.Count == 1 && token[0] == serverParameters.Token) {
return true; return true;
@ -41,4 +39,7 @@ sealed class ServerAuthorizationMiddleware {
return false; return false;
} }
} }
[AttributeUsage(AttributeTargets.Class)]
public sealed class NoAuthorization : Attribute;
} }

View File

@ -38,25 +38,19 @@ sealed class Startup {
public void Configure(IApplicationBuilder app, IHostApplicationLifetime lifetime, IDatabaseFile db, ServerParameters parameters, ResourceLoader resources, ViewerSessions viewerSessions) { public void Configure(IApplicationBuilder app, IHostApplicationLifetime lifetime, IDatabaseFile db, ServerParameters parameters, ResourceLoader resources, ViewerSessions viewerSessions) {
app.UseMiddleware<ServerLoggingMiddleware>(); app.UseMiddleware<ServerLoggingMiddleware>();
app.UseCors(); app.UseCors();
app.Map("/viewer", node => {
node.UseRouting();
node.UseEndpoints(endpoints => {
endpoints.MapGet("/{**path}", new ViewerEndpoint(db, resources).Handle);
});
});
app.UseMiddleware<ServerAuthorizationMiddleware>();
app.UseRouting(); app.UseRouting();
app.UseMiddleware<ServerAuthorizationMiddleware>();
app.UseEndpoints(endpoints => { app.UseEndpoints(endpoints => {
endpoints.MapGet("/get-tracking-script", new GetTrackingScriptEndpoint(db, parameters, resources).Handle);
endpoints.MapGet("/get-viewer-metadata", new GetViewerMetadataEndpoint(db, viewerSessions).Handle);
endpoints.MapGet("/get-viewer-messages", new GetViewerMessagesEndpoint(db, viewerSessions).Handle);
endpoints.MapGet("/get-downloaded-file/{url}", new GetDownloadedFileEndpoint(db).Handle); endpoints.MapGet("/get-downloaded-file/{url}", new GetDownloadedFileEndpoint(db).Handle);
endpoints.MapGet("/get-tracking-script", new GetTrackingScriptEndpoint(parameters, resources).Handle);
endpoints.MapGet("/get-userscript/{**ignored}", new GetUserscriptEndpoint(resources).Handle);
endpoints.MapGet("/get-viewer-messages", new GetViewerMessagesEndpoint(db, viewerSessions).Handle);
endpoints.MapGet("/get-viewer-metadata", new GetViewerMetadataEndpoint(db, viewerSessions).Handle);
endpoints.MapGet("/viewer/{**path}", new ViewerEndpoint(resources).Handle);
endpoints.MapPost("/track-channel", new TrackChannelEndpoint(db).Handle); endpoints.MapPost("/track-channel", new TrackChannelEndpoint(db).Handle);
endpoints.MapPost("/track-users", new TrackUsersEndpoint(db).Handle);
endpoints.MapPost("/track-messages", new TrackMessagesEndpoint(db).Handle); endpoints.MapPost("/track-messages", new TrackMessagesEndpoint(db).Handle);
endpoints.MapPost("/track-users", new TrackUsersEndpoint(db).Handle);
}); });
} }
} }