mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2025-08-17 10:31:41 +02:00

18 Commits
v40.0 ... v41.2

Author SHA1 Message Date
d4da64a5ed Release v41.2 2024-04-17 14:14:46 +02:00
8de309a6c4 Fix some Rider inspections and a typo 2024-04-17 14:13:33 +02:00
de8d6a1e11 Stream downloaded files during database merges 2024-04-17 13:31:25 +02:00
d79e6f53b4 Stream downloaded files from database directly into HTTP server responses 2024-04-17 13:31:25 +02:00
70c04fc986 Stream downloaded files directly into database 2024-04-17 13:31:24 +02:00
c8d8d95daa Fix not rolling back database transactions after unhandled exceptions 2024-04-17 12:30:13 +02:00
daafdbbfaf Prevent active downloads from timing out
Closes #256
2024-04-17 08:51:37 +02:00
07615de87a Fix download timeouts not marking the downloaded file as failed
References #256
2024-04-16 11:00:38 +02:00
7fdc19880e Add -concurrentdownloads program argument to configure number of concurrent download tasks
References #256
2024-04-16 10:50:50 +02:00
67b9c12843 Release v41.1 2024-02-15 13:07:36 +01:00
9030a2f010 Update message timestamp processing for latest Discord update
Closes #249
2024-02-15 13:06:06 +01:00
a6dad6b4c7 Release v41.0 2024-01-11 04:24:34 +01:00
72b8fb7c14 Update viewer to reference downloaded embeds, avatars, and emoji 2024-01-07 19:15:12 +01:00
7173dc6cfc Refactor last change to CDN URL normalization 2024-01-07 06:17:17 +01:00
2c1e5a7603 Rework download storage and start collecting download URLs from embeds, avatars, and reactions
Resolves #200
2024-01-07 05:50:46 +01:00
4929a19397 Fix button to retry failed downloads & show error if downloads fail to start 2024-01-01 16:25:32 +01:00
c5f77872fe Fix some database calls not being asynchronous 2024-01-01 14:37:05 +01:00
c9e50e1a80 Refactor database schema upgrades 2024-01-01 09:29:32 +01:00
86 changed files with 1831 additions and 1445 deletions

View File

@@ -9,15 +9,15 @@
<entry key="Desktop/Dialogs/Progress/ProgressDialog.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Dialogs/TextBox/TextBoxDialog.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/AboutWindow.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/AttachmentFilterPanel.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/DownloadItemFilterPanel.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/MessageFilterPanel.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/ServerConfigurationPanel.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/StatusBar.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/MainWindow.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/AdvancedPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/AttachmentsPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/DatabasePage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/DebugPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/DownloadsPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/TrackingPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/ViewerPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Screens/MainContentScreen.axaml" value="Desktop/Desktop.csproj" />
@@ -25,4 +25,4 @@
</map>
</option>
</component>
</project>
</project>

View File

@@ -15,6 +15,7 @@ sealed class Arguments {
public string? DatabaseFile { get; }
public ushort? ServerPort { get; }
public string? ServerToken { get; }
public byte? ConcurrentDownloads { get; }
public Arguments(IReadOnlyList<string> args) {
for (int i = FirstArgument; i < args.Count; i++) {
@@ -50,11 +51,11 @@ sealed class Arguments {
continue;
case "-port": {
if (ushort.TryParse(value, out var port)) {
ServerPort = port;
if (!ushort.TryParse(value, out var port)) {
Log.Warn("Invalid port number: " + value);
}
else {
Log.Warn("Invalid port number: " + value);
ServerPort = port;
}
continue;
@@ -63,6 +64,20 @@ sealed class Arguments {
case "-token":
ServerToken = value;
continue;
case "-concurrentdownloads":
if (!ulong.TryParse(value, out var concurrentDownloads) || concurrentDownloads == 0) {
Log.Warn("Invalid concurrent downloads count: " + value);
}
else if (concurrentDownloads > 10) {
Log.Warn("Limiting concurrent downloads to 10");
ConcurrentDownloads = 10;
}
else {
ConcurrentDownloads = (byte) concurrentDownloads;
}
continue;
default:
Log.Warn("Unknown command line argument: " + key);

View File

@@ -29,7 +29,7 @@ sealed class BytesValueConverter : IValueConverter {
private const int Scale = 1000;
private static string Convert(ulong size) {
public static string Convert(ulong size) {
int power = size == 0L ? 0 : (int) Math.Log(size, Scale);
int unit = power >= Units.Length ? Units.Length - 1 : power;
return Units[unit].Format(unit == 0 ? size : size / Math.Pow(Scale, unit));
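
Making Convert public lets other parts of this change set format byte counts outside of XAML bindings; the StatisticsRow.SizeText property in the new DownloadsPageModel further below is the consumer. A rough usage sketch follows; the exact unit labels come from the private Units array, which this diff does not show, so the rendered string is only a guess.

// Hedged example: with Scale = 1000, a 1,500,000-byte value falls into the third unit bucket,
// so this presumably renders as something like "1.5 MB" depending on the Units array.
ulong? size = 1_500_000;
string text = size == null ? "-" : BytesValueConverter.Convert(size.Value);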

View File

@@ -9,7 +9,7 @@ using DHT.Desktop.Dialogs.Message;
using DHT.Server.Database;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Utils.Logging;
namespace DHT.Desktop.Common;

View File

@@ -32,10 +32,6 @@
<ItemGroup>
<Compile Include="..\Version.cs" Link="Version.cs" />
<Compile Update="Dialogs\TextBox\TextBoxDialog.axaml.cs">
<DependentUpon>CheckBoxDialog.axaml</DependentUpon>
<SubType>Code</SubType>
</Compile>
</ItemGroup>
<ItemGroup>

View File

@@ -4,11 +4,11 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d"
x:Class="DHT.Desktop.Main.Controls.AttachmentFilterPanel"
x:DataType="controls:AttachmentFilterPanelModel">
x:Class="DHT.Desktop.Main.Controls.DownloadItemFilterPanel"
x:DataType="controls:DownloadItemFilterPanelModel">
<Design.DataContext>
<controls:AttachmentFilterPanelModel />
<controls:DownloadItemFilterPanelModel />
</Design.DataContext>
<UserControl.Styles>

View File

@@ -4,8 +4,8 @@ using Avalonia.Controls;
namespace DHT.Desktop.Main.Controls;
[SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class AttachmentFilterPanel : UserControl {
public AttachmentFilterPanel() {
public sealed partial class DownloadItemFilterPanel : UserControl {
public DownloadItemFilterPanel() {
InitializeComponent();
}
}

View File

@@ -12,7 +12,7 @@ using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls;
sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable {
sealed partial class DownloadItemFilterPanelModel : ObservableObject, IDisposable {
public sealed record Unit(string Name, uint Scale);
private static readonly Unit[] AllUnits = [
@@ -43,21 +43,21 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
private readonly State state;
private readonly string verb;
private readonly RestartableTask<long> matchingAttachmentCountTask;
private long? matchingAttachmentCount;
private readonly RestartableTask<long> downloadItemCountTask;
private long? matchingItemCount;
private readonly IDisposable attachmentCountSubscription;
private long? totalAttachmentCount;
private readonly IDisposable downloadItemCountSubscription;
private long? totalItemCount;
[Obsolete("Designer")]
public AttachmentFilterPanelModel() : this(State.Dummy) {}
public DownloadItemFilterPanelModel() : this(State.Dummy) {}
public AttachmentFilterPanelModel(State state, string verb = "Matches") {
public DownloadItemFilterPanelModel(State state, string verb = "Matches") {
this.state = state;
this.verb = verb;
this.matchingAttachmentCountTask = new RestartableTask<long>(SetAttachmentCounts, TaskScheduler.FromCurrentSynchronizationContext());
this.attachmentCountSubscription = state.Db.Attachments.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnAttachmentCountChanged);
this.downloadItemCountTask = new RestartableTask<long>(SetMatchingCount, TaskScheduler.FromCurrentSynchronizationContext());
this.downloadItemCountSubscription = state.Db.Downloads.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnDownloadItemCountChanged);
UpdateFilterStatistics();
@@ -65,7 +65,8 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
}
public void Dispose() {
attachmentCountSubscription.Dispose();
downloadItemCountTask.Cancel();
downloadItemCountSubscription.Dispose();
}
private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
@@ -74,8 +75,8 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
}
}
private void OnAttachmentCountChanged(long newAttachmentCount) {
totalAttachmentCount = newAttachmentCount;
private void OnDownloadItemCountChanged(long newItemCount) {
totalItemCount = newItemCount;
UpdateFilterStatistics();
}
@@ -83,32 +84,32 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
private void UpdateFilterStatistics() {
var filter = CreateFilter();
if (filter.IsEmpty) {
matchingAttachmentCountTask.Cancel();
matchingAttachmentCount = totalAttachmentCount;
downloadItemCountTask.Cancel();
matchingItemCount = totalItemCount;
UpdateFilterStatisticsText();
}
else {
matchingAttachmentCount = null;
matchingItemCount = null;
UpdateFilterStatisticsText();
matchingAttachmentCountTask.Restart(cancellationToken => state.Db.Attachments.Count(filter, cancellationToken));
downloadItemCountTask.Restart(cancellationToken => state.Db.Downloads.Count(filter, cancellationToken));
}
}
private void SetAttachmentCounts(long matchingAttachmentCount) {
this.matchingAttachmentCount = matchingAttachmentCount;
private void SetMatchingCount(long matchingAttachmentCount) {
this.matchingItemCount = matchingAttachmentCount;
UpdateFilterStatisticsText();
}
private void UpdateFilterStatisticsText() {
var matchingAttachmentCountStr = matchingAttachmentCount?.Format() ?? "(...)";
var totalAttachmentCountStr = totalAttachmentCount?.Format() ?? "(...)";
var matchingItemCountStr = matchingItemCount?.Format() ?? "(...)";
var totalItemCountStr = totalItemCount?.Format() ?? "(...)";
FilterStatisticsText = verb + " " + matchingAttachmentCountStr + " out of " + totalAttachmentCountStr + " attachment" + (totalAttachmentCount is null or 1 ? "." : "s.");
FilterStatisticsText = verb + " " + matchingItemCountStr + " out of " + totalItemCountStr + " file" + (totalItemCount is null or 1 ? "." : "s.");
OnPropertyChanged(nameof(FilterStatisticsText));
}
public AttachmentFilter CreateFilter() {
AttachmentFilter filter = new ();
public DownloadItemFilter CreateFilter() {
DownloadItemFilter filter = new ();
if (LimitSize) {
try {

View File

@@ -30,6 +30,7 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
private MainContentScreenModel? mainContentScreenModel;
private readonly Window window;
private readonly int? concurrentDownloads;
private State? state;
@@ -73,6 +74,8 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
if (args.ServerToken != null) {
ServerConfiguration.Token = args.ServerToken;
}
concurrentDownloads = args.ConcurrentDownloads;
}
private async void OnDatabaseSelected(object? sender, IDatabaseFile db) {
@@ -80,7 +83,7 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
await DisposeState();
state = new State(db);
state = new State(db, concurrentDownloads);
try {
await state.Server.Start(ServerConfiguration.Port, ServerConfiguration.Token);

View File

@@ -1,244 +0,0 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Main.Controls;
using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Pages;
sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
private static readonly Log Log = Log.ForType<AttachmentsPageModel>();
private static readonly DownloadItemFilter EnqueuedItemFilter = new () {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued,
DownloadStatus.Downloading
}
};
[ObservableProperty(Setter = Access.Private)]
private bool isToggleDownloadButtonEnabled = true;
public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading";
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool isRetryingFailedDownloads = false;
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool hasFailedDownloads;
public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && hasFailedDownloads;
[ObservableProperty(Setter = Access.Private)]
private string downloadMessage = "";
public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value;
public AttachmentFilterPanelModel FilterModel { get; }
private readonly StatisticsRow statisticsEnqueued = new ("Enqueued");
private readonly StatisticsRow statisticsDownloaded = new ("Downloaded");
private readonly StatisticsRow statisticsFailed = new ("Failed");
private readonly StatisticsRow statisticsSkipped = new ("Skipped");
public ObservableCollection<StatisticsRow> StatisticsRows { get; }
public bool IsDownloading => state.Downloader.IsDownloading;
private readonly State state;
private readonly ThrottledTask<int> enqueueDownloadItemsTask;
private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask;
private readonly IDisposable attachmentCountSubscription;
private readonly IDisposable downloadCountSubscription;
private IDisposable? finishedItemsSubscription;
private int doneItemsCount;
private int totalEnqueuedItemCount;
private int? totalItemsToDownloadCount;
public AttachmentsPageModel() : this(State.Dummy) {}
public AttachmentsPageModel(State state) {
this.state = state;
FilterModel = new AttachmentFilterPanelModel(state);
StatisticsRows = [
statisticsEnqueued,
statisticsDownloaded,
statisticsFailed,
statisticsSkipped
];
enqueueDownloadItemsTask = new ThrottledTask<int>(OnItemsEnqueued, TaskScheduler.FromCurrentSynchronizationContext());
downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext());
attachmentCountSubscription = state.Db.Attachments.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnAttachmentCountChanged);
downloadCountSubscription = state.Db.Downloads.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnDownloadCountChanged);
RecomputeDownloadStatistics();
}
public void Dispose() {
attachmentCountSubscription.Dispose();
downloadCountSubscription.Dispose();
finishedItemsSubscription?.Dispose();
enqueueDownloadItemsTask.Dispose();
downloadStatisticsTask.Dispose();
FilterModel.Dispose();
}
private void OnAttachmentCountChanged(long newAttachmentCount) {
if (IsDownloading) {
EnqueueDownloadItemsLater();
}
else {
RecomputeDownloadStatistics();
}
}
private void OnDownloadCountChanged(long newDownloadCount) {
RecomputeDownloadStatistics();
}
private async Task EnqueueDownloadItems() {
OnItemsEnqueued(await state.Db.Downloads.EnqueueDownloadItems(CreateAttachmentFilter()));
}
private void EnqueueDownloadItemsLater() {
var filter = CreateAttachmentFilter();
enqueueDownloadItemsTask.Post(cancellationToken => state.Db.Downloads.EnqueueDownloadItems(filter, cancellationToken));
}
private void OnItemsEnqueued(int itemCount) {
totalEnqueuedItemCount += itemCount;
totalItemsToDownloadCount = totalEnqueuedItemCount;
UpdateDownloadMessage();
RecomputeDownloadStatistics();
}
private AttachmentFilter CreateAttachmentFilter() {
var filter = FilterModel.CreateFilter();
filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent;
return filter;
}
public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false;
if (IsDownloading) {
await state.Downloader.Stop();
finishedItemsSubscription?.Dispose();
finishedItemsSubscription = null;
RecomputeDownloadStatistics();
await state.Db.Downloads.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching);
doneItemsCount = 0;
totalEnqueuedItemCount = 0;
totalItemsToDownloadCount = null;
UpdateDownloadMessage();
}
else {
var finishedItems = await state.Downloader.Start();
finishedItemsSubscription = finishedItems.Select(static _ => true)
.Buffer(TimeSpan.FromMilliseconds(100))
.Select(static items => items.Count)
.Where(static items => items > 0)
.ObserveOn(AvaloniaScheduler.Instance)
.Subscribe(OnItemsFinished);
await EnqueueDownloadItems();
}
OnPropertyChanged(nameof(ToggleDownloadButtonText));
OnPropertyChanged(nameof(IsDownloading));
IsToggleDownloadButtonEnabled = true;
}
private void OnItemsFinished(int finishedItemCount) {
doneItemsCount += finishedItemCount;
UpdateDownloadMessage();
}
public async Task OnClickRetryFailedDownloads() {
IsRetryingFailedDownloads = true;
try {
var allExceptFailedFilter = new DownloadItemFilter {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued,
DownloadStatus.Downloading,
DownloadStatus.Success
}
};
await state.Db.Downloads.RemoveDownloadItems(allExceptFailedFilter, FilterRemovalMode.KeepMatching);
if (IsDownloading) {
await EnqueueDownloadItems();
}
} catch (Exception e) {
Log.Error(e);
} finally {
IsRetryingFailedDownloads = false;
}
}
private void RecomputeDownloadStatistics() {
downloadStatisticsTask.Post(state.Db.Downloads.GetStatistics);
}
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
statisticsEnqueued.Items = statusStatistics.EnqueuedCount;
statisticsEnqueued.Size = statusStatistics.EnqueuedSize;
statisticsDownloaded.Items = statusStatistics.SuccessfulCount;
statisticsDownloaded.Size = statusStatistics.SuccessfulSize;
statisticsFailed.Items = statusStatistics.FailedCount;
statisticsFailed.Size = statusStatistics.FailedSize;
statisticsSkipped.Items = statusStatistics.SkippedCount;
statisticsSkipped.Size = statusStatistics.SkippedSize;
hasFailedDownloads = statusStatistics.FailedCount > 0;
UpdateDownloadMessage();
}
private void UpdateDownloadMessage() {
DownloadMessage = IsDownloading ? doneItemsCount.Format() + " / " + (totalItemsToDownloadCount?.Format() ?? "?") : "";
OnPropertyChanged(nameof(DownloadProgress));
}
[ObservableObject]
public sealed partial class StatisticsRow(string state) {
public string State { get; } = state;
[ObservableProperty]
private int items;
[ObservableProperty]
private ulong? size;
}
}

View File

@@ -18,7 +18,7 @@ using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Server.Database.Import;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Utils.Logging;
namespace DHT.Desktop.Main.Pages;

View File

@@ -10,164 +10,164 @@ using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Service;
namespace DHT.Desktop.Main.Pages {
sealed class DebugPageModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
namespace DHT.Desktop.Main.Pages;
private readonly Window window;
private readonly State state;
sealed class DebugPageModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
[Obsolete("Designer")]
public DebugPageModel() : this(null!, State.Dummy) {}
private readonly Window window;
private readonly State state;
public DebugPageModel(Window window, State state) {
this.window = window;
this.state = state;
[Obsolete("Designer")]
public DebugPageModel() : this(null!, State.Dummy) {}
public DebugPageModel(Window window, State state) {
this.window = window;
this.state = state;
}
public async void OnClickAddRandomDataToDatabase() {
if (!int.TryParse(GenerateChannels, out int channels) || channels < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of channels must be at least 1!");
return;
}
public async void OnClickAddRandomDataToDatabase() {
if (!int.TryParse(GenerateChannels, out int channels) || channels < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of channels must be at least 1!");
return;
}
if (!int.TryParse(GenerateUsers, out int users) || users < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of users must be at least 1!");
return;
}
if (!int.TryParse(GenerateMessages, out int messages) || messages < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of messages must be at least 1!");
return;
}
await ProgressDialog.Show(window, "Generating Random Data", async (_, callback) => await GenerateRandomData(channels, users, messages, callback));
if (!int.TryParse(GenerateUsers, out int users) || users < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of users must be at least 1!");
return;
}
private const int BatchSize = 500;
if (!int.TryParse(GenerateMessages, out int messages) || messages < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of messages must be at least 1!");
return;
}
private async Task GenerateRandomData(int channelCount, int userCount, int messageCount, IProgressCallback callback) {
int batchCount = (messageCount + BatchSize - 1) / BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, 0, batchCount);
await ProgressDialog.Show(window, "Generating Random Data", async (_, callback) => await GenerateRandomData(channels, users, messages, callback));
}
var rand = new Random();
var server = new DHT.Server.Data.Server {
Id = RandomId(rand),
Name = RandomName("s"),
Type = ServerType.Server,
};
private const int BatchSize = 500;
var channels = Enumerable.Range(0, channelCount).Select(i => new Channel {
Id = RandomId(rand),
Server = server.Id,
Name = RandomName("c"),
ParentId = null,
Position = i,
Topic = RandomText(rand, 10),
Nsfw = rand.Next(4) == 0,
private async Task GenerateRandomData(int channelCount, int userCount, int messageCount, IProgressCallback callback) {
int batchCount = (messageCount + BatchSize - 1) / BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, 0, batchCount);
var rand = new Random();
var server = new DHT.Server.Data.Server {
Id = RandomId(rand),
Name = RandomName("s"),
Type = ServerType.Server,
};
var channels = Enumerable.Range(0, channelCount).Select(i => new Channel {
Id = RandomId(rand),
Server = server.Id,
Name = RandomName("c"),
ParentId = null,
Position = i,
Topic = RandomText(rand, 10),
Nsfw = rand.Next(4) == 0,
}).ToArray();
var users = Enumerable.Range(0, userCount).Select(_ => new User {
Id = RandomId(rand),
Name = RandomName("u"),
AvatarUrl = null,
Discriminator = rand.Next(0, 9999).ToString(),
}).ToArray();
await state.Db.Users.Add(users);
await state.Db.Servers.Add([server]);
await state.Db.Channels.Add(channels);
var now = DateTimeOffset.Now;
int batchIndex = 0;
while (messageCount > 0) {
int hourOffset = batchIndex;
var messages = Enumerable.Range(0, Math.Min(messageCount, BatchSize)).Select(i => {
DateTimeOffset time = now.AddHours(hourOffset).AddMinutes(i * 60.0 / BatchSize);
DateTimeOffset? edit = rand.Next(100) == 0 ? time.AddSeconds(rand.Next(1, 60)) : null;
var timeMillis = time.ToUnixTimeMilliseconds();
var editMillis = edit?.ToUnixTimeMilliseconds();
return new Message {
Id = (ulong) timeMillis,
Sender = RandomBiasedIndex(rand, users).Id,
Channel = RandomBiasedIndex(rand, channels).Id,
Text = RandomText(rand, 100),
Timestamp = timeMillis,
EditTimestamp = editMillis,
RepliedToId = null,
Attachments = ImmutableList<Attachment>.Empty,
Embeds = ImmutableList<Embed>.Empty,
Reactions = ImmutableList<Reaction>.Empty,
};
}).ToArray();
var users = Enumerable.Range(0, userCount).Select(_ => new User {
Id = RandomId(rand),
Name = RandomName("u"),
AvatarUrl = null,
Discriminator = rand.Next(0, 9999).ToString(),
}).ToArray();
await state.Db.Messages.Add(messages);
await state.Db.Users.Add(users);
await state.Db.Servers.Add([server]);
await state.Db.Channels.Add(channels);
var now = DateTimeOffset.Now;
int batchIndex = 0;
while (messageCount > 0) {
int hourOffset = batchIndex;
var messages = Enumerable.Range(0, Math.Min(messageCount, BatchSize)).Select(i => {
DateTimeOffset time = now.AddHours(hourOffset).AddMinutes(i * 60.0 / BatchSize);
DateTimeOffset? edit = rand.Next(100) == 0 ? time.AddSeconds(rand.Next(1, 60)) : null;
var timeMillis = time.ToUnixTimeMilliseconds();
var editMillis = edit?.ToUnixTimeMilliseconds();
return new Message {
Id = (ulong) timeMillis,
Sender = RandomBiasedIndex(rand, users).Id,
Channel = RandomBiasedIndex(rand, channels).Id,
Text = RandomText(rand, 100),
Timestamp = timeMillis,
EditTimestamp = editMillis,
RepliedToId = null,
Attachments = ImmutableList<Attachment>.Empty,
Embeds = ImmutableList<Embed>.Empty,
Reactions = ImmutableList<Reaction>.Empty,
};
}).ToArray();
await state.Db.Messages.Add(messages);
messageCount -= BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount);
}
messageCount -= BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount);
}
}
private static ulong RandomId(Random rand) {
ulong h = unchecked((ulong) rand.Next());
ulong l = unchecked((ulong) rand.Next());
return (h << 32) | l;
}
private static ulong RandomId(Random rand) {
ulong h = unchecked((ulong) rand.Next());
ulong l = unchecked((ulong) rand.Next());
return (h << 32) | l;
}
private static string RandomName(string prefix) {
return prefix + "-" + ServerUtils.GenerateRandomToken(5);
}
private static string RandomName(string prefix) {
return prefix + "-" + ServerUtils.GenerateRandomToken(5);
}
private static T RandomBiasedIndex<T>(Random rand, T[] options) {
return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())];
}
private static T RandomBiasedIndex<T>(Random rand, T[] options) {
return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())];
}
private static readonly string[] RandomWords = [
"apple", "apricot", "artichoke", "arugula", "asparagus", "avocado",
"banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli",
"cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant",
"daikon", "date", "dewberry", "durian",
"edamame", "eggplant", "elderberry", "endive",
"fig",
"garlic", "ginger", "gooseberry", "grape", "grapefruit", "guava",
"honeysuckle", "horseradish", "huckleberry",
"jackfruit", "jicama",
"kale", "kiwi", "kohlrabi", "kumquat",
"leek", "lemon", "lentil", "lettuce", "lime",
"mandarin", "mango", "mushroom", "myrtle",
"nectarine", "nut",
"olive", "okra", "onion", "orange",
"papaya", "parsnip", "pawpaw", "peach", "pear", "pea", "pepper", "persimmon", "pineapple", "plum", "plantain", "pomegranate", "pomelo", "potato", "prune", "pumpkin",
"quandong", "quinoa",
"radicchio", "radish", "raisin", "raspberry", "redcurrant", "rhubarb", "rutabaga",
"spinach", "strawberry", "squash",
"tamarind", "tangerine", "tomatillo", "tomato", "turnip",
"vanilla",
"watercress", "watermelon",
"yam",
"zucchini"
];
private static readonly string[] RandomWords = [
"apple", "apricot", "artichoke", "arugula", "asparagus", "avocado",
"banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli",
"cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant",
"daikon", "date", "dewberry", "durian",
"edamame", "eggplant", "elderberry", "endive",
"fig",
"garlic", "ginger", "gooseberry", "grape", "grapefruit", "guava",
"honeysuckle", "horseradish", "huckleberry",
"jackfruit", "jicama",
"kale", "kiwi", "kohlrabi", "kumquat",
"leek", "lemon", "lentil", "lettuce", "lime",
"mandarin", "mango", "mushroom", "myrtle",
"nectarine", "nut",
"olive", "okra", "onion", "orange",
"papaya", "parsnip", "pawpaw", "peach", "pear", "pea", "pepper", "persimmon", "pineapple", "plum", "plantain", "pomegranate", "pomelo", "potato", "prune", "pumpkin",
"quandong", "quinoa",
"radicchio", "radish", "raisin", "raspberry", "redcurrant", "rhubarb", "rutabaga",
"spinach", "strawberry", "squash",
"tamarind", "tangerine", "tomatillo", "tomato", "turnip",
"vanilla",
"watercress", "watermelon",
"yam",
"zucchini"
];
private static string RandomText(Random rand, int maxWords) {
int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3));
return string.Join(' ', Enumerable.Range(0, wordCount).Select(_ => RandomWords[rand.Next(RandomWords.Length)]));
}
private static string RandomText(Random rand, int maxWords) {
int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3));
return string.Join(' ', Enumerable.Range(0, wordCount).Select(_ => RandomWords[rand.Next(RandomWords.Length)]));
}
}
#else
namespace DHT.Desktop.Main.Pages {
sealed class DebugPageModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
namespace DHT.Desktop.Main.Pages;
public void OnClickAddRandomDataToDatabase() {}
}
sealed class DebugPageModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
public void OnClickAddRandomDataToDatabase() {}
}
#endif

View File

@@ -5,11 +5,11 @@
xmlns:pages="clr-namespace:DHT.Desktop.Main.Pages"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Pages.AttachmentsPage"
x:DataType="pages:AttachmentsPageModel">
x:Class="DHT.Desktop.Main.Pages.DownloadsPage"
x:DataType="pages:DownloadsPageModel">
<Design.DataContext>
<pages:AttachmentsPageModel />
<pages:DownloadsPageModel />
</Design.DataContext>
<UserControl.Styles>
@@ -31,19 +31,15 @@
</UserControl.Styles>
<StackPanel Orientation="Vertical" Spacing="20">
<DockPanel>
<Button Command="{Binding OnClickToggleDownload}" Content="{Binding ToggleDownloadButtonText}" IsEnabled="{Binding IsToggleDownloadButtonEnabled}" DockPanel.Dock="Left" />
<TextBlock Text="{Binding DownloadMessage}" MinWidth="100" Margin="10 0 0 0" VerticalAlignment="Center" TextAlignment="Right" DockPanel.Dock="Left" />
<ProgressBar Value="{Binding DownloadProgress}" IsVisible="{Binding IsDownloading}" Margin="15 0" VerticalAlignment="Center" DockPanel.Dock="Right" />
</DockPanel>
<controls:AttachmentFilterPanel DataContext="{Binding FilterModel}" IsEnabled="{Binding !IsDownloading, RelativeSource={RelativeSource AncestorType=pages:AttachmentsPageModel}}" />
<Button Command="{Binding OnClickToggleDownload}" Content="{Binding ToggleDownloadButtonText}" IsEnabled="{Binding IsToggleDownloadButtonEnabled}" />
<controls:DownloadItemFilterPanel DataContext="{Binding FilterModel}" IsEnabled="{Binding !$parent[UserControl].((pages:DownloadsPageModel)DataContext).IsDownloading}" />
<StackPanel Orientation="Vertical" Spacing="12">
<Expander Header="Download Status" IsExpanded="True">
<DataGrid ItemsSource="{Binding StatisticsRows}" AutoGenerateColumns="False" CanUserReorderColumns="False" CanUserResizeColumns="False" CanUserSortColumns="False" IsReadOnly="True">
<DataGrid.Columns>
<DataGridTextColumn Header="State" Binding="{Binding State}" Width="*" />
<DataGridTextColumn Header="Attachments" Binding="{Binding Items, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="Size" Binding="{Binding Size, Mode=OneWay, Converter={StaticResource BytesValueConverter}}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="State" Binding="{Binding State, Mode=OneWay}" Width="*" />
<DataGridTextColumn Header="Files" Binding="{Binding Items, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="Size" Binding="{Binding SizeText, Mode=OneWay}" Width="*" CellStyleClasses="right" />
</DataGrid.Columns>
</DataGrid>
</Expander>

View File

@@ -4,8 +4,8 @@ using Avalonia.Controls;
namespace DHT.Desktop.Main.Pages;
[SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class AttachmentsPage : UserControl {
public AttachmentsPage() {
public sealed partial class DownloadsPage : UserControl {
public DownloadsPage() {
InitializeComponent();
}
}

View File

@@ -0,0 +1,186 @@
using System;
using System.Collections.ObjectModel;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Main.Controls;
using DHT.Server;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Pages;
sealed partial class DownloadsPageModel : ObservableObject, IDisposable {
private static readonly Log Log = Log.ForType<DownloadsPageModel>();
[ObservableProperty(Setter = Access.Private)]
private bool isToggleDownloadButtonEnabled = true;
public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading";
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool isRetryingFailedDownloads = false;
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool hasFailedDownloads;
public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && HasFailedDownloads;
[ObservableProperty(Setter = Access.Private)]
private string downloadMessage = "";
public DownloadItemFilterPanelModel FilterModel { get; }
private readonly StatisticsRow statisticsPending = new ("Pending");
private readonly StatisticsRow statisticsDownloaded = new ("Downloaded");
private readonly StatisticsRow statisticsFailed = new ("Failed");
private readonly StatisticsRow statisticsSkipped = new ("Skipped");
public ObservableCollection<StatisticsRow> StatisticsRows { get; }
public bool IsDownloading => state.Downloader.IsDownloading;
private readonly State state;
private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask;
private readonly IDisposable downloadItemCountSubscription;
private IDisposable? finishedItemsSubscription;
private DownloadItemFilter? currentDownloadFilter;
public DownloadsPageModel() : this(State.Dummy) {}
public DownloadsPageModel(State state) {
this.state = state;
FilterModel = new DownloadItemFilterPanelModel(state);
StatisticsRows = [
statisticsPending,
statisticsDownloaded,
statisticsFailed,
statisticsSkipped
];
downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(Log, UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext());
downloadItemCountSubscription = state.Db.Downloads.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnDownloadCountChanged);
RecomputeDownloadStatistics();
}
public void Dispose() {
finishedItemsSubscription?.Dispose();
downloadItemCountSubscription.Dispose();
downloadStatisticsTask.Dispose();
FilterModel.Dispose();
}
private void OnDownloadCountChanged(long newDownloadCount) {
RecomputeDownloadStatistics();
}
public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false;
if (IsDownloading) {
await state.Downloader.Stop();
await state.Db.Downloads.MoveDownloadingItemsBackToQueue();
finishedItemsSubscription?.Dispose();
finishedItemsSubscription = null;
currentDownloadFilter = null;
}
else {
await state.Db.Downloads.MoveDownloadingItemsBackToQueue();
var finishedItems = await state.Downloader.Start(currentDownloadFilter = FilterModel.CreateFilter());
finishedItemsSubscription = finishedItems.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnItemFinished);
}
RecomputeDownloadStatistics();
OnPropertyChanged(nameof(ToggleDownloadButtonText));
OnPropertyChanged(nameof(IsDownloading));
IsToggleDownloadButtonEnabled = true;
}
private void OnItemFinished(DownloadItem item) {
RecomputeDownloadStatistics();
}
public async Task OnClickRetryFailedDownloads() {
IsRetryingFailedDownloads = true;
try {
await state.Db.Downloads.RetryFailed();
RecomputeDownloadStatistics();
} catch (Exception e) {
Log.Error(e);
} finally {
IsRetryingFailedDownloads = false;
}
}
private void RecomputeDownloadStatistics() {
downloadStatisticsTask.Post(cancellationToken => state.Db.Downloads.GetStatistics(currentDownloadFilter ?? new DownloadItemFilter(), cancellationToken));
}
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
statisticsPending.Items = statusStatistics.PendingCount;
statisticsPending.Size = statusStatistics.PendingTotalSize;
statisticsPending.HasFilesWithUnknownSize = statusStatistics.PendingWithUnknownSizeCount > 0;
statisticsDownloaded.Items = statusStatistics.SuccessfulCount;
statisticsDownloaded.Size = statusStatistics.SuccessfulTotalSize;
statisticsDownloaded.HasFilesWithUnknownSize = statusStatistics.SuccessfulWithUnknownSizeCount > 0;
statisticsFailed.Items = statusStatistics.FailedCount;
statisticsFailed.Size = statusStatistics.FailedTotalSize;
statisticsFailed.HasFilesWithUnknownSize = statusStatistics.FailedWithUnknownSizeCount > 0;
statisticsSkipped.Items = statusStatistics.SkippedCount;
statisticsSkipped.Size = statusStatistics.SkippedTotalSize;
statisticsSkipped.HasFilesWithUnknownSize = statusStatistics.SkippedWithUnknownSizeCount > 0;
HasFailedDownloads = statusStatistics.FailedCount > 0;
}
[ObservableObject]
public sealed partial class StatisticsRow(string state) {
public string State { get; } = state;
[ObservableProperty]
private int items;
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(SizeText))]
private ulong? size;
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(SizeText))]
private bool hasFilesWithUnknownSize;
public string SizeText {
get {
if (size == null) {
return "-";
}
else if (hasFilesWithUnknownSize) {
return "\u2265 " + BytesValueConverter.Convert(size.Value);
}
else {
return BytesValueConverter.Convert(size.Value);
}
}
}
}
}

View File

@@ -18,7 +18,6 @@ using DHT.Desktop.Server;
using DHT.Server;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Export;
using DHT.Server.Database.Export.Strategy;
using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages;
@@ -63,9 +62,13 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
public async void OnClickOpenViewer() {
try {
var fullPath = await PrepareTemporaryViewerFile();
var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token);
string jsConstants = $"""
window.DHT_SERVER_URL = "{HttpUtility.JavaScriptStringEncode("http://127.0.0.1:" + ServerConfiguration.Port)}";
window.DHT_SERVER_TOKEN = "{HttpUtility.JavaScriptStringEncode(ServerConfiguration.Token)}";
""";
await ProgressDialog.ShowIndeterminate(window, "Open Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(fullPath, strategy)));
await ProgressDialog.ShowIndeterminate(window, "Open Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(fullPath, jsConstants)));
Process.Start(new ProcessStartInfo(fullPath) {
UseShellExecute = true
@@ -109,17 +112,18 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
}
try {
await ProgressDialog.ShowIndeterminate(window, "Save Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(path, StandaloneViewerExportStrategy.Instance)));
await ProgressDialog.ShowIndeterminate(window, "Save Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(path, string.Empty)));
} catch (Exception e) {
await Dialog.ShowOk(window, "Save Viewer", "Could not create or save viewer: " + e.Message);
}
}
private async Task WriteViewerFile(string path, IViewerExportStrategy strategy) {
private async Task WriteViewerFile(string path, string jsConstants) {
const string ArchiveTag = "/*[ARCHIVE]*/";
string indexFile = await Resources.ReadTextAsync("Viewer/index.html");
string viewerTemplate = indexFile.Replace("/*[JS]*/", await Resources.ReadJoinedAsync("Viewer/scripts/", '\n'))
string viewerTemplate = indexFile.Replace("/*[CONSTANTS]*/", jsConstants)
.Replace("/*[JS]*/", await Resources.ReadJoinedAsync("Viewer/scripts/", '\n'))
.Replace("/*[CSS]*/", await Resources.ReadJoinedAsync("Viewer/styles/", '\n'));
int viewerArchiveTagStart = viewerTemplate.IndexOf(ArchiveTag);
@@ -128,7 +132,7 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
string jsonTempFile = path + ".tmp";
await using (var jsonStream = new FileStream(jsonTempFile, FileMode.Create, FileAccess.ReadWrite, FileShare.Read)) {
await ViewerJsonExport.Generate(jsonStream, strategy, state.Db, FilterModel.CreateFilter());
await ViewerJsonExport.Generate(jsonStream, state.Db, FilterModel.CreateFilter());
char[] jsonBuffer = new char[Math.Min(32768, jsonStream.Position)];
jsonStream.Position = 0;

View File

@@ -74,7 +74,7 @@
<DockPanel>
<Border Classes="statusBar" DockPanel.Dock="Bottom">
<DockPanel>
<TextBlock Classes="invisibleTabItem" DockPanel.Dock="Left">Attachments</TextBlock>
<TextBlock Classes="invisibleTabItem" DockPanel.Dock="Left">Downloads</TextBlock>
<controls:StatusBar DataContext="{Binding StatusBarModel}" DockPanel.Dock="Right" />
</DockPanel>
</Border>
@@ -94,9 +94,9 @@
<ContentPresenter Content="{Binding TrackingPage}" Classes="page" />
</ScrollViewer>
</TabItem>
<TabItem x:Name="TabAttachments" Header="Attachments" Grid.Row="2">
<TabItem x:Name="TabDownloads" Header="Downloads" Grid.Row="2">
<ScrollViewer>
<ContentPresenter Content="{Binding AttachmentsPage}" Classes="page" />
<ContentPresenter Content="{Binding DownloadsPage}" Classes="page" />
</ScrollViewer>
</TabItem>
<TabItem x:Name="TabViewer" Header="Viewer" Grid.Row="3">

View File

@@ -13,8 +13,8 @@ sealed class MainContentScreenModel : IDisposable {
public TrackingPage TrackingPage { get; }
private TrackingPageModel TrackingPageModel { get; }
public AttachmentsPage AttachmentsPage { get; }
private AttachmentsPageModel AttachmentsPageModel { get; }
public DownloadsPage DownloadsPage { get; }
private DownloadsPageModel DownloadsPageModel { get; }
public ViewerPage ViewerPage { get; }
private ViewerPageModel ViewerPageModel { get; }
@@ -52,8 +52,8 @@ sealed class MainContentScreenModel : IDisposable {
TrackingPageModel = new TrackingPageModel(window);
TrackingPage = new TrackingPage { DataContext = TrackingPageModel };
AttachmentsPageModel = new AttachmentsPageModel(state);
AttachmentsPage = new AttachmentsPage { DataContext = AttachmentsPageModel };
DownloadsPageModel = new DownloadsPageModel(state);
DownloadsPage = new DownloadsPage { DataContext = DownloadsPageModel };
ViewerPageModel = new ViewerPageModel(window, state);
ViewerPage = new ViewerPage { DataContext = ViewerPageModel };
@@ -72,7 +72,7 @@ sealed class MainContentScreenModel : IDisposable {
}
public void Dispose() {
AttachmentsPageModel.Dispose();
DownloadsPageModel.Dispose();
ViewerPageModel.Dispose();
AdvancedPageModel.Dispose();
StatusBarModel.Dispose();

View File

@@ -8,7 +8,7 @@ using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Server.Database;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Database.Sqlite.Schema;
namespace DHT.Desktop.Main.Screens;

View File

@@ -48,6 +48,16 @@ const STATE = (function() {
});
};
const getDate = function(date) {
if (date instanceof Date) {
return date;
}
else {
// noinspection JSUnresolvedReference
return date.toDate();
}
};
const trackingStateChangedListeners = [];
let isTracking = false;
@@ -69,8 +79,8 @@ const STATE = (function() {
* @property {String} channel_id
* @property {DiscordUser} author
* @property {String} content
* @property {Timestamp} timestamp
* @property {Timestamp|null} editedTimestamp
* @property {Date} timestamp
* @property {Date|null} editedTimestamp
* @property {DiscordAttachment[]} attachments
* @property {Object[]} embeds
* @property {DiscordMessageReaction[]} [reactions]
@@ -106,11 +116,6 @@ const STATE = (function() {
* @property {Boolean} animated
*/
/**
* @name Timestamp
* @property {Function} toDate
*/
return {
setup(port, token) {
serverPort = port;
@@ -223,12 +228,12 @@ const STATE = (function() {
sender: msg.author.id,
channel: msg.channel_id,
text: msg.content,
timestamp: msg.timestamp.toDate().getTime()
timestamp: getDate(msg.timestamp).getTime()
};
if (msg.editedTimestamp !== null) {
// noinspection JSUnusedGlobalSymbols
obj.editTimestamp = msg.editedTimestamp.toDate().getTime();
obj.editTimestamp = getDate(msg.editedTimestamp).getTime();
}
if (msg.messageReference !== null) {

View File

@@ -6,6 +6,7 @@
<script type="text/javascript">
window.DHT_EMBEDDED = "/*[ARCHIVE]*/";
/*[CONSTANTS]*/
/*[JS]*/
</script>
<style>

View File

@@ -35,6 +35,23 @@ const DISCORD = (function() {
let templateReaction;
let templateReactionCustom;
const fileUrlProcessor = function(serverUrl, serverToken) {
if (typeof serverUrl === "string" && typeof serverToken === "string") {
return url => serverUrl + "/get-downloaded-file/" + encodeURIComponent(url) + "?token=" + encodeURIComponent(serverToken);
}
else {
return url => url;
}
}(
window["DHT_SERVER_URL"],
window["DHT_SERVER_TOKEN"]
);
const getEmoji = function(name, id, extension) {
const tag = ":" + name + ":";
return "<img src='" + fileUrlProcessor("https://cdn.discordapp.com/emojis/" + id + "." + extension) + "' alt='" + tag + "' title='" + tag + "' class='emoji'>";
};
const processMessageContents = function(contents) {
let processed = DOM.escapeHTML(contents.replace(regex.formatUrlNoEmbed, "$1"));
@@ -54,29 +71,33 @@ const DISCORD = (function() {
.replace(regex.formatStrike, "<s>$1</s>");
}
const animatedEmojiExtension = SETTINGS.enableAnimatedEmoji ? "gif" : "png";
const animatedEmojiExtension = SETTINGS.enableAnimatedEmoji ? "gif" : "webp";
// noinspection HtmlUnknownTarget
processed = processed
.replace(regex.formatUrl, "<a href='$1' target='_blank' rel='noreferrer'>$1</a>")
.replace(regex.mentionChannel, (full, match) => "<span class='link mention-chat'>#" + STATE.getChannelName(match) + "</span>")
.replace(regex.mentionUser, (full, match) => "<span class='link mention-user' title='#" + (STATE.getUserTag(match) || "????") + "'>@" + STATE.getUserName(match) + "</span>")
.replace(regex.customEmojiStatic, "<img src='https://cdn.discordapp.com/emojis/$2.png' alt=':$1:' title=':$1:' class='emoji'>")
.replace(regex.customEmojiAnimated, "<img src='https://cdn.discordapp.com/emojis/$2." + animatedEmojiExtension + "' alt=':$1:' title=':$1:' class='emoji'>");
.replace(regex.customEmojiStatic, (full, m1, m2) => getEmoji(m1, m2, "webp"))
.replace(regex.customEmojiAnimated, (full, m1, m2) => getEmoji(m1, m2, animatedEmojiExtension));
return "<p>" + processed + "</p>";
};
const getAvatarUrlObject = function(avatar) {
return { url: fileUrlProcessor("https://cdn.discordapp.com/avatars/" + avatar.id + "/" + avatar.path + ".webp") };
};
const getImageEmbed = function(url, image) {
if (!SETTINGS.enableImagePreviews) {
return "";
}
if (image.width && image.height) {
return templateEmbedImageWithSize.apply({ url, src: image.url, width: image.width, height: image.height });
return templateEmbedImageWithSize.apply({ url: fileUrlProcessor(url), src: fileUrlProcessor(image.url), width: image.width, height: image.height });
}
else {
return templateEmbedImage.apply({ url, src: image.url });
return templateEmbedImage.apply({ url: fileUrlProcessor(url), src: fileUrlProcessor(image.url) });
}
};
@@ -125,8 +146,9 @@ const DISCORD = (function() {
"</div>"
].join(""));
// noinspection HtmlUnknownTarget
templateUserAvatar = new TEMPLATE([
"<img src='https://cdn.discordapp.com/avatars/{id}/{path}.webp?size=128' alt=''>"
"<img src='{url}' alt=''>"
].join(""));
// noinspection HtmlUnknownTarget
@@ -167,8 +189,9 @@ const DISCORD = (function() {
"<span class='reaction-wrapper'><span class='reaction-emoji'>{n}</span><span class='count'>{c}</span></span>"
].join(""));
// noinspection HtmlUnknownTarget
templateReactionCustom = new TEMPLATE([
"<span class='reaction-wrapper'><img src='https://cdn.discordapp.com/emojis/{id}.{ext}' alt=':{n}:' title=':{n}:' class='reaction-emoji-custom'><span class='count'>{c}</span></span>"
"<span class='reaction-wrapper'><img src='{url}' alt=':{n}:' title=':{n}:' class='reaction-emoji-custom'><span class='count'>{c}</span></span>"
].join(""));
},
@@ -199,7 +222,7 @@ const DISCORD = (function() {
getMessageHTML(message) { // noinspection FunctionWithInconsistentReturnsJS
return (SETTINGS.enableUserAvatars ? templateMessageWithAvatar : templateMessageNoAvatar).apply(message, (property, value) => {
if (property === "avatar") {
return value ? templateUserAvatar.apply(value) : "";
return value ? templateUserAvatar.apply(getAvatarUrlObject(value)) : "";
}
else if (property === "user.tag") {
return value ? value : "????";
@@ -220,10 +243,10 @@ const DISCORD = (function() {
return templateEmbedUnsupported.apply(embed);
}
else if ("image" in embed && embed.image.url) {
return getImageEmbed(embed.url, embed.image);
return getImageEmbed(fileUrlProcessor(embed.url), embed.image);
}
else if ("thumbnail" in embed && embed.thumbnail.url) {
return getImageEmbed(embed.url, embed.thumbnail);
return getImageEmbed(fileUrlProcessor(embed.url), embed.thumbnail);
}
else if ("title" in embed && "description" in embed) {
return templateEmbedRich.apply(embed);
@@ -242,14 +265,16 @@ const DISCORD = (function() {
}
return value.map(attachment => {
const url = fileUrlProcessor(attachment.url);
if (!DISCORD.isImageAttachment(attachment) || !SETTINGS.enableImagePreviews) {
return templateAttachmentDownload.apply(attachment);
return templateAttachmentDownload.apply({ url, name: attachment.name });
}
else if ("width" in attachment && "height" in attachment) {
return templateEmbedImageWithSize.apply({ url: attachment.url, src: attachment.url, width: attachment.width, height: attachment.height });
return templateEmbedImageWithSize.apply({ url, src: url, width: attachment.width, height: attachment.height });
}
else {
return templateEmbedImage.apply({ url: attachment.url, src: attachment.url });
return templateEmbedImage.apply({ url, src: url });
}
}).join("");
}
@@ -265,7 +290,7 @@ const DISCORD = (function() {
}
const user = "<span class='reply-username' title='#" + (value.user.tag ? value.user.tag : "????") + "'>" + value.user.name + "</span>";
const avatar = SETTINGS.enableUserAvatars && value.avatar ? "<span class='reply-avatar'>" + templateUserAvatar.apply(value.avatar) + "</span>" : "";
const avatar = SETTINGS.enableUserAvatars && value.avatar ? "<span class='reply-avatar'>" + templateUserAvatar.apply(getAvatarUrlObject(value.avatar)) + "</span>" : "";
const contents = value.contents ? "<span class='reply-contents'>" + processMessageContents(value.contents) + "</span>" : "";
return "<span class='jump' data-jump='" + value.id + "'>Jump to reply</span><span class='user'>" + avatar + user + "</span>" + contents;
@@ -277,9 +302,10 @@ const DISCORD = (function() {
return "<div class='reactions'>" + value.map(reaction => {
if ("id" in reaction){
// noinspection JSUnusedGlobalSymbols, JSUnresolvedVariable
reaction.ext = reaction.a && SETTINGS.enableAnimatedEmoji ? "gif" : "png";
return templateReactionCustom.apply(reaction);
const ext = reaction.a && SETTINGS.enableAnimatedEmoji ? "gif" : "webp";
const url = fileUrlProcessor("https://cdn.discordapp.com/emojis/" + reaction.id + "." + ext);
// noinspection JSUnusedGlobalSymbols
return templateReactionCustom.apply({ url, n: reaction.n, c: reaction.c });
}
else {
return templateReaction.apply(reaction);

View File

@@ -1,15 +1,19 @@
namespace DHT.Server.Data.Aggregations;
public sealed class DownloadStatusStatistics {
public int EnqueuedCount { get; internal set; }
public ulong EnqueuedSize { get; internal set; }
public int PendingCount { get; internal init; }
public ulong PendingTotalSize { get; internal init; }
public int PendingWithUnknownSizeCount { get; internal init; }
public int SuccessfulCount { get; internal set; }
public ulong SuccessfulSize { get; internal set; }
public int SuccessfulCount { get; internal init; }
public ulong SuccessfulTotalSize { get; internal init; }
public int SuccessfulWithUnknownSizeCount { get; internal init; }
public int FailedCount { get; internal set; }
public ulong FailedSize { get; internal set; }
public int SkippedCount { get; internal set; }
public ulong SkippedSize { get; internal set; }
public int FailedCount { get; internal init; }
public ulong FailedTotalSize { get; internal init; }
public int FailedWithUnknownSizeCount { get; internal init; }
public int SkippedCount { get; internal init; }
public ulong SkippedTotalSize { get; internal init; }
public int SkippedWithUnknownSizeCount { get; internal init; }
}

View File

@@ -1,33 +1,17 @@
using System;
using System.Net;
using DHT.Server.Download;
namespace DHT.Server.Data;
public readonly struct Download {
internal static Download NewSuccess(DownloadItem item, byte[] data) {
return new Download(item.NormalizedUrl, item.DownloadUrl, DownloadStatus.Success, (ulong) Math.Max(data.LongLength, 0), data);
}
internal static Download NewFailure(DownloadItem item, HttpStatusCode? statusCode, ulong size) {
return new Download(item.NormalizedUrl, item.DownloadUrl, statusCode.HasValue ? (DownloadStatus) (int) statusCode : DownloadStatus.GenericError, size);
}
public sealed class Download {
public string NormalizedUrl { get; }
public string DownloadUrl { get; }
public DownloadStatus Status { get; }
public ulong Size { get; }
public byte[]? Data { get; }
public string? Type { get; }
public ulong? Size { get; }
internal Download(string normalizedUrl, string downloadUrl, DownloadStatus status, ulong size, byte[]? data = null) {
internal Download(string normalizedUrl, string downloadUrl, DownloadStatus status, string? type, ulong? size) {
NormalizedUrl = normalizedUrl;
DownloadUrl = downloadUrl;
Status = status;
Type = type;
Size = size;
Data = data;
}
internal Download WithData(byte[] data) {
return new Download(NormalizedUrl, DownloadUrl, Status, Size, data);
}
}

View File

@@ -6,8 +6,9 @@ namespace DHT.Server.Data;
/// Extends <see cref="HttpStatusCode"/> with custom status codes in the range 0-99.
/// </summary>
public enum DownloadStatus {
Enqueued = 0,
Pending = 0,
GenericError = 1,
Downloading = 2,
LastCustomCode = 99,
Success = HttpStatusCode.OK
}

View File

@@ -1,6 +0,0 @@
namespace DHT.Server.Data;
public readonly struct DownloadedAttachment {
public string? Type { get; internal init; }
public byte[] Data { get; internal init; }
}

View File

@@ -0,0 +1,17 @@
namespace DHT.Server.Data.Embeds;
sealed class DiscordEmbedJson {
public string? Type { get; set; }
public string? Url { get; set; }
public JsonImage? Image { get; set; }
public JsonImage? Thumbnail { get; set; }
public JsonImage? Video { get; set; }
public sealed class JsonImage {
public string? Url { get; set; }
public string? ProxyUrl { get; set; }
public int? Width { get; set; }
public int? Height { get; set; }
}
}

View File

@@ -0,0 +1,7 @@
using System.Text.Json.Serialization;
namespace DHT.Server.Data.Embeds;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(DiscordEmbedJson))]
sealed partial class DiscordEmbedJsonContext : JsonSerializerContext;

View File

@@ -1,15 +0,0 @@
namespace DHT.Server.Data.Filters;
public sealed class AttachmentFilter {
public ulong? MaxBytes { get; set; } = null;
public DownloadItemRules? DownloadItemRule { get; set; } = null;
public bool IsEmpty => MaxBytes == null &&
DownloadItemRule == null;
public enum DownloadItemRules {
OnlyNotPresent,
OnlyPresent
}
}

View File

@@ -3,8 +3,10 @@ using System.Collections.Generic;
namespace DHT.Server.Data.Filters;
public sealed class DownloadItemFilter {
public HashSet<DownloadStatus>? IncludeStatuses { get; init; } = null;
public HashSet<DownloadStatus>? ExcludeStatuses { get; init; } = null;
public HashSet<DownloadStatus>? IncludeStatuses { get; set; } = null;
public HashSet<DownloadStatus>? ExcludeStatuses { get; set; } = null;
public ulong? MaxBytes { get; set; } = null;
public bool IsEmpty => IncludeStatuses == null && ExcludeStatuses == null;
public bool IsEmpty => IncludeStatuses == null && ExcludeStatuses == null && MaxBytes == null;
}
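
For example, a caller could presumably combine the statuses and the size limit like this (hypothetical values, not a call site from this diff):

using DHT.Server.Data;
using DHT.Server.Data.Filters;

var filter = new DownloadItemFilter {
	IncludeStatuses = [DownloadStatus.Pending],
	MaxBytes = 10UL * 1024 * 1024 // only items up to 10 MiB
};
// IsEmpty is false here, so a WHERE clause will be generated from the filter.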

View File

@@ -25,8 +25,10 @@ public static class DatabaseExtensions {
await target.Messages.Add(batchedMessages);
await foreach (var download in source.Downloads.GetWithoutData()) {
await target.Downloads.AddDownload(download.Status == DownloadStatus.Success ? await source.Downloads.HydrateWithData(download) : download);
await foreach (var download in source.Downloads.Get()) {
if (download.Status != DownloadStatus.Success || !await source.Downloads.GetDownloadData(download.NormalizedUrl, stream => target.Downloads.AddDownload(download, stream))) {
await target.Downloads.AddDownload(download, stream: null);
}
}
}
}

View File

@@ -14,7 +14,6 @@ sealed class DummyDatabaseFile : IDatabaseFile {
public IServerRepository Servers { get; } = new IServerRepository.Dummy();
public IChannelRepository Channels { get; } = new IChannelRepository.Dummy();
public IMessageRepository Messages { get; } = new IMessageRepository.Dummy();
public IAttachmentRepository Attachments { get; } = new IAttachmentRepository.Dummy();
public IDownloadRepository Downloads { get; } = new IDownloadRepository.Dummy();
private DummyDatabaseFile() {}

View File

@@ -5,9 +5,9 @@ namespace DHT.Server.Database.Exceptions;
public sealed class DatabaseTooNewException : Exception {
public int DatabaseVersion { get; }
public int CurrentVersion => Schema.Version;
public int CurrentVersion => SqliteSchema.Version;
internal DatabaseTooNewException(int databaseVersion) : base("Database is too new: " + databaseVersion + " > " + Schema.Version) {
internal DatabaseTooNewException(int databaseVersion) : base("Database is too new: " + databaseVersion + " > " + SqliteSchema.Version) {
this.DatabaseVersion = databaseVersion;
}
}

View File

@@ -1,7 +0,0 @@
using DHT.Server.Data;
namespace DHT.Server.Database.Export.Strategy;
public interface IViewerExportStrategy {
string GetAttachmentUrl(Attachment attachment);
}

View File

@@ -1,18 +0,0 @@
using System.Net;
using DHT.Server.Data;
namespace DHT.Server.Database.Export.Strategy;
public sealed class LiveViewerExportStrategy : IViewerExportStrategy {
private readonly string safePort;
private readonly string safeToken;
public LiveViewerExportStrategy(ushort port, string token) {
this.safePort = port.ToString();
this.safeToken = WebUtility.UrlEncode(token);
}
public string GetAttachmentUrl(Attachment attachment) {
return "http://127.0.0.1:" + safePort + "/get-attachment/" + WebUtility.UrlEncode(attachment.NormalizedUrl) + "?token=" + safeToken;
}
}

View File

@@ -1,18 +0,0 @@
using DHT.Server.Data;
namespace DHT.Server.Database.Export.Strategy;
public sealed class StandaloneViewerExportStrategy : IViewerExportStrategy {
public static StandaloneViewerExportStrategy Instance { get; } = new ();
private StandaloneViewerExportStrategy() {}
public string GetAttachmentUrl(Attachment attachment) {
// The normalized URL will not load files from Discord CDN once the time limit is enforced.
// The downloaded URL would work, but only for a limited time, so it is better for the links to not work
// rather than give users a false sense of security.
return attachment.NormalizedUrl;
}
}

View File

@@ -6,7 +6,6 @@ using System.Text.Json;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Export.Strategy;
using DHT.Utils.Logging;
namespace DHT.Server.Database.Export;
@@ -14,7 +13,7 @@ namespace DHT.Server.Database.Export;
public static class ViewerJsonExport {
private static readonly Log Log = Log.ForType(typeof(ViewerJsonExport));
public static async Task Generate(Stream stream, IViewerExportStrategy strategy, IDatabaseFile db, MessageFilter? filter = null) {
public static async Task Generate(Stream stream, IDatabaseFile db, MessageFilter? filter = null) {
var perf = Log.Start();
var includedUserIds = new HashSet<ulong>();
@@ -49,7 +48,7 @@ public static class ViewerJsonExport {
Servers = servers,
Channels = channels
},
Data = GenerateMessageList(includedMessages, userIndices, strategy)
Data = GenerateMessageList(includedMessages, userIndices)
};
perf.Step("Generate value object");
@@ -125,7 +124,7 @@ public static class ViewerJsonExport {
return channels;
}
private static Dictionary<Snowflake, Dictionary<Snowflake, ViewerJson.JsonMessage>> GenerateMessageList(List<Message> includedMessages, Dictionary<ulong, int> userIndices, IViewerExportStrategy strategy) {
private static Dictionary<Snowflake, Dictionary<Snowflake, ViewerJson.JsonMessage>> GenerateMessageList(List<Message> includedMessages, Dictionary<ulong, int> userIndices) {
var data = new Dictionary<Snowflake, Dictionary<Snowflake, ViewerJson.JsonMessage>>();
foreach (var grouping in includedMessages.GroupBy(static message => message.Channel)) {
@@ -142,9 +141,9 @@ public static class ViewerJsonExport {
Te = message.EditTimestamp,
R = message.RepliedToId?.ToString(),
A = message.Attachments.IsEmpty ? null : message.Attachments.Select(attachment => {
A = message.Attachments.IsEmpty ? null : message.Attachments.Select(static attachment => {
var a = new ViewerJson.JsonMessageAttachment {
Url = strategy.GetAttachmentUrl(attachment),
Url = attachment.DownloadUrl,
Name = Uri.TryCreate(attachment.NormalizedUrl, UriKind.Absolute, out var uri) ? Path.GetFileName(uri.LocalPath) : attachment.NormalizedUrl
};

View File

@@ -11,7 +11,6 @@ public interface IDatabaseFile : IAsyncDisposable {
IServerRepository Servers { get; }
IChannelRepository Channels { get; }
IMessageRepository Messages { get; }
IAttachmentRepository Attachments { get; }
IDownloadRepository Downloads { get; }
Task Vacuum();

View File

@@ -161,7 +161,7 @@ public static class LegacyArchiveImport {
var messagesObj = data.HasKey(channelIdStr) ? data.RequireObject(channelIdStr, DataPath) : (JsonElement?) null;
if (messagesObj == null) {
return Array.Empty<Message>();
return [];
}
return messagesObj.Value.EnumerateObject().Select(item => {

View File

@@ -1,21 +0,0 @@
using System;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
namespace DHT.Server.Database.Repositories;
public interface IAttachmentRepository {
IObservable<long> TotalCount { get; }
Task<long> Count(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
internal sealed class Dummy : IAttachmentRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task<long> Count(AttachmentFilter? filter = null, CancellationToken cancellationToken = default) {
return Task.FromResult(0L);
}
}
}

View File

@@ -1,10 +1,10 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
@@ -14,55 +14,61 @@ namespace DHT.Server.Database.Repositories;
public interface IDownloadRepository {
IObservable<long> TotalCount { get; }
Task AddDownload(Data.Download download);
Task AddDownload(Data.Download item, Stream? stream);
Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken = default);
Task<long> Count(DownloadItemFilter filter, CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Download> GetWithoutData();
Task<DownloadStatusStatistics> GetStatistics(DownloadItemFilter nonSkippedFilter, CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Download> Get();
Task<Data.Download> HydrateWithData(Data.Download download);
Task<bool> GetDownloadData(string normalizedUrl, Func<Stream, Task> dataProcessor);
Task<bool> GetSuccessfulDownloadWithData(string normalizedUrl, Func<Data.Download, Stream, Task> dataProcessor);
Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl);
Task<int> EnqueueDownloadItems(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken = default);
Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
IAsyncEnumerable<DownloadItem> PullPendingDownloadItems(int count, DownloadItemFilter filter, CancellationToken cancellationToken = default);
Task MoveDownloadingItemsBackToQueue(CancellationToken cancellationToken = default);
Task<int> RetryFailed(CancellationToken cancellationToken = default);
internal sealed class Dummy : IDownloadRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task AddDownload(Data.Download download) {
public Task AddDownload(Data.Download item, Stream? stream) {
return Task.CompletedTask;
}
public Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
public Task<long> Count(DownloadItemFilter filter, CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public Task<DownloadStatusStatistics> GetStatistics(DownloadItemFilter nonSkippedFilter, CancellationToken cancellationToken) {
return Task.FromResult(new DownloadStatusStatistics());
}
public IAsyncEnumerable<Data.Download> GetWithoutData() {
public IAsyncEnumerable<Data.Download> Get() {
return AsyncEnumerable.Empty<Data.Download>();
}
public Task<Data.Download> HydrateWithData(Data.Download download) {
return Task.FromResult(download);
public Task<bool> GetDownloadData(string normalizedUrl, Func<Stream, Task> dataProcessor) {
return Task.FromResult(false);
}
public Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
return Task.FromResult<DownloadedAttachment?>(null);
public Task<bool> GetSuccessfulDownloadWithData(string normalizedUrl, Func<Data.Download, Stream, Task> dataProcessor) {
return Task.FromResult(false);
}
public Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0);
}
public IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken) {
public IAsyncEnumerable<DownloadItem> PullPendingDownloadItems(int count, DownloadItemFilter filter, CancellationToken cancellationToken) {
return AsyncEnumerable.Empty<DownloadItem>();
}
public Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
public Task MoveDownloadingItemsBackToQueue(CancellationToken cancellationToken) {
return Task.CompletedTask;
}
public Task<int> RetryFailed(CancellationToken cancellationToken) {
return Task.FromResult(0);
}
}
}
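
A consumption sketch (names hypothetical) of the new streaming methods, in the spirit of the "stream downloaded files into HTTP server responses" change: the stored blob is copied straight to an output stream instead of being loaded into a byte array first.

using System.IO;
using System.Threading.Tasks;
using DHT.Server.Database.Repositories;

static class DownloadStreamingSketch {
	// Copies a successfully downloaded file from the database into an output stream,
	// e.g. an HTTP response body. Returns false if no blob is stored for the URL.
	public static Task<bool> WriteToAsync(IDownloadRepository downloads, string normalizedUrl, Stream output) {
		return downloads.GetSuccessfulDownloadWithData(normalizedUrl, (download, blob) => blob.CopyToAsync(output));
	}
}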

View File

@@ -1,16 +1,20 @@
using System;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
namespace DHT.Server.Database.Sqlite.Repositories;
abstract class BaseSqliteRepository : IDisposable {
private readonly ObservableThrottledTask<long> totalCountTask = new (TaskScheduler.Default);
public IObservable<long> TotalCount => totalCountTask;
private readonly ObservableThrottledTask<long> totalCountTask;
protected BaseSqliteRepository() {
public IObservable<long> TotalCount { get; }
protected BaseSqliteRepository(Log log) {
totalCountTask = new ObservableThrottledTask<long>(log, TaskScheduler.Default);
TotalCount = totalCountTask.DistinctUntilChanged();
UpdateTotalCount();
}
@@ -21,6 +25,6 @@ abstract class BaseSqliteRepository : IDisposable {
protected void UpdateTotalCount() {
totalCountTask.Post(Count);
}
public abstract Task<long> Count(CancellationToken cancellationToken);
}
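
A minimal consumption sketch (the subscriber is hypothetical): because TotalCount is an IObservable&lt;long&gt; with DistinctUntilChanged applied, a view model only sees real changes.

using System;
using DHT.Server.Database;

static class TotalCountSketch {
	// Watches the download count; duplicate values are already filtered out by the repository.
	public static IDisposable WatchDownloadCount(IDatabaseFile db) {
		return db.Downloads.TotalCount.Subscribe(count => Console.WriteLine($"Downloads: {count}"));
	}
}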

View File

@@ -1,28 +0,0 @@
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteAttachmentRepository : BaseSqliteRepository, IAttachmentRepository {
private readonly SqliteConnectionPool pool;
public SqliteAttachmentRepository(SqliteConnectionPool pool) {
this.pool = pool;
}
internal new void UpdateTotalCount() {
base.UpdateTotalCount();
}
public override Task<long> Count(CancellationToken cancellationToken) {
return Count(filter: null, cancellationToken);
}
public async Task<long> Count(AttachmentFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
}

View File

@@ -4,21 +4,24 @@ using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository {
private static readonly Log Log = Log.ForType<SqliteChannelRepository>();
private readonly SqliteConnectionPool pool;
public SqliteChannelRepository(SqliteConnectionPool pool) {
public SqliteChannelRepository(SqliteConnectionPool pool) : base(Log) {
this.pool = pool;
}
public async Task Add(IReadOnlyList<Channel> channels) {
await using var conn = await pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
await using var cmd = conn.Upsert("channels", [
("id", SqliteType.Integer),
("server", SqliteType.Integer),
@@ -40,7 +43,7 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
}
UpdateTotalCount();
@@ -57,7 +60,7 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository
await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
while (await reader.ReadAsync()) {
yield return new Channel {
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),

View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -9,198 +10,274 @@ using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepository {
private readonly SqliteConnectionPool pool;
sealed class SqliteDownloadRepository(SqliteConnectionPool pool) : BaseSqliteRepository(Log), IDownloadRepository {
private static readonly Log Log = Log.ForType<SqliteDownloadRepository>();
public SqliteDownloadRepository(SqliteConnectionPool pool) {
this.pool = pool;
internal sealed class NewDownloadCollector : IAsyncDisposable {
private readonly SqliteDownloadRepository repository;
private bool hasAdded = false;
private readonly SqliteCommand metadataCmd;
public NewDownloadCollector(SqliteDownloadRepository repository, ISqliteConnection conn) {
this.repository = repository;
metadataCmd = conn.Command(
"""
INSERT INTO download_metadata (normalized_url, download_url, status, type, size)
VALUES (:normalized_url, :download_url, :status, :type, :size)
ON CONFLICT DO NOTHING
"""
);
metadataCmd.Add(":normalized_url", SqliteType.Text);
metadataCmd.Add(":download_url", SqliteType.Text);
metadataCmd.Add(":status", SqliteType.Integer);
metadataCmd.Add(":type", SqliteType.Text);
metadataCmd.Add(":size", SqliteType.Integer);
}
public async Task Add(Data.Download download) {
metadataCmd.Set(":normalized_url", download.NormalizedUrl);
metadataCmd.Set(":download_url", download.DownloadUrl);
metadataCmd.Set(":status", (int) download.Status);
metadataCmd.Set(":type", download.Type);
metadataCmd.Set(":size", download.Size);
hasAdded |= await metadataCmd.ExecuteNonQueryAsync() > 0;
}
public void OnCommitted() {
if (hasAdded) {
repository.UpdateTotalCount();
}
}
public async ValueTask DisposeAsync() {
await metadataCmd.DisposeAsync();
}
}
public async Task AddDownload(Data.Download download) {
public async Task AddDownload(Data.Download item, Stream? stream) {
await using (var conn = await pool.Take()) {
await using var cmd = conn.Upsert("downloads", [
await conn.BeginTransactionAsync();
await using var metadataCmd = conn.Upsert("download_metadata", [
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("status", SqliteType.Integer),
("type", SqliteType.Text),
("size", SqliteType.Integer),
("blob", SqliteType.Blob)
]);
cmd.Set(":normalized_url", download.NormalizedUrl);
cmd.Set(":download_url", download.DownloadUrl);
cmd.Set(":status", (int) download.Status);
cmd.Set(":size", download.Size);
cmd.Set(":blob", download.Data);
await cmd.ExecuteNonQueryAsync();
metadataCmd.Set(":normalized_url", item.NormalizedUrl);
metadataCmd.Set(":download_url", item.DownloadUrl);
metadataCmd.Set(":status", (int) item.Status);
metadataCmd.Set(":type", item.Type);
metadataCmd.Set(":size", item.Size);
await metadataCmd.ExecuteNonQueryAsync();
if (stream == null) {
await using var deleteBlobCmd = conn.Command("DELETE FROM download_blobs WHERE normalized_url = :normalized_url");
deleteBlobCmd.AddAndSet(":normalized_url", SqliteType.Text, item.NormalizedUrl);
await deleteBlobCmd.ExecuteNonQueryAsync();
}
else {
await using var upsertBlobCmd = conn.Command(
"""
INSERT INTO download_blobs (normalized_url, blob)
VALUES (:normalized_url, ZEROBLOB(:blob_length))
ON CONFLICT (normalized_url) DO UPDATE SET blob = excluded.blob
RETURNING rowid
"""
);
upsertBlobCmd.AddAndSet(":normalized_url", SqliteType.Text, item.NormalizedUrl);
upsertBlobCmd.AddAndSet(":blob_length", SqliteType.Integer, item.Size);
long rowid = await upsertBlobCmd.ExecuteLongScalarAsync();
await using var blob = new SqliteBlob(conn.InnerConnection, "download_blobs", "blob", rowid);
await stream.CopyToAsync(blob);
}
await conn.CommitTransactionAsync();
}
UpdateTotalCount();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
public override Task<long> Count(CancellationToken cancellationToken) {
return Count(filter: null, cancellationToken);
}
public async Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
static async Task LoadUndownloadedStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT IFNULL(COUNT(size), 0), IFNULL(SUM(size), 0)
FROM (SELECT MAX(a.size) size
FROM attachments a
WHERE a.normalized_url NOT IN (SELECT d.normalized_url FROM downloads d)
GROUP BY a.normalized_url)
""");
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.SkippedCount = reader.GetInt32(0);
result.SkippedSize = reader.GetUint64(1);
}
}
static async Task LoadSuccessStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN size ELSE 0 END), 0)
FROM downloads
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.EnqueuedCount = reader.GetInt32(0);
result.EnqueuedSize = reader.GetUint64(1);
result.SuccessfulCount = reader.GetInt32(2);
result.SuccessfulSize = reader.GetUint64(3);
result.FailedCount = reader.GetInt32(4);
result.FailedSize = reader.GetUint64(5);
}
}
var result = new DownloadStatusStatistics();
public async Task<long> Count(DownloadItemFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
await LoadUndownloadedStatistics(conn, result, cancellationToken);
await LoadSuccessStatistics(conn, result, cancellationToken);
return result;
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM download_metadata" + filter.GenerateConditions().BuildWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
public async IAsyncEnumerable<Data.Download> GetWithoutData() {
public async Task<DownloadStatusStatistics> GetStatistics(DownloadItemFilter nonSkippedFilter, CancellationToken cancellationToken) {
nonSkippedFilter.IncludeStatuses = null;
nonSkippedFilter.ExcludeStatuses = null;
string nonSkippedFilterConditions = nonSkippedFilter.GenerateConditions().Build();
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
await using var cmd = conn.Command(
$"""
SELECT
IFNULL(SUM(CASE WHEN (status = :downloading) OR (status = :pending AND {nonSkippedFilterConditions}) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN (status = :downloading) OR (status = :pending AND {nonSkippedFilterConditions}) THEN IFNULL(size, 0) ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN ((status = :downloading) OR (status = :pending AND {nonSkippedFilterConditions})) AND size IS NULL THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN IFNULL(size, 0) ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success AND size IS NULL THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:pending, :downloading, :success) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:pending, :downloading, :success) THEN IFNULL(size, 0) ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:pending, :downloading, :success) AND size IS NULL THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :pending AND NOT ({nonSkippedFilterConditions}) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :pending AND NOT ({nonSkippedFilterConditions}) THEN IFNULL(size, 0) ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :pending AND NOT ({nonSkippedFilterConditions}) AND size IS NULL THEN 1 ELSE 0 END), 0)
FROM download_metadata
"""
);
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (!await reader.ReadAsync(cancellationToken)) {
return new DownloadStatusStatistics();
}
return new DownloadStatusStatistics {
PendingCount = reader.GetInt32(0),
PendingTotalSize = reader.GetUint64(1),
PendingWithUnknownSizeCount = reader.GetInt32(2),
SuccessfulCount = reader.GetInt32(3),
SuccessfulTotalSize = reader.GetUint64(4),
SuccessfulWithUnknownSizeCount = reader.GetInt32(5),
FailedCount = reader.GetInt32(6),
FailedTotalSize = reader.GetUint64(7),
FailedWithUnknownSizeCount = reader.GetInt32(8),
SkippedCount = reader.GetInt32(9),
SkippedTotalSize = reader.GetUint64(10),
SkippedWithUnknownSizeCount = reader.GetInt32(11)
};
}
public async IAsyncEnumerable<Data.Download> Get() {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, type, size FROM download_metadata");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
while (await reader.ReadAsync()) {
string normalizedUrl = reader.GetString(0);
string downloadUrl = reader.GetString(1);
var status = (DownloadStatus) reader.GetInt32(2);
ulong size = reader.GetUint64(3);
string? type = reader.IsDBNull(3) ? null : reader.GetString(3);
ulong? size = reader.IsDBNull(4) ? null : reader.GetUint64(4);
yield return new Data.Download(normalizedUrl, downloadUrl, status, size);
yield return new Data.Download(normalizedUrl, downloadUrl, status, type, size);
}
}
public async Task<Data.Download> HydrateWithData(Data.Download download) {
public async Task<bool> GetDownloadData(string normalizedUrl, Func<Stream, Task> dataProcessor) {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
await using var cmd = conn.Command("SELECT rowid FROM download_blobs WHERE normalized_url = :normalized_url");
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
long rowid;
await using (var reader = await cmd.ExecuteReaderAsync()) {
if (!await reader.ReadAsync()) {
return false;
}
await using var reader = await cmd.ExecuteReaderAsync();
rowid = reader.GetInt64(0);
}
await using (var blob = new SqliteBlob(conn.InnerConnection, "download_blobs", "blob", rowid, readOnly: true)) {
await dataProcessor(blob);
}
if (reader.Read() && !reader.IsDBNull(0)) {
return download.WithData((byte[]) reader["blob"]);
}
else {
return download;
}
return true;
}
public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
public async Task<bool> GetSuccessfulDownloadWithData(string normalizedUrl, Func<Data.Download, Stream, Task> dataProcessor) {
await using var conn = await pool.Take();
await using var cmd = conn.Command(
"""
SELECT a.type, d.blob FROM downloads d
LEFT JOIN attachments a ON d.normalized_url = a.normalized_url
WHERE d.normalized_url = :normalized_url AND d.status = :success AND d.blob IS NOT NULL
SELECT dm.download_url, dm.type, db.rowid FROM download_metadata dm
JOIN download_blobs db ON dm.normalized_url = db.normalized_url
WHERE dm.normalized_url = :normalized_url AND dm.status = :success
"""
);
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync();
string downloadUrl;
string? type;
long rowid;
await using (var reader = await cmd.ExecuteReaderAsync()) {
if (!await reader.ReadAsync()) {
return false;
}
if (!reader.Read()) {
return null;
downloadUrl = reader.GetString(0);
type = reader.IsDBNull(1) ? null : reader.GetString(1);
rowid = reader.GetInt64(2);
}
await using (var blob = new SqliteBlob(conn.InnerConnection, "download_blobs", "blob", rowid, readOnly: true)) {
await dataProcessor(new Data.Download(normalizedUrl, downloadUrl, DownloadStatus.Success, type, (ulong) blob.Length), blob);
}
return new DownloadedAttachment {
Type = reader.IsDBNull(0) ? null : reader.GetString(0),
Data = (byte[]) reader["blob"],
};
return true;
}
public async Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
public async IAsyncEnumerable<DownloadItem> PullPendingDownloadItems(int count, DownloadItemFilter filter, [EnumeratorCancellation] CancellationToken cancellationToken) {
filter.IncludeStatuses = [DownloadStatus.Pending];
filter.ExcludeStatuses = null;
await using var cmd = conn.Command(
$"""
INSERT INTO downloads (normalized_url, download_url, status, size)
SELECT a.normalized_url, a.download_url, :enqueued, MAX(a.size)
FROM attachments a
{filter.GenerateWhereClause("a")}
GROUP BY a.normalized_url
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
return await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) {
var found = new List<DownloadItem>();
await using var conn = await pool.Take();
await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
var sql = $"""
SELECT normalized_url, download_url, type, size
FROM download_metadata
{filter.GenerateConditions().BuildWhereClause()}
LIMIT :limit
""";
await using (var cmd = conn.Command(sql)) {
cmd.AddAndSet(":limit", SqliteType.Integer, Math.Max(0, count));
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
while (reader.Read()) {
while (await reader.ReadAsync(cancellationToken)) {
found.Add(new DownloadItem {
NormalizedUrl = reader.GetString(0),
DownloadUrl = reader.GetString(1),
Size = reader.GetUint64(2),
Type = reader.IsDBNull(2) ? null : reader.GetString(2),
Size = reader.IsDBNull(3) ? null : reader.GetUint64(3)
});
}
}
if (found.Count != 0) {
await using var cmd = conn.Command("UPDATE downloads SET status = :downloading WHERE normalized_url = :normalized_url AND status = :enqueued");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
await using var cmd = conn.Command("UPDATE download_metadata SET status = :downloading WHERE normalized_url = :normalized_url AND status = :pending");
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.Add(":normalized_url", SqliteType.Text);
@@ -214,17 +291,23 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
}
}
public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
await using (var conn = await pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM downloads
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
public async Task MoveDownloadingItemsBackToQueue(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
UpdateTotalCount();
await using var cmd = conn.Command("UPDATE download_metadata SET status = :pending WHERE status = :downloading");
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public async Task<int> RetryFailed(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
await using var cmd = conn.Command("UPDATE download_metadata SET status = :pending WHERE status = :generic_error OR (status > :last_custom_code AND status != :success)");
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
cmd.AddAndSet(":generic_error", SqliteType.Integer, (int) DownloadStatus.GenericError);
cmd.AddAndSet(":last_custom_code", SqliteType.Integer, (int) DownloadStatus.LastCustomCode);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
return await cmd.ExecuteNonQueryAsync(cancellationToken);
}
}
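
To tie the pieces together, a hedged sketch of how a downloader task might feed AddDownload; the HttpClient handling is illustrative (the actual downloader is not part of this diff), and it assumes the code runs inside the DHT.Server assembly, since the Download constructor is internal. Because AddDownload pre-allocates the blob with ZEROBLOB(:blob_length), the size has to be known before the copy starts.

using System;
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Download;

static class StreamingDownloadSketch {
	// Streams one response body straight into the download_blobs table.
	public static async Task DownloadAndStoreAsync(HttpClient client, IDownloadRepository downloads, DownloadItem item) {
		using var response = await client.GetAsync(item.DownloadUrl, HttpCompletionOption.ResponseHeadersRead);
		response.EnsureSuccessStatusCode();

		// The size must be known up front because the blob is allocated with ZEROBLOB.
		long contentLength = response.Content.Headers.ContentLength ?? throw new InvalidOperationException("Unknown content length.");
		string? type = response.Content.Headers.ContentType?.ToString();
		var download = new Download(item.NormalizedUrl, item.DownloadUrl, DownloadStatus.Success, type, (ulong) contentLength);

		await using Stream body = await response.Content.ReadAsStreamAsync();
		await downloads.AddDownload(download, body);
	}
}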

View File

@@ -7,17 +7,21 @@ using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository {
private static readonly Log Log = Log.ForType<SqliteMessageRepository>();
private readonly SqliteConnectionPool pool;
private readonly SqliteAttachmentRepository attachments;
private readonly SqliteDownloadRepository downloads;
public SqliteMessageRepository(SqliteConnectionPool pool, SqliteAttachmentRepository attachments) {
public SqliteMessageRepository(SqliteConnectionPool pool, SqliteDownloadRepository downloads) : base(Log) {
this.pool = pool;
this.attachments = attachments;
this.downloads = downloads;
}
public async Task Add(IReadOnlyList<Message> messages) {
@@ -34,10 +38,8 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
await cmd.ExecuteNonQueryAsync();
}
bool addedAttachments = false;
await using (var conn = await pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await conn.BeginTransactionAsync();
await using var messageCmd = conn.Upsert("messages", [
("message_id", SqliteType.Integer),
@@ -88,6 +90,8 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer)
]);
await using var downloadCollector = new SqliteDownloadRepository.NewDownloadCollector(downloads, conn);
foreach (var message in messages) {
object messageId = message.Id;
@@ -119,8 +123,6 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
}
if (!message.Attachments.IsEmpty) {
addedAttachments = true;
foreach (var attachment in message.Attachments) {
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
@@ -132,6 +134,8 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
attachmentCmd.Set(":width", attachment.Width);
attachmentCmd.Set(":height", attachment.Height);
await attachmentCmd.ExecuteNonQueryAsync();
await downloadCollector.Add(DownloadLinkExtractor.FromAttachment(attachment));
}
}
@@ -140,6 +144,10 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
await embedCmd.ExecuteNonQueryAsync();
if (DownloadLinkExtractor.TryFromEmbedJson(embed.Json) is {} download) {
await downloadCollector.Add(download);
}
}
}
@@ -151,18 +159,19 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
await reactionCmd.ExecuteNonQueryAsync();
if (reaction.EmojiId is {} emojiId) {
await downloadCollector.Add(DownloadLinkExtractor.FromEmoji(emojiId, reaction.EmojiFlags));
}
}
}
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
downloadCollector.OnCommitted();
}
UpdateTotalCount();
if (addedAttachments) {
attachments.UpdateTotalCount();
}
}
public override Task<long> Count(CancellationToken cancellationToken) {
@@ -171,14 +180,14 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateConditions().BuildWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
private sealed class MesageToManyCommand<T> : IAsyncDisposable {
private sealed class MessageToManyCommand<T> : IAsyncDisposable {
private readonly SqliteCommand cmd;
private readonly Func<SqliteDataReader, T> readItem;
public MesageToManyCommand(ISqliteConnection conn, string sql, Func<SqliteDataReader, T> readItem) {
public MessageToManyCommand(ISqliteConnection conn, string sql, Func<SqliteDataReader, T> readItem) {
this.cmd = conn.Command(sql);
this.cmd.Add(":message_id", SqliteType.Integer);
@@ -214,7 +223,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
WHERE message_id = :message_id
""";
await using var attachmentCmd = new MesageToManyCommand<Attachment>(conn, AttachmentSql, static reader => new Attachment {
await using var attachmentCmd = new MessageToManyCommand<Attachment>(conn, AttachmentSql, static reader => new Attachment {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = reader.IsDBNull(2) ? null : reader.GetString(2),
@@ -232,7 +241,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
WHERE message_id = :message_id
""";
await using var embedCmd = new MesageToManyCommand<Embed>(conn, EmbedSql, static reader => new Embed {
await using var embedCmd = new MessageToManyCommand<Embed>(conn, EmbedSql, static reader => new Embed {
Json = reader.GetString(0)
});
@@ -243,7 +252,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
WHERE message_id = :message_id
""";
await using var reactionsCmd = new MesageToManyCommand<Reaction>(conn, ReactionSql, static reader => new Reaction {
await using var reactionsCmd = new MessageToManyCommand<Reaction>(conn, ReactionSql, static reader => new Reaction {
EmojiId = reader.IsDBNull(0) ? null : reader.GetUint64(0),
EmojiName = reader.IsDBNull(1) ? null : reader.GetString(1),
EmojiFlags = (EmojiFlags) reader.GetInt16(2),
@@ -256,7 +265,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id
{filter.GenerateWhereClause("m")}
{filter.GenerateConditions("m").BuildWhereClause()}
"""
);
@@ -283,7 +292,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateConditions().BuildWhereClause());
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
@@ -297,7 +306,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
$"""
-- noinspection SqlWithoutWhere
DELETE FROM messages
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
{filter.GenerateConditions(invert: mode == FilterRemovalMode.KeepMatching).BuildWhereClause()}
"""
);
}

View File

@@ -4,21 +4,24 @@ using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
private static readonly Log Log = Log.ForType<SqliteServerRepository>();
private readonly SqliteConnectionPool pool;
public SqliteServerRepository(SqliteConnectionPool pool) {
public SqliteServerRepository(SqliteConnectionPool pool) : base(Log) {
this.pool = pool;
}
public async Task Add(IReadOnlyList<Data.Server> servers) {
await using var conn = await pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
await using var cmd = conn.Upsert("servers", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
@@ -32,7 +35,7 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
}
UpdateTotalCount();
@@ -49,7 +52,7 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
await using var cmd = conn.Command("SELECT id, name, type FROM servers");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
while (await reader.ReadAsync()) {
yield return new Data.Server {
Id = reader.GetUint64(0),
Name = reader.GetString(1),

View File

@@ -4,21 +4,27 @@ using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
private static readonly Log Log = Log.ForType<SqliteUserRepository>();
private readonly SqliteConnectionPool pool;
public SqliteUserRepository(SqliteConnectionPool pool) {
private readonly SqliteDownloadRepository downloads;
public SqliteUserRepository(SqliteConnectionPool pool, SqliteDownloadRepository downloads) : base(Log) {
this.pool = pool;
this.downloads = downloads;
}
public async Task Add(IReadOnlyList<User> users) {
await using (var conn = await pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await conn.BeginTransactionAsync();
await using var cmd = conn.Upsert("users", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
@@ -26,15 +32,22 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
("discriminator", SqliteType.Text)
]);
await using var downloadCollector = new SqliteDownloadRepository.NewDownloadCollector(downloads, conn);
foreach (var user in users) {
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
await cmd.ExecuteNonQueryAsync();
if (user.AvatarUrl is {} avatarUrl) {
await downloadCollector.Add(DownloadLinkExtractor.FromUserAvatar(user.Id, avatarUrl));
}
}
await tx.CommitAsync();
await conn.CommitTransactionAsync();
downloadCollector.OnCommitted();
}
UpdateTotalCount();
@@ -47,11 +60,11 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
public async IAsyncEnumerable<User> Get() {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
while (await reader.ReadAsync()) {
yield return new User {
Id = reader.GetUint64(0),
Name = reader.GetString(1),

View File

@@ -1,358 +0,0 @@
using System.Collections.Generic;
using System.Data.Common;
using System.Threading.Tasks;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite;
sealed class Schema {
internal const int Version = 6;
private static readonly Log Log = Log.ForType<Schema>();
private readonly ISqliteConnection conn;
public Schema(ISqliteConnection conn) {
this.conn = conn;
}
public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) {
await conn.ExecuteAsync("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = await conn.ExecuteReaderAsync("SELECT value FROM metadata WHERE key = 'version'", static reader => reader?.GetString(0));
if (dbVersionStr == null) {
await InitializeSchemas();
}
else if (!int.TryParse(dbVersionStr, out int dbVersion) || dbVersion < 1) {
throw new InvalidDatabaseVersionException(dbVersionStr);
}
else if (dbVersion > Version) {
throw new DatabaseTooNewException(dbVersion);
}
else if (dbVersion < Version) {
var proceed = await callbacks.CanUpgrade();
if (!proceed) {
return false;
}
await callbacks.Start(Version - dbVersion, async reporter => await UpgradeSchemas(dbVersion, reporter));
}
return true;
}
private async Task InitializeSchemas() {
await conn.ExecuteAsync("""
CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
avatar_url TEXT,
discriminator TEXT
)
""");
await conn.ExecuteAsync("""
CREATE TABLE servers (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE channels (
id INTEGER PRIMARY KEY NOT NULL,
server INTEGER NOT NULL,
name TEXT NOT NULL,
parent_id INTEGER,
position INTEGER,
topic TEXT,
nsfw INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE messages (
message_id INTEGER PRIMARY KEY NOT NULL,
sender_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL,
text TEXT NOT NULL,
timestamp INTEGER NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE attachments (
message_id INTEGER NOT NULL,
attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT,
normalized_url TEXT NOT NULL,
download_url TEXT,
size INTEGER NOT NULL,
width INTEGER,
height INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE embeds (
message_id INTEGER NOT NULL,
json TEXT NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE downloads (
normalized_url TEXT NOT NULL PRIMARY KEY,
download_url TEXT,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
await conn.ExecuteAsync("""
CREATE TABLE reactions (
message_id INTEGER NOT NULL,
emoji_id INTEGER,
emoji_name TEXT,
emoji_flags INTEGER NOT NULL,
count INTEGER NOT NULL
)
""");
await CreateMessageEditTimestampTable();
await CreateMessageRepliedToTable();
await conn.ExecuteAsync("CREATE INDEX attachments_message_ix ON attachments(message_id)");
await conn.ExecuteAsync("CREATE INDEX embeds_message_ix ON embeds(message_id)");
await conn.ExecuteAsync("CREATE INDEX reactions_message_ix ON reactions(message_id)");
await conn.ExecuteAsync("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")");
}
private async Task CreateMessageEditTimestampTable() {
await conn.ExecuteAsync("""
CREATE TABLE edit_timestamps (
message_id INTEGER PRIMARY KEY NOT NULL,
edit_timestamp INTEGER NOT NULL
)
""");
}
private async Task CreateMessageRepliedToTable() {
await conn.ExecuteAsync("""
CREATE TABLE replied_to (
message_id INTEGER PRIMARY KEY NOT NULL,
replied_to_id INTEGER NOT NULL
)
""");
}
private async Task NormalizeAttachmentUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing attachments...", 0, 0);
var normalizedUrls = new Dictionary<long, string>();
await using (var selectCmd = conn.Command("SELECT attachment_id, url FROM attachments")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (reader.Read()) {
var attachmentId = reader.GetInt64(0);
var originalUrl = reader.GetString(1);
normalizedUrls[attachmentId] = DiscordCdn.NormalizeUrl(originalUrl);
}
}
await using var tx = await conn.BeginTransactionAsync();
int totalUrls = normalizedUrls.Count;
int processedUrls = -1;
await using (var updateCmd = conn.Command("UPDATE attachments SET download_url = url, url = :normalized_url WHERE attachment_id = :attachment_id")) {
updateCmd.Add(":attachment_id", SqliteType.Integer);
updateCmd.Add(":normalized_url", SqliteType.Text);
foreach (var (attachmentId, normalizedUrl) in normalizedUrls) {
if (++processedUrls % 1000 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
}
updateCmd.Set(":attachment_id", attachmentId);
updateCmd.Set(":normalized_url", normalizedUrl);
updateCmd.ExecuteNonQuery();
}
}
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync();
}
private async Task NormalizeDownloadUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing downloads...", 0, 0);
var normalizedUrlsToOriginalUrls = new Dictionary<string, string>();
var duplicateUrlsToDelete = new HashSet<string>();
await using (var selectCmd = conn.Command("SELECT url FROM downloads ORDER BY CASE WHEN status = 200 THEN 0 ELSE 1 END")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (reader.Read()) {
var originalUrl = reader.GetString(0);
var normalizedUrl = DiscordCdn.NormalizeUrl(originalUrl);
if (!normalizedUrlsToOriginalUrls.TryAdd(normalizedUrl, originalUrl)) {
duplicateUrlsToDelete.Add(originalUrl);
}
}
}
await conn.ExecuteAsync("PRAGMA cache_size = -20000");
DbTransaction tx;
await using (tx = await conn.BeginTransactionAsync()) {
await reporter.SubWork("Deleting duplicates...", 0, 0);
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
foreach (var duplicateUrl in duplicateUrlsToDelete) {
deleteCmd.Set(":url", duplicateUrl);
deleteCmd.ExecuteNonQuery();
}
}
await tx.CommitAsync();
}
int totalUrls = normalizedUrlsToOriginalUrls.Count;
int processedUrls = -1;
tx = await conn.BeginTransactionAsync();
await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Add(":normalized_url", SqliteType.Text);
updateCmd.Add(":download_url", SqliteType.Text);
foreach (var (normalizedUrl, downloadUrl) in normalizedUrlsToOriginalUrls) {
if (++processedUrls % 100 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
// Not the proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyway, so this doesn't make it worse.
await tx.CommitAsync();
await tx.DisposeAsync();
tx = await conn.BeginTransactionAsync();
updateCmd.Transaction = (SqliteTransaction) tx;
}
updateCmd.Set(":normalized_url", normalizedUrl);
updateCmd.Set(":download_url", downloadUrl);
updateCmd.ExecuteNonQuery();
}
}
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync();
await tx.DisposeAsync();
await conn.ExecuteAsync("PRAGMA cache_size = -2000");
}
private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
var perf = Log.Start("from version " + dbVersion);
await conn.ExecuteAsync("UPDATE metadata SET value = " + Version + " WHERE key = 'version'");
if (dbVersion <= 1) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE channels ADD parent_id INTEGER");
perf.Step("Upgrade to version 2");
await reporter.NextVersion();
}
if (dbVersion <= 2) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await CreateMessageEditTimestampTable();
await CreateMessageRepliedToTable();
await conn.ExecuteAsync("""
INSERT INTO edit_timestamps (message_id, edit_timestamp)
SELECT message_id, edit_timestamp
FROM messages
WHERE edit_timestamp IS NOT NULL
""");
await conn.ExecuteAsync("""
INSERT INTO replied_to (message_id, replied_to_id)
SELECT message_id, replied_to_id
FROM messages
WHERE replied_to_id IS NOT NULL
""");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN replied_to_id");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN edit_timestamp");
perf.Step("Upgrade to version 3");
await reporter.MainWork("Vacuuming the database...", 1, 1);
await conn.ExecuteAsync("VACUUM");
perf.Step("Vacuum");
await reporter.NextVersion();
}
if (dbVersion <= 3) {
await conn.ExecuteAsync("""
CREATE TABLE downloads (
url TEXT NOT NULL PRIMARY KEY,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
perf.Step("Upgrade to version 4");
await reporter.NextVersion();
}
if (dbVersion <= 4) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE attachments ADD width INTEGER");
await conn.ExecuteAsync("ALTER TABLE attachments ADD height INTEGER");
perf.Step("Upgrade to version 5");
await reporter.NextVersion();
}
if (dbVersion <= 5) {
await reporter.MainWork("Applying schema changes...", 0, 3);
await conn.ExecuteAsync("ALTER TABLE attachments ADD download_url TEXT");
await conn.ExecuteAsync("ALTER TABLE downloads ADD download_url TEXT");
await reporter.MainWork("Updating attachments...", 1, 3);
await NormalizeAttachmentUrls(reporter);
await reporter.MainWork("Updating downloads...", 2, 3);
await NormalizeDownloadUrls(reporter);
await reporter.MainWork("Applying schema changes...", 3, 3);
await conn.ExecuteAsync("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
await conn.ExecuteAsync("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
perf.Step("Upgrade to version 6");
await reporter.NextVersion();
}
perf.End();
}
}

View File

@@ -0,0 +1,8 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
interface ISchemaUpgrade {
Task Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter);
}

View File

@@ -1,7 +1,7 @@
using System;
using System.Threading.Tasks;
namespace DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
public interface ISchemaUpgradeCallbacks {
Task<bool> CanUpgrade();

View File

@@ -0,0 +1,11 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo2 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE channels ADD parent_id INTEGER");
}
}

View File

@@ -0,0 +1,33 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo3 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await SqliteSchema.CreateMessageEditTimestampTable(conn);
await SqliteSchema.CreateMessageRepliedToTable(conn);
await conn.ExecuteAsync("""
INSERT INTO edit_timestamps (message_id, edit_timestamp)
SELECT message_id, edit_timestamp
FROM messages
WHERE edit_timestamp IS NOT NULL
""");
await conn.ExecuteAsync("""
INSERT INTO replied_to (message_id, replied_to_id)
SELECT message_id, replied_to_id
FROM messages
WHERE replied_to_id IS NOT NULL
""");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN replied_to_id");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN edit_timestamp");
await reporter.MainWork("Vacuuming the database...", 1, 1);
await conn.ExecuteAsync("VACUUM");
}
}

View File

@@ -0,0 +1,19 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo4 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("""
CREATE TABLE downloads (
url TEXT NOT NULL PRIMARY KEY,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
}
}

View File

@@ -0,0 +1,12 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo5 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE attachments ADD width INTEGER");
await conn.ExecuteAsync("ALTER TABLE attachments ADD height INTEGER");
}
}

View File

@@ -0,0 +1,132 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 3);
await conn.ExecuteAsync("ALTER TABLE attachments ADD download_url TEXT");
await conn.ExecuteAsync("ALTER TABLE downloads ADD download_url TEXT");
await reporter.MainWork("Updating attachments...", 1, 3);
await NormalizeAttachmentUrls(conn, reporter);
await reporter.MainWork("Updating downloads...", 2, 3);
await NormalizeDownloadUrls(conn, reporter);
await reporter.MainWork("Applying schema changes...", 3, 3);
await conn.ExecuteAsync("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
await conn.ExecuteAsync("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
}
private async Task NormalizeAttachmentUrls(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing attachments...", 0, 0);
var normalizedUrls = new Dictionary<long, string>();
await using (var selectCmd = conn.Command("SELECT attachment_id, url FROM attachments")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
var attachmentId = reader.GetInt64(0);
var originalUrl = reader.GetString(1);
normalizedUrls[attachmentId] = DiscordCdn.NormalizeUrl(originalUrl);
}
}
await conn.BeginTransactionAsync();
int totalUrls = normalizedUrls.Count;
int processedUrls = -1;
await using (var updateCmd = conn.Command("UPDATE attachments SET download_url = url, url = :normalized_url WHERE attachment_id = :attachment_id")) {
updateCmd.Add(":attachment_id", SqliteType.Integer);
updateCmd.Add(":normalized_url", SqliteType.Text);
foreach (var (attachmentId, normalizedUrl) in normalizedUrls) {
if (++processedUrls % 1000 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
}
updateCmd.Set(":attachment_id", attachmentId);
updateCmd.Set(":normalized_url", normalizedUrl);
await updateCmd.ExecuteNonQueryAsync();
}
}
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await conn.CommitTransactionAsync();
}
private async Task NormalizeDownloadUrls(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing downloads...", 0, 0);
var normalizedUrlsToOriginalUrls = new Dictionary<string, string>();
var duplicateUrlsToDelete = new HashSet<string>();
await using (var selectCmd = conn.Command("SELECT url FROM downloads ORDER BY CASE WHEN status = 200 THEN 0 ELSE 1 END")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
var originalUrl = reader.GetString(0);
var normalizedUrl = DiscordCdn.NormalizeUrl(originalUrl);
if (!normalizedUrlsToOriginalUrls.TryAdd(normalizedUrl, originalUrl)) {
duplicateUrlsToDelete.Add(originalUrl);
}
}
}
await conn.ExecuteAsync("PRAGMA cache_size = -20000");
await conn.BeginTransactionAsync();
await reporter.SubWork("Deleting duplicates...", 0, 0);
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
foreach (var duplicateUrl in duplicateUrlsToDelete) {
deleteCmd.Set(":url", duplicateUrl);
await deleteCmd.ExecuteNonQueryAsync();
}
}
await conn.CommitTransactionAsync();
int totalUrls = normalizedUrlsToOriginalUrls.Count;
int processedUrls = -1;
await conn.BeginTransactionAsync();
await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Add(":normalized_url", SqliteType.Text);
updateCmd.Add(":download_url", SqliteType.Text);
foreach (var (normalizedUrl, downloadUrl) in normalizedUrlsToOriginalUrls) {
if (++processedUrls % 100 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
// Not the proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyway, so this doesn't make it worse.
await conn.CommitTransactionAsync();
await conn.BeginTransactionAsync();
conn.AssignActiveTransaction(updateCmd);
}
updateCmd.Set(":normalized_url", normalizedUrl);
updateCmd.Set(":download_url", downloadUrl);
await updateCmd.ExecuteNonQueryAsync();
}
}
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await conn.CommitTransactionAsync();
await conn.ExecuteAsync("PRAGMA cache_size = -2000");
}
}
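
The duplicate pass above exists because normalization is lossy: two signed links to the same file collapse to a single normalized URL, which would otherwise collide on the primary key once the UPDATE runs. A minimal sketch (the URLs are made up):

// Illustrative only: both signed variants normalize to the same key, so only one row
// survives (preferring status 200, per the ORDER BY above) and the other is deleted.
string first = DiscordCdn.NormalizeUrl("https://cdn.discordapp.com/attachments/1/2/file.png?ex=111&is=222&hm=333");
string second = DiscordCdn.NormalizeUrl("https://cdn.discordapp.com/attachments/1/2/file.png?ex=444&is=555&hm=666");
// first == second; the ex/is/hm signature parameters are gone.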

View File

@@ -0,0 +1,153 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo7 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 6);
await SqliteSchema.CreateDownloadTables(conn);
await reporter.MainWork("Migrating download metadata...", 1, 6);
await conn.ExecuteAsync("INSERT INTO download_metadata (normalized_url, download_url, status, size) SELECT normalized_url, download_url, status, size FROM downloads");
await reporter.MainWork("Merging attachment metadata...", 2, 6);
await conn.ExecuteAsync("UPDATE download_metadata SET type = (SELECT type FROM attachments WHERE download_metadata.normalized_url = attachments.normalized_url)");
await reporter.MainWork("Migrating downloaded files...", 3, 6);
await MigrateDownloadBlobsToNewTable(conn, reporter);
await reporter.MainWork("Applying schema changes...", 4, 6);
await conn.ExecuteAsync("DROP TABLE downloads");
await reporter.MainWork("Discovering downloadable links...", 5, 6);
await DiscoverDownloadableLinks(conn, reporter);
}
private async Task MigrateDownloadBlobsToNewTable(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Listing downloaded files...", 0, 0);
var urlsToMigrate = await GetDownloadedFileUrls(conn);
int totalFiles = urlsToMigrate.Count;
int processedFiles = -1;
await reporter.SubWork("Processing downloaded files...", 0, totalFiles);
await conn.BeginTransactionAsync();
await using (var insertCmd = conn.Command("INSERT INTO download_blobs (normalized_url, blob) SELECT normalized_url, blob FROM downloads WHERE normalized_url = :normalized_url"))
await using (var deleteCmd = conn.Command("DELETE FROM downloads WHERE normalized_url = :normalized_url")) {
insertCmd.Add(":normalized_url", SqliteType.Text);
deleteCmd.Add(":normalized_url", SqliteType.Text);
foreach (var url in urlsToMigrate) {
if (++processedFiles % 10 == 0) {
await reporter.SubWork("Processing downloaded files...", processedFiles, totalFiles);
// Not the proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyway, so this doesn't make it worse.
await conn.CommitTransactionAsync();
await conn.BeginTransactionAsync();
conn.AssignActiveTransaction(insertCmd);
conn.AssignActiveTransaction(deleteCmd);
}
insertCmd.Set(":normalized_url", url);
await insertCmd.ExecuteNonQueryAsync();
deleteCmd.Set(":normalized_url", url);
await deleteCmd.ExecuteNonQueryAsync();
}
}
await reporter.SubWork("Processing downloaded files...", totalFiles, totalFiles);
await conn.CommitTransactionAsync();
}
private async Task<List<string>> GetDownloadedFileUrls(ISqliteConnection conn) {
var urls = new List<string>();
await using var selectCmd = conn.Command("SELECT normalized_url FROM downloads WHERE blob IS NOT NULL");
await using var reader = await selectCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
urls.Add(reader.GetString(0));
}
return urls;
}
private async Task DiscoverDownloadableLinks(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Processing attachments...", 0, 4);
await using (var cmd = conn.Command("""
INSERT OR IGNORE INTO download_metadata (normalized_url, download_url, status, type, size)
SELECT a.normalized_url, a.download_url, :pending, a.type, MAX(a.size)
FROM attachments a
GROUP BY a.normalized_url
""")) {
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
await cmd.ExecuteNonQueryAsync();
}
static async Task InsertDownload(SqliteCommand insertCmd, Data.Download? download) {
if (download == null) {
return;
}
insertCmd.Set(":normalized_url", download.NormalizedUrl);
insertCmd.Set(":download_url", download.DownloadUrl);
insertCmd.Set(":status", (int) download.Status);
insertCmd.Set(":type", download.Type);
insertCmd.Set(":size", download.Size);
await insertCmd.ExecuteNonQueryAsync();
}
await conn.BeginTransactionAsync();
await using var insertCmd = conn.Command("INSERT OR IGNORE INTO download_metadata (normalized_url, download_url, status, type, size) VALUES (:normalized_url, :download_url, :status, :type, :size)");
insertCmd.Add(":normalized_url", SqliteType.Text);
insertCmd.Add(":download_url", SqliteType.Text);
insertCmd.Add(":status", SqliteType.Integer);
insertCmd.Add(":type", SqliteType.Text);
insertCmd.Add(":size", SqliteType.Integer);
await reporter.SubWork("Processing embeds...", 1, 4);
await using (var embedCmd = conn.Command("SELECT json FROM embeds")) {
await using var reader = await embedCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, await DownloadLinkExtractor.TryFromEmbedJson(reader.GetStream(0)));
}
}
await reporter.SubWork("Processing users...", 2, 4);
await using (var avatarCmd = conn.Command("SELECT id, avatar_url FROM users WHERE avatar_url IS NOT NULL")) {
await using var reader = await avatarCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, DownloadLinkExtractor.FromUserAvatar(reader.GetUint64(0), reader.GetString(1)));
}
}
await reporter.SubWork("Processing reactions...", 3, 4);
await using (var avatarCmd = conn.Command("SELECT DISTINCT emoji_id, emoji_flags FROM reactions WHERE emoji_id IS NOT NULL")) {
await using var reader = await avatarCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, DownloadLinkExtractor.FromEmoji(reader.GetUint64(0), (EmojiFlags) reader.GetInt16(1)));
}
}
await conn.CommitTransactionAsync();
}
}
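
After this upgrade, download metadata and file contents live in separate tables, which is what later allows downloads to be listed, filtered, and merged without pulling blobs into memory. A minimal sketch of a query against the new layout (illustrative, not part of this change; status 200 is the value earlier upgrades treat as a successful download):

// Illustrative only: find successful downloads whose blob is missing, without ever
// touching blob contents.
await using var cmd = conn.Command("""
SELECT m.normalized_url
FROM download_metadata m
LEFT JOIN download_blobs b ON b.normalized_url = m.normalized_url
WHERE m.status = 200 AND b.normalized_url IS NULL
""");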

View File

@@ -2,6 +2,7 @@ using System;
using System.Threading.Tasks;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Repositories;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
@@ -21,7 +22,7 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
try {
await using var conn = await pool.Take();
wasOpened = await new Schema(conn).Setup(schemaUpgradeCallbacks);
wasOpened = await new SqliteSchema(conn).Setup(schemaUpgradeCallbacks);
} catch (Exception) {
await pool.DisposeAsync();
throw;
@@ -42,7 +43,6 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
public IServerRepository Servers => servers;
public IChannelRepository Channels => channels;
public IMessageRepository Messages => messages;
public IAttachmentRepository Attachments => attachments;
public IDownloadRepository Downloads => downloads;
private readonly SqliteConnectionPool pool;
@@ -51,18 +51,17 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
private readonly SqliteServerRepository servers;
private readonly SqliteChannelRepository channels;
private readonly SqliteMessageRepository messages;
private readonly SqliteAttachmentRepository attachments;
private readonly SqliteDownloadRepository downloads;
private SqliteDatabaseFile(string path, SqliteConnectionPool pool) {
this.Path = path;
this.pool = pool;
users = new SqliteUserRepository(pool);
downloads = new SqliteDownloadRepository(pool);
users = new SqliteUserRepository(pool, downloads);
servers = new SqliteServerRepository(pool);
channels = new SqliteChannelRepository(pool);
messages = new SqliteMessageRepository(pool, attachments = new SqliteAttachmentRepository(pool));
downloads = new SqliteDownloadRepository(pool);
messages = new SqliteMessageRepository(pool, downloads);
}
public async ValueTask DisposeAsync() {
@@ -70,7 +69,6 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
servers.Dispose();
channels.Dispose();
messages.Dispose();
attachments.Dispose();
downloads.Dispose();
await pool.DisposeAsync();
}

View File

@@ -8,77 +8,53 @@ using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite;
static class SqliteFilters {
private static string WhereAll(bool invert) {
return invert ? "WHERE FALSE" : "";
public static SqliteConditionBuilder GenerateConditions(this MessageFilter? filter, string? tableAlias = null, bool invert = false) {
var builder = new SqliteConditionBuilder(tableAlias, invert);
if (filter != null) {
if (filter.StartDate != null) {
builder.AddCondition("timestamp >= " + new DateTimeOffset(filter.StartDate.Value).ToUnixTimeMilliseconds());
}
if (filter.EndDate != null) {
builder.AddCondition("timestamp <= " + new DateTimeOffset(filter.EndDate.Value).ToUnixTimeMilliseconds());
}
if (filter.ChannelIds != null) {
builder.AddCondition("channel_id IN (" + string.Join(",", filter.ChannelIds) + ")");
}
if (filter.UserIds != null) {
builder.AddCondition("sender_id IN (" + string.Join(",", filter.UserIds) + ")");
}
if (filter.MessageIds != null) {
builder.AddCondition("message_id IN (" + string.Join(",", filter.MessageIds) + ")");
}
}
return builder;
}
public static string GenerateWhereClause(this MessageFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
public static SqliteConditionBuilder GenerateConditions(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) {
var builder = new SqliteConditionBuilder(tableAlias, invert);
if (filter != null) {
if (filter.IncludeStatuses != null) {
builder.AddCondition("status IN (" + filter.IncludeStatuses.In() + ")");
}
if (filter.ExcludeStatuses != null) {
builder.AddCondition("status NOT IN (" + filter.ExcludeStatuses.In() + ")");
}
if (filter.MaxBytes != null) {
builder.AddCondition("size IS NOT NULL");
builder.AddCondition("size <= " + filter.MaxBytes);
}
}
var where = new SqliteWhereGenerator(tableAlias, invert);
if (filter.StartDate != null) {
where.AddCondition("timestamp >= " + new DateTimeOffset(filter.StartDate.Value).ToUnixTimeMilliseconds());
}
if (filter.EndDate != null) {
where.AddCondition("timestamp <= " + new DateTimeOffset(filter.EndDate.Value).ToUnixTimeMilliseconds());
}
if (filter.ChannelIds != null) {
where.AddCondition("channel_id IN (" + string.Join(",", filter.ChannelIds) + ")");
}
if (filter.UserIds != null) {
where.AddCondition("sender_id IN (" + string.Join(",", filter.UserIds) + ")");
}
if (filter.MessageIds != null) {
where.AddCondition("message_id IN (" + string.Join(",", filter.MessageIds) + ")");
}
return where.Generate();
}
public static string GenerateWhereClause(this AttachmentFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
}
var where = new SqliteWhereGenerator(tableAlias, invert);
if (filter.MaxBytes != null) {
where.AddCondition("size <= " + filter.MaxBytes);
}
if (filter.DownloadItemRule == AttachmentFilter.DownloadItemRules.OnlyNotPresent) {
where.AddCondition("normalized_url NOT IN (SELECT normalized_url FROM downloads)");
}
else if (filter.DownloadItemRule == AttachmentFilter.DownloadItemRules.OnlyPresent) {
where.AddCondition("normalized_url IN (SELECT normalized_url FROM downloads)");
}
return where.Generate();
}
public static string GenerateWhereClause(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
}
var where = new SqliteWhereGenerator(tableAlias, invert);
if (filter.IncludeStatuses != null) {
where.AddCondition("status IN (" + filter.IncludeStatuses.In() + ")");
}
if (filter.ExcludeStatuses != null) {
where.AddCondition("status NOT IN (" + filter.ExcludeStatuses.In() + ")");
}
return where.Generate();
return builder;
}
private static string In(this ISet<DownloadStatus> statuses) {

View File

@@ -0,0 +1,193 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
namespace DHT.Server.Database.Sqlite;
sealed class SqliteSchema {
internal const int Version = 7;
private static readonly Log Log = Log.ForType<SqliteSchema>();
private readonly ISqliteConnection conn;
public SqliteSchema(ISqliteConnection conn) {
this.conn = conn;
}
public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) {
await conn.ExecuteAsync("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = await conn.ExecuteReaderAsync("SELECT value FROM metadata WHERE key = 'version'", static reader => reader?.GetString(0));
if (dbVersionStr == null) {
await InitializeSchemas();
}
else if (!int.TryParse(dbVersionStr, out int dbVersion) || dbVersion < 1) {
throw new InvalidDatabaseVersionException(dbVersionStr);
}
else if (dbVersion > Version) {
throw new DatabaseTooNewException(dbVersion);
}
else if (dbVersion < Version) {
var proceed = await callbacks.CanUpgrade();
if (!proceed) {
return false;
}
await callbacks.Start(Version - dbVersion, async reporter => await UpgradeSchemas(dbVersion, reporter));
}
return true;
}
private async Task InitializeSchemas() {
await conn.ExecuteAsync("""
CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
avatar_url TEXT,
discriminator TEXT
)
""");
await conn.ExecuteAsync("""
CREATE TABLE servers (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE channels (
id INTEGER PRIMARY KEY NOT NULL,
server INTEGER NOT NULL,
name TEXT NOT NULL,
parent_id INTEGER,
position INTEGER,
topic TEXT,
nsfw INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE messages (
message_id INTEGER PRIMARY KEY NOT NULL,
sender_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL,
text TEXT NOT NULL,
timestamp INTEGER NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE attachments (
message_id INTEGER NOT NULL,
attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT,
normalized_url TEXT NOT NULL,
download_url TEXT,
size INTEGER NOT NULL,
width INTEGER,
height INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE embeds (
message_id INTEGER NOT NULL,
json TEXT NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE reactions (
message_id INTEGER NOT NULL,
emoji_id INTEGER,
emoji_name TEXT,
emoji_flags INTEGER NOT NULL,
count INTEGER NOT NULL
)
""");
await CreateMessageEditTimestampTable(conn);
await CreateMessageRepliedToTable(conn);
await CreateDownloadTables(conn);
await conn.ExecuteAsync("CREATE INDEX attachments_message_ix ON attachments(message_id)");
await conn.ExecuteAsync("CREATE INDEX embeds_message_ix ON embeds(message_id)");
await conn.ExecuteAsync("CREATE INDEX reactions_message_ix ON reactions(message_id)");
await conn.ExecuteAsync("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")");
}
internal static async Task CreateMessageEditTimestampTable(ISqliteConnection conn) {
await conn.ExecuteAsync("""
CREATE TABLE edit_timestamps (
message_id INTEGER PRIMARY KEY NOT NULL,
edit_timestamp INTEGER NOT NULL
)
""");
}
internal static async Task CreateMessageRepliedToTable(ISqliteConnection conn) {
await conn.ExecuteAsync("""
CREATE TABLE replied_to (
message_id INTEGER PRIMARY KEY NOT NULL,
replied_to_id INTEGER NOT NULL
)
""");
}
internal static async Task CreateDownloadTables(ISqliteConnection conn) {
await conn.ExecuteAsync("""
CREATE TABLE download_metadata (
normalized_url TEXT NOT NULL PRIMARY KEY,
download_url TEXT NOT NULL,
status INTEGER NOT NULL,
type TEXT,
size INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE download_blobs (
normalized_url TEXT NOT NULL PRIMARY KEY,
blob BLOB NOT NULL,
FOREIGN KEY (normalized_url) REFERENCES download_metadata (normalized_url) ON UPDATE CASCADE ON DELETE CASCADE
)
""");
}
private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
var upgrades = new Dictionary<int, ISchemaUpgrade> {
{ 1, new SqliteSchemaUpgradeTo2() },
{ 2, new SqliteSchemaUpgradeTo3() },
{ 3, new SqliteSchemaUpgradeTo4() },
{ 4, new SqliteSchemaUpgradeTo5() },
{ 5, new SqliteSchemaUpgradeTo6() },
{ 6, new SqliteSchemaUpgradeTo7() },
};
var perf = Log.Start("from version " + dbVersion);
for (int fromVersion = dbVersion; fromVersion < Version; fromVersion++) {
var toVersion = fromVersion + 1;
if (upgrades.TryGetValue(fromVersion, out var upgrade)) {
await upgrade.Run(conn, reporter);
}
await conn.ExecuteAsync("UPDATE metadata SET value = " + toVersion + " WHERE key = 'version'");
perf.Step("Upgrade to version " + toVersion);
await reporter.NextVersion();
}
perf.End();
}
}
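
The upgrade map above is keyed by the version being upgraded from, so a future schema change only needs a bumped Version constant and one more entry. A hypothetical sketch of such a step (the class, the column, and version 8 are invented for illustration):

// Hypothetical: an eighth schema version adding a made-up column.
sealed class SqliteSchemaUpgradeTo8 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE download_metadata ADD last_attempt INTEGER"); // hypothetical column
}
}
// ...and in UpgradeSchemas: { 7, new SqliteSchemaUpgradeTo8() } once Version becomes 8.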

View File

@@ -1,8 +1,15 @@
using System;
using System.Threading.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
interface ISqliteConnection : IAsyncDisposable {
SqliteConnection InnerConnection { get; }
Task BeginTransactionAsync();
Task CommitTransactionAsync();
Task RollbackTransactionAsync();
void AssignActiveTransaction(SqliteCommand command);
}

View File

@@ -2,12 +2,12 @@ using System.Collections.Generic;
namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteWhereGenerator {
sealed class SqliteConditionBuilder {
private readonly string? tableAlias;
private readonly bool invert;
private readonly List<string> conditions = [];
public SqliteWhereGenerator(string? tableAlias, bool invert) {
public SqliteConditionBuilder(string? tableAlias, bool invert) {
this.tableAlias = tableAlias;
this.invert = invert;
}
@@ -16,16 +16,20 @@ sealed class SqliteWhereGenerator {
conditions.Add(tableAlias == null ? condition : tableAlias + '.' + condition);
}
public string Generate() {
public string Build() {
if (conditions.Count == 0) {
return "";
return invert ? "FALSE" : "TRUE";
}
if (invert) {
return " WHERE NOT (" + string.Join(" AND ", conditions) + ")";
return "NOT (" + string.Join(" AND ", conditions) + ")";
}
else {
return " WHERE " + string.Join(" AND ", conditions);
return string.Join(" AND ", conditions);
}
}
public string BuildWhereClause() {
return " WHERE " + Build();
}
}
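
A minimal sketch of how the renamed builder is consumed by the GenerateConditions extensions earlier in this diff (the query and values are illustrative):

// Illustrative only: conditions are prefixed with the table alias, and an empty builder
// now degrades to TRUE/FALSE instead of omitting the WHERE clause entirely.
var builder = new SqliteConditionBuilder("m", invert: false);
builder.AddCondition("timestamp >= 1704067200000");
builder.AddCondition("channel_id IN (1, 2, 3)");
string sql = "SELECT COUNT(*) FROM messages m" + builder.BuildWhereClause();
// sql == "SELECT COUNT(*) FROM messages m WHERE m.timestamp >= 1704067200000 AND m.channel_id IN (1, 2, 3)"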

View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using DHT.Utils.Collections;
@@ -42,9 +43,8 @@ sealed class SqliteConnectionPool : IAsyncDisposable {
var pooledConnection = new PooledConnection(this, conn);
await using (var cmd = pooledConnection.Command("PRAGMA journal_mode=WAL")) {
await cmd.ExecuteNonQueryAsync(disposalToken);
}
await pooledConnection.ExecuteAsync("PRAGMA journal_mode=WAL", disposalToken);
await pooledConnection.ExecuteAsync("PRAGMA foreign_keys=ON", disposalToken);
all.Add(pooledConnection);
await free.Push(pooledConnection, disposalToken);
@@ -74,17 +74,48 @@ sealed class SqliteConnectionPool : IAsyncDisposable {
disposalTokenSource.Dispose();
}
private sealed class PooledConnection : ISqliteConnection {
public SqliteConnection InnerConnection { get; }
private sealed class PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) : ISqliteConnection {
public SqliteConnection InnerConnection { get; } = conn;
private readonly SqliteConnectionPool pool;
private DbTransaction? activeTransaction;
public PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) {
this.pool = pool;
this.InnerConnection = conn;
public async Task BeginTransactionAsync() {
if (activeTransaction != null) {
throw new InvalidOperationException("A transaction is already active.");
}
activeTransaction = await InnerConnection.BeginTransactionAsync();
}
public async Task CommitTransactionAsync() {
if (activeTransaction == null) {
throw new InvalidOperationException("No active transaction to commit.");
}
await activeTransaction.CommitAsync();
await activeTransaction.DisposeAsync();
activeTransaction = null;
}
public async Task RollbackTransactionAsync() {
if (activeTransaction == null) {
throw new InvalidOperationException("No active transaction to rollback.");
}
await activeTransaction.RollbackAsync();
await activeTransaction.DisposeAsync();
activeTransaction = null;
}
public void AssignActiveTransaction(SqliteCommand command) {
command.Transaction = (SqliteTransaction?) activeTransaction;
}
public async ValueTask DisposeAsync() {
if (activeTransaction != null) {
await RollbackTransactionAsync();
}
await pool.Return(this);
}
}
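
A minimal sketch of the transaction contract these members add, following the pattern the schema upgrades above rely on (the statement and values are illustrative):

// Illustrative only: a command created around BeginTransactionAsync is attached to the
// active transaction via AssignActiveTransaction, and a connection returned to the pool
// with an open transaction is rolled back by DisposeAsync.
await conn.BeginTransactionAsync();
await using var cmd = conn.Command("UPDATE download_metadata SET status = :status WHERE normalized_url = :url");
cmd.Add(":status", SqliteType.Integer);
cmd.Add(":url", SqliteType.Text);
conn.AssignActiveTransaction(cmd);
cmd.Set(":status", 200);
cmd.Set(":url", "https://cdn.discordapp.com/attachments/1/2/file.png");
await cmd.ExecuteNonQueryAsync();
await conn.CommitTransactionAsync();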

View File

@@ -1,5 +1,4 @@
using System;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@@ -9,11 +8,7 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
static class SqliteExtensions {
public static ValueTask<DbTransaction> BeginTransactionAsync(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransactionAsync();
}
public static SqliteCommand Command(this ISqliteConnection conn, string sql) {
public static SqliteCommand Command(this ISqliteConnection conn, [LanguageInjection("sql")] string sql) {
var cmd = conn.InnerConnection.CreateCommand();
cmd.CommandText = sql;
return cmd;
@@ -28,7 +23,11 @@ static class SqliteExtensions {
await using var cmd = conn.Command(sql);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
return reader.Read() ? readFunction(reader) : readFunction(null);
return await reader.ReadAsync(cancellationToken) ? readFunction(reader) : readFunction(null);
}
public static async Task<long> ExecuteLongScalarAsync(this SqliteCommand command) {
return (long) (await command.ExecuteScalarAsync())!;
}
public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {

View File

@@ -1,5 +1,7 @@
using System;
using System.Collections.Frozen;
using System.Diagnostics.CodeAnalysis;
using System.Web;
namespace DHT.Server.Download;
@@ -7,9 +9,35 @@ static class DiscordCdn {
private static FrozenSet<string> CdnHosts { get; } = new[] {
"cdn.discordapp.com",
"cdn.discord.com",
"media.discordapp.net"
}.ToFrozenSet();
private static bool IsCdnUrl(string originalUrl, [NotNullWhen(true)] out Uri? uri) {
return Uri.TryCreate(originalUrl, UriKind.Absolute, out uri) && CdnHosts.Contains(uri.Host);
}
public static string NormalizeUrl(string originalUrl) {
return Uri.TryCreate(originalUrl, UriKind.Absolute, out var uri) && CdnHosts.Contains(uri.Host) ? uri.GetLeftPart(UriPartial.Path) : originalUrl;
return IsCdnUrl(originalUrl, out var uri) ? DoNormalize(uri) : originalUrl;
}
public static bool NormalizeUrlAndReturnIfCdn(string originalUrl, out string normalizedUrl) {
if (IsCdnUrl(originalUrl, out var uri)) {
normalizedUrl = DoNormalize(uri);
return true;
}
else {
normalizedUrl = originalUrl;
return false;
}
}
private static string DoNormalize(Uri uri) {
var query = HttpUtility.ParseQueryString(uri.Query);
query.Remove("ex");
query.Remove("is");
query.Remove("hm");
return new UriBuilder(uri) { Query = query.ToString() }.Uri.ToString();
}
}
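
A small usage sketch of the new helper, mirroring how DownloadLinkExtractor uses it below (the links are made up):

// Illustrative only: CDN links are normalized and reported as downloadable,
// while links to other hosts are returned unchanged and skipped.
bool cdn = DiscordCdn.NormalizeUrlAndReturnIfCdn("https://media.discordapp.net/attachments/1/2/pic.jpg?ex=1&is=2&hm=3", out string normalized);
// cdn == true; normalized has the ex/is/hm signature parameters stripped.
bool external = DiscordCdn.NormalizeUrlAndReturnIfCdn("https://example.com/pic.jpg", out string unchanged);
// external == false; unchanged == "https://example.com/pic.jpg"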

View File

@@ -1,7 +1,21 @@
using System;
using System.Net;
using DHT.Server.Data;
namespace DHT.Server.Download;
public readonly struct DownloadItem {
public string NormalizedUrl { get; init; }
public string DownloadUrl { get; init; }
public ulong Size { get; init; }
public string? Type { get; init; }
public ulong? Size { get; init; }
internal Data.Download ToSuccess(long size) {
return new Data.Download(NormalizedUrl, DownloadUrl, DownloadStatus.Success, Type, (ulong) Math.Max(size, 0));
}
internal Data.Download ToFailure(HttpStatusCode? statusCode = null) {
var status = statusCode.HasValue ? (DownloadStatus) (int) statusCode : DownloadStatus.GenericError;
return new Data.Download(NormalizedUrl, DownloadUrl, status, Type, Size);
}
}
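
A brief sketch of how DownloaderTask (further below) converts a queued item once a request finishes (the values are illustrative):

// Illustrative only: a successful response records the actual content length, while a
// failure keeps the expected size and stores the HTTP status code (or GenericError).
var item = new DownloadItem {
NormalizedUrl = "https://cdn.discordapp.com/emojis/123.webp",
DownloadUrl = "https://cdn.discordapp.com/emojis/123.webp",
Type = MediaTypeNames.Image.Webp,
Size = null
};
Data.Download ok = item.ToSuccess(size: 4096);
Data.Download missing = item.ToFailure(HttpStatusCode.NotFound);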

View File

@@ -0,0 +1,122 @@
using System;
using System.IO;
using System.Net.Mime;
using System.Text.Json;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Embeds;
using DHT.Utils.Logging;
namespace DHT.Server.Download;
static class DownloadLinkExtractor {
private static readonly Log Log = Log.ForType(typeof(DownloadLinkExtractor));
public static Data.Download FromUserAvatar(ulong userId, string avatarPath) {
string url = $"https://cdn.discordapp.com/avatars/{userId}/{avatarPath}.webp";
return new Data.Download(url, url, DownloadStatus.Pending, MediaTypeNames.Image.Webp, size: null);
}
public static Data.Download FromEmoji(ulong emojiId, EmojiFlags flags) {
var isAnimated = flags.HasFlag(EmojiFlags.Animated);
string ext = isAnimated ? "gif" : "webp";
string type = isAnimated ? MediaTypeNames.Image.Gif : MediaTypeNames.Image.Webp;
string url = $"https://cdn.discordapp.com/emojis/{emojiId}.{ext}";
return new Data.Download(url, url, DownloadStatus.Pending, type, size: null);
}
public static Data.Download FromAttachment(Attachment attachment) {
return new Data.Download(attachment.NormalizedUrl, attachment.DownloadUrl, DownloadStatus.Pending, attachment.Type, attachment.Size);
}
public static async Task<Data.Download?> TryFromEmbedJson(Stream jsonStream) {
try {
return FromEmbed(await JsonSerializer.DeserializeAsync(jsonStream, DiscordEmbedJsonContext.Default.DiscordEmbedJson));
} catch (Exception e) {
Log.Error("Could not parse embed json: " + e);
return null;
}
}
public static Data.Download? TryFromEmbedJson(string json) {
try {
return FromEmbed(JsonSerializer.Deserialize(json, DiscordEmbedJsonContext.Default.DiscordEmbedJson));
} catch (Exception e) {
Log.Error("Could not parse embed json: " + e);
return null;
}
}
private static Data.Download? FromEmbed(DiscordEmbedJson? embed) {
if (embed is { Type: "image", Image.Url: {} imageUrl }) {
return FromEmbedImage(imageUrl);
}
else if (embed is { Type: "video", Video.Url: {} videoUrl }) {
return FromEmbedVideo(videoUrl);
}
else {
return null;
}
}
private static Data.Download? FromEmbedImage(string url) {
if (DiscordCdn.NormalizeUrlAndReturnIfCdn(url, out var normalizedUrl)) {
return new Data.Download(normalizedUrl, url, DownloadStatus.Pending, GuessImageType(normalizedUrl), size: null);
}
else {
Log.Debug("Skipping non-CDN image url: " + url);
return null;
}
}
private static Data.Download? FromEmbedVideo(string url) {
if (DiscordCdn.NormalizeUrlAndReturnIfCdn(url, out var normalizedUrl)) {
return new Data.Download(normalizedUrl, url, DownloadStatus.Pending, GuessVideoType(normalizedUrl), size: null);
}
else {
Log.Debug("Skipping non-CDN video url: " + url);
return null;
}
}
private static string? GuessImageType(string url) {
if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) {
return null;
}
ReadOnlySpan<char> extension = Path.GetExtension(uri.AbsolutePath).ToLowerInvariant();
// Remove Twitter quality suffix.
int colonIndex = extension.IndexOf(':');
if (colonIndex != -1) {
extension = extension[..colonIndex];
}
return extension switch {
".jpg" => MediaTypeNames.Image.Jpeg,
".jpeg" => MediaTypeNames.Image.Jpeg,
".png" => MediaTypeNames.Image.Png,
".gif" => MediaTypeNames.Image.Gif,
".webp" => MediaTypeNames.Image.Webp,
".bmp" => MediaTypeNames.Image.Bmp,
_ => null
};
}
private static string? GuessVideoType(string url) {
if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) {
return null;
}
string extension = Path.GetExtension(uri.AbsolutePath).ToLowerInvariant();
return extension switch {
".mp4" => "video/mp4",
".mpeg" => "video/mpeg",
".webm" => "video/webm",
".mov" => "video/quicktime",
_ => null
};
}
}
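
For context, a minimal sketch of the kind of payload the string overload accepts, assuming DiscordEmbedJson maps Discord's lowercase embed field names (the link is made up):

// Illustrative only: an "image" embed pointing at the Discord CDN yields a pending
// download typed by its file extension; non-CDN or non-image/video embeds yield null.
const string json = """{ "type": "image", "image": { "url": "https://media.discordapp.net/attachments/1/2/pic.jpg" } }""";
Data.Download? download = DownloadLinkExtractor.TryFromEmbedJson(json);
// download?.Status == DownloadStatus.Pending, download?.Type == MediaTypeNames.Image.Jpeg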

View File

@@ -1,6 +1,7 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
namespace DHT.Server.Download;
@@ -10,16 +11,18 @@ public sealed class Downloader {
public bool IsDownloading => current != null;
private readonly IDatabaseFile db;
private readonly int? concurrentDownloads;
private readonly SemaphoreSlim semaphore = new (1, 1);
internal Downloader(IDatabaseFile db) {
internal Downloader(IDatabaseFile db, int? concurrentDownloads) {
this.db = db;
this.concurrentDownloads = concurrentDownloads;
}
public async Task<IObservable<DownloadItem>> Start() {
public async Task<IObservable<DownloadItem>> Start(DownloadItemFilter filter) {
await semaphore.WaitAsync();
try {
current ??= new DownloaderTask(db);
current ??= new DownloaderTask(db, filter, concurrentDownloads);
return current.FinishedItems;
} finally {
semaphore.Release();

View File

@@ -5,6 +5,7 @@ using System.Reactive.Subjects;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
@@ -14,10 +15,14 @@ namespace DHT.Server.Download;
sealed class DownloaderTask : IAsyncDisposable {
private static readonly Log Log = Log.ForType<DownloaderTask>();
private const int DownloadTasks = 4;
private const int DefaultConcurrentDownloads = 4;
private const int QueueSize = 25;
private const string UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
private static int GetDownloadTaskCount(int? concurrentDownloads) {
return Math.Max(1, concurrentDownloads ?? DefaultConcurrentDownloads);
}
private readonly Channel<DownloadItem> downloadQueue = Channel.CreateBounded<DownloadItem>(new BoundedChannelOptions(QueueSize) {
SingleReader = false,
SingleWriter = true,
@@ -29,6 +34,7 @@ sealed class DownloaderTask : IAsyncDisposable {
private readonly CancellationToken cancellationToken;
private readonly IDatabaseFile db;
private readonly DownloadItemFilter filter;
private readonly ISubject<DownloadItem> finishedItemPublisher = Subject.Synchronize(new Subject<DownloadItem>());
private readonly Task queueWriterTask;
@@ -36,16 +42,17 @@ sealed class DownloaderTask : IAsyncDisposable {
public IObservable<DownloadItem> FinishedItems => finishedItemPublisher;
internal DownloaderTask(IDatabaseFile db) {
internal DownloaderTask(IDatabaseFile db, DownloadItemFilter filter, int? concurrentDownloads) {
this.db = db;
this.filter = filter;
this.cancellationToken = cancellationTokenSource.Token;
this.queueWriterTask = Task.Run(RunQueueWriterTask);
this.downloadTasks = Enumerable.Range(1, DownloadTasks).Select(taskIndex => Task.Run(() => RunDownloadTask(taskIndex))).ToArray();
this.downloadTasks = Enumerable.Range(1, GetDownloadTaskCount(concurrentDownloads)).Select(taskIndex => Task.Run(() => RunDownloadTask(taskIndex))).ToArray();
}
private async Task RunQueueWriterTask() {
while (await downloadQueue.Writer.WaitToWriteAsync(cancellationToken)) {
var newItems = await db.Downloads.PullEnqueuedDownloadItems(QueueSize, cancellationToken).ToListAsync(cancellationToken);
var newItems = await db.Downloads.PullPendingDownloadItems(QueueSize, filter, cancellationToken).ToListAsync(cancellationToken);
if (newItems.Count == 0) {
await Task.Delay(TimeSpan.FromMilliseconds(50), cancellationToken);
continue;
@@ -60,24 +67,39 @@ sealed class DownloaderTask : IAsyncDisposable {
private async Task RunDownloadTask(int taskIndex) {
var log = Log.ForType<DownloaderTask>("Task " + taskIndex);
var client = new HttpClient();
var client = new HttpClient(new SocketsHttpHandler {
ConnectTimeout = TimeSpan.FromSeconds(30)
});
client.Timeout = Timeout.InfiniteTimeSpan;
client.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent);
client.Timeout = TimeSpan.FromSeconds(30);
while (!cancellationToken.IsCancellationRequested) {
var item = await downloadQueue.Reader.ReadAsync(cancellationToken);
log.Debug("Downloading " + item.DownloadUrl + "...");
try {
var downloadedBytes = await client.GetByteArrayAsync(item.DownloadUrl, cancellationToken);
await db.Downloads.AddDownload(Data.Download.NewSuccess(item, downloadedBytes));
} catch (OperationCanceledException) {
var response = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, item.DownloadUrl), HttpCompletionOption.ResponseHeadersRead, cancellationToken);
response.EnsureSuccessStatusCode();
if (response.Content.Headers.ContentLength is {} contentLength) {
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
await db.Downloads.AddDownload(item.ToSuccess(contentLength), stream);
}
else {
await db.Downloads.AddDownload(item.ToFailure(), stream: null);
log.Error("Download response has no content length: " + item.DownloadUrl);
}
} catch (OperationCanceledException e) when (e.CancellationToken == cancellationToken) {
// Ignore.
} catch (TaskCanceledException e) when (e.InnerException is TimeoutException) {
await db.Downloads.AddDownload(item.ToFailure(), stream: null);
log.Error("Download timed out: " + item.DownloadUrl);
} catch (HttpRequestException e) {
await db.Downloads.AddDownload(Data.Download.NewFailure(item, e.StatusCode, item.Size));
await db.Downloads.AddDownload(item.ToFailure(e.StatusCode), stream: null);
log.Error(e);
} catch (Exception e) {
await db.Downloads.AddDownload(Data.Download.NewFailure(item, null, item.Size));
await db.Downloads.AddDownload(item.ToFailure(), stream: null);
log.Error(e);
} finally {
try {

View File

@@ -9,37 +9,37 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
abstract class BaseEndpoint {
abstract class BaseEndpoint(IDatabaseFile db) {
private static readonly Log Log = Log.ForType<BaseEndpoint>();
protected IDatabaseFile Db { get; }
protected BaseEndpoint(IDatabaseFile db) {
this.Db = db;
}
protected IDatabaseFile Db { get; } = db;
public async Task Handle(HttpContext ctx) {
var response = ctx.Response;
try {
response.StatusCode = (int) HttpStatusCode.OK;
var output = await Respond(ctx);
await output.WriteTo(response);
await Respond(ctx.Request, response);
} catch (HttpException e) {
Log.Error(e);
response.StatusCode = (int) e.StatusCode;
await response.WriteAsync(e.Message);
if (response.HasStarted) {
Log.Warn("Response has already started, cannot write status message: " + e.Message);
}
else {
await response.WriteAsync(e.Message);
}
} catch (Exception e) {
Log.Error(e);
response.StatusCode = (int) HttpStatusCode.InternalServerError;
}
}
protected abstract Task<IHttpOutput> Respond(HttpContext ctx);
protected abstract Task Respond(HttpRequest request, HttpResponse response);
protected static async Task<JsonElement> ReadJson(HttpContext ctx) {
protected static async Task<JsonElement> ReadJson(HttpRequest request) {
try {
return await ctx.Request.ReadFromJsonAsync(JsonElementContext.Default.JsonElement);
return await request.ReadFromJsonAsync(JsonElementContext.Default.JsonElement);
} catch (JsonException) {
throw new HttpException(HttpStatusCode.UnsupportedMediaType, "This endpoint only accepts JSON.");
}

View File

@@ -1,24 +0,0 @@
using System.Net;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetAttachmentEndpoint : BaseEndpoint {
public GetAttachmentEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!);
DownloadedAttachment? maybeDownloadedAttachment = await Db.Downloads.GetDownloadedAttachment(attachmentUrl);
if (maybeDownloadedAttachment is {} downloadedAttachment) {
return new HttpOutput.File(downloadedAttachment.Type, downloadedAttachment.Data);
}
else {
return new HttpOutput.Redirect(attachmentUrl, permanent: false);
}
}
}

View File

@@ -0,0 +1,19 @@
using System.Net;
using System.Threading.Tasks;
using DHT.Server.Database;
using DHT.Server.Download;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetDownloadedFileEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
protected override async Task Respond(HttpRequest request, HttpResponse response) {
string url = WebUtility.UrlDecode((string) request.RouteValues["url"]!);
string normalizedUrl = DiscordCdn.NormalizeUrl(url);
if (!await Db.Downloads.GetSuccessfulDownloadWithData(normalizedUrl, (download, stream) => response.WriteStreamAsync(download.Type, download.Size, stream))) {
response.Redirect(url, permanent: false);
}
}
}
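
A small sketch of how the viewer side is expected to address this endpoint (the link is made up):

// Illustrative only: the route value is the URL-encoded original link. The endpoint
// normalizes it before the lookup, so signed and unsigned variants resolve to the same
// stored blob, and anything not stored locally is answered with a redirect.
string link = "https://cdn.discordapp.com/avatars/123/abcdef.webp";
string route = "/get-downloaded-file/" + WebUtility.UrlEncode(link);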

View File

@@ -1,5 +1,5 @@
using System.Net.Mime;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using System.Web;
using DHT.Server.Database;
@@ -10,25 +10,19 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetTrackingScriptEndpoint : BaseEndpoint {
sealed class GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters) : BaseEndpoint(db) {
private static ResourceLoader Resources { get; } = new (Assembly.GetExecutingAssembly());
private readonly ServerParameters serverParameters;
public GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db) {
serverParameters = parameters;
}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
protected override async Task Respond(HttpRequest request, HttpResponse response) {
string bootstrap = await Resources.ReadTextAsync("Tracker/bootstrap.js");
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + serverParameters.Port + ";")
.Replace("/*[TOKEN]*/", HttpUtility.JavaScriptStringEncode(serverParameters.Token))
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + parameters.Port + ";")
.Replace("/*[TOKEN]*/", HttpUtility.JavaScriptStringEncode(parameters.Token))
.Replace("/*[IMPORTS]*/", await Resources.ReadJoinedAsync("Tracker/scripts/", '\n'))
.Replace("/*[CSS-CONTROLLER]*/", await Resources.ReadTextAsync("Tracker/styles/controller.css"))
.Replace("/*[CSS-SETTINGS]*/", await Resources.ReadTextAsync("Tracker/styles/settings.css"))
.Replace("/*[DEBUGGER]*/", ctx.Request.Query.ContainsKey("debug") ? "debugger;" : "");
.Replace("/*[DEBUGGER]*/", request.Query.ContainsKey("debug") ? "debugger;" : "");
ctx.Response.Headers.Append("X-DHT", "1");
return new HttpOutput.File("text/javascript", Encoding.UTF8.GetBytes(script));
response.Headers.Append("X-DHT", "1");
await response.WriteTextAsync(MediaTypeNames.Text.JavaScript, script);
}
}

View File

@@ -8,18 +8,14 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackChannelEndpoint : BaseEndpoint {
public TrackChannelEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
sealed class TrackChannelEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
protected override async Task Respond(HttpRequest request, HttpResponse response) {
var root = await ReadJson(request);
var server = ReadServer(root.RequireObject("server"), "server");
var channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id);
await Db.Servers.Add([server]);
await Db.Channels.Add([channel]);
return HttpOutput.None;
}
private static Data.Server ReadServer(JsonElement json, string path) => new () {

View File

@@ -15,14 +15,12 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackMessagesEndpoint : BaseEndpoint {
sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
private const string HasNewMessages = "1";
private const string NoNewMessages = "0";
public TrackMessagesEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
protected override async Task Respond(HttpRequest request, HttpResponse response) {
var root = await ReadJson(request);
if (root.ValueKind != JsonValueKind.Array) {
throw new HttpException(HttpStatusCode.BadRequest, "Expected root element to be an array.");
@@ -43,7 +41,7 @@ sealed class TrackMessagesEndpoint : BaseEndpoint {
await Db.Messages.Add(messages);
return new HttpOutput.Text(anyNewMessages ? HasNewMessages : NoNewMessages);
await response.WriteTextAsync(anyNewMessages ? HasNewMessages : NoNewMessages);
}
private static Message ReadMessage(JsonElement json, string path) => new () {

View File

@@ -8,11 +8,9 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackUsersEndpoint : BaseEndpoint {
public TrackUsersEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
protected override async Task Respond(HttpRequest request, HttpResponse response) {
var root = await ReadJson(request);
if (root.ValueKind != JsonValueKind.Array) {
throw new HttpException(HttpStatusCode.BadRequest, "Expected root element to be an array.");
@@ -26,8 +24,6 @@ sealed class TrackUsersEndpoint : BaseEndpoint {
}
await Db.Users.Add(users);
return HttpOutput.None;
}
private static User ReadUser(JsonElement json, string path) => new () {

View File

@@ -41,7 +41,7 @@ sealed class Startup {
app.UseEndpoints(endpoints => {
endpoints.MapGet("/get-tracking-script", new GetTrackingScriptEndpoint(db, parameters).Handle);
endpoints.MapGet("/get-attachment/{url}", new GetAttachmentEndpoint(db).Handle);
endpoints.MapGet("/get-downloaded-file/{url}", new GetDownloadedFileEndpoint(db).Handle);
endpoints.MapPost("/track-channel", new TrackChannelEndpoint(db).Handle);
endpoints.MapPost("/track-users", new TrackUsersEndpoint(db).Handle);
endpoints.MapPost("/track-messages", new TrackMessagesEndpoint(db).Handle);

View File

@@ -6,18 +6,12 @@ using DHT.Server.Service;
namespace DHT.Server;
public sealed class State : IAsyncDisposable {
public static State Dummy { get; } = new (DummyDatabaseFile.Instance);
public sealed class State(IDatabaseFile db, int? concurrentDownloads) : IAsyncDisposable {
public static State Dummy { get; } = new (DummyDatabaseFile.Instance, null);
public IDatabaseFile Db { get; }
public Downloader Downloader { get; }
public ServerManager Server { get; }
public State(IDatabaseFile db) {
Db = db;
Downloader = new Downloader(db);
Server = new ServerManager(db);
}
public IDatabaseFile Db { get; } = db;
public Downloader Downloader { get; } = new (db, concurrentDownloads);
public ServerManager Server { get; } = new (db);
public async ValueTask DisposeAsync() {
await Downloader.Stop();

View File

@@ -8,7 +8,7 @@ public static class LinqExtensions {
HashSet<TKey>? seenKeys = null;
foreach (var item in collection) {
seenKeys ??= new HashSet<TKey>();
seenKeys ??= [];
if (seenKeys.Add(getKeyFromItem(item))) {
yield return item;

View File

@@ -0,0 +1,33 @@
using System.IO;
using System.Net.Mime;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace DHT.Utils.Http;
public static class HttpExtensions {
public static Task WriteTextAsync(this HttpResponse response, string text) {
return WriteTextAsync(response, MediaTypeNames.Text.Plain, text);
}
public static async Task WriteTextAsync(this HttpResponse response, string contentType, string text) {
response.ContentType = contentType;
await response.StartAsync();
await response.WriteAsync(text, Encoding.UTF8);
}
public static async Task WriteFileAsync(this HttpResponse response, string? contentType, byte[] bytes) {
response.ContentType = contentType ?? string.Empty;
response.ContentLength = bytes.Length;
await response.StartAsync();
await response.Body.WriteAsync(bytes);
}
public static async Task WriteStreamAsync(this HttpResponse response, string? contentType, ulong? contentLength, Stream source) {
response.ContentType = contentType ?? string.Empty;
response.ContentLength = (long?) contentLength;
await response.StartAsync();
await source.CopyToAsync(response.Body);
}
}
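
A minimal sketch of the streaming helper in use; the file stream stands in for the database blob stream used by GetDownloadedFileEndpoint above:

// Illustrative only, inside an endpoint's Respond method: content type and length are
// set before StartAsync, after which only the body can be written; this is why
// BaseEndpoint now skips writing an error message once a response has started.
await using var blob = File.OpenRead("example.bin"); // hypothetical stand-in for a blob stream
await response.WriteStreamAsync("image/png", (ulong) blob.Length, blob);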

View File

@@ -1,35 +0,0 @@
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace DHT.Utils.Http;
public static class HttpOutput {
public static IHttpOutput None { get; } = new NoneImpl();
private sealed class NoneImpl : IHttpOutput {
public Task WriteTo(HttpResponse response) {
return Task.CompletedTask;
}
}
public sealed class Text(string text) : IHttpOutput {
public Task WriteTo(HttpResponse response) {
return response.WriteAsync(text, Encoding.UTF8);
}
}
public sealed class File(string? contentType, byte[] bytes) : IHttpOutput {
public async Task WriteTo(HttpResponse response) {
response.ContentType = contentType ?? string.Empty;
await response.Body.WriteAsync(bytes);
}
}
public sealed class Redirect(string url, bool permanent) : IHttpOutput {
public Task WriteTo(HttpResponse response) {
response.Redirect(url, permanent);
return Task.CompletedTask;
}
}
}

View File

@@ -1,8 +0,0 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace DHT.Utils.Http;
public interface IHttpOutput {
Task WriteTo(HttpResponse response);
}

View File

@@ -2,6 +2,7 @@ using System;
using System.Reactive.Subjects;
using System.Threading;
using System.Threading.Tasks;
using DHT.Utils.Logging;
namespace DHT.Utils.Tasks;
@@ -9,9 +10,9 @@ public sealed class ObservableThrottledTask<T> : IObservable<T>, IDisposable {
private readonly ReplaySubject<T> subject;
private readonly ThrottledTask<T> task;
public ObservableThrottledTask(TaskScheduler resultScheduler) {
public ObservableThrottledTask(Log log, TaskScheduler resultScheduler) {
this.subject = new ReplaySubject<T>(bufferSize: 1);
this.task = new ThrottledTask<T>(subject.OnNext, resultScheduler);
this.task = new ThrottledTask<T>(log, subject.OnNext, resultScheduler);
}
public void Post(Func<CancellationToken, Task<T>> resultComputer) {

View File

@@ -2,6 +2,7 @@ using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using DHT.Utils.Logging;
namespace DHT.Utils.Tasks;
@@ -14,8 +15,11 @@ public abstract class ThrottledTaskBase<T> : IDisposable {
});
private readonly CancellationTokenSource cancellationTokenSource = new ();
private readonly Log log;
internal ThrottledTaskBase() {}
internal ThrottledTaskBase(Log log) {
this.log = log;
}
protected async Task ReaderTask() {
var cancellationToken = cancellationTokenSource.Token;
@@ -26,8 +30,8 @@ public abstract class ThrottledTaskBase<T> : IDisposable {
await Run(item, cancellationToken);
} catch (OperationCanceledException) {
throw;
} catch (Exception) {
// Ignore.
} catch (Exception e) {
log.Error("Caught exception in task: " + e);
}
}
} catch (OperationCanceledException) {
@@ -53,7 +57,7 @@ public sealed class ThrottledTask : ThrottledTaskBase<Task> {
private readonly Action resultProcessor;
private readonly TaskScheduler resultScheduler;
public ThrottledTask(Action resultProcessor, TaskScheduler resultScheduler) {
public ThrottledTask(Log log, Action resultProcessor, TaskScheduler resultScheduler) : base(log) {
this.resultProcessor = resultProcessor;
this.resultScheduler = resultScheduler;
@@ -70,7 +74,7 @@ public sealed class ThrottledTask<T> : ThrottledTaskBase<Task<T>> {
private readonly Action<T> resultProcessor;
private readonly TaskScheduler resultScheduler;
public ThrottledTask(Action<T> resultProcessor, TaskScheduler resultScheduler) {
public ThrottledTask(Log log, Action<T> resultProcessor, TaskScheduler resultScheduler) : base(log) {
this.resultProcessor = resultProcessor;
this.resultScheduler = resultScheduler;

View File

@@ -8,5 +8,5 @@ using DHT.Utils;
namespace DHT.Utils;
static class Version {
public const string Tag = "40.0.0.0";
public const string Tag = "41.2.0.0";
}

Binary file not shown.