1
0
mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2025-08-17 19:31:42 +02:00

34 Commits
v39 ... v40.0

Author SHA1 Message Date
ef3e34066a Release v40.0 2023-12-31 20:18:24 +01:00
37374eeb18 Migrate ConfigureAwait to Task.Run 2023-12-31 20:18:08 +01:00
23ddb45a0d Make opening/saving viewer asynchronous 2023-12-31 20:18:08 +01:00
9904a711f7 Make database connection pool asynchronous 2023-12-31 19:47:28 +01:00
d5720c8758 Code cleanup 2023-12-31 19:44:44 +01:00
89161e14b1 Increase default window size slightly 2023-12-31 16:54:14 +01:00
9d208b026c Synchronize publishing downloaded items to UI 2023-12-31 16:51:33 +01:00
119649ef9b Optimize rendering of download statistics table 2023-12-31 15:33:00 +01:00
0bc6232da7 Convert database events to reactive 2023-12-31 14:42:03 +01:00
de266473c5 Migrate models to a custom version of MVVM Community Toolkit 2023-12-31 14:42:03 +01:00
e0f359c15b Fix wrong attachment download progress information after re-enqueueing items 2023-12-30 12:27:44 +01:00
935f11d736 Make database schema upgrades asynchronous 2023-12-30 12:27:44 +01:00
f64141e768 Rewrite database interface to be asynchronous and improve UI 2023-12-30 12:27:44 +01:00
edea3470df Set C# language version to 12 2023-12-28 20:34:45 +01:00
031d521402 Convert attachment download events to reactive 2023-12-28 17:33:30 +01:00
0131f8cb50 Refactor integrated server management 2023-12-28 17:33:30 +01:00
3bf5acfa65 Fix trailing spaces Rider generates in code 2023-12-27 08:22:23 +01:00
f603c861c5 Encapsulate server-side state and ensure graceful shutdown when closing main window 2023-12-27 08:22:23 +01:00
d2934f4d6a Update user agent of downloader 2023-12-25 09:36:14 +01:00
567253d147 Use multiple threads to download attachments 2023-12-25 09:36:14 +01:00
aa6555990c Fix throwing exceptions in UI binding converters 2023-12-25 07:04:31 +01:00
3d9d6a454a Remove unnecessary ASP.NET features 2023-12-23 13:47:33 +01:00
ee39780928 Rewrite token authorization checks in integrated server 2023-12-23 13:47:31 +01:00
7b58f973a0 Disable ASP.NET logging and use custom logging for request duration 2023-12-23 11:19:54 +01:00
93fe018343 Add -console argument to show a console on Windows 2023-12-23 08:46:43 +01:00
4f5e27f651 Release v39.1 2023-12-22 16:31:11 +01:00
cbf81ec95a Fix missing JSON source generator when parsing integrated server requests 2023-12-22 16:31:11 +01:00
8a80cb8c20 Show progress dialog when upgrading database schema 2023-12-22 16:18:03 +01:00
865deb356a Fix progress dialog not propagating exceptions from its task 2023-12-22 14:47:55 +01:00
069ab97196 Disable reflection-based JSON serialization 2023-12-22 05:54:24 +01:00
caab038eaa Use source generators for JSON serialization everywhere 2023-12-22 05:24:28 +01:00
fb837374fc Enable single file compression and disable unnecessary .NET features 2023-12-22 02:30:24 +01:00
65d935cca1 Use compiled bindings in Avalonia XAML 2023-12-21 08:55:56 +01:00
6e64c86d7a Optimize viewer JSON export using source generators 2023-12-21 08:29:07 +01:00
124 changed files with 3555 additions and 2757 deletions

View File

@@ -1,4 +1,3 @@
using System;
using Avalonia;
using Avalonia.Controls.ApplicationLifetimes;
using Avalonia.Markup.Xaml;
@@ -13,7 +12,7 @@ sealed class App : Application {
public override void OnFrameworkInitializationCompleted() {
if (ApplicationLifetime is IClassicDesktopStyleApplicationLifetime desktop) {
desktop.MainWindow = new MainWindow(new Arguments(desktop.Args ?? Array.Empty<string>()));
desktop.MainWindow = new MainWindow(Program.Arguments);
}
base.OnFrameworkInitializationCompleted();

View File

@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using DHT.Utils.Logging;
namespace DHT.Desktop;
@@ -6,29 +7,36 @@ namespace DHT.Desktop;
sealed class Arguments {
private static readonly Log Log = Log.ForType<Arguments>();
public static Arguments Empty => new(Array.Empty<string>());
private const int FirstArgument = 1;
public static Arguments Empty => new (Array.Empty<string>());
public bool Console { get; }
public string? DatabaseFile { get; }
public ushort? ServerPort { get; }
public string? ServerToken { get; }
public Arguments(string[] args) {
for (int i = 0; i < args.Length; i++) {
public Arguments(IReadOnlyList<string> args) {
for (int i = FirstArgument; i < args.Count; i++) {
string key = args[i];
switch (key) {
case "-debug":
Log.IsDebugEnabled = true;
continue;
case "-console":
Console = true;
continue;
}
string value;
if (i == 0 && !key.StartsWith('-')) {
if (i == FirstArgument && !key.StartsWith('-')) {
value = key;
key = "-db";
}
else if (i >= args.Length - 1) {
else if (i >= args.Count - 1) {
Log.Warn("Missing value for command line argument: " + key);
continue;
}

View File

@@ -19,13 +19,13 @@ sealed class BytesValueConverter : IValueConverter {
}
}
private static readonly Unit[] Units = {
new ("B", decimalPlaces: 0),
new ("kB", decimalPlaces: 0),
new ("MB", decimalPlaces: 1),
new ("GB", decimalPlaces: 1),
new ("TB", decimalPlaces: 1)
};
private static readonly Unit[] Units = [
new Unit("B", decimalPlaces: 0),
new Unit("kB", decimalPlaces: 0),
new Unit("MB", decimalPlaces: 1),
new Unit("GB", decimalPlaces: 1),
new Unit("TB", decimalPlaces: 1)
];
private const int Scale = 1000;

View File

@@ -9,6 +9,7 @@ using DHT.Desktop.Dialogs.Message;
using DHT.Server.Database;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
namespace DHT.Desktop.Common;
@@ -18,9 +19,9 @@ static class DatabaseGui {
private const string DatabaseFileInitialName = "archive.dht";
private static readonly IReadOnlyList<FilePickerFileType> DatabaseFileDialogFilter = new List<FilePickerFileType> {
FileDialogs.CreateFilter("Discord History Tracker Database", new [] { "dht" })
};
private static readonly IReadOnlyList<FilePickerFileType> DatabaseFileDialogFilter = [
FileDialogs.CreateFilter("Discord History Tracker Database", ["dht"])
];
public static async Task<string[]> NewOpenDatabaseFilesDialog(Window window, string? suggestedDirectory) {
return await window.StorageProvider.OpenFiles(new FilePickerOpenOptions {
@@ -41,11 +42,11 @@ static class DatabaseGui {
});
}
public static async Task<IDatabaseFile?> TryOpenOrCreateDatabaseFromPath(string path, Window window, Func<Task<bool>> checkCanUpgradeDatabase) {
public static async Task<IDatabaseFile?> TryOpenOrCreateDatabaseFromPath(string path, Window window, ISchemaUpgradeCallbacks schemaUpgradeCallbacks) {
IDatabaseFile? file = null;
try {
file = await SqliteDatabaseFile.OpenOrCreate(path, checkCanUpgradeDatabase);
file = await SqliteDatabaseFile.OpenOrCreate(path, schemaUpgradeCallbacks);
} catch (InvalidDatabaseVersionException ex) {
await Dialog.ShowOk(window, "Database Error", "Database '" + Path.GetFileName(path) + "' appears to be corrupted (invalid version: " + ex.Version + ").");
} catch (DatabaseTooNewException ex) {

View File

@@ -9,6 +9,7 @@
<PropertyGroup>
<OutputType>WinExe</OutputType>
<ApplicationIcon>./Resources/icon.ico</ApplicationIcon>
<AvaloniaUseCompiledBindingsByDefault>true</AvaloniaUseCompiledBindingsByDefault>
<CheckForOverflowUnderflow>true</CheckForOverflowUnderflow>
<SatelliteResourceLanguages>en</SatelliteResourceLanguages>
</PropertyGroup>
@@ -20,7 +21,9 @@
<PackageReference Include="Avalonia.Desktop" Version="11.0.6" />
<PackageReference Include="Avalonia.Diagnostics" Version="11.0.6" Condition=" '$(Configuration)' == 'Debug' " />
<PackageReference Include="Avalonia.Fonts.Inter" Version="11.0.6" />
<PackageReference Include="Avalonia.ReactiveUI" Version="11.0.6" />
<PackageReference Include="Avalonia.Themes.Fluent" Version="11.0.6" />
<PackageReference Include="CommunityToolkit.Mvvm" Version="999.0.0-build.0.g0d941a6a62" />
</ItemGroup>
<ItemGroup>

View File

@@ -5,6 +5,7 @@
xmlns:namespace="clr-namespace:DHT.Desktop.Dialogs.CheckBox"
mc:Ignorable="d" d:DesignWidth="500"
x:Class="DHT.Desktop.Dialogs.CheckBox.CheckBoxDialog"
x:DataType="namespace:CheckBoxDialogModel"
Title="{Binding Title}"
Icon="avares://DiscordHistoryTracker/Resources/icon.ico"
Width="500" SizeToContent="Height" CanResize="False"
@@ -37,7 +38,7 @@
<ItemsRepeater ItemsSource="{Binding Items}">
<ItemsRepeater.ItemTemplate>
<DataTemplate>
<CheckBox IsChecked="{Binding Checked}">
<CheckBox IsChecked="{Binding IsChecked}">
<Label>
<TextBlock Text="{Binding Title}" TextWrapping="Wrap" />
</Label>

View File

@@ -2,11 +2,11 @@ using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using DHT.Utils.Models;
using CommunityToolkit.Mvvm.ComponentModel;
namespace DHT.Desktop.Dialogs.CheckBox;
class CheckBoxDialogModel : BaseModel {
class CheckBoxDialogModel : ObservableObject {
public string Title { get; init; } = "";
private IReadOnlyList<CheckBoxItem> items = Array.Empty<CheckBoxItem>();
@@ -29,8 +29,8 @@ class CheckBoxDialogModel : BaseModel {
private bool pauseCheckEvents = false;
public bool AreAllSelected => Items.All(static item => item.Checked);
public bool AreNoneSelected => Items.All(static item => !item.Checked);
public bool AreAllSelected => Items.All(static item => item.IsChecked);
public bool AreNoneSelected => Items.All(static item => !item.IsChecked);
public void SelectAll() => SetAllChecked(true);
public void SelectNone() => SetAllChecked(false);
@@ -39,7 +39,7 @@ class CheckBoxDialogModel : BaseModel {
pauseCheckEvents = true;
foreach (var item in Items) {
item.Checked = isChecked;
item.IsChecked = isChecked;
}
pauseCheckEvents = false;
@@ -52,16 +52,16 @@ class CheckBoxDialogModel : BaseModel {
}
private void OnItemPropertyChanged(object? sender, PropertyChangedEventArgs e) {
if (!pauseCheckEvents && e.PropertyName == nameof(CheckBoxItem.Checked)) {
if (!pauseCheckEvents && e.PropertyName == nameof(CheckBoxItem.IsChecked)) {
UpdateBulkButtons();
}
}
}
sealed class CheckBoxDialogModel<T> : CheckBoxDialogModel {
public new IReadOnlyList<CheckBoxItem<T>> Items { get; }
private new IReadOnlyList<CheckBoxItem<T>> Items { get; }
public IEnumerable<CheckBoxItem<T>> SelectedItems => Items.Where(static item => item.Checked);
public IEnumerable<CheckBoxItem<T>> SelectedItems => Items.Where(static item => item.IsChecked);
public CheckBoxDialogModel(IEnumerable<CheckBoxItem<T>> items) {
this.Items = new List<CheckBoxItem<T>>(items);

View File

@@ -1,17 +1,13 @@
using DHT.Utils.Models;
using CommunityToolkit.Mvvm.ComponentModel;
namespace DHT.Desktop.Dialogs.CheckBox;
class CheckBoxItem : BaseModel {
partial class CheckBoxItem : ObservableObject {
public string Title { get; init; } = "";
public object? Item { get; init; } = null;
[ObservableProperty]
private bool isChecked = false;
public bool Checked {
get => isChecked;
set => Change(ref isChecked, value);
}
}
sealed class CheckBoxItem<T> : CheckBoxItem {

View File

@@ -5,6 +5,7 @@
xmlns:namespace="clr-namespace:DHT.Desktop.Dialogs.Message"
mc:Ignorable="d" d:DesignWidth="500"
x:Class="DHT.Desktop.Dialogs.Message.MessageDialog"
x:DataType="namespace:MessageDialogModel"
Title="{Binding Title}"
Icon="avares://DiscordHistoryTracker/Resources/icon.ico"
Width="500" SizeToContent="Height" CanResize="False"

View File

@@ -4,4 +4,6 @@ namespace DHT.Desktop.Dialogs.Progress;
interface IProgressCallback {
Task Update(string message, int finishedItems, int totalItems);
Task UpdateIndeterminate(string message);
Task Hide();
}

View File

@@ -5,6 +5,7 @@
xmlns:namespace="clr-namespace:DHT.Desktop.Dialogs.Progress"
mc:Ignorable="d" d:DesignWidth="500"
x:Class="DHT.Desktop.Dialogs.Progress.ProgressDialog"
x:DataType="namespace:ProgressDialogModel"
Title="{Binding Title}"
Icon="avares://DiscordHistoryTracker/Resources/icon.ico"
Opened="OnOpened"
@@ -31,12 +32,18 @@
</Style>
</Window.Styles>
<StackPanel Margin="20">
<DockPanel>
<TextBlock DockPanel.Dock="Right" Text="{Binding Items}" Classes="items" />
<TextBlock DockPanel.Dock="Left" Text="{Binding Message}" />
</DockPanel>
<ProgressBar Value="{Binding Progress}" />
</StackPanel>
<ItemsRepeater ItemsSource="{Binding Items}" Margin="0 10">
<ItemsRepeater.ItemTemplate>
<DataTemplate>
<StackPanel Margin="20 10" IsHitTestVisible="{Binding IsVisible}" Opacity="{Binding Opacity}">
<DockPanel>
<TextBlock DockPanel.Dock="Right" Text="{Binding Items}" Classes="items" />
<TextBlock DockPanel.Dock="Left" Text="{Binding Message}" />
</DockPanel>
<ProgressBar IsIndeterminate="{Binding IsIndeterminate}" Value="{Binding Progress}" />
</StackPanel>
</DataTemplate>
</ItemsRepeater.ItemTemplate>
</ItemsRepeater>
</Window>

View File

@@ -7,7 +7,60 @@ namespace DHT.Desktop.Dialogs.Progress;
[SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class ProgressDialog : Window {
internal static async Task Show(Window owner, string title, Func<ProgressDialog, IProgressCallback, Task> action) {
var taskCompletionSource = new TaskCompletionSource();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
try {
await action(dialog, callbacks[0]);
taskCompletionSource.SetResult();
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
await taskCompletionSource.Task;
}
internal static async Task ShowIndeterminate(Window owner, string title, string message, Func<ProgressDialog, Task> action) {
var taskCompletionSource = new TaskCompletionSource();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
await callbacks[0].UpdateIndeterminate(message);
try {
await action(dialog);
taskCompletionSource.SetResult();
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
await taskCompletionSource.Task;
}
internal static async Task<T> ShowIndeterminate<T>(Window owner, string title, string message, Func<ProgressDialog, Task<T>> action) {
var taskCompletionSource = new TaskCompletionSource<T>();
var dialog = new ProgressDialog();
dialog.DataContext = new ProgressDialogModel(title, async callbacks => {
await callbacks[0].UpdateIndeterminate(message);
try {
taskCompletionSource.SetResult(await action(dialog));
} catch (Exception e) {
taskCompletionSource.SetException(e);
}
});
await dialog.ShowProgressDialog(owner);
return await taskCompletionSource.Task;
}
private bool isFinished = false;
private Task progressTask = Task.CompletedTask;
public ProgressDialog() {
InitializeComponent();
@@ -15,7 +68,8 @@ public sealed partial class ProgressDialog : Window {
public void OnOpened(object? sender, EventArgs e) {
if (DataContext is ProgressDialogModel model) {
Task.Run(model.StartTask).ContinueWith(OnFinished, TaskScheduler.FromCurrentSynchronizationContext());
progressTask = Task.Run(model.StartTask);
progressTask.ContinueWith(OnFinished, TaskScheduler.FromCurrentSynchronizationContext());
}
}
@@ -27,4 +81,9 @@ public sealed partial class ProgressDialog : Window {
isFinished = true;
Close();
}
public async Task ShowProgressDialog(Window owner) {
await ShowDialog(owner);
await progressTask;
}
}

View File

@@ -1,65 +1,63 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Avalonia.Threading;
using DHT.Desktop.Common;
using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.Progress;
sealed class ProgressDialogModel : BaseModel {
sealed class ProgressDialogModel {
public string Title { get; init; } = "";
private string message = "";
public string Message {
get => message;
private set => Change(ref message, value);
}
private string items = "";
public string Items {
get => items;
private set => Change(ref items, value);
}
private int progress = 0;
public int Progress {
get => progress;
private set => Change(ref progress, value);
}
public IReadOnlyList<ProgressItem> Items { get; } = Array.Empty<ProgressItem>();
private readonly TaskRunner? task;
[Obsolete("Designer")]
public ProgressDialogModel() {}
public ProgressDialogModel(TaskRunner task) {
public ProgressDialogModel(string title, TaskRunner task, int progressItems = 1) {
this.Title = title;
this.task = task;
this.Items = Enumerable.Range(0, progressItems).Select(static _ => new ProgressItem()).ToArray();
}
internal async Task StartTask() {
if (task != null) {
await task(new Callback(this));
await task(Items.Select(static item => new Callback(item)).ToArray());
}
}
public delegate Task TaskRunner(IProgressCallback callback);
public delegate Task TaskRunner(IReadOnlyList<IProgressCallback> callbacks);
private sealed class Callback : IProgressCallback {
private readonly ProgressDialogModel model;
private readonly ProgressItem item;
public Callback(ProgressDialogModel model) {
this.model = model;
public Callback(ProgressItem item) {
this.item = item;
}
async Task IProgressCallback.Update(string message, int finishedItems, int totalItems) {
public async Task Update(string message, int finishedItems, int totalItems) {
await Dispatcher.UIThread.InvokeAsync(() => {
model.Message = message;
model.Items = finishedItems.Format() + " / " + totalItems.Format();
model.Progress = 100 * finishedItems / totalItems;
item.Message = message;
item.Items = totalItems == 0 ? string.Empty : finishedItems.Format() + " / " + totalItems.Format();
item.Progress = totalItems == 0 ? 0 : 100 * finishedItems / totalItems;
item.IsIndeterminate = false;
});
}
public async Task UpdateIndeterminate(string message) {
await Dispatcher.UIThread.InvokeAsync(() => {
item.Message = message;
item.Items = string.Empty;
item.Progress = 0;
item.IsIndeterminate = true;
});
}
public Task Hide() {
return Update(string.Empty, 0, 0);
}
}
}

View File

@@ -0,0 +1,30 @@
using CommunityToolkit.Mvvm.ComponentModel;
namespace DHT.Desktop.Dialogs.Progress;
sealed partial class ProgressItem : ObservableObject {
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(Opacity))]
private bool isVisible = false;
public double Opacity => IsVisible ? 1.0 : 0.0;
private string message = "";
public string Message {
get => message;
set {
SetProperty(ref message, value);
IsVisible = !string.IsNullOrEmpty(value);
}
}
[ObservableProperty]
private string items = "";
[ObservableProperty]
private int progress = 0;
[ObservableProperty]
private bool isIndeterminate;
}

View File

@@ -5,6 +5,7 @@
xmlns:namespace="clr-namespace:DHT.Desktop.Dialogs.TextBox"
mc:Ignorable="d" d:DesignWidth="500"
x:Class="DHT.Desktop.Dialogs.TextBox.TextBoxDialog"
x:DataType="namespace:TextBoxDialogModel"
Title="{Binding Title}"
Icon="avares://DiscordHistoryTracker/Resources/icon.ico"
Width="500" SizeToContent="Height" CanResize="False"

View File

@@ -3,7 +3,7 @@ using Avalonia.Controls;
using Avalonia.Interactivity;
using DHT.Desktop.Dialogs.Message;
namespace DHT.Desktop.Dialogs.TextBox;
namespace DHT.Desktop.Dialogs.TextBox;
[SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class TextBoxDialog : Window {

View File

@@ -2,11 +2,11 @@ using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using DHT.Utils.Models;
using CommunityToolkit.Mvvm.ComponentModel;
namespace DHT.Desktop.Dialogs.TextBox;
namespace DHT.Desktop.Dialogs.TextBox;
class TextBoxDialogModel : BaseModel {
class TextBoxDialogModel : ObservableObject {
public string Title { get; init; } = "";
public string Description { get; init; } = "";
@@ -36,7 +36,7 @@ class TextBoxDialogModel : BaseModel {
}
sealed class TextBoxDialogModel<T> : TextBoxDialogModel {
public new IReadOnlyList<TextBoxItem<T>> Items { get; }
private new IReadOnlyList<TextBoxItem<T>> Items { get; }
public IEnumerable<TextBoxItem<T>> ValidItems => Items.Where(static item => item.IsValid);

View File

@@ -1,11 +1,11 @@
using System;
using System.Collections;
using System.ComponentModel;
using DHT.Utils.Models;
using CommunityToolkit.Mvvm.ComponentModel;
namespace DHT.Desktop.Dialogs.TextBox;
namespace DHT.Desktop.Dialogs.TextBox;
class TextBoxItem : BaseModel, INotifyDataErrorInfo {
class TextBoxItem : ObservableObject, INotifyDataErrorInfo {
public string Title { get; init; } = "";
public object? Item { get; init; } = null;
@@ -17,7 +17,7 @@ class TextBoxItem : BaseModel, INotifyDataErrorInfo {
public string Value {
get => this.value;
set {
Change(ref this.value, value);
SetProperty(ref this.value, value);
ErrorsChanged?.Invoke(this, new DataErrorsChangedEventArgs(nameof(Value)));
}
}

View File

@@ -1,9 +1,9 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using DHT.Utils.Logging;
using static System.Environment.SpecialFolder;
@@ -39,7 +39,8 @@ static class DiscordAppSettings {
public static async Task<bool?> AreDevToolsEnabled() {
try {
return AreDevToolsEnabled(await ReadSettingsJson());
var settingsJson = await ReadSettingsJson();
return AreDevToolsEnabled(settingsJson);
} catch (Exception e) {
Log.Error("Cannot read settings file.");
Log.Error(e);
@@ -47,12 +48,12 @@ static class DiscordAppSettings {
}
}
private static bool AreDevToolsEnabled(Dictionary<string, object?> json) {
return json.TryGetValue(JsonKeyDevTools, out var value) && value is JsonElement { ValueKind: JsonValueKind.True };
private static bool AreDevToolsEnabled(JsonObject json) {
return json.TryGetPropertyValue(JsonKeyDevTools, out var node) && node?.GetValueKind() == JsonValueKind.True;
}
public static async Task<SettingsJsonResult> ConfigureDevTools(bool enable) {
Dictionary<string, object?> json;
JsonObject json;
try {
json = await ReadSettingsJson();
@@ -109,13 +110,13 @@ static class DiscordAppSettings {
return SettingsJsonResult.Success;
}
private static async Task<Dictionary<string, object?>> ReadSettingsJson() {
private static async Task<JsonObject> ReadSettingsJson() {
await using var stream = new FileStream(JsonFilePath, FileMode.Open, FileAccess.Read, FileShare.Read);
return await JsonSerializer.DeserializeAsync<Dictionary<string, object?>?>(stream) ?? throw new JsonException();
return await JsonSerializer.DeserializeAsync(stream, DiscordAppSettingsJsonContext.Default.JsonObject) ?? throw new JsonException();
}
private static async Task WriteSettingsJson(Dictionary<string, object?> json) {
private static async Task WriteSettingsJson(JsonObject json) {
await using var stream = new FileStream(JsonFilePath, FileMode.Truncate, FileAccess.Write, FileShare.None);
await JsonSerializer.SerializeAsync(stream, json, new JsonSerializerOptions { WriteIndented = true });
await JsonSerializer.SerializeAsync(stream, json, DiscordAppSettingsJsonContext.Default.JsonObject);
}
}

View File

@@ -0,0 +1,8 @@
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
namespace DHT.Desktop.Discord;
[JsonSourceGenerationOptions(GenerationMode = JsonSourceGenerationMode.Default, WriteIndented = true)]
[JsonSerializable(typeof(JsonObject))]
sealed partial class DiscordAppSettingsJsonContext : JsonSerializerContext;

View File

@@ -5,6 +5,7 @@
xmlns:main="clr-namespace:DHT.Desktop.Main"
mc:Ignorable="d" d:DesignWidth="480" d:DesignHeight="295"
x:Class="DHT.Desktop.Main.AboutWindow"
x:DataType="main:AboutWindowModel"
Title="About Discord History Tracker"
Icon="avares://DiscordHistoryTracker/Resources/icon.ico"
Width="480" Height="295" CanResize="False"

View File

@@ -4,7 +4,8 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d"
x:Class="DHT.Desktop.Main.Controls.AttachmentFilterPanel">
x:Class="DHT.Desktop.Main.Controls.AttachmentFilterPanel"
x:DataType="controls:AttachmentFilterPanelModel">
<Design.DataContext>
<controls:AttachmentFilterPanelModel />

View File

@@ -1,76 +1,71 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Server;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls;
sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable {
public sealed record Unit(string Name, uint Scale);
private static readonly Unit[] AllUnits = {
new ("B", 1),
new ("kB", 1024),
new ("MB", 1024 * 1024)
};
private static readonly Unit[] AllUnits = [
new Unit("B", 1),
new Unit("kB", 1024),
new Unit("MB", 1024 * 1024)
];
private static readonly HashSet<string> FilterProperties = new () {
private static readonly HashSet<string> FilterProperties = [
nameof(LimitSize),
nameof(MaximumSize),
nameof(MaximumSizeUnit)
};
];
public string FilterStatisticsText { get; private set; } = "";
[ObservableProperty]
private bool limitSize = false;
[ObservableProperty]
private ulong maximumSize = 0L;
[ObservableProperty]
private Unit maximumSizeUnit = AllUnits[0];
public bool LimitSize {
get => limitSize;
set => Change(ref limitSize, value);
}
public ulong MaximumSize {
get => maximumSize;
set => Change(ref maximumSize, value);
}
public Unit MaximumSizeUnit {
get => maximumSizeUnit;
set => Change(ref maximumSizeUnit, value);
}
public IEnumerable<Unit> Units => AllUnits;
private readonly IDatabaseFile db;
private readonly State state;
private readonly string verb;
private readonly AsyncValueComputer<long> matchingAttachmentCountComputer;
private readonly RestartableTask<long> matchingAttachmentCountTask;
private long? matchingAttachmentCount;
private readonly IDisposable attachmentCountSubscription;
private long? totalAttachmentCount;
[Obsolete("Designer")]
public AttachmentFilterPanelModel() : this(DummyDatabaseFile.Instance) {}
public AttachmentFilterPanelModel() : this(State.Dummy) {}
public AttachmentFilterPanelModel(IDatabaseFile db, string verb = "Matches") {
this.db = db;
public AttachmentFilterPanelModel(State state, string verb = "Matches") {
this.state = state;
this.verb = verb;
this.matchingAttachmentCountComputer = AsyncValueComputer<long>.WithResultProcessor(SetAttachmentCounts).Build();
this.matchingAttachmentCountTask = new RestartableTask<long>(SetAttachmentCounts, TaskScheduler.FromCurrentSynchronizationContext());
this.attachmentCountSubscription = state.Db.Attachments.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnAttachmentCountChanged);
UpdateFilterStatistics();
PropertyChanged += OnPropertyChanged;
db.Statistics.PropertyChanged += OnDbStatisticsChanged;
}
public void Dispose() {
db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
attachmentCountSubscription.Dispose();
}
private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
@@ -78,25 +73,24 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
UpdateFilterStatistics();
}
}
private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) {
totalAttachmentCount = db.Statistics.TotalAttachments;
UpdateFilterStatistics();
}
private void OnAttachmentCountChanged(long newAttachmentCount) {
totalAttachmentCount = newAttachmentCount;
UpdateFilterStatistics();
}
private void UpdateFilterStatistics() {
var filter = CreateFilter();
if (filter.IsEmpty) {
matchingAttachmentCountComputer.Cancel();
matchingAttachmentCountTask.Cancel();
matchingAttachmentCount = totalAttachmentCount;
UpdateFilterStatisticsText();
}
else {
matchingAttachmentCount = null;
UpdateFilterStatisticsText();
matchingAttachmentCountComputer.Compute(() => db.CountAttachments(filter));
matchingAttachmentCountTask.Restart(cancellationToken => state.Db.Attachments.Count(filter, cancellationToken));
}
}
@@ -114,7 +108,7 @@ sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
}
public AttachmentFilter CreateFilter() {
AttachmentFilter filter = new();
AttachmentFilter filter = new ();
if (LimitSize) {
try {

View File

@@ -4,7 +4,8 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d"
x:Class="DHT.Desktop.Main.Controls.MessageFilterPanel">
x:Class="DHT.Desktop.Main.Controls.MessageFilterPanel"
x:DataType="controls:MessageFilterPanelModel">
<Design.DataContext>
<controls:MessageFilterPanelModel />

View File

@@ -2,22 +2,25 @@ using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Reactive.Linq;
using System.Text;
using System.Threading.Tasks;
using Avalonia.Controls;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.CheckBox;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls;
sealed class MessageFilterPanelModel : BaseModel, IDisposable {
private static readonly HashSet<string> FilterProperties = new () {
sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
private static readonly HashSet<string> FilterProperties = [
nameof(FilterByDate),
nameof(StartDate),
nameof(EndDate),
@@ -25,7 +28,7 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
nameof(IncludedChannels),
nameof(FilterByUser),
nameof(IncludedUsers)
};
];
public string FilterStatisticsText { get; private set; } = "";
@@ -33,91 +36,76 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
public bool HasAnyFilters => FilterByDate || FilterByChannel || FilterByUser;
[ObservableProperty]
private bool filterByDate = false;
[ObservableProperty]
private DateTime? startDate = null;
[ObservableProperty]
private DateTime? endDate = null;
[ObservableProperty]
private bool filterByChannel = false;
[ObservableProperty]
private HashSet<ulong>? includedChannels = null;
[ObservableProperty]
private bool filterByUser = false;
[ObservableProperty]
private HashSet<ulong>? includedUsers = null;
public bool FilterByDate {
get => filterByDate;
set => Change(ref filterByDate, value);
}
public DateTime? StartDate {
get => startDate;
set => Change(ref startDate, value);
}
public DateTime? EndDate {
get => endDate;
set => Change(ref endDate, value);
}
public bool FilterByChannel {
get => filterByChannel;
set => Change(ref filterByChannel, value);
}
public HashSet<ulong> IncludedChannels {
get => includedChannels ?? db.GetAllChannels().Select(static channel => channel.Id).ToHashSet();
set => Change(ref includedChannels, value);
}
public bool FilterByUser {
get => filterByUser;
set => Change(ref filterByUser, value);
}
public HashSet<ulong> IncludedUsers {
get => includedUsers ?? db.GetAllUsers().Select(static user => user.Id).ToHashSet();
set => Change(ref includedUsers, value);
}
[ObservableProperty]
private string channelFilterLabel = "";
public string ChannelFilterLabel {
get => channelFilterLabel;
set => Change(ref channelFilterLabel, value);
}
[ObservableProperty]
private string userFilterLabel = "";
public string UserFilterLabel {
get => userFilterLabel;
set => Change(ref userFilterLabel, value);
}
private readonly Window window;
private readonly IDatabaseFile db;
private readonly State state;
private readonly string verb;
private readonly AsyncValueComputer<long> exportedMessageCountComputer;
private readonly RestartableTask<long> exportedMessageCountTask;
private long? exportedMessageCount;
private readonly IDisposable messageCountSubscription;
private long? totalMessageCount;
private readonly IDisposable channelCountSubscription;
private long? totalChannelCount;
private readonly IDisposable userCountSubscription;
private long? totalUserCount;
[Obsolete("Designer")]
public MessageFilterPanelModel() : this(null!, DummyDatabaseFile.Instance) {}
public MessageFilterPanelModel() : this(null!, State.Dummy) {}
public MessageFilterPanelModel(Window window, IDatabaseFile db, string verb = "Matches") {
public MessageFilterPanelModel(Window window, State state, string verb = "Matches") {
this.window = window;
this.db = db;
this.state = state;
this.verb = verb;
this.exportedMessageCountComputer = AsyncValueComputer<long>.WithResultProcessor(SetExportedMessageCount).Build();
this.exportedMessageCountTask = new RestartableTask<long>(SetExportedMessageCount, TaskScheduler.FromCurrentSynchronizationContext());
this.messageCountSubscription = state.Db.Messages.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnMessageCountChanged);
this.channelCountSubscription = state.Db.Channels.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnChannelCountChanged);
this.userCountSubscription = state.Db.Users.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnUserCountChanged);
UpdateFilterStatistics();
UpdateChannelFilterLabel();
UpdateUserFilterLabel();
PropertyChanged += OnPropertyChanged;
db.Statistics.PropertyChanged += OnDbStatisticsChanged;
}
public void Dispose() {
db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
exportedMessageCountTask.Cancel();
messageCountSubscription.Dispose();
channelCountSubscription.Dispose();
userCountSubscription.Dispose();
}
private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
@@ -134,30 +122,54 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
}
}
private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalMessages)) {
totalMessageCount = db.Statistics.TotalMessages;
UpdateFilterStatistics();
private void OnMessageCountChanged(long newMessageCount) {
totalMessageCount = newMessageCount;
UpdateFilterStatistics();
}
private void OnChannelCountChanged(long newChannelCount) {
totalChannelCount = newChannelCount;
UpdateChannelFilterLabel();
}
private void OnUserCountChanged(long newUserCount) {
totalUserCount = newUserCount;
UpdateUserFilterLabel();
}
private void UpdateChannelFilterLabel() {
if (totalChannelCount.HasValue) {
long total = totalChannelCount.Value;
long included = FilterByChannel && IncludedChannels != null ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
}
else if (e.PropertyName == nameof(DatabaseStatistics.TotalChannels)) {
UpdateChannelFilterLabel();
else {
ChannelFilterLabel = "Loading...";
}
else if (e.PropertyName == nameof(DatabaseStatistics.TotalUsers)) {
UpdateUserFilterLabel();
}
private void UpdateUserFilterLabel() {
if (totalUserCount.HasValue) {
long total = totalUserCount.Value;
long included = FilterByUser && IncludedUsers != null ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
}
else {
UserFilterLabel = "Loading...";
}
}
private void UpdateFilterStatistics() {
var filter = CreateFilter();
if (filter.IsEmpty) {
exportedMessageCountComputer.Cancel();
exportedMessageCountTask.Cancel();
exportedMessageCount = totalMessageCount;
UpdateFilterStatisticsText();
}
else {
exportedMessageCount = null;
UpdateFilterStatisticsText();
exportedMessageCountComputer.Compute(() => db.CountMessages(filter));
exportedMessageCountTask.Restart(cancellationToken => state.Db.Messages.Count(filter, cancellationToken));
}
}
@@ -174,103 +186,98 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
OnPropertyChanged(nameof(FilterStatisticsText));
}
public async void OpenChannelFilterDialog() {
var servers = db.GetAllServers().ToDictionary(static server => server.Id);
var items = new List<CheckBoxItem<ulong>>();
var included = IncludedChannels;
public async Task OpenChannelFilterDialog() {
async Task<List<CheckBoxItem<ulong>>> PrepareChannelItems(ProgressDialog dialog) {
var items = new List<CheckBoxItem<ulong>>();
var servers = await state.Db.Servers.Get().ToDictionaryAsync(static server => server.Id);
foreach (var channel in db.GetAllChannels()) {
var channelId = channel.Id;
var channelName = channel.Name;
await foreach (var channel in state.Db.Channels.Get()) {
var channelId = channel.Id;
var channelName = channel.Name;
string title;
if (servers.TryGetValue(channel.Server, out var server)) {
var titleBuilder = new StringBuilder();
var serverType = server.Type;
string title;
if (servers.TryGetValue(channel.Server, out var server)) {
var titleBuilder = new StringBuilder();
var serverType = server.Type;
titleBuilder.Append('[')
.Append(ServerTypes.ToString(serverType))
.Append("] ");
titleBuilder.Append('[')
.Append(ServerTypes.ToString(serverType))
.Append("] ");
if (serverType == ServerType.DirectMessage) {
titleBuilder.Append(channelName);
if (serverType == ServerType.DirectMessage) {
titleBuilder.Append(channelName);
}
else {
titleBuilder.Append(server.Name)
.Append(" - ")
.Append(channelName);
}
title = titleBuilder.ToString();
}
else {
titleBuilder.Append(server.Name)
.Append(" - ")
.Append(channelName);
title = channelName;
}
title = titleBuilder.ToString();
}
else {
title = channelName;
items.Add(new CheckBoxItem<ulong>(channelId) {
Title = title,
IsChecked = IncludedChannels == null || IncludedChannels.Contains(channelId)
});
}
items.Add(new CheckBoxItem<ulong>(channelId) {
Title = title,
Checked = included.Contains(channelId)
});
return items;
}
var result = await OpenIdFilterDialog(window, "Included Channels", items);
const string Title = "Included Channels";
List<CheckBoxItem<ulong>> items;
try {
items = await ProgressDialog.ShowIndeterminate(window, Title, "Loading channels...", PrepareChannelItems);
} catch (Exception e) {
await Dialog.ShowOk(window, Title, "Error loading channels: " + e.Message);
return;
}
var result = await OpenIdFilterDialog(Title, items);
if (result != null) {
IncludedChannels = result;
}
}
public async void OpenUserFilterDialog() {
var items = new List<CheckBoxItem<ulong>>();
var included = IncludedUsers;
public async Task OpenUserFilterDialog() {
async Task<List<CheckBoxItem<ulong>>> PrepareUserItems(ProgressDialog dialog) {
var checkBoxItems = new List<CheckBoxItem<ulong>>();
foreach (var user in db.GetAllUsers()) {
var name = user.Name;
var discriminator = user.Discriminator;
await foreach (var user in state.Db.Users.Get()) {
var name = user.Name;
var discriminator = user.Discriminator;
items.Add(new CheckBoxItem<ulong>(user.Id) {
Title = discriminator == null ? name : name + " #" + discriminator,
Checked = included.Contains(user.Id)
});
checkBoxItems.Add(new CheckBoxItem<ulong>(user.Id) {
Title = discriminator == null ? name : name + " #" + discriminator,
IsChecked = IncludedUsers == null || IncludedUsers.Contains(user.Id)
});
}
return checkBoxItems;
}
var result = await OpenIdFilterDialog(window, "Included Users", items);
const string Title = "Included Users";
List<CheckBoxItem<ulong>> items;
try {
items = await ProgressDialog.ShowIndeterminate(window, Title, "Loading users...", PrepareUserItems);
} catch (Exception e) {
await Dialog.ShowOk(window, Title, "Error loading users: " + e.Message);
return;
}
var result = await OpenIdFilterDialog(Title, items);
if (result != null) {
IncludedUsers = result;
}
}
private void UpdateChannelFilterLabel() {
long total = db.Statistics.TotalChannels;
long included = FilterByChannel ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
}
private void UpdateUserFilterLabel() {
long total = db.Statistics.TotalUsers;
long included = FilterByUser ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
}
public MessageFilter CreateFilter() {
MessageFilter filter = new();
if (FilterByDate) {
filter.StartDate = StartDate;
filter.EndDate = EndDate?.AddDays(1).AddMilliseconds(-1);
}
if (FilterByChannel) {
filter.ChannelIds = new HashSet<ulong>(IncludedChannels);
}
if (FilterByUser) {
filter.UserIds = new HashSet<ulong>(IncludedUsers);
}
return filter;
}
private static async Task<HashSet<ulong>?> OpenIdFilterDialog(Window window, string title, List<CheckBoxItem<ulong>> items) {
private async Task<HashSet<ulong>?> OpenIdFilterDialog(string title, List<CheckBoxItem<ulong>> items) {
items.Sort(static (item1, item2) => item1.Title.CompareTo(item2.Title));
var model = new CheckBoxDialogModel<ulong>(items) {
@@ -282,4 +289,23 @@ sealed class MessageFilterPanelModel : BaseModel, IDisposable {
return result == DialogResult.OkCancel.Ok ? model.SelectedItems.Select(static item => item.Item).ToHashSet() : null;
}
public MessageFilter CreateFilter() {
MessageFilter filter = new ();
if (FilterByDate) {
filter.StartDate = StartDate;
filter.EndDate = EndDate?.AddDays(1).AddMilliseconds(-1);
}
if (FilterByChannel && IncludedChannels != null) {
filter.ChannelIds = new HashSet<ulong>(IncludedChannels);
}
if (FilterByUser && IncludedUsers != null) {
filter.UserIds = new HashSet<ulong>(IncludedUsers);
}
return filter;
}
}

View File

@@ -4,7 +4,8 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d"
x:Class="DHT.Desktop.Main.Controls.ServerConfigurationPanel">
x:Class="DHT.Desktop.Main.Controls.ServerConfigurationPanel"
x:DataType="controls:ServerConfigurationPanelModel">
<Design.DataContext>
<controls:ServerConfigurationPanelModel />

View File

@@ -1,96 +1,95 @@
using System;
using System.Threading.Tasks;
using Avalonia.Controls;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Server;
using DHT.Server.Database;
using DHT.Server;
using DHT.Server.Service;
using DHT.Utils.Models;
using DHT.Utils.Logging;
namespace DHT.Desktop.Main.Controls;
sealed class ServerConfigurationPanelModel : BaseModel, IDisposable {
sealed partial class ServerConfigurationPanelModel : ObservableObject, IDisposable {
private static readonly Log Log = Log.ForType<ServerConfigurationPanelModel>();
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(HasMadeChanges))]
private string inputPort;
public string InputPort {
get => inputPort;
set {
Change(ref inputPort, value);
OnPropertyChanged(nameof(HasMadeChanges));
}
}
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(HasMadeChanges))]
private string inputToken;
public string InputToken {
get => inputToken;
set {
Change(ref inputToken, value);
OnPropertyChanged(nameof(HasMadeChanges));
}
}
public bool HasMadeChanges => ServerManager.Port.ToString() != InputPort || ServerManager.Token != InputToken;
public bool HasMadeChanges => ServerConfiguration.Port.ToString() != InputPort || ServerConfiguration.Token != InputToken;
[ObservableProperty(Setter = Access.Private)]
private bool isToggleServerButtonEnabled = true;
public bool IsToggleServerButtonEnabled {
get => isToggleServerButtonEnabled;
set => Change(ref isToggleServerButtonEnabled, value);
}
public string ToggleServerButtonText => serverManager.IsRunning ? "Stop Server" : "Start Server";
public event EventHandler<StatusBarModel.Status>? ServerStatusChanged;
public string ToggleServerButtonText => server.IsRunning ? "Stop Server" : "Start Server";
private readonly Window window;
private readonly ServerManager serverManager;
private readonly ServerManager server;
[Obsolete("Designer")]
public ServerConfigurationPanelModel() : this(null!, new ServerManager(DummyDatabaseFile.Instance)) {}
public ServerConfigurationPanelModel() : this(null!, State.Dummy) {}
public ServerConfigurationPanelModel(Window window, ServerManager serverManager) {
public ServerConfigurationPanelModel(Window window, State state) {
this.window = window;
this.serverManager = serverManager;
this.inputPort = ServerManager.Port.ToString();
this.inputToken = ServerManager.Token;
}
public void Initialize() {
ServerLauncher.ServerStatusChanged += ServerLauncherOnServerStatusChanged;
this.server = state.Server;
this.inputPort = ServerConfiguration.Port.ToString();
this.inputToken = ServerConfiguration.Token;
server.StatusChanged += OnServerStatusChanged;
}
public void Dispose() {
ServerLauncher.ServerStatusChanged -= ServerLauncherOnServerStatusChanged;
server.StatusChanged -= OnServerStatusChanged;
}
private void ServerLauncherOnServerStatusChanged(object? sender, EventArgs e) {
ServerStatusChanged?.Invoke(this, serverManager.IsRunning ? StatusBarModel.Status.Ready : StatusBarModel.Status.Stopped);
private void OnServerStatusChanged(object? sender, ServerManager.Status e) {
Dispatcher.UIThread.InvokeAsync(UpdateServerStatus);
}
private void UpdateServerStatus() {
OnPropertyChanged(nameof(ToggleServerButtonText));
}
private async Task StartServer() {
IsToggleServerButtonEnabled = false;
try {
await server.Start(ServerConfiguration.Port, ServerConfiguration.Token);
} catch (Exception e) {
Log.Error(e);
await Dialog.ShowOk(window, "Internal Server Error", e.Message);
}
UpdateServerStatus();
IsToggleServerButtonEnabled = true;
}
private void BeforeServerStart() {
private async Task StopServer() {
IsToggleServerButtonEnabled = false;
ServerStatusChanged?.Invoke(this, StatusBarModel.Status.Starting);
try {
await server.Stop();
} catch (Exception e) {
Log.Error(e);
await Dialog.ShowOk(window, "Internal Server Error", e.Message);
}
UpdateServerStatus();
IsToggleServerButtonEnabled = true;
}
private void StartServer() {
BeforeServerStart();
serverManager.Launch();
}
private void StopServer() {
IsToggleServerButtonEnabled = false;
ServerStatusChanged?.Invoke(this, StatusBarModel.Status.Stopping);
serverManager.Stop();
}
public void OnClickToggleServerButton() {
if (serverManager.IsRunning) {
StopServer();
public async Task OnClickToggleServerButton() {
if (server.IsRunning) {
await StopServer();
}
else {
StartServer();
await StartServer();
}
}
@@ -98,19 +97,22 @@ sealed class ServerConfigurationPanelModel : BaseModel, IDisposable {
InputToken = ServerUtils.GenerateRandomToken(20);
}
public async void OnClickApplyChanges() {
public async Task OnClickApplyChanges() {
if (!ushort.TryParse(InputPort, out ushort port)) {
await Dialog.ShowOk(window, "Invalid Port", "Port must be a number between 0 and 65535.");
return;
}
BeforeServerStart();
serverManager.Relaunch(port, InputToken);
ServerConfiguration.Port = port;
ServerConfiguration.Token = inputToken;
OnPropertyChanged(nameof(HasMadeChanges));
await StartServer();
}
public void OnClickCancelChanges() {
InputPort = ServerManager.Port.ToString();
InputToken = ServerManager.Token;
InputPort = ServerConfiguration.Port.ToString();
InputToken = ServerConfiguration.Token;
}
}

View File

@@ -4,7 +4,8 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d"
x:Class="DHT.Desktop.Main.Controls.StatusBar">
x:Class="DHT.Desktop.Main.Controls.StatusBar"
x:DataType="controls:StatusBarModel">
<Design.DataContext>
<controls:StatusBarModel />
@@ -39,22 +40,22 @@
<StackPanel Orientation="Horizontal" Margin="6 3">
<StackPanel Orientation="Vertical" Width="65">
<TextBlock Classes="label">Status</TextBlock>
<TextBlock FontSize="12" Margin="0 3 0 0" Text="{Binding StatusText}" />
<TextBlock FontSize="12" Margin="0 3 0 0" Text="{Binding ServerStatusText}" />
</StackPanel>
<Rectangle />
<StackPanel Orientation="Vertical">
<TextBlock Classes="label">Servers</TextBlock>
<TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalServers, Converter={StaticResource NumberValueConverter}}" />
<TextBlock Classes="value" Text="{Binding ServerCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel>
<Rectangle />
<StackPanel Orientation="Vertical">
<TextBlock Classes="label">Channels</TextBlock>
<TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalChannels, Converter={StaticResource NumberValueConverter}}" />
<TextBlock Classes="value" Text="{Binding ChannelCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel>
<Rectangle />
<StackPanel Orientation="Vertical">
<TextBlock Classes="label">Messages</TextBlock>
<TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalMessages, Converter={StaticResource NumberValueConverter}}" />
<TextBlock Classes="value" Text="{Binding MessageCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel>
</StackPanel>

View File

@@ -1,45 +1,63 @@
using System;
using DHT.Server.Database;
using DHT.Utils.Models;
using System.Reactive.Linq;
using Avalonia.ReactiveUI;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Server;
using DHT.Server.Service;
namespace DHT.Desktop.Main.Controls;
sealed class StatusBarModel : BaseModel {
public DatabaseStatistics DatabaseStatistics { get; }
sealed partial class StatusBarModel : ObservableObject, IDisposable {
[ObservableProperty(Setter = Access.Private)]
private long? serverCount;
[ObservableProperty(Setter = Access.Private)]
private long? channelCount;
[ObservableProperty(Setter = Access.Private)]
private long? messageCount;
private Status status = Status.Stopped;
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(ServerStatusText))]
private ServerManager.Status serverStatus;
public Status CurrentStatus {
get => status;
set {
status = value;
OnPropertyChanged(nameof(StatusText));
}
}
public string ServerStatusText => serverStatus switch {
ServerManager.Status.Starting => "STARTING",
ServerManager.Status.Started => "READY",
ServerManager.Status.Stopping => "STOPPING",
ServerManager.Status.Stopped => "STOPPED",
_ => ""
};
public string StatusText {
get {
return CurrentStatus switch {
Status.Starting => "STARTING",
Status.Ready => "READY",
Status.Stopping => "STOPPING",
Status.Stopped => "STOPPED",
_ => ""
};
}
}
private readonly State state;
private readonly IDisposable serverCountSubscription;
private readonly IDisposable channelCountSubscription;
private readonly IDisposable messageCountSubscription;
[Obsolete("Designer")]
public StatusBarModel() : this(new DatabaseStatistics()) {}
public StatusBarModel() : this(State.Dummy) {}
public StatusBarModel(DatabaseStatistics databaseStatistics) {
this.DatabaseStatistics = databaseStatistics;
public StatusBarModel(State state) {
this.state = state;
serverCountSubscription = state.Db.Servers.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newServerCount => ServerCount = newServerCount);
channelCountSubscription = state.Db.Channels.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newChannelCount => ChannelCount = newChannelCount);
messageCountSubscription = state.Db.Messages.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newMessageCount => MessageCount = newMessageCount);
state.Server.StatusChanged += OnServerStatusChanged;
serverStatus = state.Server.IsRunning ? ServerManager.Status.Started : ServerManager.Status.Stopped;
}
public enum Status {
Starting,
Ready,
Stopping,
Stopped
public void Dispose() {
serverCountSubscription.Dispose();
channelCountSubscription.Dispose();
messageCountSubscription.Dispose();
state.Server.StatusChanged -= OnServerStatusChanged;
}
private void OnServerStatusChanged(object? sender, ServerManager.Status e) {
Dispatcher.UIThread.InvokeAsync(() => ServerStatus = e);
}
}

View File

@@ -5,12 +5,13 @@
xmlns:main="clr-namespace:DHT.Desktop.Main"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.MainWindow"
x:DataType="main:MainWindowModel"
Title="{Binding Title}"
Icon="avares://DiscordHistoryTracker/Resources/icon.ico"
Width="800" Height="500"
Width="820" Height="520"
MinWidth="520" MinHeight="300"
WindowStartupLocation="CenterScreen"
Closed="OnClosed">
Closing="OnClosing">
<Design.DataContext>
<main:MainWindowModel />

View File

@@ -1,14 +1,18 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading.Tasks;
using Avalonia.Controls;
using DHT.Desktop.Main.Pages;
using DHT.Utils.Logging;
using JetBrains.Annotations;
namespace DHT.Desktop.Main;
[SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class MainWindow : Window {
private static readonly Log Log = Log.ForType<MainWindow>();
[UsedImplicitly]
public MainWindow() {
InitializeComponent();
@@ -20,9 +24,24 @@ public sealed partial class MainWindow : Window {
DataContext = new MainWindowModel(this, args);
}
public void OnClosed(object? sender, EventArgs e) {
if (DataContext is IDisposable disposable) {
disposable.Dispose();
public async void OnClosing(object? sender, WindowClosingEventArgs e) {
e.Cancel = true;
Closing -= OnClosing;
try {
await Dispose();
} finally {
Close();
}
}
private async Task Dispose() {
if (DataContext is MainWindowModel model) {
try {
await model.DisposeAsync();
} catch (Exception ex) {
Log.Error("Caught exception while disposing window: " + ex);
}
}
foreach (var temporaryFile in ViewerPageModel.TemporaryFiles) {

View File

@@ -1,33 +1,37 @@
using System;
using System.ComponentModel;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Main.Screens;
using DHT.Desktop.Server;
using DHT.Server;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Utils.Logging;
namespace DHT.Desktop.Main;
sealed class MainWindowModel : BaseModel, IDisposable {
sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
private const string DefaultTitle = "Discord History Tracker";
public string Title { get; private set; } = DefaultTitle;
private static readonly Log Log = Log.ForType<MainWindowModel>();
public UserControl CurrentScreen { get; private set; }
[ObservableProperty(Setter = Access.Private)]
private string title = DefaultTitle;
[ObservableProperty(Setter = Access.Private)]
private UserControl currentScreen;
private readonly WelcomeScreen welcomeScreen;
private readonly WelcomeScreenModel welcomeScreenModel;
private MainContentScreen? mainContentScreen;
private MainContentScreenModel? mainContentScreenModel;
private readonly Window window;
private IDatabaseFile? db;
private State? state;
[Obsolete("Designer")]
public MainWindowModel() : this(null!, Arguments.Empty) {}
@@ -36,10 +40,10 @@ sealed class MainWindowModel : BaseModel, IDisposable {
this.window = window;
welcomeScreenModel = new WelcomeScreenModel(window);
welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel };
CurrentScreen = welcomeScreen;
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
welcomeScreenModel.PropertyChanged += WelcomeScreenModelOnPropertyChanged;
welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel };
currentScreen = welcomeScreen;
var dbFile = args.DatabaseFile;
if (!string.IsNullOrWhiteSpace(dbFile)) {
@@ -63,54 +67,61 @@ sealed class MainWindowModel : BaseModel, IDisposable {
}
if (args.ServerPort != null) {
ServerManager.Port = args.ServerPort.Value;
ServerConfiguration.Port = args.ServerPort.Value;
}
if (args.ServerToken != null) {
ServerManager.Token = args.ServerToken;
ServerConfiguration.Token = args.ServerToken;
}
}
private async void WelcomeScreenModelOnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(welcomeScreenModel.Db)) {
if (mainContentScreenModel != null) {
mainContentScreenModel.DatabaseClosed -= MainContentScreenModelOnDatabaseClosed;
mainContentScreenModel.Dispose();
}
private async void OnDatabaseSelected(object? sender, IDatabaseFile db) {
welcomeScreenModel.DatabaseSelected -= OnDatabaseSelected;
await DisposeState();
state = new State(db);
db?.Dispose();
db = welcomeScreenModel.Db;
try {
await state.Server.Start(ServerConfiguration.Port, ServerConfiguration.Token);
} catch (Exception ex) {
Log.Error(ex);
await Dialog.ShowOk(window, "Internal Server Error", ex.Message);
}
if (db == null) {
Title = DefaultTitle;
mainContentScreenModel = null;
mainContentScreen = null;
CurrentScreen = welcomeScreen;
}
else {
Title = Path.GetFileName(db.Path) + " - " + DefaultTitle;
mainContentScreenModel = new MainContentScreenModel(window, db);
await mainContentScreenModel.Initialize();
mainContentScreenModel.DatabaseClosed += MainContentScreenModelOnDatabaseClosed;
mainContentScreen = new MainContentScreen { DataContext = mainContentScreenModel };
CurrentScreen = mainContentScreen;
}
mainContentScreenModel = new MainContentScreenModel(window, state);
mainContentScreenModel.DatabaseClosed += MainContentScreenModelOnDatabaseClosed;
Title = Path.GetFileName(state.Db.Path) + " - " + DefaultTitle;
CurrentScreen = new MainContentScreen { DataContext = mainContentScreenModel };
OnPropertyChanged(nameof(CurrentScreen));
OnPropertyChanged(nameof(Title));
window.Focus();
}
window.Focus();
private async void MainContentScreenModelOnDatabaseClosed(object? sender, EventArgs e) {
if (mainContentScreenModel != null) {
mainContentScreenModel.DatabaseClosed -= MainContentScreenModelOnDatabaseClosed;
mainContentScreenModel.Dispose();
mainContentScreenModel = null;
}
await DisposeState();
Title = DefaultTitle;
CurrentScreen = welcomeScreen;
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
}
private async Task DisposeState() {
if (state != null) {
await state.DisposeAsync();
state = null;
}
}
private void MainContentScreenModelOnDatabaseClosed(object? sender, EventArgs e) {
welcomeScreenModel.CloseDatabase();
}
public void Dispose() {
welcomeScreenModel.Dispose();
public async ValueTask DisposeAsync() {
mainContentScreenModel?.Dispose();
db?.Dispose();
db = null;
await DisposeState();
}
}

View File

@@ -5,7 +5,8 @@
xmlns:pages="clr-namespace:DHT.Desktop.Main.Pages"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Pages.AdvancedPage">
x:Class="DHT.Desktop.Main.Pages.AdvancedPage"
x:DataType="pages:AdvancedPageModel">
<Design.DataContext>
<pages:AdvancedPageModel />

View File

@@ -1,39 +1,36 @@
using System;
using System.Threading.Tasks;
using Avalonia.Controls;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Main.Controls;
using DHT.Desktop.Server;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Server;
namespace DHT.Desktop.Main.Pages;
sealed class AdvancedPageModel : BaseModel, IDisposable {
sealed class AdvancedPageModel : IDisposable {
public ServerConfigurationPanelModel ServerConfigurationModel { get; }
private readonly Window window;
private readonly IDatabaseFile db;
private readonly State state;
[Obsolete("Designer")]
public AdvancedPageModel() : this(null!, DummyDatabaseFile.Instance, new ServerManager(DummyDatabaseFile.Instance)) {}
public AdvancedPageModel() : this(null!, State.Dummy) {}
public AdvancedPageModel(Window window, IDatabaseFile db, ServerManager serverManager) {
public AdvancedPageModel(Window window, State state) {
this.window = window;
this.db = db;
this.state = state;
ServerConfigurationModel = new ServerConfigurationPanelModel(window, serverManager);
}
public void Initialize() {
ServerConfigurationModel.Initialize();
ServerConfigurationModel = new ServerConfigurationPanelModel(window, state);
}
public void Dispose() {
ServerConfigurationModel.Dispose();
}
public async void VacuumDatabase() {
db.Vacuum();
await Dialog.ShowOk(window, "Vacuum Database", "Done.");
public async Task VacuumDatabase() {
const string Title = "Vacuum Database";
await ProgressDialog.ShowIndeterminate(window, Title, "Vacuuming database...", _ => state.Db.Vacuum());
await Dialog.ShowOk(window, Title, "Done.");
}
}

View File

@@ -5,7 +5,8 @@
xmlns:pages="clr-namespace:DHT.Desktop.Main.Pages"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Pages.AttachmentsPage">
x:Class="DHT.Desktop.Main.Pages.AttachmentsPage"
x:DataType="pages:AttachmentsPageModel">
<Design.DataContext>
<pages:AttachmentsPageModel />
@@ -32,22 +33,22 @@
<StackPanel Orientation="Vertical" Spacing="20">
<DockPanel>
<Button Command="{Binding OnClickToggleDownload}" Content="{Binding ToggleDownloadButtonText}" IsEnabled="{Binding IsToggleDownloadButtonEnabled}" DockPanel.Dock="Left" />
<TextBlock Text="{Binding DownloadMessage}" Margin="10 0 0 0" VerticalAlignment="Center" DockPanel.Dock="Left" />
<TextBlock Text="{Binding DownloadMessage}" MinWidth="100" Margin="10 0 0 0" VerticalAlignment="Center" TextAlignment="Right" DockPanel.Dock="Left" />
<ProgressBar Value="{Binding DownloadProgress}" IsVisible="{Binding IsDownloading}" Margin="15 0" VerticalAlignment="Center" DockPanel.Dock="Right" />
</DockPanel>
<controls:AttachmentFilterPanel DataContext="{Binding FilterModel}" IsEnabled="{Binding !DataContext.IsDownloading, RelativeSource={RelativeSource AncestorType=UserControl}}" />
<controls:AttachmentFilterPanel DataContext="{Binding FilterModel}" IsEnabled="{Binding !IsDownloading, RelativeSource={RelativeSource AncestorType=pages:AttachmentsPageModel}}" />
<StackPanel Orientation="Vertical" Spacing="12">
<Expander Header="Download Status" IsExpanded="True">
<DataGrid ItemsSource="{Binding StatisticsRows}" AutoGenerateColumns="False" CanUserReorderColumns="False" CanUserResizeColumns="False" CanUserSortColumns="False" IsReadOnly="True">
<DataGrid.Columns>
<DataGridTextColumn Header="State" Binding="{Binding State}" Width="*" />
<DataGridTextColumn Header="Attachments" Binding="{Binding Items, Converter={StaticResource NumberValueConverter}}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="Size" Binding="{Binding Size, Converter={StaticResource BytesValueConverter}}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="Attachments" Binding="{Binding Items, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="Size" Binding="{Binding Size, Mode=OneWay, Converter={StaticResource BytesValueConverter}}" Width="*" CellStyleClasses="right" />
</DataGrid.Columns>
</DataGrid>
</Expander>
<StackPanel Orientation="Horizontal" Spacing="10">
<Button Command="{Binding OnClickRetryFailedDownloads}" IsEnabled="{Binding HasFailedDownloads}">Retry Failed Downloads</Button>
<Button Command="{Binding OnClickRetryFailedDownloads}" IsEnabled="{Binding IsRetryFailedOnDownloadsButtonEnabled}">Retry Failed Downloads</Button>
</StackPanel>
</StackPanel>
</StackPanel>

View File

@@ -1,38 +1,50 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Threading;
using Avalonia.Threading;
using System.Collections.ObjectModel;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Main.Controls;
using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Server.Download;
using DHT.Utils.Models;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Pages;
sealed class AttachmentsPageModel : BaseModel, IDisposable {
private static readonly DownloadItemFilter EnqueuedItemFilter = new() {
sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
private static readonly Log Log = Log.ForType<AttachmentsPageModel>();
private static readonly DownloadItemFilter EnqueuedItemFilter = new () {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued
DownloadStatus.Enqueued,
DownloadStatus.Downloading
}
};
private bool isThreadDownloadButtonEnabled = true;
[ObservableProperty(Setter = Access.Private)]
private bool isToggleDownloadButtonEnabled = true;
public string ToggleDownloadButtonText => downloadThread == null ? "Start Downloading" : "Stop Downloading";
public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading";
public bool IsToggleDownloadButtonEnabled {
get => isThreadDownloadButtonEnabled;
set => Change(ref isThreadDownloadButtonEnabled, value);
}
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool isRetryingFailedDownloads = false;
public string DownloadMessage { get; set; } = "";
public double DownloadProgress => allItemsCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / allItemsCount.Value;
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool hasFailedDownloads;
public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && hasFailedDownloads;
[ObservableProperty(Setter = Access.Private)]
private string downloadMessage = "";
public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value;
public AttachmentFilterPanelModel FilterModel { get; }
@@ -41,71 +53,161 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
private readonly StatisticsRow statisticsFailed = new ("Failed");
private readonly StatisticsRow statisticsSkipped = new ("Skipped");
public List<StatisticsRow> StatisticsRows {
get {
return new List<StatisticsRow> {
statisticsEnqueued,
statisticsDownloaded,
statisticsFailed,
statisticsSkipped
};
}
}
public ObservableCollection<StatisticsRow> StatisticsRows { get; }
public bool IsDownloading => downloadThread != null;
public bool HasFailedDownloads => statisticsFailed.Items > 0;
public bool IsDownloading => state.Downloader.IsDownloading;
private readonly IDatabaseFile db;
private readonly AsyncValueComputer<DownloadStatusStatistics>.Single downloadStatisticsComputer;
private BackgroundDownloadThread? downloadThread;
private readonly State state;
private readonly ThrottledTask<int> enqueueDownloadItemsTask;
private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask;
private readonly IDisposable attachmentCountSubscription;
private readonly IDisposable downloadCountSubscription;
private IDisposable? finishedItemsSubscription;
private int doneItemsCount;
private int? allItemsCount;
private int totalEnqueuedItemCount;
private int? totalItemsToDownloadCount;
public AttachmentsPageModel() : this(DummyDatabaseFile.Instance) {}
public AttachmentsPageModel() : this(State.Dummy) {}
public AttachmentsPageModel(IDatabaseFile db) {
this.db = db;
this.FilterModel = new AttachmentFilterPanelModel(db);
public AttachmentsPageModel(State state) {
this.state = state;
this.downloadStatisticsComputer = AsyncValueComputer<DownloadStatusStatistics>.WithResultProcessor(UpdateStatistics).WithOutdatedResults().BuildWithComputer(db.GetDownloadStatusStatistics);
this.downloadStatisticsComputer.Recompute();
FilterModel = new AttachmentFilterPanelModel(state);
StatisticsRows = [
statisticsEnqueued,
statisticsDownloaded,
statisticsFailed,
statisticsSkipped
];
db.Statistics.PropertyChanged += OnDbStatisticsChanged;
enqueueDownloadItemsTask = new ThrottledTask<int>(OnItemsEnqueued, TaskScheduler.FromCurrentSynchronizationContext());
downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext());
attachmentCountSubscription = state.Db.Attachments.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnAttachmentCountChanged);
downloadCountSubscription = state.Db.Downloads.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnDownloadCountChanged);
RecomputeDownloadStatistics();
}
public void Dispose() {
db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
attachmentCountSubscription.Dispose();
downloadCountSubscription.Dispose();
finishedItemsSubscription?.Dispose();
enqueueDownloadItemsTask.Dispose();
downloadStatisticsTask.Dispose();
FilterModel.Dispose();
DisposeDownloadThread();
}
private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) {
if (IsDownloading) {
EnqueueDownloadItems();
}
else {
downloadStatisticsComputer.Recompute();
}
private void OnAttachmentCountChanged(long newAttachmentCount) {
if (IsDownloading) {
EnqueueDownloadItemsLater();
}
else if (e.PropertyName == nameof(DatabaseStatistics.TotalDownloads)) {
downloadStatisticsComputer.Recompute();
else {
RecomputeDownloadStatistics();
}
}
private void EnqueueDownloadItems() {
private void OnDownloadCountChanged(long newDownloadCount) {
RecomputeDownloadStatistics();
}
private async Task EnqueueDownloadItems() {
OnItemsEnqueued(await state.Db.Downloads.EnqueueDownloadItems(CreateAttachmentFilter()));
}
private void EnqueueDownloadItemsLater() {
var filter = CreateAttachmentFilter();
enqueueDownloadItemsTask.Post(cancellationToken => state.Db.Downloads.EnqueueDownloadItems(filter, cancellationToken));
}
private void OnItemsEnqueued(int itemCount) {
totalEnqueuedItemCount += itemCount;
totalItemsToDownloadCount = totalEnqueuedItemCount;
UpdateDownloadMessage();
RecomputeDownloadStatistics();
}
private AttachmentFilter CreateAttachmentFilter() {
var filter = FilterModel.CreateFilter();
filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent;
db.EnqueueDownloadItems(filter);
return filter;
}
downloadStatisticsComputer.Recompute();
public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false;
if (IsDownloading) {
await state.Downloader.Stop();
finishedItemsSubscription?.Dispose();
finishedItemsSubscription = null;
RecomputeDownloadStatistics();
await state.Db.Downloads.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching);
doneItemsCount = 0;
totalEnqueuedItemCount = 0;
totalItemsToDownloadCount = null;
UpdateDownloadMessage();
}
else {
var finishedItems = await state.Downloader.Start();
finishedItemsSubscription = finishedItems.Select(static _ => true)
.Buffer(TimeSpan.FromMilliseconds(100))
.Select(static items => items.Count)
.Where(static items => items > 0)
.ObserveOn(AvaloniaScheduler.Instance)
.Subscribe(OnItemsFinished);
await EnqueueDownloadItems();
}
OnPropertyChanged(nameof(ToggleDownloadButtonText));
OnPropertyChanged(nameof(IsDownloading));
IsToggleDownloadButtonEnabled = true;
}
private void OnItemsFinished(int finishedItemCount) {
doneItemsCount += finishedItemCount;
UpdateDownloadMessage();
}
public async Task OnClickRetryFailedDownloads() {
IsRetryingFailedDownloads = true;
try {
var allExceptFailedFilter = new DownloadItemFilter {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued,
DownloadStatus.Downloading,
DownloadStatus.Success
}
};
await state.Db.Downloads.RemoveDownloadItems(allExceptFailedFilter, FilterRemovalMode.KeepMatching);
if (IsDownloading) {
await EnqueueDownloadItems();
}
} catch (Exception e) {
Log.Error(e);
} finally {
IsRetryingFailedDownloads = false;
}
}
private void RecomputeDownloadStatistics() {
downloadStatisticsTask.Post(state.Db.Downloads.GetStatistics);
}
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
var hadFailedDownloads = HasFailedDownloads;
statisticsEnqueued.Items = statusStatistics.EnqueuedCount;
statisticsEnqueued.Size = statusStatistics.EnqueuedSize;
@@ -118,88 +220,25 @@ sealed class AttachmentsPageModel : BaseModel, IDisposable {
statisticsSkipped.Items = statusStatistics.SkippedCount;
statisticsSkipped.Size = statusStatistics.SkippedSize;
OnPropertyChanged(nameof(StatisticsRows));
hasFailedDownloads = statusStatistics.FailedCount > 0;
if (hadFailedDownloads != HasFailedDownloads) {
OnPropertyChanged(nameof(HasFailedDownloads));
}
allItemsCount = doneItemsCount + statisticsEnqueued.Items;
UpdateDownloadMessage();
}
private void UpdateDownloadMessage() {
DownloadMessage = IsDownloading ? doneItemsCount.Format() + " / " + (allItemsCount?.Format() ?? "?") : "";
DownloadMessage = IsDownloading ? doneItemsCount.Format() + " / " + (totalItemsToDownloadCount?.Format() ?? "?") : "";
OnPropertyChanged(nameof(DownloadMessage));
OnPropertyChanged(nameof(DownloadProgress));
}
private void DownloadThreadOnOnItemFinished(object? sender, DownloadItem e) {
Interlocked.Increment(ref doneItemsCount);
[ObservableObject]
public sealed partial class StatisticsRow(string state) {
public string State { get; } = state;
Dispatcher.UIThread.Invoke(UpdateDownloadMessage);
downloadStatisticsComputer.Recompute();
}
[ObservableProperty]
private int items;
private void DownloadThreadOnOnServerStopped(object? sender, EventArgs e) {
downloadStatisticsComputer.Recompute();
IsToggleDownloadButtonEnabled = true;
}
public void OnClickToggleDownload() {
if (downloadThread == null) {
EnqueueDownloadItems();
downloadThread = new BackgroundDownloadThread(db);
downloadThread.OnItemFinished += DownloadThreadOnOnItemFinished;
downloadThread.OnServerStopped += DownloadThreadOnOnServerStopped;
}
else {
IsToggleDownloadButtonEnabled = false;
DisposeDownloadThread();
db.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching);
doneItemsCount = 0;
allItemsCount = null;
UpdateDownloadMessage();
}
OnPropertyChanged(nameof(ToggleDownloadButtonText));
OnPropertyChanged(nameof(IsDownloading));
}
public void OnClickRetryFailedDownloads() {
var allExceptFailedFilter = new DownloadItemFilter {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued,
DownloadStatus.Success
}
};
db.RemoveDownloadItems(allExceptFailedFilter, FilterRemovalMode.KeepMatching);
if (IsDownloading) {
EnqueueDownloadItems();
}
}
private void DisposeDownloadThread() {
if (downloadThread != null) {
downloadThread.OnItemFinished -= DownloadThreadOnOnItemFinished;
downloadThread.StopThread();
}
downloadThread = null;
}
public sealed class StatisticsRow {
public string State { get; }
public int Items { get; set; }
public ulong? Size { get; set; }
public StatisticsRow(string state) {
State = state;
}
[ObservableProperty]
private ulong? size;
}
}

View File

@@ -4,7 +4,8 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:pages="clr-namespace:DHT.Desktop.Main.Pages"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Pages.DatabasePage">
x:Class="DHT.Desktop.Main.Pages.DatabasePage"
x:DataType="pages:DatabasePageModel">
<Design.DataContext>
<pages:DatabasePageModel />

View File

@@ -14,15 +14,16 @@ using DHT.Desktop.Dialogs.File;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Dialogs.TextBox;
using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Server.Database.Import;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages;
namespace DHT.Desktop.Main.Pages;
sealed class DatabasePageModel : BaseModel {
sealed class DatabasePageModel {
private static readonly Log Log = Log.ForType<DatabasePageModel>();
public IDatabaseFile Db { get; }
@@ -32,14 +33,14 @@ sealed class DatabasePageModel : BaseModel {
private readonly Window window;
[Obsolete("Designer")]
public DatabasePageModel() : this(null!, DummyDatabaseFile.Instance) {}
public DatabasePageModel() : this(null!, State.Dummy) {}
public DatabasePageModel(Window window, IDatabaseFile db) {
public DatabasePageModel(Window window, State state) {
this.window = window;
this.Db = db;
this.Db = state.Db;
}
public async void OpenDatabaseFolder() {
public async Task OpenDatabaseFolder() {
string file = Db.Path;
string? folder = Path.GetDirectoryName(file);
@@ -70,69 +71,77 @@ sealed class DatabasePageModel : BaseModel {
DatabaseClosed?.Invoke(this, EventArgs.Empty);
}
public async void MergeWithDatabase() {
public async Task MergeWithDatabase() {
var paths = await DatabaseGui.NewOpenDatabaseFilesDialog(window, Path.GetDirectoryName(Db.Path));
if (paths.Length == 0) {
return;
if (paths.Length > 0) {
await ProgressDialog.Show(window, "Database Merge", async (dialog, callback) => await MergeWithDatabaseFromPaths(Db, paths, dialog, callback));
}
ProgressDialog progressDialog = new ProgressDialog();
progressDialog.DataContext = new ProgressDialogModel(async callback => await MergeWithDatabaseFromPaths(Db, paths, progressDialog, callback)) {
Title = "Database Merge"
};
await progressDialog.ShowDialog(window);
}
private static async Task MergeWithDatabaseFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) {
int total = paths.Length;
DialogResult.YesNo? upgradeResult = null;
async Task<bool> CheckCanUpgradeDatabase() {
upgradeResult ??= total > 1
? await DatabaseGui.ShowCanUpgradeMultipleDatabaseDialog(dialog)
: await DatabaseGui.ShowCanUpgradeDatabaseDialog(dialog);
return DialogResult.YesNo.Yes == upgradeResult;
}
var schemaUpgradeCallbacks = new SchemaUpgradeCallbacks(dialog, paths.Length);
await PerformImport(target, paths, dialog, callback, "Database Merge", "Database Error", "database file", async path => {
SynchronizationContext? prevSyncContext = SynchronizationContext.Current;
SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext());
IDatabaseFile? db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, dialog, CheckCanUpgradeDatabase);
SynchronizationContext.SetSynchronizationContext(prevSyncContext);
IDatabaseFile? db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, dialog, schemaUpgradeCallbacks);
if (db == null) {
return false;
}
try {
target.AddFrom(db);
await target.AddFrom(db);
return true;
} finally {
db.Dispose();
await db.DisposeAsync();
}
});
}
public async void ImportLegacyArchive() {
private sealed class SchemaUpgradeCallbacks : ISchemaUpgradeCallbacks {
private readonly ProgressDialog dialog;
private readonly int total;
private bool? decision;
public SchemaUpgradeCallbacks(ProgressDialog dialog, int total) {
this.total = total;
this.dialog = dialog;
}
public async Task<bool> CanUpgrade() {
return decision ??= (total > 1
? await DatabaseGui.ShowCanUpgradeMultipleDatabaseDialog(dialog)
: await DatabaseGui.ShowCanUpgradeDatabaseDialog(dialog)) == DialogResult.YesNo.Yes;
}
public Task Start(int versionSteps, Func<ISchemaUpgradeCallbacks.IProgressReporter, Task> doUpgrade) {
return doUpgrade(new NullReporter());
}
private sealed class NullReporter : ISchemaUpgradeCallbacks.IProgressReporter {
public Task NextVersion() {
return Task.CompletedTask;
}
public Task MainWork(string message, int finishedItems, int totalItems) {
return Task.CompletedTask;
}
public Task SubWork(string message, int finishedItems, int totalItems) {
return Task.CompletedTask;
}
}
}
public async Task ImportLegacyArchive() {
var paths = await window.StorageProvider.OpenFiles(new FilePickerOpenOptions {
Title = "Open Legacy DHT Archive",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(Db.Path)),
AllowMultiple = true
});
if (paths.Length == 0) {
return;
if (paths.Length > 0) {
await ProgressDialog.Show(window, "Legacy Archive Import", async (dialog, callback) => await ImportLegacyArchiveFromPaths(Db, paths, dialog, callback));
}
ProgressDialog progressDialog = new ProgressDialog();
progressDialog.DataContext = new ProgressDialogModel(async callback => await ImportLegacyArchiveFromPaths(Db, paths, progressDialog, callback)) {
Title = "Legacy Archive Import"
};
await progressDialog.ShowDialog(window);
}
private static async Task ImportLegacyArchiveFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) {
@@ -140,7 +149,7 @@ sealed class DatabasePageModel : BaseModel {
await PerformImport(target, paths, dialog, callback, "Legacy Archive Import", "Legacy Archive Error", "archive file", async path => {
await using var jsonStream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read);
return await LegacyArchiveImport.Read(jsonStream, target, fakeSnowflake, async servers => {
SynchronizationContext? prevSyncContext = SynchronizationContext.Current;
SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext());
@@ -155,7 +164,7 @@ sealed class DatabasePageModel : BaseModel {
static bool IsValidSnowflake(string value) {
return string.IsNullOrEmpty(value) || ulong.TryParse(value, out _);
}
var items = new List<TextBoxItem<DHT.Server.Data.Server>>();
foreach (var server in servers.OrderBy(static server => server.Type).ThenBy(static server => server.Name)) {
@@ -184,7 +193,7 @@ sealed class DatabasePageModel : BaseModel {
private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func<string, Task<bool>> performImport) {
int total = paths.Length;
var oldStatistics = target.SnapshotStatistics();
var oldStatistics = await DatabaseStatistics.Take(target);
int successful = 0;
int finished = 0;
@@ -215,14 +224,26 @@ sealed class DatabasePageModel : BaseModel {
return;
}
await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, target.SnapshotStatistics(), successful, total, itemName));
var newStatistics = await DatabaseStatistics.Take(target);
await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, newStatistics, successful, total, itemName));
}
private static string GetImportDialogMessage(DatabaseStatisticsSnapshot oldStatistics, DatabaseStatisticsSnapshot newStatistics, int successfulItems, int totalItems, string itemName) {
long newServers = newStatistics.TotalServers - oldStatistics.TotalServers;
long newChannels = newStatistics.TotalChannels - oldStatistics.TotalChannels;
long newUsers = newStatistics.TotalUsers - oldStatistics.TotalUsers;
long newMessages = newStatistics.TotalMessages - oldStatistics.TotalMessages;
private sealed record DatabaseStatistics(long ServerCount, long ChannelCount, long UserCount, long MessageCount) {
public static async Task<DatabaseStatistics> Take(IDatabaseFile db) {
return new DatabaseStatistics(
await db.Servers.Count(),
await db.Channels.Count(),
await db.Users.Count(),
await db.Messages.Count()
);
}
}
private static string GetImportDialogMessage(DatabaseStatistics oldStatistics, DatabaseStatistics newStatistics, int successfulItems, int totalItems, string itemName) {
long newServers = newStatistics.ServerCount - oldStatistics.ServerCount;
long newChannels = newStatistics.ChannelCount - oldStatistics.ChannelCount;
long newUsers = newStatistics.UserCount - oldStatistics.UserCount;
long newMessages = newStatistics.MessageCount - oldStatistics.MessageCount;
StringBuilder message = new StringBuilder();
message.Append("Processed ");

View File

@@ -4,7 +4,8 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:pages="clr-namespace:DHT.Desktop.Main.Pages"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Pages.DebugPage">
x:Class="DHT.Desktop.Main.Pages.DebugPage"
x:DataType="pages:DebugPageModel">
<Design.DataContext>
<pages:DebugPageModel />

View File

@@ -6,26 +6,25 @@ using System.Threading.Tasks;
using Avalonia.Controls;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages {
sealed class DebugPageModel : BaseModel {
sealed class DebugPageModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
private readonly Window window;
private readonly IDatabaseFile db;
private readonly State state;
[Obsolete("Designer")]
public DebugPageModel() : this(null!, DummyDatabaseFile.Instance) {}
public DebugPageModel() : this(null!, State.Dummy) {}
public DebugPageModel(Window window, IDatabaseFile db) {
public DebugPageModel(Window window, State state) {
this.window = window;
this.db = db;
this.state = state;
}
public async void OnClickAddRandomDataToDatabase() {
@@ -44,13 +43,7 @@ namespace DHT.Desktop.Main.Pages {
return;
}
ProgressDialog progressDialog = new ProgressDialog {
DataContext = new ProgressDialogModel(async callback => await GenerateRandomData(channels, users, messages, callback)) {
Title = "Generating Random Data"
}
};
await progressDialog.ShowDialog(window);
await ProgressDialog.Show(window, "Generating Random Data", async (_, callback) => await GenerateRandomData(channels, users, messages, callback));
}
private const int BatchSize = 500;
@@ -83,12 +76,9 @@ namespace DHT.Desktop.Main.Pages {
Discriminator = rand.Next(0, 9999).ToString(),
}).ToArray();
db.AddServer(server);
db.AddUsers(users);
foreach (var channel in channels) {
db.AddChannel(channel);
}
await state.Db.Users.Add(users);
await state.Db.Servers.Add([server]);
await state.Db.Channels.Add(channels);
var now = DateTimeOffset.Now;
int batchIndex = 0;
@@ -111,13 +101,13 @@ namespace DHT.Desktop.Main.Pages {
Timestamp = timeMillis,
EditTimestamp = editMillis,
RepliedToId = null,
Attachments = ImmutableArray<Attachment>.Empty,
Embeds = ImmutableArray<Embed>.Empty,
Reactions = ImmutableArray<Reaction>.Empty,
Attachments = ImmutableList<Attachment>.Empty,
Embeds = ImmutableList<Embed>.Empty,
Reactions = ImmutableList<Reaction>.Empty,
};
}).ToArray();
db.AddMessages(messages);
await state.Db.Messages.Add(messages);
messageCount -= BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount);
@@ -138,7 +128,7 @@ namespace DHT.Desktop.Main.Pages {
return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())];
}
private static readonly string[] RandomWords = {
private static readonly string[] RandomWords = [
"apple", "apricot", "artichoke", "arugula", "asparagus", "avocado",
"banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli",
"cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant",
@@ -161,8 +151,8 @@ namespace DHT.Desktop.Main.Pages {
"vanilla",
"watercress", "watermelon",
"yam",
"zucchini",
};
"zucchini"
];
private static string RandomText(Random rand, int maxWords) {
int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3));
@@ -171,10 +161,8 @@ namespace DHT.Desktop.Main.Pages {
}
}
#else
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages {
sealed class DebugPageModel : BaseModel {
sealed class DebugPageModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";

View File

@@ -4,7 +4,8 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:pages="clr-namespace:DHT.Desktop.Main.Pages"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Pages.TrackingPage">
x:Class="DHT.Desktop.Main.Pages.TrackingPage"
x:DataType="pages:TrackingPageModel">
<Design.DataContext>
<pages:TrackingPageModel />
@@ -15,7 +16,7 @@
To start tracking messages, copy the tracking script and paste it into the console of either the Discord app, or your browser. The console is usually opened by pressing Ctrl+Shift+I.
</TextBlock>
<StackPanel DockPanel.Dock="Left" Orientation="Horizontal" Spacing="10">
<Button x:Name="CopyTrackingScript" Click="CopyTrackingScriptButton_OnClick">Copy Tracking Script</Button>
<Button x:Name="CopyTrackingScript" Click="CopyTrackingScriptButton_OnClick" IsEnabled="{Binding IsCopyTrackingScriptButtonEnabled}">Copy Tracking Script</Button>
</StackPanel>
<TextBlock TextWrapping="Wrap" Margin="0 5 0 0">
By default, the Discord app does not allow opening the console. The button below will change a hidden setting in the Discord app that controls whether the Ctrl+Shift+I shortcut is enabled.

View File

@@ -1,35 +1,39 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Discord;
using DHT.Desktop.Server;
using DHT.Utils.Models;
using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages;
sealed class TrackingPageModel : BaseModel {
private bool areDevToolsEnabled;
sealed partial class TrackingPageModel : ObservableObject {
[ObservableProperty(Setter = Access.Private)]
private bool isCopyTrackingScriptButtonEnabled = true;
private bool AreDevToolsEnabled {
get => areDevToolsEnabled;
set {
Change(ref areDevToolsEnabled, value);
OnPropertyChanged(nameof(ToggleAppDevToolsButtonText));
}
}
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(ToggleAppDevToolsButtonText))]
private bool? areDevToolsEnabled = null;
public bool IsToggleAppDevToolsButtonEnabled { get; private set; } = true;
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(ToggleAppDevToolsButtonText))]
private bool isToggleAppDevToolsButtonEnabled = false;
public string ToggleAppDevToolsButtonText {
get {
if (!AreDevToolsEnabled.HasValue) {
return "Loading...";
}
if (!IsToggleAppDevToolsButtonEnabled) {
return "Unavailable";
}
return AreDevToolsEnabled ? "Disable Ctrl+Shift+I" : "Enable Ctrl+Shift+I";
return AreDevToolsEnabled.Value ? "Disable Ctrl+Shift+I" : "Enable Ctrl+Shift+I";
}
}
@@ -40,21 +44,22 @@ sealed class TrackingPageModel : BaseModel {
public TrackingPageModel(Window window) {
this.window = window;
Task.Factory.StartNew(InitializeDevToolsToggle, CancellationToken.None, TaskCreationOptions.None, TaskScheduler.FromCurrentSynchronizationContext());
}
public async Task Initialize() {
bool? devToolsEnabled = await DiscordAppSettings.AreDevToolsEnabled();
if (devToolsEnabled.HasValue) {
AreDevToolsEnabled = devToolsEnabled.Value;
}
else {
IsToggleAppDevToolsButtonEnabled = false;
OnPropertyChanged(nameof(IsToggleAppDevToolsButtonEnabled));
public async Task<bool> OnClickCopyTrackingScript() {
IsCopyTrackingScriptButtonEnabled = false;
try {
return await CopyTrackingScript();
} finally {
IsCopyTrackingScriptButtonEnabled = true;
}
}
public async Task<bool> OnClickCopyTrackingScript() {
string url = $"http://127.0.0.1:{ServerManager.Port}/get-tracking-script?token={HttpUtility.UrlEncode(ServerManager.Token)}";
private async Task<bool> CopyTrackingScript() {
string url = $"http://127.0.0.1:{ServerConfiguration.Port}/get-tracking-script?token={HttpUtility.UrlEncode(ServerConfiguration.Token)}";
string script = (await Resources.ReadTextAsync("tracker-loader.js")).Trim().Replace("{url}", url);
var clipboard = window.Clipboard;
@@ -72,10 +77,26 @@ sealed class TrackingPageModel : BaseModel {
}
}
public async void OnClickToggleAppDevTools() {
private async Task InitializeDevToolsToggle() {
bool? devToolsEnabled = await Task.Run(DiscordAppSettings.AreDevToolsEnabled);
if (devToolsEnabled.HasValue) {
AreDevToolsEnabled = devToolsEnabled.Value;
IsToggleAppDevToolsButtonEnabled = true;
}
else {
IsToggleAppDevToolsButtonEnabled = false;
}
}
public async Task OnClickToggleAppDevTools() {
const string DialogTitle = "Discord App Settings File";
bool oldState = AreDevToolsEnabled;
if (!AreDevToolsEnabled.HasValue) {
return;
}
bool oldState = AreDevToolsEnabled.Value;
bool newState = !oldState;
switch (await DiscordAppSettings.ConfigureDevTools(newState)) {

View File

@@ -5,7 +5,8 @@
xmlns:pages="clr-namespace:DHT.Desktop.Main.Pages"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Pages.ViewerPage">
x:Class="DHT.Desktop.Main.Pages.ViewerPage"
x:DataType="pages:ViewerPageModel">
<Design.DataContext>
<pages:ViewerPageModel />

View File

@@ -8,46 +8,47 @@ using System.Threading.Tasks;
using System.Web;
using Avalonia.Controls;
using Avalonia.Platform.Storage;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.File;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Main.Controls;
using DHT.Desktop.Server;
using DHT.Server;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Server.Database.Export;
using DHT.Server.Database.Export.Strategy;
using DHT.Utils.Models;
using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages;
sealed class ViewerPageModel : BaseModel, IDisposable {
public static readonly ConcurrentBag<string> TemporaryFiles = new ();
sealed partial class ViewerPageModel : ObservableObject, IDisposable {
public static readonly ConcurrentBag<string> TemporaryFiles = [];
private static readonly FilePickerFileType[] ViewerFileTypes = [
FileDialogs.CreateFilter("Discord History Viewer", ["html"])
];
public bool DatabaseToolFilterModeKeep { get; set; } = true;
public bool DatabaseToolFilterModeRemove { get; set; } = false;
[ObservableProperty]
private bool hasFilters = false;
public bool HasFilters {
get => hasFilters;
set => Change(ref hasFilters, value);
}
private MessageFilterPanelModel FilterModel { get; }
public MessageFilterPanelModel FilterModel { get; }
private readonly Window window;
private readonly IDatabaseFile db;
private readonly State state;
[Obsolete("Designer")]
public ViewerPageModel() : this(null!, DummyDatabaseFile.Instance) {}
public ViewerPageModel() : this(null!, State.Dummy) {}
public ViewerPageModel(Window window, IDatabaseFile db) {
public ViewerPageModel(Window window, State state) {
this.window = window;
this.db = db;
this.state = state;
FilterModel = new MessageFilterPanelModel(window, db, "Will export");
FilterModel = new MessageFilterPanelModel(window, state, "Will export");
FilterModel.FilterPropertyChanged += OnFilterPropertyChanged;
}
@@ -59,6 +60,61 @@ sealed class ViewerPageModel : BaseModel, IDisposable {
HasFilters = FilterModel.HasAnyFilters;
}
public async void OnClickOpenViewer() {
try {
var fullPath = await PrepareTemporaryViewerFile();
var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token);
await ProgressDialog.ShowIndeterminate(window, "Open Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(fullPath, strategy)));
Process.Start(new ProcessStartInfo(fullPath) {
UseShellExecute = true
});
} catch (Exception e) {
await Dialog.ShowOk(window, "Open Viewer", "Could not create or save viewer: " + e.Message);
}
}
private async Task<string> PrepareTemporaryViewerFile() {
return await Task.Run(() => {
string rootPath = Path.Combine(Path.GetTempPath(), "DiscordHistoryTracker");
string filenameBase = Path.GetFileNameWithoutExtension(state.Db.Path) + "-" + DateTime.Now.ToString("yyyy-MM-dd");
string fullPath = Path.Combine(rootPath, filenameBase + ".html");
int counter = 0;
while (File.Exists(fullPath)) {
++counter;
fullPath = Path.Combine(rootPath, filenameBase + "-" + counter + ".html");
}
TemporaryFiles.Add(fullPath);
Directory.CreateDirectory(rootPath);
return fullPath;
});
}
public async void OnClickSaveViewer() {
string? path = await window.StorageProvider.SaveFile(new FilePickerSaveOptions {
Title = "Save Viewer",
FileTypeChoices = ViewerFileTypes,
SuggestedFileName = Path.GetFileNameWithoutExtension(state.Db.Path) + ".html",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(state.Db.Path)),
});
if (path == null) {
return;
}
try {
await ProgressDialog.ShowIndeterminate(window, "Save Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(path, StandaloneViewerExportStrategy.Instance)));
} catch (Exception e) {
await Dialog.ShowOk(window, "Save Viewer", "Could not create or save viewer: " + e.Message);
}
}
private async Task WriteViewerFile(string path, IViewerExportStrategy strategy) {
const string ArchiveTag = "/*[ARCHIVE]*/";
@@ -72,7 +128,7 @@ sealed class ViewerPageModel : BaseModel, IDisposable {
string jsonTempFile = path + ".tmp";
await using (var jsonStream = new FileStream(jsonTempFile, FileMode.Create, FileAccess.ReadWrite, FileShare.Read)) {
await ViewerJsonExport.Generate(jsonStream, strategy, db, FilterModel.CreateFilter());
await ViewerJsonExport.Generate(jsonStream, strategy, state.Db, FilterModel.CreateFilter());
char[] jsonBuffer = new char[Math.Min(32768, jsonStream.Position)];
jsonStream.Position = 0;
@@ -96,54 +152,23 @@ sealed class ViewerPageModel : BaseModel, IDisposable {
File.Delete(jsonTempFile);
}
public async void OnClickOpenViewer() {
string rootPath = Path.Combine(Path.GetTempPath(), "DiscordHistoryTracker");
string filenameBase = Path.GetFileNameWithoutExtension(db.Path) + "-" + DateTime.Now.ToString("yyyy-MM-dd");
string fullPath = Path.Combine(rootPath, filenameBase + ".html");
int counter = 0;
while (File.Exists(fullPath)) {
++counter;
fullPath = Path.Combine(rootPath, filenameBase + "-" + counter + ".html");
}
TemporaryFiles.Add(fullPath);
Directory.CreateDirectory(rootPath);
await WriteViewerFile(fullPath, new LiveViewerExportStrategy(ServerManager.Port, ServerManager.Token));
Process.Start(new ProcessStartInfo(fullPath) { UseShellExecute = true });
}
private static readonly FilePickerFileType[] ViewerFileTypes = {
FileDialogs.CreateFilter("Discord History Viewer", new string[] { "html" }),
};
public async void OnClickSaveViewer() {
string? path = await window.StorageProvider.SaveFile(new FilePickerSaveOptions {
Title = "Save Viewer",
FileTypeChoices = ViewerFileTypes,
SuggestedFileName = Path.GetFileNameWithoutExtension(db.Path) + ".html",
SuggestedStartLocation = await FileDialogs.GetSuggestedStartLocation(window, Path.GetDirectoryName(db.Path)),
});
if (path != null) {
await WriteViewerFile(path, StandaloneViewerExportStrategy.Instance);
}
}
public async void OnClickApplyFiltersToDatabase() {
public async Task OnClickApplyFiltersToDatabase() {
var filter = FilterModel.CreateFilter();
var messageCount = await ProgressDialog.ShowIndeterminate(window, "Apply Filters", "Counting matching messages...", _ => state.Db.Messages.Count(filter));
if (DatabaseToolFilterModeKeep) {
if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Keep Matching Messages in This Database", db.CountMessages(filter).Pluralize("message") + " will be kept, and the rest will be removed from this database. This action cannot be undone. Proceed?")) {
db.RemoveMessages(filter, FilterRemovalMode.KeepMatching);
if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Keep Matching Messages in This Database", messageCount.Pluralize("message") + " will be kept, and the rest will be removed from this database. This action cannot be undone. Proceed?")) {
await ApplyFilterToDatabase(filter, FilterRemovalMode.KeepMatching);
}
}
else if (DatabaseToolFilterModeRemove) {
if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Remove Matching Messages in This Database", db.CountMessages(filter).Pluralize("message") + " will be removed from this database. This action cannot be undone. Proceed?")) {
db.RemoveMessages(filter, FilterRemovalMode.RemoveMatching);
if (DialogResult.YesNo.Yes == await Dialog.ShowYesNo(window, "Remove Matching Messages in This Database", messageCount.Pluralize("message") + " will be removed from this database. This action cannot be undone. Proceed?")) {
await ApplyFilterToDatabase(filter, FilterRemovalMode.RemoveMatching);
}
}
}
private async Task ApplyFilterToDatabase(MessageFilter filter, FilterRemovalMode removalMode) {
await ProgressDialog.ShowIndeterminate(window, "Apply Filters", "Removing messages...", _ => state.Db.Messages.Remove(filter, removalMode));
}
}

View File

@@ -5,7 +5,8 @@
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
xmlns:screens="clr-namespace:DHT.Desktop.Main.Screens"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Screens.MainContentScreen">
x:Class="DHT.Desktop.Main.Screens.MainContentScreen"
x:DataType="screens:MainContentScreenModel">
<Design.DataContext>
<screens:MainContentScreenModel />

View File

@@ -1,19 +1,12 @@
using System;
using System.Threading.Tasks;
using Avalonia.Controls;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Main.Controls;
using DHT.Desktop.Main.Pages;
using DHT.Desktop.Server;
using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Logging;
using DHT.Server;
namespace DHT.Desktop.Main.Screens;
sealed class MainContentScreenModel : IDisposable {
private static readonly Log Log = Log.ForType<MainContentScreenModel>();
public DatabasePage DatabasePage { get; }
private DatabasePageModel DatabasePageModel { get; }
@@ -35,7 +28,7 @@ sealed class MainContentScreenModel : IDisposable {
public bool HasDebugPage => true;
private DebugPageModel DebugPageModel { get; }
#else
public bool HasDebugPage => false;
public bool HasDebugPage => false;
#endif
public StatusBarModel StatusBarModel { get; }
@@ -49,71 +42,39 @@ sealed class MainContentScreenModel : IDisposable {
}
}
private readonly Window window;
private readonly ServerManager serverManager;
[Obsolete("Designer")]
public MainContentScreenModel() : this(null!, DummyDatabaseFile.Instance) {}
public MainContentScreenModel() : this(null!, State.Dummy) {}
public MainContentScreenModel(Window window, IDatabaseFile db) {
this.window = window;
this.serverManager = new ServerManager(db);
ServerLauncher.ServerManagementExceptionCaught += ServerLauncherOnServerManagementExceptionCaught;
DatabasePageModel = new DatabasePageModel(window, db);
public MainContentScreenModel(Window window, State state) {
DatabasePageModel = new DatabasePageModel(window, state);
DatabasePage = new DatabasePage { DataContext = DatabasePageModel };
TrackingPageModel = new TrackingPageModel(window);
TrackingPage = new TrackingPage { DataContext = TrackingPageModel };
AttachmentsPageModel = new AttachmentsPageModel(db);
AttachmentsPageModel = new AttachmentsPageModel(state);
AttachmentsPage = new AttachmentsPage { DataContext = AttachmentsPageModel };
ViewerPageModel = new ViewerPageModel(window, db);
ViewerPageModel = new ViewerPageModel(window, state);
ViewerPage = new ViewerPage { DataContext = ViewerPageModel };
AdvancedPageModel = new AdvancedPageModel(window, db, serverManager);
AdvancedPageModel = new AdvancedPageModel(window, state);
AdvancedPage = new AdvancedPage { DataContext = AdvancedPageModel };
#if DEBUG
DebugPageModel = new DebugPageModel(window, db);
DebugPageModel = new DebugPageModel(window, state);
DebugPage = new DebugPage { DataContext = DebugPageModel };
#else
DebugPage = null;
DebugPage = null;
#endif
StatusBarModel = new StatusBarModel(db.Statistics);
AdvancedPageModel.ServerConfigurationModel.ServerStatusChanged += OnServerStatusChanged;
DatabaseClosed += OnDatabaseClosed;
StatusBarModel.CurrentStatus = serverManager.IsRunning ? StatusBarModel.Status.Ready : StatusBarModel.Status.Stopped;
}
public async Task Initialize() {
await TrackingPageModel.Initialize();
AdvancedPageModel.Initialize();
serverManager.Launch();
StatusBarModel = new StatusBarModel(state);
}
public void Dispose() {
ServerLauncher.ServerManagementExceptionCaught -= ServerLauncherOnServerManagementExceptionCaught;
AttachmentsPageModel.Dispose();
ViewerPageModel.Dispose();
serverManager.Dispose();
}
private void OnServerStatusChanged(object? sender, StatusBarModel.Status e) {
StatusBarModel.CurrentStatus = e;
}
private void OnDatabaseClosed(object? sender, EventArgs e) {
serverManager.Stop();
}
private async void ServerLauncherOnServerManagementExceptionCaught(object? sender, Exception ex) {
Log.Error(ex);
await Dialog.ShowOk(window, "Internal Server Error", ex.Message);
AdvancedPageModel.Dispose();
StatusBarModel.Dispose();
}
}

View File

@@ -4,7 +4,8 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:screens="clr-namespace:DHT.Desktop.Main.Screens"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Screens.WelcomeScreen">
x:Class="DHT.Desktop.Main.Screens.WelcomeScreen"
x:DataType="screens:WelcomeScreenModel">
<Design.DataContext>
<screens:WelcomeScreenModel />
@@ -31,7 +32,7 @@
<TextBlock Text="{Binding Version, StringFormat=Discord History Tracker v{0}}" FontSize="25" Margin="0 0 0 30" HorizontalAlignment="Center" />
<StackPanel Orientation="Horizontal" HorizontalAlignment="Center">
<Button Command="{Binding OpenOrCreateDatabase}">Open or Create Database</Button>
<Button Command="{Binding OpenOrCreateDatabase}" IsEnabled="{Binding IsOpenOrCreateDatabaseButtonEnabled}">Open or Create Database</Button>
<Button Command="{Binding ShowAboutDialog}">About</Button>
<Button Command="{Binding Exit}">Exit</Button>
</StackPanel>

View File

@@ -1,20 +1,25 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Desktop.Main.Screens;
sealed class WelcomeScreenModel : BaseModel, IDisposable {
sealed partial class WelcomeScreenModel : ObservableObject {
public string Version => Program.Version;
public IDatabaseFile? Db { get; private set; }
public bool HasDatabase => Db != null;
[ObservableProperty(Setter = Access.Private)]
private bool isOpenOrCreateDatabaseButtonEnabled = true;
public event EventHandler<IDatabaseFile>? DatabaseSelected;
private readonly Window window;
private string? dbFilePath;
@@ -26,45 +31,89 @@ sealed class WelcomeScreenModel : BaseModel, IDisposable {
this.window = window;
}
public async void OpenOrCreateDatabase() {
var path = await DatabaseGui.NewOpenOrCreateDatabaseFileDialog(window, Path.GetDirectoryName(dbFilePath));
if (path != null) {
await OpenOrCreateDatabaseFromPath(path);
public async Task OpenOrCreateDatabase() {
IsOpenOrCreateDatabaseButtonEnabled = false;
try {
var path = await DatabaseGui.NewOpenOrCreateDatabaseFileDialog(window, Path.GetDirectoryName(dbFilePath));
if (path != null) {
await OpenOrCreateDatabaseFromPath(path);
}
} finally {
IsOpenOrCreateDatabaseButtonEnabled = true;
}
}
public async Task OpenOrCreateDatabaseFromPath(string path) {
if (Db != null) {
Db = null;
dbFilePath = path;
var db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, window, new SchemaUpgradeCallbacks(window));
if (db != null) {
DatabaseSelected?.Invoke(this, db);
}
}
private sealed class SchemaUpgradeCallbacks : ISchemaUpgradeCallbacks {
private readonly Window window;
public SchemaUpgradeCallbacks(Window window) {
this.window = window;
}
dbFilePath = path;
Db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, window, CheckCanUpgradeDatabase);
public async Task<bool> CanUpgrade() {
return DialogResult.YesNo.Yes == await DatabaseGui.ShowCanUpgradeDatabaseDialog(window);
}
OnPropertyChanged(nameof(Db));
OnPropertyChanged(nameof(HasDatabase));
public async Task Start(int versionSteps, Func<ISchemaUpgradeCallbacks.IProgressReporter, Task> doUpgrade) {
async Task StartUpgrade(IReadOnlyList<IProgressCallback> callbacks) {
var reporter = new ProgressReporter(versionSteps, callbacks);
await reporter.NextVersion();
await Task.Delay(TimeSpan.FromMilliseconds(800));
await doUpgrade(reporter);
await Task.Delay(TimeSpan.FromMilliseconds(600));
}
await new ProgressDialog { DataContext = new ProgressDialogModel("Upgrading Database", StartUpgrade, progressItems: 3) }.ShowProgressDialog(window);
}
private sealed class ProgressReporter : ISchemaUpgradeCallbacks.IProgressReporter {
private readonly IReadOnlyList<IProgressCallback> callbacks;
private readonly int versionSteps;
private int versionProgress = 0;
public ProgressReporter(int versionSteps, IReadOnlyList<IProgressCallback> callbacks) {
this.callbacks = callbacks;
this.versionSteps = versionSteps;
}
public async Task NextVersion() {
await callbacks[0].Update("Upgrading schema version...", versionProgress++, versionSteps);
await HideChildren(0);
}
public async Task MainWork(string message, int finishedItems, int totalItems) {
await callbacks[1].Update(message, finishedItems, totalItems);
await HideChildren(1);
}
public async Task SubWork(string message, int finishedItems, int totalItems) {
await callbacks[2].Update(message, finishedItems, totalItems);
await HideChildren(2);
}
private async Task HideChildren(int parentIndex) {
for (int i = parentIndex + 1; i < callbacks.Count; i++) {
await callbacks[i].Hide();
}
}
}
}
private async Task<bool> CheckCanUpgradeDatabase() {
return DialogResult.YesNo.Yes == await DatabaseGui.ShowCanUpgradeDatabaseDialog(window);
}
public void CloseDatabase() {
Dispose();
OnPropertyChanged(nameof(Db));
OnPropertyChanged(nameof(HasDatabase));
}
public async void ShowAboutDialog() {
await new AboutWindow { DataContext = new AboutWindowModel() }.ShowDialog(this.window);
public async Task ShowAboutDialog() {
await new AboutWindow { DataContext = new AboutWindowModel() }.ShowDialog(window);
}
public void Exit() {
window.Close();
}
public void Dispose() {
Db?.Dispose();
Db = null;
}
}

View File

@@ -1,6 +1,8 @@
using System.Globalization;
using System;
using System.Globalization;
using System.Reflection;
using Avalonia;
using DHT.Utils.Logging;
using DHT.Utils.Resources;
namespace DHT.Desktop;
@@ -9,6 +11,7 @@ static class Program {
public static string Version { get; }
public static CultureInfo Culture { get; }
public static ResourceLoader Resources { get; }
public static Arguments Arguments { get; }
static Program() {
var assembly = Assembly.GetExecutingAssembly();
@@ -25,10 +28,21 @@ static class Program {
CultureInfo.DefaultThreadCurrentUICulture = CultureInfo.InvariantCulture;
Resources = new ResourceLoader(assembly);
Arguments = new Arguments(Environment.GetCommandLineArgs());
}
public static void Main(string[] args) {
BuildAvaloniaApp().StartWithClassicDesktopLifetime(args);
if (Arguments.Console && OperatingSystem.IsWindows()) {
WindowsConsole.AllocConsole();
}
try {
BuildAvaloniaApp().StartWithClassicDesktopLifetime(args);
} finally {
if (Arguments.Console && OperatingSystem.IsWindows()) {
WindowsConsole.FreeConsole();
}
}
}
private static AppBuilder BuildAvaloniaApp() {

View File

@@ -0,0 +1,8 @@
using DHT.Server.Service;
namespace DHT.Desktop.Server;
static class ServerConfiguration {
public static ushort Port { get; set; } = ServerUtils.FindAvailablePort(50000, 60000);
public static string Token { get; set; } = ServerUtils.GenerateRandomToken(20);
}

View File

@@ -1,50 +0,0 @@
using System;
using DHT.Server.Database;
using DHT.Server.Service;
namespace DHT.Desktop.Server;
sealed class ServerManager : IDisposable {
public static ushort Port { get; set; } = ServerUtils.FindAvailablePort(50000, 60000);
public static string Token { get; set; } = ServerUtils.GenerateRandomToken(20);
private static ServerManager? instance;
public bool IsRunning => ServerLauncher.IsRunning;
private readonly IDatabaseFile db;
public ServerManager(IDatabaseFile db) {
if (db != DummyDatabaseFile.Instance) {
if (instance != null) {
throw new InvalidOperationException("Only one instance of ServerManager can exist at the same time!");
}
instance = this;
}
this.db = db;
}
public void Launch() {
ServerLauncher.Relaunch(Port, Token, db);
}
public void Relaunch(ushort port, string token) {
Port = port;
Token = token;
Launch();
}
public void Stop() {
ServerLauncher.Stop();
}
public void Dispose() {
Stop();
if (instance == this) {
instance = null;
}
}
}

View File

@@ -2,7 +2,7 @@
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<LangVersion>11</LangVersion>
<LangVersion>12</LangVersion>
<Nullable>enable</Nullable>
</PropertyGroup>
@@ -19,9 +19,21 @@
</PropertyGroup>
<PropertyGroup>
<SuppressTrimAnalysisWarnings>false</SuppressTrimAnalysisWarnings>
<PublishTrimmed>true</PublishTrimmed>
<TrimMode>partial</TrimMode>
<JsonSerializerIsReflectionEnabledByDefault>true</JsonSerializerIsReflectionEnabledByDefault>
<EnableUnsafeBinaryFormatterSerialization>false</EnableUnsafeBinaryFormatterSerialization>
<EnableUnsafeUTF7Encoding>false</EnableUnsafeUTF7Encoding>
<EventSourceSupport>false</EventSourceSupport>
<HttpActivityPropagationSupport>false</HttpActivityPropagationSupport>
<JsonSerializerIsReflectionEnabledByDefault>false</JsonSerializerIsReflectionEnabledByDefault>
</PropertyGroup>
<PropertyGroup>
<PublishSingleFile>true</PublishSingleFile>
<PublishReadyToRun>false</PublishReadyToRun>
<EnableCompressionInSingleFile>true</EnableCompressionInSingleFile>
<IncludeNativeLibrariesForSelfExtract>true</IncludeNativeLibrariesForSelfExtract>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)' == 'Release' ">

6
app/NuGet.Config Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<packageSources>
<add key="chylex's repository" value="https://nuget.chylex.com/feed/index.json" />
</packageSources>
</configuration>

View File

@@ -8,5 +8,6 @@ namespace DHT.Server.Data;
public enum DownloadStatus {
Enqueued = 0,
GenericError = 1,
Downloading = 2,
Success = HttpStatusCode.OK
}

View File

@@ -10,7 +10,7 @@ public readonly struct Message {
public long Timestamp { get; init; }
public long? EditTimestamp { get; init; }
public ulong? RepliedToId { get; init; }
public ImmutableArray<Attachment> Attachments { get; init; }
public ImmutableArray<Embed> Embeds { get; init; }
public ImmutableArray<Reaction> Reactions { get; init; }
public ImmutableList<Attachment> Attachments { get; init; }
public ImmutableList<Embed> Embeds { get; init; }
public ImmutableList<Reaction> Reactions { get; init; }
}

View File

@@ -1,29 +1,32 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database;
public static class DatabaseExtensions {
public static void AddFrom(this IDatabaseFile target, IDatabaseFile source) {
target.AddServers(source.GetAllServers());
target.AddChannels(source.GetAllChannels());
target.AddUsers(source.GetAllUsers().ToArray());
target.AddMessages(source.GetMessages().ToArray());
public static async Task AddFrom(this IDatabaseFile target, IDatabaseFile source) {
await target.Users.Add(await source.Users.Get().ToListAsync());
await target.Servers.Add(await source.Servers.Get().ToListAsync());
await target.Channels.Add(await source.Channels.Get().ToListAsync());
foreach (var download in source.GetDownloadsWithoutData()) {
target.AddDownload(download.Status == DownloadStatus.Success ? source.GetDownloadWithData(download) : download);
const int MessageBatchSize = 100;
List<Message> batchedMessages = new (MessageBatchSize);
await foreach (var message in source.Messages.Get()) {
batchedMessages.Add(message);
if (batchedMessages.Count >= MessageBatchSize) {
await target.Messages.Add(batchedMessages);
batchedMessages.Clear();
}
}
}
await target.Messages.Add(batchedMessages);
internal static void AddServers(this IDatabaseFile target, IEnumerable<Data.Server> servers) {
foreach (var server in servers) {
target.AddServer(server);
}
}
internal static void AddChannels(this IDatabaseFile target, IEnumerable<Channel> channels) {
foreach (var channel in channels) {
target.AddChannel(channel);
await foreach (var download in source.Downloads.GetWithoutData()) {
await target.Downloads.AddDownload(download.Status == DownloadStatus.Success ? await source.Downloads.HydrateWithData(download) : download);
}
}
}

View File

@@ -1,46 +0,0 @@
using DHT.Utils.Models;
namespace DHT.Server.Database;
/// <summary>
/// A live view of database statistics.
/// Some of the totals are computed asynchronously and may not reflect the most recent version of the database, or may not be available at all until computed for the first time.
/// </summary>
public sealed class DatabaseStatistics : BaseModel {
private long totalServers;
private long totalChannels;
private long totalUsers;
private long? totalMessages;
private long? totalAttachments;
private long? totalDownloads;
public long TotalServers {
get => totalServers;
internal set => Change(ref totalServers, value);
}
public long TotalChannels {
get => totalChannels;
internal set => Change(ref totalChannels, value);
}
public long TotalUsers {
get => totalUsers;
internal set => Change(ref totalUsers, value);
}
public long? TotalMessages {
get => totalMessages;
internal set => Change(ref totalMessages, value);
}
public long? TotalAttachments {
get => totalAttachments;
internal set => Change(ref totalAttachments, value);
}
public long? TotalDownloads {
get => totalDownloads;
internal set => Change(ref totalDownloads, value);
}
}

View File

@@ -1,11 +0,0 @@
namespace DHT.Server.Database;
/// <summary>
/// A complete snapshot of database statistics at a particular point in time.
/// </summary>
public readonly struct DatabaseStatisticsSnapshot {
public long TotalServers { get; internal init; }
public long TotalChannels { get; internal init; }
public long TotalUsers { get; internal init; }
public long TotalMessages { get; internal init; }
}

View File

@@ -1,90 +1,29 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
using System.Threading.Tasks;
using DHT.Server.Database.Repositories;
namespace DHT.Server.Database;
[SuppressMessage("ReSharper", "ArrangeObjectCreationWhenTypeNotEvident")]
public sealed class DummyDatabaseFile : IDatabaseFile {
public static DummyDatabaseFile Instance { get; } = new();
sealed class DummyDatabaseFile : IDatabaseFile {
public static DummyDatabaseFile Instance { get; } = new ();
public string Path => "";
public DatabaseStatistics Statistics { get; } = new();
public IUserRepository Users { get; } = new IUserRepository.Dummy();
public IServerRepository Servers { get; } = new IServerRepository.Dummy();
public IChannelRepository Channels { get; } = new IChannelRepository.Dummy();
public IMessageRepository Messages { get; } = new IMessageRepository.Dummy();
public IAttachmentRepository Attachments { get; } = new IAttachmentRepository.Dummy();
public IDownloadRepository Downloads { get; } = new IDownloadRepository.Dummy();
private DummyDatabaseFile() {}
public DatabaseStatisticsSnapshot SnapshotStatistics() {
return new();
public Task Vacuum() {
return Task.CompletedTask;
}
public void AddServer(Data.Server server) {}
public List<Data.Server> GetAllServers() {
return new();
public ValueTask DisposeAsync() {
return ValueTask.CompletedTask;
}
public void AddChannel(Channel channel) {}
public List<Channel> GetAllChannels() {
return new();
}
public void AddUsers(User[] users) {}
public List<User> GetAllUsers() {
return new();
}
public void AddMessages(Message[] messages) {}
public int CountMessages(MessageFilter? filter = null) {
return 0;
}
public List<Message> GetMessages(MessageFilter? filter = null) {
return new();
}
public HashSet<ulong> GetMessageIds(MessageFilter? filter = null) {
return new();
}
public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) {}
public int CountAttachments(AttachmentFilter? filter = null) {
return new();
}
public List<Data.Download> GetDownloadsWithoutData() {
return new();
}
public Data.Download GetDownloadWithData(Data.Download download) {
return download;
}
public DownloadedAttachment? GetDownloadedAttachment(string url) {
return null;
}
public void AddDownload(Data.Download download) {}
public void EnqueueDownloadItems(AttachmentFilter? filter = null) {}
public List<DownloadItem> GetEnqueuedDownloadItems(int count) {
return new();
}
public void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {}
public DownloadStatusStatistics GetDownloadStatusStatistics() {
return new();
}
public void Vacuum() {}
public void Dispose() {}
}

View File

@@ -0,0 +1,3 @@
namespace DHT.Server.Database.Export;
readonly record struct Snowflake(ulong Id);

View File

@@ -0,0 +1,23 @@
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace DHT.Server.Database.Export;
sealed class SnowflakeJsonSerializer : JsonConverter<Snowflake> {
public override Snowflake Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
return new Snowflake(ulong.Parse(reader.GetString()!));
}
public override void Write(Utf8JsonWriter writer, Snowflake value, JsonSerializerOptions options) {
writer.WriteStringValue(value.Id.ToString());
}
public override Snowflake ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
return new Snowflake(ulong.Parse(reader.GetString()!));
}
public override void WriteAsPropertyName(Utf8JsonWriter writer, Snowflake value, JsonSerializerOptions options) {
writer.WritePropertyName(value.Id.ToString());
}
}

View File

@@ -0,0 +1,93 @@
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace DHT.Server.Database.Export;
sealed class ViewerJson {
public required JsonMeta Meta { get; init; }
public required Dictionary<Snowflake, Dictionary<Snowflake, JsonMessage>> Data { get; init; }
public sealed class JsonMeta {
public required Dictionary<Snowflake, JsonUser> Users { get; init; }
public required List<Snowflake> Userindex { get; init; }
public required List<JsonServer> Servers { get; init; }
public required Dictionary<Snowflake, JsonChannel> Channels { get; init; }
}
public sealed class JsonUser {
public required string Name { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Avatar { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Tag { get; init; }
}
public sealed class JsonServer {
public required string Name { get; init; }
public required string Type { get; init; }
}
public sealed class JsonChannel {
public required int Server { get; init; }
public required string Name { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Parent { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? Position { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Topic { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public bool? Nsfw { get; init; }
}
public sealed class JsonMessage {
public required int U { get; init; }
public required long T { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? M { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public long? Te { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? R { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public JsonMessageAttachment[]? A { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string[]? E { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public JsonMessageReaction[]? Re { get; init; }
}
public sealed class JsonMessageAttachment {
public required string Url { get; init; }
public required string Name { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? Width { get; set; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? Height { get; set; }
}
public sealed class JsonMessageReaction {
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Id { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? N { get; init; }
public required bool A { get; init; }
public required int C { get; init; }
}
}

View File

@@ -0,0 +1,11 @@
using System.Text.Json.Serialization;
namespace DHT.Server.Database.Export;
[JsonSourceGenerationOptions(
Converters = [typeof(SnowflakeJsonSerializer)],
PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
GenerationMode = JsonSourceGenerationMode.Default
)]
[JsonSerializable(typeof(ViewerJson))]
sealed partial class ViewerJsonContext : JsonSerializerContext;

View File

@@ -21,7 +21,7 @@ public static class ViewerJsonExport {
var includedChannelIds = new HashSet<ulong>();
var includedServerIds = new HashSet<ulong>();
var includedMessages = db.GetMessages(filter);
var includedMessages = await db.Messages.Get(filter).ToListAsync();
var includedChannels = new List<Channel>();
foreach (var message in includedMessages) {
@@ -29,185 +29,145 @@ public static class ViewerJsonExport {
includedChannelIds.Add(message.Channel);
}
foreach (var channel in db.GetAllChannels()) {
await foreach (var channel in db.Channels.Get()) {
if (includedChannelIds.Contains(channel.Id)) {
includedChannels.Add(channel);
includedServerIds.Add(channel.Server);
}
}
var users = GenerateUserList(db, includedUserIds, out var userindex, out var userIndices);
var servers = GenerateServerList(db, includedServerIds, out var serverindex);
var channels = GenerateChannelList(includedChannels, serverindex);
var (users, userIndex, userIndices) = await GenerateUserList(db, includedUserIds);
var (servers, serverIndices) = await GenerateServerList(db, includedServerIds);
var channels = GenerateChannelList(includedChannels, serverIndices);
perf.Step("Collect database data");
var value = new {
meta = new { users, userindex, servers, channels },
data = GenerateMessageList(includedMessages, userIndices, strategy),
var value = new ViewerJson {
Meta = new ViewerJson.JsonMeta {
Users = users,
Userindex = userIndex,
Servers = servers,
Channels = channels
},
Data = GenerateMessageList(includedMessages, userIndices, strategy)
};
perf.Step("Generate value object");
var opts = new JsonSerializerOptions();
opts.Converters.Add(new ViewerJsonSnowflakeSerializer());
await JsonSerializer.SerializeAsync(stream, value, opts);
await JsonSerializer.SerializeAsync(stream, value, ViewerJsonContext.Default.ViewerJson);
perf.Step("Serialize to JSON");
perf.End();
}
private static Dictionary<string, object> GenerateUserList(IDatabaseFile db, HashSet<ulong> userIds, out List<string> userindex, out Dictionary<ulong, object> userIndices) {
var users = new Dictionary<string, object>();
userindex = new List<string>();
userIndices = new Dictionary<ulong, object>();
private static async Task<(Dictionary<Snowflake, ViewerJson.JsonUser> Users, List<Snowflake> UserIndex, Dictionary<ulong, int> UserIndices)> GenerateUserList(IDatabaseFile db, HashSet<ulong> userIds) {
var users = new Dictionary<Snowflake, ViewerJson.JsonUser>();
var userIndex = new List<Snowflake>();
var userIndices = new Dictionary<ulong, int>();
foreach (var user in db.GetAllUsers()) {
await foreach (var user in db.Users.Get()) {
var id = user.Id;
if (!userIds.Contains(id)) {
continue;
}
var obj = new Dictionary<string, object> {
["name"] = user.Name
};
if (user.AvatarUrl != null) {
obj["avatar"] = user.AvatarUrl;
}
if (user.Discriminator != null) {
obj["tag"] = user.Discriminator;
}
var idStr = id.ToString();
var idSnowflake = new Snowflake(id);
userIndices[id] = users.Count;
userindex.Add(idStr);
users[idStr] = obj;
userIndex.Add(idSnowflake);
users[idSnowflake] = new ViewerJson.JsonUser {
Name = user.Name,
Avatar = user.AvatarUrl,
Tag = user.Discriminator
};
}
return users;
return (users, userIndex, userIndices);
}
private static List<object> GenerateServerList(IDatabaseFile db, HashSet<ulong> serverIds, out Dictionary<ulong, object> serverIndices) {
var servers = new List<object>();
serverIndices = new Dictionary<ulong, object>();
private static async Task<(List<ViewerJson.JsonServer> Servers, Dictionary<ulong, int> ServerIndices)> GenerateServerList(IDatabaseFile db, HashSet<ulong> serverIds) {
var servers = new List<ViewerJson.JsonServer>();
var serverIndices = new Dictionary<ulong, int>();
foreach (var server in db.GetAllServers()) {
await foreach (var server in db.Servers.Get()) {
var id = server.Id;
if (!serverIds.Contains(id)) {
continue;
}
serverIndices[id] = servers.Count;
servers.Add(new Dictionary<string, object> {
["name"] = server.Name,
["type"] = ServerTypes.ToJsonViewerString(server.Type),
servers.Add(new ViewerJson.JsonServer {
Name = server.Name,
Type = ServerTypes.ToJsonViewerString(server.Type)
});
}
return servers;
return (servers, serverIndices);
}
private static Dictionary<string, object> GenerateChannelList(List<Channel> includedChannels, Dictionary<ulong, object> serverIndices) {
var channels = new Dictionary<string, object>();
private static Dictionary<Snowflake, ViewerJson.JsonChannel> GenerateChannelList(List<Channel> includedChannels, Dictionary<ulong, int> serverIndices) {
var channels = new Dictionary<Snowflake, ViewerJson.JsonChannel>();
foreach (var channel in includedChannels) {
var obj = new Dictionary<string, object> {
["server"] = serverIndices[channel.Server],
["name"] = channel.Name,
var channelIdSnowflake = new Snowflake(channel.Id);
channels[channelIdSnowflake] = new ViewerJson.JsonChannel {
Server = serverIndices[channel.Server],
Name = channel.Name,
Parent = channel.ParentId?.ToString(),
Position = channel.Position,
Topic = channel.Topic,
Nsfw = channel.Nsfw
};
if (channel.ParentId != null) {
obj["parent"] = channel.ParentId;
}
if (channel.Position != null) {
obj["position"] = channel.Position;
}
if (channel.Topic != null) {
obj["topic"] = channel.Topic;
}
if (channel.Nsfw != null) {
obj["nsfw"] = channel.Nsfw;
}
channels[channel.Id.ToString()] = obj;
}
return channels;
}
private static Dictionary<string, Dictionary<string, object>> GenerateMessageList( List<Message> includedMessages, Dictionary<ulong, object> userIndices, IViewerExportStrategy strategy) {
var data = new Dictionary<string, Dictionary<string, object>>();
private static Dictionary<Snowflake, Dictionary<Snowflake, ViewerJson.JsonMessage>> GenerateMessageList(List<Message> includedMessages, Dictionary<ulong, int> userIndices, IViewerExportStrategy strategy) {
var data = new Dictionary<Snowflake, Dictionary<Snowflake, ViewerJson.JsonMessage>>();
foreach (var grouping in includedMessages.GroupBy(static message => message.Channel)) {
var channel = grouping.Key.ToString();
var channelData = new Dictionary<string, object>();
var channelIdSnowflake = new Snowflake(grouping.Key);
var channelData = new Dictionary<Snowflake, ViewerJson.JsonMessage>();
foreach (var message in grouping) {
var obj = new Dictionary<string, object> {
["u"] = userIndices[message.Sender],
["t"] = message.Timestamp,
};
if (!string.IsNullOrEmpty(message.Text)) {
obj["m"] = message.Text;
}
if (message.EditTimestamp != null) {
obj["te"] = message.EditTimestamp;
}
if (message.RepliedToId != null) {
obj["r"] = message.RepliedToId.Value;
}
if (!message.Attachments.IsEmpty) {
obj["a"] = message.Attachments.Select(attachment => {
var a = new Dictionary<string, object> {
{ "url", strategy.GetAttachmentUrl(attachment) },
{ "name", Uri.TryCreate(attachment.NormalizedUrl, UriKind.Absolute, out var uri) ? Path.GetFileName(uri.LocalPath) : attachment.NormalizedUrl },
var messageIdSnowflake = new Snowflake(message.Id);
channelData[messageIdSnowflake] = new ViewerJson.JsonMessage {
U = userIndices[message.Sender],
T = message.Timestamp,
M = string.IsNullOrEmpty(message.Text) ? null : message.Text,
Te = message.EditTimestamp,
R = message.RepliedToId?.ToString(),
A = message.Attachments.IsEmpty ? null : message.Attachments.Select(attachment => {
var a = new ViewerJson.JsonMessageAttachment {
Url = strategy.GetAttachmentUrl(attachment),
Name = Uri.TryCreate(attachment.NormalizedUrl, UriKind.Absolute, out var uri) ? Path.GetFileName(uri.LocalPath) : attachment.NormalizedUrl
};
if (attachment is { Width: not null, Height: not null }) {
a["width"] = attachment.Width;
a["height"] = attachment.Height;
a.Width = attachment.Width;
a.Height = attachment.Height;
}
return a;
}).ToArray();
}
if (!message.Embeds.IsEmpty) {
obj["e"] = message.Embeds.Select(static embed => embed.Json).ToArray();
}
if (!message.Reactions.IsEmpty) {
obj["re"] = message.Reactions.Select(static reaction => {
var r = new Dictionary<string, object>();
if (reaction.EmojiId != null) {
r["id"] = reaction.EmojiId.Value;
}
if (reaction.EmojiName != null) {
r["n"] = reaction.EmojiName;
}
r["a"] = reaction.EmojiFlags.HasFlag(EmojiFlags.Animated);
r["c"] = reaction.Count;
return r;
}).ToArray();
}
channelData[message.Id.ToString()] = obj;
}).ToArray(),
E = message.Embeds.IsEmpty ? null : message.Embeds.Select(static embed => embed.Json).ToArray(),
Re = message.Reactions.IsEmpty ? null : message.Reactions.Select(static reaction => new ViewerJson.JsonMessageReaction {
Id = reaction.EmojiId?.ToString(),
N = reaction.EmojiName,
A = reaction.EmojiFlags.HasFlag(EmojiFlags.Animated),
C = reaction.Count
}).ToArray()
};
}
data[channel] = channelData;
data[channelIdSnowflake] = channelData;
}
return data;

View File

@@ -1,15 +0,0 @@
using System;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace DHT.Server.Database.Export;
sealed class ViewerJsonSnowflakeSerializer : JsonConverter<ulong> {
public override ulong Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
return ulong.Parse(reader.GetString()!);
}
public override void Write(Utf8JsonWriter writer, ulong value, JsonSerializerOptions options) {
writer.WriteStringValue(value.ToString());
}
}

View File

@@ -1,43 +1,18 @@
using System;
using System.Collections.Generic;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
using System.Threading.Tasks;
using DHT.Server.Database.Repositories;
namespace DHT.Server.Database;
public interface IDatabaseFile : IDisposable {
public interface IDatabaseFile : IAsyncDisposable {
string Path { get; }
DatabaseStatistics Statistics { get; }
DatabaseStatisticsSnapshot SnapshotStatistics();
void AddServer(Data.Server server);
List<Data.Server> GetAllServers();
IUserRepository Users { get; }
IServerRepository Servers { get; }
IChannelRepository Channels { get; }
IMessageRepository Messages { get; }
IAttachmentRepository Attachments { get; }
IDownloadRepository Downloads { get; }
void AddChannel(Channel channel);
List<Channel> GetAllChannels();
void AddUsers(User[] users);
List<User> GetAllUsers();
void AddMessages(Message[] messages);
int CountMessages(MessageFilter? filter = null);
List<Message> GetMessages(MessageFilter? filter = null);
HashSet<ulong> GetMessageIds(MessageFilter? filter = null);
void RemoveMessages(MessageFilter filter, FilterRemovalMode mode);
int CountAttachments(AttachmentFilter? filter = null);
void AddDownload(Data.Download download);
List<Data.Download> GetDownloadsWithoutData();
Data.Download GetDownloadWithData(Data.Download download);
DownloadedAttachment? GetDownloadedAttachment(string url);
void EnqueueDownloadItems(AttachmentFilter? filter = null);
List<DownloadItem> GetEnqueuedDownloadItems(int count);
void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
DownloadStatusStatistics GetDownloadStatusStatistics();
void Vacuum();
Task Vacuum();
}

View File

@@ -0,0 +1,23 @@
using System.Text.Json.Serialization;
namespace DHT.Server.Database.Import;
sealed class DiscordEmbedLegacyJson {
public required string Url { get; init; }
public required string Type { get; init; }
public bool DhtLegacy { get; } = true;
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Title { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Description { get; init; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ImageJson? Image { get; init; }
public sealed class ImageJson {
public required string Url { get; init; }
}
}

View File

@@ -0,0 +1,7 @@
using System.Text.Json.Serialization;
namespace DHT.Server.Database.Import;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(DiscordEmbedLegacyJson))]
sealed partial class DiscordEmbedLegacyJsonContext : JsonSerializerContext;

View File

@@ -21,7 +21,7 @@ public static class LegacyArchiveImport {
public static async Task<bool> Read(Stream stream, IDatabaseFile db, FakeSnowflake fakeSnowflake, Func<Data.Server[], Task<Dictionary<Data.Server, ulong>?>> askForServerIds) {
var perf = Log.Start();
var root = await JsonSerializer.DeserializeAsync<JsonElement>(stream);
var root = await JsonSerializer.DeserializeAsync(stream, JsonElementContext.Default.JsonElement);
try {
var meta = root.RequireObject("meta");
@@ -33,10 +33,8 @@ public static class LegacyArchiveImport {
var servers = ReadServerList(meta, fakeSnowflake);
var newServersOnly = new HashSet<Data.Server>(servers);
var oldServersById = db.GetAllServers().ToDictionary(static server => server.Id, static server => server);
var oldChannels = db.GetAllChannels();
var oldChannelsById = oldChannels.ToDictionary(static channel => channel.Id, static channel => channel);
var oldServersById = await db.Servers.Get().ToDictionaryAsync(static server => server.Id, static server => server);
var oldChannelsById = await db.Channels.Get().ToDictionaryAsync(static channel => channel.Id, static channel => channel);
foreach (var (channelId, serverIndex) in ReadChannelToServerIndexMapping(meta, servers)) {
if (oldChannelsById.TryGetValue(channelId, out var oldChannel) && oldServersById.TryGetValue(oldChannel.Server, out var oldServer) && newServersOnly.Remove(servers[serverIndex])) {
@@ -66,17 +64,17 @@ public static class LegacyArchiveImport {
perf.Step("Read channel list");
var oldMessageIds = db.GetMessageIds();
var oldMessageIds = await db.Messages.GetIds().ToHashSetAsync();
var newMessages = channels.SelectMany(channel => ReadMessages(data, channel, users, fakeSnowflake))
.Where(message => !oldMessageIds.Contains(message.Id))
.ToArray();
perf.Step("Read messages");
db.AddUsers(users);
db.AddServers(servers);
db.AddChannels(channels);
db.AddMessages(newMessages);
await db.Users.Add(users);
await db.Servers.Add(servers);
await db.Channels.Add(channels);
await db.Messages.Add(newMessages);
perf.Step("Import into database");
} catch (HttpException e) {
@@ -179,9 +177,9 @@ public static class LegacyArchiveImport {
Timestamp = messageObj.RequireLong("t", path),
EditTimestamp = messageObj.HasKey("te") ? messageObj.RequireLong("te", path) : null,
RepliedToId = messageObj.HasKey("r") ? messageObj.RequireSnowflake("r", path) : null,
Attachments = messageObj.HasKey("a") ? ReadMessageAttachments(messageObj.RequireArray("a", path), fakeSnowflake, path + ".a[]").ToImmutableArray() : ImmutableArray<Attachment>.Empty,
Embeds = messageObj.HasKey("e") ? ReadMessageEmbeds(messageObj.RequireArray("e", path), path + ".e[]").ToImmutableArray() : ImmutableArray<Embed>.Empty,
Reactions = messageObj.HasKey("re") ? ReadMessageReactions(messageObj.RequireArray("re", path), path + ".re[]").ToImmutableArray() : ImmutableArray<Reaction>.Empty,
Attachments = messageObj.HasKey("a") ? ReadMessageAttachments(messageObj.RequireArray("a", path), fakeSnowflake, path + ".a[]").ToImmutableList() : ImmutableList<Attachment>.Empty,
Embeds = messageObj.HasKey("e") ? ReadMessageEmbeds(messageObj.RequireArray("e", path), path + ".e[]").ToImmutableList() : ImmutableList<Embed>.Empty,
Reactions = messageObj.HasKey("re") ? ReadMessageReactions(messageObj.RequireArray("re", path), path + ".re[]").ToImmutableList() : ImmutableList<Reaction>.Empty,
};
}).ToArray();
}
@@ -212,30 +210,17 @@ public static class LegacyArchiveImport {
return embedsArray.Where(static embedObj => embedObj.HasKey("url")).Select(embedObj => {
string url = embedObj.RequireString("url", path);
string type = embedObj.RequireString("type", path);
var embedJson = new Dictionary<string, object> {
{ "url", url },
{ "type", type },
{ "dht_legacy", true },
var embed = new DiscordEmbedLegacyJson {
Url = url,
Type = type,
Title = type == "rich" && embedObj.HasKey("t") ? embedObj.RequireString("t", path) : null,
Description = type == "rich" && embedObj.HasKey("d") ? embedObj.RequireString("d", path) : null,
Image = type == "image" ? new DiscordEmbedLegacyJson.ImageJson { Url = url } : null
};
if (type == "image") {
embedJson["image"] = new Dictionary<string, string> {
{ "url", url }
};
}
else if (type == "rich") {
if (embedObj.HasKey("t")) {
embedJson["title"] = embedObj.RequireString("t", path);
}
if (embedObj.HasKey("d")) {
embedJson["description"] = embedObj.RequireString("d", path);
}
}
return new Embed {
Json = JsonSerializer.Serialize(embedJson)
Json = JsonSerializer.Serialize(embed, DiscordEmbedLegacyJsonContext.Default.DiscordEmbedLegacyJson)
};
});
}

View File

@@ -0,0 +1,21 @@
using System;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
namespace DHT.Server.Database.Repositories;
public interface IAttachmentRepository {
IObservable<long> TotalCount { get; }
Task<long> Count(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
internal sealed class Dummy : IAttachmentRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task<long> Count(AttachmentFilter? filter = null, CancellationToken cancellationToken = default) {
return Task.FromResult(0L);
}
}
}

View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database.Repositories;
public interface IChannelRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Channel> channels);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<Channel> Get();
internal sealed class Dummy : IChannelRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Channel> channels) {
return Task.CompletedTask;
}
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Channel> Get() {
return AsyncEnumerable.Empty<Channel>();
}
}
}

View File

@@ -0,0 +1,68 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
namespace DHT.Server.Database.Repositories;
public interface IDownloadRepository {
IObservable<long> TotalCount { get; }
Task AddDownload(Data.Download download);
Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Download> GetWithoutData();
Task<Data.Download> HydrateWithData(Data.Download download);
Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl);
Task<int> EnqueueDownloadItems(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken = default);
Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
internal sealed class Dummy : IDownloadRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task AddDownload(Data.Download download) {
return Task.CompletedTask;
}
public Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
return Task.FromResult(new DownloadStatusStatistics());
}
public IAsyncEnumerable<Data.Download> GetWithoutData() {
return AsyncEnumerable.Empty<Data.Download>();
}
public Task<Data.Download> HydrateWithData(Data.Download download) {
return Task.FromResult(download);
}
public Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
return Task.FromResult<DownloadedAttachment?>(null);
}
public Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0);
}
public IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken) {
return AsyncEnumerable.Empty<DownloadItem>();
}
public Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
return Task.CompletedTask;
}
}
}

View File

@@ -0,0 +1,48 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
namespace DHT.Server.Database.Repositories;
public interface IMessageRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Message> messages);
Task<long> Count(MessageFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<Message> Get(MessageFilter? filter = null);
IAsyncEnumerable<ulong> GetIds(MessageFilter? filter = null);
Task Remove(MessageFilter filter, FilterRemovalMode mode);
internal sealed class Dummy : IMessageRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Message> messages) {
return Task.CompletedTask;
}
public Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Message> Get(MessageFilter? filter) {
return AsyncEnumerable.Empty<Message>();
}
public IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
return AsyncEnumerable.Empty<ulong>();
}
public Task Remove(MessageFilter filter, FilterRemovalMode mode) {
return Task.CompletedTask;
}
}
}

View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace DHT.Server.Database.Repositories;
public interface IServerRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Data.Server> servers);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Server> Get();
internal sealed class Dummy : IServerRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Data.Server> servers) {
return Task.CompletedTask;
}
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Data.Server> Get() {
return AsyncEnumerable.Empty<Data.Server>();
}
}
}

View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database.Repositories;
public interface IUserRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<User> users);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<User> Get();
internal sealed class Dummy : IUserRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<User> users) {
return Task.CompletedTask;
}
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<User> Get() {
return AsyncEnumerable.Empty<User>();
}
}
}

View File

@@ -0,0 +1,26 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using DHT.Utils.Tasks;
namespace DHT.Server.Database.Sqlite.Repositories;
abstract class BaseSqliteRepository : IDisposable {
private readonly ObservableThrottledTask<long> totalCountTask = new (TaskScheduler.Default);
public IObservable<long> TotalCount => totalCountTask;
protected BaseSqliteRepository() {
UpdateTotalCount();
}
public void Dispose() {
totalCountTask.Dispose();
}
protected void UpdateTotalCount() {
totalCountTask.Post(Count);
}
public abstract Task<long> Count(CancellationToken cancellationToken);
}

View File

@@ -0,0 +1,28 @@
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteAttachmentRepository : BaseSqliteRepository, IAttachmentRepository {
private readonly SqliteConnectionPool pool;
public SqliteAttachmentRepository(SqliteConnectionPool pool) {
this.pool = pool;
}
internal new void UpdateTotalCount() {
base.UpdateTotalCount();
}
public override Task<long> Count(CancellationToken cancellationToken) {
return Count(filter: null, cancellationToken);
}
public async Task<long> Count(AttachmentFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
}

View File

@@ -0,0 +1,72 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository {
private readonly SqliteConnectionPool pool;
public SqliteChannelRepository(SqliteConnectionPool pool) {
this.pool = pool;
}
public async Task Add(IReadOnlyList<Channel> channels) {
await using var conn = await pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("channels", [
("id", SqliteType.Integer),
("server", SqliteType.Integer),
("name", SqliteType.Text),
("parent_id", SqliteType.Integer),
("position", SqliteType.Integer),
("topic", SqliteType.Text),
("nsfw", SqliteType.Integer)
]);
foreach (var channel in channels) {
cmd.Set(":id", channel.Id);
cmd.Set(":server", channel.Server);
cmd.Set(":name", channel.Name);
cmd.Set(":parent_id", channel.ParentId);
cmd.Set(":position", channel.Position);
cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw);
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
UpdateTotalCount();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
public async IAsyncEnumerable<Channel> Get() {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new Channel {
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),
Name = reader.GetString(2),
ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3),
Position = reader.IsDBNull(4) ? null : reader.GetInt32(4),
Topic = reader.IsDBNull(5) ? null : reader.GetString(5),
Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6),
};
}
}
}

View File

@@ -0,0 +1,230 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepository {
private readonly SqliteConnectionPool pool;
public SqliteDownloadRepository(SqliteConnectionPool pool) {
this.pool = pool;
}
public async Task AddDownload(Data.Download download) {
await using (var conn = await pool.Take()) {
await using var cmd = conn.Upsert("downloads", [
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("status", SqliteType.Integer),
("size", SqliteType.Integer),
("blob", SqliteType.Blob)
]);
cmd.Set(":normalized_url", download.NormalizedUrl);
cmd.Set(":download_url", download.DownloadUrl);
cmd.Set(":status", (int) download.Status);
cmd.Set(":size", download.Size);
cmd.Set(":blob", download.Data);
await cmd.ExecuteNonQueryAsync();
}
UpdateTotalCount();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
public async Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
static async Task LoadUndownloadedStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT IFNULL(COUNT(size), 0), IFNULL(SUM(size), 0)
FROM (SELECT MAX(a.size) size
FROM attachments a
WHERE a.normalized_url NOT IN (SELECT d.normalized_url FROM downloads d)
GROUP BY a.normalized_url)
""");
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.SkippedCount = reader.GetInt32(0);
result.SkippedSize = reader.GetUint64(1);
}
}
static async Task LoadSuccessStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN size ELSE 0 END), 0)
FROM downloads
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.EnqueuedCount = reader.GetInt32(0);
result.EnqueuedSize = reader.GetUint64(1);
result.SuccessfulCount = reader.GetInt32(2);
result.SuccessfulSize = reader.GetUint64(3);
result.FailedCount = reader.GetInt32(4);
result.FailedSize = reader.GetUint64(5);
}
}
var result = new DownloadStatusStatistics();
await using var conn = await pool.Take();
await LoadUndownloadedStatistics(conn, result, cancellationToken);
await LoadSuccessStatistics(conn, result, cancellationToken);
return result;
}
public async IAsyncEnumerable<Data.Download> GetWithoutData() {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
string normalizedUrl = reader.GetString(0);
string downloadUrl = reader.GetString(1);
var status = (DownloadStatus) reader.GetInt32(2);
ulong size = reader.GetUint64(3);
yield return new Data.Download(normalizedUrl, downloadUrl, status, size);
}
}
public async Task<Data.Download> HydrateWithData(Data.Download download) {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
await using var reader = await cmd.ExecuteReaderAsync();
if (reader.Read() && !reader.IsDBNull(0)) {
return download.WithData((byte[]) reader["blob"]);
}
else {
return download;
}
}
public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
await using var conn = await pool.Take();
await using var cmd = conn.Command(
"""
SELECT a.type, d.blob FROM downloads d
LEFT JOIN attachments a ON d.normalized_url = a.normalized_url
WHERE d.normalized_url = :normalized_url AND d.status = :success AND d.blob IS NOT NULL
"""
);
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync();
if (!reader.Read()) {
return null;
}
return new DownloadedAttachment {
Type = reader.IsDBNull(0) ? null : reader.GetString(0),
Data = (byte[]) reader["blob"],
};
}
public async Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
await using var cmd = conn.Command(
$"""
INSERT INTO downloads (normalized_url, download_url, status, size)
SELECT a.normalized_url, a.download_url, :enqueued, MAX(a.size)
FROM attachments a
{filter.GenerateWhereClause("a")}
GROUP BY a.normalized_url
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
return await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) {
var found = new List<DownloadItem>();
await using var conn = await pool.Take();
await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":limit", SqliteType.Integer, Math.Max(0, count));
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
while (reader.Read()) {
found.Add(new DownloadItem {
NormalizedUrl = reader.GetString(0),
DownloadUrl = reader.GetString(1),
Size = reader.GetUint64(2),
});
}
}
if (found.Count != 0) {
await using var cmd = conn.Command("UPDATE downloads SET status = :downloading WHERE normalized_url = :normalized_url AND status = :enqueued");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.Add(":normalized_url", SqliteType.Text);
foreach (var item in found) {
cmd.Set(":normalized_url", item.NormalizedUrl);
if (await cmd.ExecuteNonQueryAsync(cancellationToken) == 1) {
yield return item;
}
}
}
}
public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
await using (var conn = await pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM downloads
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
UpdateTotalCount();
}
}

View File

@@ -0,0 +1,307 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository {
private readonly SqliteConnectionPool pool;
private readonly SqliteAttachmentRepository attachments;
public SqliteMessageRepository(SqliteConnectionPool pool, SqliteAttachmentRepository attachments) {
this.pool = pool;
this.attachments = attachments;
}
public async Task Add(IReadOnlyList<Message> messages) {
if (messages.Count == 0) {
return;
}
static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) {
return conn.Delete(tableName, ("message_id", SqliteType.Integer));
}
static async Task ExecuteDeleteByMessageId(SqliteCommand cmd, object id) {
cmd.Set(":message_id", id);
await cmd.ExecuteNonQueryAsync();
}
bool addedAttachments = false;
await using (var conn = await pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await using var messageCmd = conn.Upsert("messages", [
("message_id", SqliteType.Integer),
("sender_id", SqliteType.Integer),
("channel_id", SqliteType.Integer),
("text", SqliteType.Text),
("timestamp", SqliteType.Integer)
]);
await using var deleteEditTimestampCmd = DeleteByMessageId(conn, "edit_timestamps");
await using var deleteRepliedToCmd = DeleteByMessageId(conn, "replied_to");
await using var deleteAttachmentsCmd = DeleteByMessageId(conn, "attachments");
await using var deleteEmbedsCmd = DeleteByMessageId(conn, "embeds");
await using var deleteReactionsCmd = DeleteByMessageId(conn, "reactions");
await using var editTimestampCmd = conn.Insert("edit_timestamps", [
("message_id", SqliteType.Integer),
("edit_timestamp", SqliteType.Integer)
]);
await using var repliedToCmd = conn.Insert("replied_to", [
("message_id", SqliteType.Integer),
("replied_to_id", SqliteType.Integer)
]);
await using var attachmentCmd = conn.Insert("attachments", [
("message_id", SqliteType.Integer),
("attachment_id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("size", SqliteType.Integer),
("width", SqliteType.Integer),
("height", SqliteType.Integer)
]);
await using var embedCmd = conn.Insert("embeds", [
("message_id", SqliteType.Integer),
("json", SqliteType.Text)
]);
await using var reactionCmd = conn.Insert("reactions", [
("message_id", SqliteType.Integer),
("emoji_id", SqliteType.Integer),
("emoji_name", SqliteType.Text),
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer)
]);
foreach (var message in messages) {
object messageId = message.Id;
messageCmd.Set(":message_id", messageId);
messageCmd.Set(":sender_id", message.Sender);
messageCmd.Set(":channel_id", message.Channel);
messageCmd.Set(":text", message.Text);
messageCmd.Set(":timestamp", message.Timestamp);
await messageCmd.ExecuteNonQueryAsync();
await ExecuteDeleteByMessageId(deleteEditTimestampCmd, messageId);
await ExecuteDeleteByMessageId(deleteRepliedToCmd, messageId);
await ExecuteDeleteByMessageId(deleteAttachmentsCmd, messageId);
await ExecuteDeleteByMessageId(deleteEmbedsCmd, messageId);
await ExecuteDeleteByMessageId(deleteReactionsCmd, messageId);
if (message.EditTimestamp is {} timestamp) {
editTimestampCmd.Set(":message_id", messageId);
editTimestampCmd.Set(":edit_timestamp", timestamp);
await editTimestampCmd.ExecuteNonQueryAsync();
}
if (message.RepliedToId is {} repliedToId) {
repliedToCmd.Set(":message_id", messageId);
repliedToCmd.Set(":replied_to_id", repliedToId);
await repliedToCmd.ExecuteNonQueryAsync();
}
if (!message.Attachments.IsEmpty) {
addedAttachments = true;
foreach (var attachment in message.Attachments) {
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
attachmentCmd.Set(":name", attachment.Name);
attachmentCmd.Set(":type", attachment.Type);
attachmentCmd.Set(":normalized_url", attachment.NormalizedUrl);
attachmentCmd.Set(":download_url", attachment.DownloadUrl);
attachmentCmd.Set(":size", attachment.Size);
attachmentCmd.Set(":width", attachment.Width);
attachmentCmd.Set(":height", attachment.Height);
await attachmentCmd.ExecuteNonQueryAsync();
}
}
if (!message.Embeds.IsEmpty) {
foreach (var embed in message.Embeds) {
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
await embedCmd.ExecuteNonQueryAsync();
}
}
if (!message.Reactions.IsEmpty) {
foreach (var reaction in message.Reactions) {
reactionCmd.Set(":message_id", messageId);
reactionCmd.Set(":emoji_id", reaction.EmojiId);
reactionCmd.Set(":emoji_name", reaction.EmojiName);
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
await reactionCmd.ExecuteNonQueryAsync();
}
}
}
await tx.CommitAsync();
}
UpdateTotalCount();
if (addedAttachments) {
attachments.UpdateTotalCount();
}
}
public override Task<long> Count(CancellationToken cancellationToken) {
return Count(filter: null, cancellationToken);
}
public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
private sealed class MesageToManyCommand<T> : IAsyncDisposable {
private readonly SqliteCommand cmd;
private readonly Func<SqliteDataReader, T> readItem;
public MesageToManyCommand(ISqliteConnection conn, string sql, Func<SqliteDataReader, T> readItem) {
this.cmd = conn.Command(sql);
this.cmd.Add(":message_id", SqliteType.Integer);
this.readItem = readItem;
}
public async Task<ImmutableList<T>> GetItems(ulong messageId) {
cmd.Set(":message_id", messageId);
var items = ImmutableList<T>.Empty;
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
items = items.Add(readItem(reader));
}
return items;
}
public async ValueTask DisposeAsync() {
await cmd.DisposeAsync();
}
}
public async IAsyncEnumerable<Message> Get(MessageFilter? filter) {
await using var conn = await pool.Take();
const string AttachmentSql =
"""
SELECT attachment_id, name, type, normalized_url, download_url, size, width, height
FROM attachments
WHERE message_id = :message_id
""";
await using var attachmentCmd = new MesageToManyCommand<Attachment>(conn, AttachmentSql, static reader => new Attachment {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = reader.IsDBNull(2) ? null : reader.GetString(2),
NormalizedUrl = reader.GetString(3),
DownloadUrl = reader.GetString(4),
Size = reader.GetUint64(5),
Width = reader.IsDBNull(6) ? null : reader.GetInt32(6),
Height = reader.IsDBNull(7) ? null : reader.GetInt32(7),
});
const string EmbedSql =
"""
SELECT json
FROM embeds
WHERE message_id = :message_id
""";
await using var embedCmd = new MesageToManyCommand<Embed>(conn, EmbedSql, static reader => new Embed {
Json = reader.GetString(0)
});
const string ReactionSql =
"""
SELECT emoji_id, emoji_name, emoji_flags, count
FROM reactions
WHERE message_id = :message_id
""";
await using var reactionsCmd = new MesageToManyCommand<Reaction>(conn, ReactionSql, static reader => new Reaction {
EmojiId = reader.IsDBNull(0) ? null : reader.GetUint64(0),
EmojiName = reader.IsDBNull(1) ? null : reader.GetString(1),
EmojiFlags = (EmojiFlags) reader.GetInt16(2),
Count = reader.GetInt32(3),
});
await using var messageCmd = conn.Command(
$"""
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id
{filter.GenerateWhereClause("m")}
"""
);
await using var reader = await messageCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
ulong messageId = reader.GetUint64(0);
yield return new Message {
Id = messageId,
Sender = reader.GetUint64(1),
Channel = reader.GetUint64(2),
Text = reader.GetString(3),
Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),
Attachments = await attachmentCmd.GetItems(messageId),
Embeds = await embedCmd.GetItems(messageId),
Reactions = await reactionsCmd.GetItems(messageId)
};
}
}
public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
yield return reader.GetUint64(0);
}
}
public async Task Remove(MessageFilter filter, FilterRemovalMode mode) {
await using (var conn = await pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM messages
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
UpdateTotalCount();
}
}

View File

@@ -0,0 +1,60 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
private readonly SqliteConnectionPool pool;
public SqliteServerRepository(SqliteConnectionPool pool) {
this.pool = pool;
}
public async Task Add(IReadOnlyList<Data.Server> servers) {
await using var conn = await pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("servers", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text)
]);
foreach (var server in servers) {
cmd.Set(":id", server.Id);
cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type));
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
UpdateTotalCount();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
public async IAsyncEnumerable<Data.Server> Get() {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT id, name, type FROM servers");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new Data.Server {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = ServerTypes.FromString(reader.GetString(2)),
};
}
}
}

View File

@@ -0,0 +1,63 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
private readonly SqliteConnectionPool pool;
public SqliteUserRepository(SqliteConnectionPool pool) {
this.pool = pool;
}
public async Task Add(IReadOnlyList<User> users) {
await using (var conn = await pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await using var cmd = conn.Upsert("users", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
("avatar_url", SqliteType.Text),
("discriminator", SqliteType.Text)
]);
foreach (var user in users) {
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
await cmd.ExecuteNonQueryAsync();
}
await tx.CommitAsync();
}
UpdateTotalCount();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
public async IAsyncEnumerable<User> Get() {
await using var conn = await pool.Take();
await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
await using var reader = await cmd.ExecuteReaderAsync();
while (reader.Read()) {
yield return new User {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2),
Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3),
};
}
}
}

View File

@@ -1,5 +1,5 @@
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Threading.Tasks;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Utils;
@@ -20,150 +20,148 @@ sealed class Schema {
this.conn = conn;
}
private void Execute(string sql) {
conn.Command(sql).ExecuteNonQuery();
}
public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) {
await conn.ExecuteAsync("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
public async Task<bool> Setup(Func<Task<bool>> checkCanUpgradeSchemas) {
Execute(@"CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = conn.SelectScalar("SELECT value FROM metadata WHERE key = 'version'");
var dbVersionStr = await conn.ExecuteReaderAsync("SELECT value FROM metadata WHERE key = 'version'", static reader => reader?.GetString(0));
if (dbVersionStr == null) {
InitializeSchemas();
await InitializeSchemas();
}
else if (!int.TryParse(dbVersionStr.ToString(), out int dbVersion) || dbVersion < 1) {
throw new InvalidDatabaseVersionException(dbVersionStr.ToString() ?? "<null>");
else if (!int.TryParse(dbVersionStr, out int dbVersion) || dbVersion < 1) {
throw new InvalidDatabaseVersionException(dbVersionStr);
}
else if (dbVersion > Version) {
throw new DatabaseTooNewException(dbVersion);
}
else if (dbVersion < Version) {
var proceed = await checkCanUpgradeSchemas();
var proceed = await callbacks.CanUpgrade();
if (!proceed) {
return false;
}
UpgradeSchemas(dbVersion);
await callbacks.Start(Version - dbVersion, async reporter => await UpgradeSchemas(dbVersion, reporter));
}
return true;
}
private void InitializeSchemas() {
Execute("""
CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
avatar_url TEXT,
discriminator TEXT
)
""");
private async Task InitializeSchemas() {
await conn.ExecuteAsync("""
CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
avatar_url TEXT,
discriminator TEXT
)
""");
Execute("""
CREATE TABLE servers (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE servers (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT NOT NULL
)
""");
Execute("""
CREATE TABLE channels (
id INTEGER PRIMARY KEY NOT NULL,
server INTEGER NOT NULL,
name TEXT NOT NULL,
parent_id INTEGER,
position INTEGER,
topic TEXT,
nsfw INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE channels (
id INTEGER PRIMARY KEY NOT NULL,
server INTEGER NOT NULL,
name TEXT NOT NULL,
parent_id INTEGER,
position INTEGER,
topic TEXT,
nsfw INTEGER
)
""");
Execute("""
CREATE TABLE messages (
message_id INTEGER PRIMARY KEY NOT NULL,
sender_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL,
text TEXT NOT NULL,
timestamp INTEGER NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE messages (
message_id INTEGER PRIMARY KEY NOT NULL,
sender_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL,
text TEXT NOT NULL,
timestamp INTEGER NOT NULL
)
""");
Execute("""
CREATE TABLE attachments (
message_id INTEGER NOT NULL,
attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT,
normalized_url TEXT NOT NULL,
download_url TEXT,
size INTEGER NOT NULL,
width INTEGER,
height INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE attachments (
message_id INTEGER NOT NULL,
attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT,
normalized_url TEXT NOT NULL,
download_url TEXT,
size INTEGER NOT NULL,
width INTEGER,
height INTEGER
)
""");
Execute("""
CREATE TABLE embeds (
message_id INTEGER NOT NULL,
json TEXT NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE embeds (
message_id INTEGER NOT NULL,
json TEXT NOT NULL
)
""");
Execute("""
CREATE TABLE downloads (
normalized_url TEXT NOT NULL PRIMARY KEY,
download_url TEXT,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
Execute("""
CREATE TABLE reactions (
message_id INTEGER NOT NULL,
emoji_id INTEGER,
emoji_name TEXT,
emoji_flags INTEGER NOT NULL,
count INTEGER NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE downloads (
normalized_url TEXT NOT NULL PRIMARY KEY,
download_url TEXT,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
CreateMessageEditTimestampTable();
CreateMessageRepliedToTable();
await conn.ExecuteAsync("""
CREATE TABLE reactions (
message_id INTEGER NOT NULL,
emoji_id INTEGER,
emoji_name TEXT,
emoji_flags INTEGER NOT NULL,
count INTEGER NOT NULL
)
""");
Execute("CREATE INDEX attachments_message_ix ON attachments(message_id)");
Execute("CREATE INDEX embeds_message_ix ON embeds(message_id)");
Execute("CREATE INDEX reactions_message_ix ON reactions(message_id)");
await CreateMessageEditTimestampTable();
await CreateMessageRepliedToTable();
Execute("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")");
await conn.ExecuteAsync("CREATE INDEX attachments_message_ix ON attachments(message_id)");
await conn.ExecuteAsync("CREATE INDEX embeds_message_ix ON embeds(message_id)");
await conn.ExecuteAsync("CREATE INDEX reactions_message_ix ON reactions(message_id)");
await conn.ExecuteAsync("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")");
}
private void CreateMessageEditTimestampTable() {
Execute("""
CREATE TABLE edit_timestamps (
message_id INTEGER PRIMARY KEY NOT NULL,
edit_timestamp INTEGER NOT NULL
)
""");
private async Task CreateMessageEditTimestampTable() {
await conn.ExecuteAsync("""
CREATE TABLE edit_timestamps (
message_id INTEGER PRIMARY KEY NOT NULL,
edit_timestamp INTEGER NOT NULL
)
""");
}
private void CreateMessageRepliedToTable() {
Execute("""
CREATE TABLE replied_to (
message_id INTEGER PRIMARY KEY NOT NULL,
replied_to_id INTEGER NOT NULL
)
""");
private async Task CreateMessageRepliedToTable() {
await conn.ExecuteAsync("""
CREATE TABLE replied_to (
message_id INTEGER PRIMARY KEY NOT NULL,
replied_to_id INTEGER NOT NULL
)
""");
}
private void NormalizeAttachmentUrls() {
private async Task NormalizeAttachmentUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing attachments...", 0, 0);
var normalizedUrls = new Dictionary<long, string>();
using (var selectCmd = conn.Command("SELECT attachment_id, url FROM attachments")) {
using var reader = selectCmd.ExecuteReader();
await using (var selectCmd = conn.Command("SELECT attachment_id, url FROM attachments")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (reader.Read()) {
var attachmentId = reader.GetInt64(0);
var originalUrl = reader.GetString(1);
@@ -171,28 +169,39 @@ sealed class Schema {
}
}
using var tx = conn.BeginTransaction();
using (var updateCmd = conn.Command("UPDATE attachments SET download_url = url, url = :normalized_url WHERE attachment_id = :attachment_id")) {
updateCmd.Parameters.Add(":attachment_id", SqliteType.Integer);
updateCmd.Parameters.Add(":normalized_url", SqliteType.Text);
await using var tx = await conn.BeginTransactionAsync();
int totalUrls = normalizedUrls.Count;
int processedUrls = -1;
await using (var updateCmd = conn.Command("UPDATE attachments SET download_url = url, url = :normalized_url WHERE attachment_id = :attachment_id")) {
updateCmd.Add(":attachment_id", SqliteType.Integer);
updateCmd.Add(":normalized_url", SqliteType.Text);
foreach (var (attachmentId, normalizedUrl) in normalizedUrls) {
if (++processedUrls % 1000 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
}
updateCmd.Set(":attachment_id", attachmentId);
updateCmd.Set(":normalized_url", normalizedUrl);
updateCmd.ExecuteNonQuery();
}
}
tx.Commit();
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync();
}
private void NormalizeDownloadUrls() {
private async Task NormalizeDownloadUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing downloads...", 0, 0);
var normalizedUrlsToOriginalUrls = new Dictionary<string, string>();
var duplicateUrlsToDelete = new HashSet<string>();
using (var selectCmd = conn.Command("SELECT url FROM downloads ORDER BY CASE WHEN status = 200 THEN 0 ELSE 1 END")) {
using var reader = selectCmd.ExecuteReader();
await using (var selectCmd = conn.Command("SELECT url FROM downloads ORDER BY CASE WHEN status = 200 THEN 0 ELSE 1 END")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (reader.Read()) {
var originalUrl = reader.GetString(0);
@@ -204,96 +213,144 @@ sealed class Schema {
}
}
using var tx = conn.BeginTransaction();
using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
foreach (var duplicateUrl in duplicateUrlsToDelete) {
deleteCmd.Set(":url", duplicateUrl);
deleteCmd.ExecuteNonQuery();
await conn.ExecuteAsync("PRAGMA cache_size = -20000");
DbTransaction tx;
await using (tx = await conn.BeginTransactionAsync()) {
await reporter.SubWork("Deleting duplicates...", 0, 0);
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
foreach (var duplicateUrl in duplicateUrlsToDelete) {
deleteCmd.Set(":url", duplicateUrl);
deleteCmd.ExecuteNonQuery();
}
}
await tx.CommitAsync();
}
using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Parameters.Add(":normalized_url", SqliteType.Text);
updateCmd.Parameters.Add(":download_url", SqliteType.Text);
int totalUrls = normalizedUrlsToOriginalUrls.Count;
int processedUrls = -1;
tx = await conn.BeginTransactionAsync();
await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Add(":normalized_url", SqliteType.Text);
updateCmd.Add(":download_url", SqliteType.Text);
foreach (var (normalizedUrl, downloadUrl) in normalizedUrlsToOriginalUrls) {
if (++processedUrls % 100 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
// Not proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyways, so this doesn't make it worse.
await tx.CommitAsync();
await tx.DisposeAsync();
tx = await conn.BeginTransactionAsync();
updateCmd.Transaction = (SqliteTransaction) tx;
}
updateCmd.Set(":normalized_url", normalizedUrl);
updateCmd.Set(":download_url", downloadUrl);
updateCmd.ExecuteNonQuery();
}
}
tx.Commit();
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync();
await tx.DisposeAsync();
await conn.ExecuteAsync("PRAGMA cache_size = -2000");
}
private void UpgradeSchemas(int dbVersion) {
private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
var perf = Log.Start("from version " + dbVersion);
Execute("UPDATE metadata SET value = " + Version + " WHERE key = 'version'");
await conn.ExecuteAsync("UPDATE metadata SET value = " + Version + " WHERE key = 'version'");
if (dbVersion <= 1) {
Execute("ALTER TABLE channels ADD parent_id INTEGER");
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE channels ADD parent_id INTEGER");
perf.Step("Upgrade to version 2");
await reporter.NextVersion();
}
if (dbVersion <= 2) {
CreateMessageEditTimestampTable();
CreateMessageRepliedToTable();
await reporter.MainWork("Applying schema changes...", 0, 1);
Execute("""
INSERT INTO edit_timestamps (message_id, edit_timestamp)
SELECT message_id, edit_timestamp
FROM messages
WHERE edit_timestamp IS NOT NULL
""");
await CreateMessageEditTimestampTable();
await CreateMessageRepliedToTable();
Execute("""
INSERT INTO replied_to (message_id, replied_to_id)
SELECT message_id, replied_to_id
FROM messages
WHERE replied_to_id IS NOT NULL
""");
await conn.ExecuteAsync("""
INSERT INTO edit_timestamps (message_id, edit_timestamp)
SELECT message_id, edit_timestamp
FROM messages
WHERE edit_timestamp IS NOT NULL
""");
Execute("ALTER TABLE messages DROP COLUMN replied_to_id");
Execute("ALTER TABLE messages DROP COLUMN edit_timestamp");
await conn.ExecuteAsync("""
INSERT INTO replied_to (message_id, replied_to_id)
SELECT message_id, replied_to_id
FROM messages
WHERE replied_to_id IS NOT NULL
""");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN replied_to_id");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN edit_timestamp");
perf.Step("Upgrade to version 3");
Execute("VACUUM");
await reporter.MainWork("Vacuuming the database...", 1, 1);
await conn.ExecuteAsync("VACUUM");
perf.Step("Vacuum");
await reporter.NextVersion();
}
if (dbVersion <= 3) {
Execute("""
CREATE TABLE downloads (
url TEXT NOT NULL PRIMARY KEY,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
await conn.ExecuteAsync("""
CREATE TABLE downloads (
url TEXT NOT NULL PRIMARY KEY,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
perf.Step("Upgrade to version 4");
await reporter.NextVersion();
}
if (dbVersion <= 4) {
Execute("ALTER TABLE attachments ADD width INTEGER");
Execute("ALTER TABLE attachments ADD height INTEGER");
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE attachments ADD width INTEGER");
await conn.ExecuteAsync("ALTER TABLE attachments ADD height INTEGER");
perf.Step("Upgrade to version 5");
await reporter.NextVersion();
}
if (dbVersion <= 5) {
Execute("ALTER TABLE attachments ADD download_url TEXT");
Execute("ALTER TABLE downloads ADD download_url TEXT");
NormalizeAttachmentUrls();
NormalizeDownloadUrls();
Execute("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
Execute("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
await reporter.MainWork("Applying schema changes...", 0, 3);
await conn.ExecuteAsync("ALTER TABLE attachments ADD download_url TEXT");
await conn.ExecuteAsync("ALTER TABLE downloads ADD download_url TEXT");
await reporter.MainWork("Updating attachments...", 1, 3);
await NormalizeAttachmentUrls(reporter);
await reporter.MainWork("Updating downloads...", 2, 3);
await NormalizeDownloadUrls(reporter);
await reporter.MainWork("Applying schema changes...", 3, 3);
await conn.ExecuteAsync("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
await conn.ExecuteAsync("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
perf.Step("Upgrade to version 6");
await reporter.NextVersion();
}
perf.End();

View File

@@ -1,16 +1,8 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Text;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Collections;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite;
@@ -18,696 +10,73 @@ namespace DHT.Server.Database.Sqlite;
public sealed class SqliteDatabaseFile : IDatabaseFile {
private const int DefaultPoolSize = 5;
public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, Func<Task<bool>> checkCanUpgradeSchemas) {
public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, ISchemaUpgradeCallbacks schemaUpgradeCallbacks) {
var connectionString = new SqliteConnectionStringBuilder {
DataSource = path,
Mode = SqliteOpenMode.ReadWriteCreate,
};
var pool = new SqliteConnectionPool(connectionString, DefaultPoolSize);
var pool = await SqliteConnectionPool.Create(connectionString, DefaultPoolSize);
bool wasOpened;
using (var conn = pool.Take()) {
wasOpened = await new Schema(conn).Setup(checkCanUpgradeSchemas);
try {
await using var conn = await pool.Take();
wasOpened = await new Schema(conn).Setup(schemaUpgradeCallbacks);
} catch (Exception) {
await pool.DisposeAsync();
throw;
}
if (wasOpened) {
return new SqliteDatabaseFile(path, pool);
}
else {
pool.Dispose();
await pool.DisposeAsync();
return null;
}
}
public string Path { get; }
public DatabaseStatistics Statistics { get; }
private readonly Log log;
public IUserRepository Users => users;
public IServerRepository Servers => servers;
public IChannelRepository Channels => channels;
public IMessageRepository Messages => messages;
public IAttachmentRepository Attachments => attachments;
public IDownloadRepository Downloads => downloads;
private readonly SqliteConnectionPool pool;
private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
private readonly SqliteUserRepository users;
private readonly SqliteServerRepository servers;
private readonly SqliteChannelRepository channels;
private readonly SqliteMessageRepository messages;
private readonly SqliteAttachmentRepository attachments;
private readonly SqliteDownloadRepository downloads;
private SqliteDatabaseFile(string path, SqliteConnectionPool pool) {
this.log = Log.ForType(typeof(SqliteDatabaseFile), System.IO.Path.GetFileName(path));
this.Path = path;
this.pool = pool;
this.totalMessagesComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateMessageStatistics).WithOutdatedResults().BuildWithComputer(ComputeMessageStatistics);
this.totalAttachmentsComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateAttachmentStatistics).WithOutdatedResults().BuildWithComputer(ComputeAttachmentStatistics);
this.totalDownloadsComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateDownloadStatistics).WithOutdatedResults().BuildWithComputer(ComputeDownloadStatistics);
this.Path = path;
this.Statistics = new DatabaseStatistics();
using (var conn = pool.Take()) {
UpdateServerStatistics(conn);
UpdateChannelStatistics(conn);
UpdateUserStatistics(conn);
}
totalMessagesComputer.Recompute();
totalAttachmentsComputer.Recompute();
totalDownloadsComputer.Recompute();
users = new SqliteUserRepository(pool);
servers = new SqliteServerRepository(pool);
channels = new SqliteChannelRepository(pool);
messages = new SqliteMessageRepository(pool, attachments = new SqliteAttachmentRepository(pool));
downloads = new SqliteDownloadRepository(pool);
}
public void Dispose() {
pool.Dispose();
public async ValueTask DisposeAsync() {
users.Dispose();
servers.Dispose();
channels.Dispose();
messages.Dispose();
attachments.Dispose();
downloads.Dispose();
await pool.DisposeAsync();
}
public DatabaseStatisticsSnapshot SnapshotStatistics() {
return new DatabaseStatisticsSnapshot {
TotalServers = Statistics.TotalServers,
TotalChannels = Statistics.TotalChannels,
TotalUsers = Statistics.TotalUsers,
TotalMessages = ComputeMessageStatistics(),
};
}
public void AddServer(Data.Server server) {
using var conn = pool.Take();
using var cmd = conn.Upsert("servers", new[] {
("id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
});
cmd.Set(":id", server.Id);
cmd.Set(":name", server.Name);
cmd.Set(":type", ServerTypes.ToString(server.Type));
cmd.ExecuteNonQuery();
UpdateServerStatistics(conn);
}
public List<Data.Server> GetAllServers() {
var perf = log.Start();
var list = new List<Data.Server>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, type FROM servers");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new Data.Server {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = ServerTypes.FromString(reader.GetString(2)),
});
}
perf.End();
return list;
}
public void AddChannel(Channel channel) {
using var conn = pool.Take();
using var cmd = conn.Upsert("channels", new[] {
("id", SqliteType.Integer),
("server", SqliteType.Integer),
("name", SqliteType.Text),
("parent_id", SqliteType.Integer),
("position", SqliteType.Integer),
("topic", SqliteType.Text),
("nsfw", SqliteType.Integer),
});
cmd.Set(":id", channel.Id);
cmd.Set(":server", channel.Server);
cmd.Set(":name", channel.Name);
cmd.Set(":parent_id", channel.ParentId);
cmd.Set(":position", channel.Position);
cmd.Set(":topic", channel.Topic);
cmd.Set(":nsfw", channel.Nsfw);
cmd.ExecuteNonQuery();
UpdateChannelStatistics(conn);
}
public List<Channel> GetAllChannels() {
var list = new List<Channel>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new Channel {
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),
Name = reader.GetString(2),
ParentId = reader.IsDBNull(3) ? null : reader.GetUint64(3),
Position = reader.IsDBNull(4) ? null : reader.GetInt32(4),
Topic = reader.IsDBNull(5) ? null : reader.GetString(5),
Nsfw = reader.IsDBNull(6) ? null : reader.GetBoolean(6),
});
}
return list;
}
public void AddUsers(User[] users) {
using var conn = pool.Take();
using var tx = conn.BeginTransaction();
using var cmd = conn.Upsert("users", new[] {
("id", SqliteType.Integer),
("name", SqliteType.Text),
("avatar_url", SqliteType.Text),
("discriminator", SqliteType.Text),
});
foreach (var user in users) {
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
cmd.ExecuteNonQuery();
}
tx.Commit();
UpdateUserStatistics(conn);
}
public List<User> GetAllUsers() {
var perf = log.Start();
var list = new List<User>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new User {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
AvatarUrl = reader.IsDBNull(2) ? null : reader.GetString(2),
Discriminator = reader.IsDBNull(3) ? null : reader.GetString(3),
});
}
perf.End();
return list;
}
public void AddMessages(Message[] messages) {
static SqliteCommand DeleteByMessageId(ISqliteConnection conn, string tableName) {
return conn.Delete(tableName, ("message_id", SqliteType.Integer));
}
static void ExecuteDeleteByMessageId(SqliteCommand cmd, object id) {
cmd.Set(":message_id", id);
cmd.ExecuteNonQuery();
}
bool addedAttachments = false;
using (var conn = pool.Take()) {
using var tx = conn.BeginTransaction();
using var messageCmd = conn.Upsert("messages", new[] {
("message_id", SqliteType.Integer),
("sender_id", SqliteType.Integer),
("channel_id", SqliteType.Integer),
("text", SqliteType.Text),
("timestamp", SqliteType.Integer),
});
using var deleteEditTimestampCmd = DeleteByMessageId(conn, "edit_timestamps");
using var deleteRepliedToCmd = DeleteByMessageId(conn, "replied_to");
using var deleteAttachmentsCmd = DeleteByMessageId(conn, "attachments");
using var deleteEmbedsCmd = DeleteByMessageId(conn, "embeds");
using var deleteReactionsCmd = DeleteByMessageId(conn, "reactions");
using var editTimestampCmd = conn.Insert("edit_timestamps", new [] {
("message_id", SqliteType.Integer),
("edit_timestamp", SqliteType.Integer),
});
using var repliedToCmd = conn.Insert("replied_to", new [] {
("message_id", SqliteType.Integer),
("replied_to_id", SqliteType.Integer),
});
using var attachmentCmd = conn.Insert("attachments", new[] {
("message_id", SqliteType.Integer),
("attachment_id", SqliteType.Integer),
("name", SqliteType.Text),
("type", SqliteType.Text),
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("size", SqliteType.Integer),
("width", SqliteType.Integer),
("height", SqliteType.Integer),
});
using var embedCmd = conn.Insert("embeds", new[] {
("message_id", SqliteType.Integer),
("json", SqliteType.Text),
});
using var reactionCmd = conn.Insert("reactions", new[] {
("message_id", SqliteType.Integer),
("emoji_id", SqliteType.Integer),
("emoji_name", SqliteType.Text),
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer),
});
foreach (var message in messages) {
object messageId = message.Id;
messageCmd.Set(":message_id", messageId);
messageCmd.Set(":sender_id", message.Sender);
messageCmd.Set(":channel_id", message.Channel);
messageCmd.Set(":text", message.Text);
messageCmd.Set(":timestamp", message.Timestamp);
messageCmd.ExecuteNonQuery();
ExecuteDeleteByMessageId(deleteEditTimestampCmd, messageId);
ExecuteDeleteByMessageId(deleteRepliedToCmd, messageId);
ExecuteDeleteByMessageId(deleteAttachmentsCmd, messageId);
ExecuteDeleteByMessageId(deleteEmbedsCmd, messageId);
ExecuteDeleteByMessageId(deleteReactionsCmd, messageId);
if (message.EditTimestamp is {} timestamp) {
editTimestampCmd.Set(":message_id", messageId);
editTimestampCmd.Set(":edit_timestamp", timestamp);
editTimestampCmd.ExecuteNonQuery();
}
if (message.RepliedToId is {} repliedToId) {
repliedToCmd.Set(":message_id", messageId);
repliedToCmd.Set(":replied_to_id", repliedToId);
repliedToCmd.ExecuteNonQuery();
}
if (!message.Attachments.IsEmpty) {
addedAttachments = true;
foreach (var attachment in message.Attachments) {
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
attachmentCmd.Set(":name", attachment.Name);
attachmentCmd.Set(":type", attachment.Type);
attachmentCmd.Set(":normalized_url", attachment.NormalizedUrl);
attachmentCmd.Set(":download_url", attachment.DownloadUrl);
attachmentCmd.Set(":size", attachment.Size);
attachmentCmd.Set(":width", attachment.Width);
attachmentCmd.Set(":height", attachment.Height);
attachmentCmd.ExecuteNonQuery();
}
}
if (!message.Embeds.IsEmpty) {
foreach (var embed in message.Embeds) {
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
embedCmd.ExecuteNonQuery();
}
}
if (!message.Reactions.IsEmpty) {
foreach (var reaction in message.Reactions) {
reactionCmd.Set(":message_id", messageId);
reactionCmd.Set(":emoji_id", reaction.EmojiId);
reactionCmd.Set(":emoji_name", reaction.EmojiName);
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
reactionCmd.ExecuteNonQuery();
}
}
}
tx.Commit();
}
totalMessagesComputer.Recompute();
if (addedAttachments) {
totalAttachmentsComputer.Recompute();
}
}
public int CountMessages(MessageFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause());
using var reader = cmd.ExecuteReader();
return reader.Read() ? reader.GetInt32(0) : 0;
}
public List<Message> GetMessages(MessageFilter? filter = null) {
var perf = log.Start();
var list = new List<Message>();
var attachments = GetAllAttachments();
var embeds = GetAllEmbeds();
var reactions = GetAllReactions();
using var conn = pool.Take();
using var cmd = conn.Command($"""
SELECT m.message_id, m.sender_id, m.channel_id, m.text, m.timestamp, et.edit_timestamp, rt.replied_to_id
FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id
{filter.GenerateWhereClause("m")}
""");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong id = reader.GetUint64(0);
list.Add(new Message {
Id = id,
Sender = reader.GetUint64(1),
Channel = reader.GetUint64(2),
Text = reader.GetString(3),
Timestamp = reader.GetInt64(4),
EditTimestamp = reader.IsDBNull(5) ? null : reader.GetInt64(5),
RepliedToId = reader.IsDBNull(6) ? null : reader.GetUint64(6),
Attachments = attachments.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Attachment>.Empty,
Embeds = embeds.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Embed>.Empty,
Reactions = reactions.GetListOrNull(id)?.ToImmutableArray() ?? ImmutableArray<Reaction>.Empty,
});
}
perf.End();
return list;
}
public HashSet<ulong> GetMessageIds(MessageFilter? filter = null) {
var perf = log.Start();
var ids = new HashSet<ulong>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ids.Add(reader.GetUint64(0));
}
perf.End();
return ids;
}
public void RemoveMessages(MessageFilter filter, FilterRemovalMode mode) {
var perf = log.Start();
DeleteFromTable("messages", filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching));
totalMessagesComputer.Recompute();
perf.End();
}
public int CountAttachments(AttachmentFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"));
using var reader = cmd.ExecuteReader();
return reader.Read() ? reader.GetInt32(0) : 0;
}
public void AddDownload(Data.Download download) {
using var conn = pool.Take();
using var cmd = conn.Upsert("downloads", new[] {
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("status", SqliteType.Integer),
("size", SqliteType.Integer),
("blob", SqliteType.Blob),
});
cmd.Set(":normalized_url", download.NormalizedUrl);
cmd.Set(":download_url", download.DownloadUrl);
cmd.Set(":status", (int) download.Status);
cmd.Set(":size", download.Size);
cmd.Set(":blob", download.Data);
cmd.ExecuteNonQuery();
totalDownloadsComputer.Recompute();
}
public List<Data.Download> GetDownloadsWithoutData() {
var list = new List<Data.Download>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
string normalizedUrl = reader.GetString(0);
string downloadUrl = reader.GetString(1);
var status = (DownloadStatus) reader.GetInt32(2);
ulong size = reader.GetUint64(3);
list.Add(new Data.Download(normalizedUrl, downloadUrl, status, size));
}
return list;
}
public Data.Download GetDownloadWithData(Data.Download download) {
using var conn = pool.Take();
using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
using var reader = cmd.ExecuteReader();
if (reader.Read() && !reader.IsDBNull(0)) {
return download.WithData((byte[]) reader["blob"]);
}
else {
return download;
}
}
public DownloadedAttachment? GetDownloadedAttachment(string normalizedUrl) {
using var conn = pool.Take();
using var cmd = conn.Command("""
SELECT a.type, d.blob FROM downloads d
LEFT JOIN attachments a ON d.normalized_url = a.normalized_url
WHERE d.normalized_url = :normalized_url AND d.status = :success AND d.blob IS NOT NULL
""");
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
using var reader = cmd.ExecuteReader();
if (!reader.Read()) {
return null;
}
return new DownloadedAttachment {
Type = reader.IsDBNull(0) ? null : reader.GetString(0),
Data = (byte[]) reader["blob"],
};
}
public void EnqueueDownloadItems(AttachmentFilter? filter = null) {
using var conn = pool.Take();
using var cmd = conn.Command($"""
INSERT INTO downloads (normalized_url, download_url, status, size)
SELECT a.normalized_url, a.download_url, :enqueued, MAX(a.size)
FROM attachments a
{filter.GenerateWhereClause("a")}
GROUP BY a.normalized_url
""");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.ExecuteNonQuery();
}
public List<DownloadItem> GetEnqueuedDownloadItems(int count) {
var list = new List<DownloadItem>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":limit", SqliteType.Integer, Math.Max(0, count));
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
list.Add(new DownloadItem {
NormalizedUrl = reader.GetString(0),
DownloadUrl = reader.GetString(1),
Size = reader.GetUint64(2),
});
}
return list;
}
public void RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
DeleteFromTable("downloads", filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching));
totalDownloadsComputer.Recompute();
}
public DownloadStatusStatistics GetDownloadStatusStatistics() {
static void LoadUndownloadedStatistics(ISqliteConnection conn, DownloadStatusStatistics result) {
using var cmd = conn.Command("SELECT IFNULL(COUNT(size), 0), IFNULL(SUM(size), 0) FROM (SELECT MAX(a.size) size FROM attachments a WHERE a.normalized_url NOT IN (SELECT d.normalized_url FROM downloads d) GROUP BY a.normalized_url)");
using var reader = cmd.ExecuteReader();
if (reader.Read()) {
result.SkippedCount = reader.GetInt32(0);
result.SkippedSize = reader.GetUint64(1);
}
}
static void LoadSuccessStatistics(ISqliteConnection conn, DownloadStatusStatistics result) {
using var cmd = conn.Command("""
SELECT
IFNULL(SUM(CASE WHEN status = :enqueued THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :enqueued THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status != :enqueued AND status != :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status != :enqueued AND status != :success THEN size ELSE 0 END), 0)
FROM downloads
""");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
using var reader = cmd.ExecuteReader();
if (reader.Read()) {
result.EnqueuedCount = reader.GetInt32(0);
result.EnqueuedSize = reader.GetUint64(1);
result.SuccessfulCount = reader.GetInt32(2);
result.SuccessfulSize = reader.GetUint64(3);
result.FailedCount = reader.GetInt32(4);
result.FailedSize = reader.GetUint64(5);
}
}
var result = new DownloadStatusStatistics();
using var conn = pool.Take();
LoadUndownloadedStatistics(conn, result);
LoadSuccessStatistics(conn, result);
return result;
}
private MultiDictionary<ulong, Attachment> GetAllAttachments() {
var dict = new MultiDictionary<ulong, Attachment>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, attachment_id, name, type, normalized_url, download_url, size, width, height FROM attachments");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Attachment {
Id = reader.GetUint64(1),
Name = reader.GetString(2),
Type = reader.IsDBNull(3) ? null : reader.GetString(3),
NormalizedUrl = reader.GetString(4),
DownloadUrl = reader.GetString(5),
Size = reader.GetUint64(6),
Width = reader.IsDBNull(7) ? null : reader.GetInt32(7),
Height = reader.IsDBNull(8) ? null : reader.GetInt32(8),
});
}
return dict;
}
private MultiDictionary<ulong, Embed> GetAllEmbeds() {
var dict = new MultiDictionary<ulong, Embed>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, json FROM embeds");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Embed {
Json = reader.GetString(1),
});
}
return dict;
}
private MultiDictionary<ulong, Reaction> GetAllReactions() {
var dict = new MultiDictionary<ulong, Reaction>();
using var conn = pool.Take();
using var cmd = conn.Command("SELECT message_id, emoji_id, emoji_name, emoji_flags, count FROM reactions");
using var reader = cmd.ExecuteReader();
while (reader.Read()) {
ulong messageId = reader.GetUint64(0);
dict.Add(messageId, new Reaction {
EmojiId = reader.IsDBNull(1) ? null : reader.GetUint64(1),
EmojiName = reader.IsDBNull(2) ? null : reader.GetString(2),
EmojiFlags = (EmojiFlags) reader.GetInt16(3),
Count = reader.GetInt32(4),
});
}
return dict;
}
private void DeleteFromTable(string table, string whereClause) {
// Rider is being stupid...
StringBuilder build = new StringBuilder()
.Append("DELETE ")
.Append("FROM ")
.Append(table)
.Append(whereClause);
using var conn = pool.Take();
using var cmd = conn.Command(build.ToString());
cmd.ExecuteNonQuery();
}
public void Vacuum() {
using var conn = pool.Take();
using var cmd = conn.Command("VACUUM");
cmd.ExecuteNonQuery();
}
private void UpdateServerStatistics(ISqliteConnection conn) {
Statistics.TotalServers = conn.SelectScalar("SELECT COUNT(*) FROM servers") as long? ?? 0;
}
private void UpdateChannelStatistics(ISqliteConnection conn) {
Statistics.TotalChannels = conn.SelectScalar("SELECT COUNT(*) FROM channels") as long? ?? 0;
}
private void UpdateUserStatistics(ISqliteConnection conn) {
Statistics.TotalUsers = conn.SelectScalar("SELECT COUNT(*) FROM users") as long? ?? 0;
}
private long ComputeMessageStatistics() {
using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(*) FROM messages") as long? ?? 0L;
}
private void UpdateMessageStatistics(long totalMessages) {
Statistics.TotalMessages = totalMessages;
}
private long ComputeAttachmentStatistics() {
using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(DISTINCT normalized_url) FROM attachments") as long? ?? 0L;
}
private void UpdateAttachmentStatistics(long totalAttachments) {
Statistics.TotalAttachments = totalAttachments;
}
private long ComputeDownloadStatistics() {
using var conn = pool.Take();
return conn.SelectScalar("SELECT COUNT(*) FROM downloads") as long? ?? 0L;
}
private void UpdateDownloadStatistics(long totalDownloads) {
Statistics.TotalDownloads = totalDownloads;
public async Task Vacuum() {
await using var conn = await pool.Take();
await conn.ExecuteAsync("VACUUM");
}
}

View File

@@ -8,9 +8,13 @@ using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite;
static class SqliteFilters {
private static string WhereAll(bool invert) {
return invert ? "WHERE FALSE" : "";
}
public static string GenerateWhereClause(this MessageFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) {
return "";
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
}
var where = new SqliteWhereGenerator(tableAlias, invert);
@@ -39,8 +43,8 @@ static class SqliteFilters {
}
public static string GenerateWhereClause(this AttachmentFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) {
return "";
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
}
var where = new SqliteWhereGenerator(tableAlias, invert);
@@ -60,8 +64,8 @@ static class SqliteFilters {
}
public static string GenerateWhereClause(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null) {
return "";
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
}
var where = new SqliteWhereGenerator(tableAlias, invert);

View File

@@ -0,0 +1,15 @@
using System;
using System.Threading.Tasks;
namespace DHT.Server.Database.Sqlite.Utils;
public interface ISchemaUpgradeCallbacks {
Task<bool> CanUpgrade();
Task Start(int versionSteps, Func<IProgressReporter, Task> doUpgrade);
public interface IProgressReporter {
Task NextVersion();
Task MainWork(string message, int finishedItems, int totalItems);
Task SubWork(string message, int finishedItems, int totalItems);
}
}

View File

@@ -3,6 +3,6 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
interface ISqliteConnection : IDisposable {
interface ISqliteConnection : IAsyncDisposable {
SqliteConnection InnerConnection { get; }
}

View File

@@ -1,100 +1,77 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using DHT.Utils.Logging;
using System.Threading.Tasks;
using DHT.Utils.Collections;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteConnectionPool : IDisposable {
sealed class SqliteConnectionPool : IAsyncDisposable {
public static async Task<SqliteConnectionPool> Create(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
var pool = new SqliteConnectionPool(poolSize);
await pool.InitializePooledConnections(connectionStringBuilder);
return pool;
}
private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) {
connectionStringBuilder.Pooling = false;
return connectionStringBuilder.ToString();
}
private readonly object monitor = new ();
private readonly Random rand = new ();
private volatile bool isDisposed;
private readonly int poolSize;
private readonly List<PooledConnection> all;
private readonly ConcurrentPool<PooledConnection> free;
private readonly BlockingCollection<PooledConnection> free = new (new ConcurrentStack<PooledConnection>());
private readonly List<PooledConnection> used;
private readonly CancellationTokenSource disposalTokenSource = new ();
private readonly CancellationToken disposalToken;
public SqliteConnectionPool(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
private SqliteConnectionPool(int poolSize) {
this.poolSize = poolSize;
this.all = new List<PooledConnection>(poolSize);
this.free = new ConcurrentPool<PooledConnection>(poolSize);
this.disposalToken = disposalTokenSource.Token;
}
private async Task InitializePooledConnections(SqliteConnectionStringBuilder connectionStringBuilder) {
var connectionString = GetConnectionString(connectionStringBuilder);
for (int i = 0; i < poolSize; i++) {
var conn = new SqliteConnection(connectionString);
conn.Open();
var pooledConn = new PooledConnection(this, conn);
var pooledConnection = new PooledConnection(this, conn);
using (var cmd = pooledConn.Command("PRAGMA journal_mode=WAL")) {
cmd.ExecuteNonQuery();
await using (var cmd = pooledConnection.Command("PRAGMA journal_mode=WAL")) {
await cmd.ExecuteNonQueryAsync(disposalToken);
}
free.Add(pooledConn);
}
used = new List<PooledConnection>(poolSize);
}
private void ThrowIfDisposed() {
ObjectDisposedException.ThrowIf(isDisposed, nameof(SqliteConnectionPool));
}
public ISqliteConnection Take() {
while (true) {
ThrowIfDisposed();
lock (monitor) {
if (free.TryTake(out var conn)) {
used.Add(conn);
return conn;
}
else {
Log.ForType<SqliteConnectionPool>().Warn("Thread " + Environment.CurrentManagedThreadId + " is starving for connections.");
}
}
Thread.Sleep(TimeSpan.FromMilliseconds(rand.Next(100, 200)));
all.Add(pooledConnection);
await free.Push(pooledConnection, disposalToken);
}
}
private void Return(PooledConnection conn) {
ThrowIfDisposed();
lock (monitor) {
if (used.Remove(conn)) {
free.Add(conn);
}
}
public async Task<ISqliteConnection> Take() {
return await free.Pop(disposalToken);
}
public void Dispose() {
if (isDisposed) {
private async Task Return(PooledConnection conn) {
await free.Push(conn, disposalToken);
}
public async ValueTask DisposeAsync() {
if (disposalToken.IsCancellationRequested) {
return;
}
isDisposed = true;
lock (monitor) {
while (free.TryTake(out var conn)) {
Close(conn.InnerConnection);
}
foreach (var conn in used) {
Close(conn.InnerConnection);
}
free.Dispose();
used.Clear();
await disposalTokenSource.CancelAsync();
foreach (var conn in all) {
await conn.InnerConnection.CloseAsync();
await conn.InnerConnection.DisposeAsync();
}
}
private static void Close(SqliteConnection conn) {
conn.Close();
conn.Dispose();
disposalTokenSource.Dispose();
}
private sealed class PooledConnection : ISqliteConnection {
@@ -107,8 +84,8 @@ sealed class SqliteConnectionPool : IDisposable {
this.InnerConnection = conn;
}
void IDisposable.Dispose() {
pool.Return(this);
public async ValueTask DisposeAsync() {
await pool.Return(this);
}
}
}

View File

@@ -1,23 +1,34 @@
using System;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
static class SqliteExtensions {
public static ValueTask<DbTransaction> BeginTransactionAsync(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransactionAsync();
}
public static SqliteCommand Command(this ISqliteConnection conn, string sql) {
var cmd = conn.InnerConnection.CreateCommand();
cmd.CommandText = sql;
return cmd;
}
public static SqliteTransaction BeginTransaction(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransaction();
public static async Task<int> ExecuteAsync(this ISqliteConnection conn, [LanguageInjection("sql")] string sql, CancellationToken cancellationToken = default) {
await using var cmd = conn.Command(sql);
return await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public static async Task<T> ExecuteReaderAsync<T>(this ISqliteConnection conn, string sql, Func<SqliteDataReader?, T> readFunction, CancellationToken cancellationToken = default) {
await using var cmd = conn.Command(sql);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
public static object? SelectScalar(this ISqliteConnection conn, string sql) {
using var cmd = conn.Command(sql);
return cmd.ExecuteScalar();
return reader.Read() ? readFunction(reader) : readFunction(null);
}
public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
@@ -47,7 +58,7 @@ static class SqliteExtensions {
public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, new [] { column });
CreateParameters(cmd, [column]);
return cmd;
}
@@ -57,6 +68,10 @@ static class SqliteExtensions {
}
}
public static void Add(this SqliteCommand cmd, string key, SqliteType type) {
cmd.Parameters.Add(key, type);
}
public static void AddAndSet(this SqliteCommand cmd, string key, SqliteType type, object? value) {
cmd.Parameters.Add(key, type).Value = value ?? DBNull.Value;
}

View File

@@ -5,7 +5,7 @@ namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteWhereGenerator {
private readonly string? tableAlias;
private readonly bool invert;
private readonly List<string> conditions = new ();
private readonly List<string> conditions = [];
public SqliteWhereGenerator(string? tableAlias, bool invert) {
this.tableAlias = tableAlias;

View File

@@ -1,130 +0,0 @@
using System;
using System.Collections.Concurrent;
using System.Net.Http;
using System.Threading;
using DHT.Server.Database;
using DHT.Utils.Logging;
using DHT.Utils.Models;
namespace DHT.Server.Download;
public sealed class BackgroundDownloadThread : BaseModel {
private static readonly Log Log = Log.ForType<BackgroundDownloadThread>();
public event EventHandler<DownloadItem>? OnItemFinished {
add => parameters.OnItemFinished += value;
remove => parameters.OnItemFinished -= value;
}
public event EventHandler? OnServerStopped {
add => parameters.OnServerStopped += value;
remove => parameters.OnServerStopped -= value;
}
private readonly CancellationTokenSource cancellationTokenSource;
private readonly ThreadInstance.Parameters parameters;
public BackgroundDownloadThread(IDatabaseFile db) {
this.cancellationTokenSource = new CancellationTokenSource();
this.parameters = new ThreadInstance.Parameters(db, cancellationTokenSource);
var thread = new Thread(new ThreadInstance().Work) {
Name = "DHT download thread"
};
thread.Start(parameters);
}
public void StopThread() {
try {
cancellationTokenSource.Cancel();
} catch (ObjectDisposedException) {
Log.Warn("Attempted to stop background download thread after the cancellation token has been disposed.");
}
}
private sealed class ThreadInstance {
private const int QueueSize = 32;
public sealed class Parameters {
public event EventHandler<DownloadItem>? OnItemFinished;
public event EventHandler? OnServerStopped;
public IDatabaseFile Db { get; }
public CancellationTokenSource CancellationTokenSource { get; }
public Parameters(IDatabaseFile db, CancellationTokenSource cancellationTokenSource) {
Db = db;
CancellationTokenSource = cancellationTokenSource;
}
public void FireOnItemFinished(DownloadItem item) {
OnItemFinished?.Invoke(null, item);
}
public void FireOnServerStopped() {
OnServerStopped?.Invoke(null, EventArgs.Empty);
}
}
private readonly HttpClient client = new ();
public ThreadInstance() {
client.DefaultRequestHeaders.UserAgent.ParseAdd("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.51 Safari/537.36");
}
public async void Work(object? obj) {
var parameters = (Parameters) obj!;
var cancellationTokenSource = parameters.CancellationTokenSource;
var cancellationToken = cancellationTokenSource.Token;
var db = parameters.Db;
var queue = new ConcurrentQueue<DownloadItem>();
try {
while (!cancellationToken.IsCancellationRequested) {
FillQueue(db, queue, cancellationToken);
while (!cancellationToken.IsCancellationRequested && queue.TryDequeue(out var item)) {
var downloadUrl = item.DownloadUrl;
Log.Debug("Downloading " + downloadUrl + "...");
try {
db.AddDownload(Data.Download.NewSuccess(item, await client.GetByteArrayAsync(downloadUrl, cancellationToken)));
} catch (HttpRequestException e) {
db.AddDownload(Data.Download.NewFailure(item, e.StatusCode, item.Size));
Log.Error(e);
} catch (Exception e) {
db.AddDownload(Data.Download.NewFailure(item, null, item.Size));
Log.Error(e);
} finally {
parameters.FireOnItemFinished(item);
}
}
}
} catch (OperationCanceledException) {
//
} catch (ObjectDisposedException) {
//
} finally {
cancellationTokenSource.Dispose();
parameters.FireOnServerStopped();
}
}
private static void FillQueue(IDatabaseFile db, ConcurrentQueue<DownloadItem> queue, CancellationToken cancellationToken) {
while (!cancellationToken.IsCancellationRequested && queue.IsEmpty) {
var newItems = db.GetEnqueuedDownloadItems(QueueSize);
if (newItems.Count == 0) {
Thread.Sleep(TimeSpan.FromMilliseconds(50));
}
else {
foreach (var item in newItems) {
queue.Enqueue(item);
}
}
}
}
}
}

View File

@@ -1,10 +1,10 @@
using System;
using System.Collections.Frozen;
namespace DHT.Server.Download;
namespace DHT.Server.Download;
static class DiscordCdn {
private static FrozenSet<string> CdnHosts { get; } = new [] {
private static FrozenSet<string> CdnHosts { get; } = new[] {
"cdn.discordapp.com",
"cdn.discord.com",
}.ToFrozenSet();

View File

@@ -0,0 +1,40 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Database;
namespace DHT.Server.Download;
public sealed class Downloader {
private DownloaderTask? current;
public bool IsDownloading => current != null;
private readonly IDatabaseFile db;
private readonly SemaphoreSlim semaphore = new (1, 1);
internal Downloader(IDatabaseFile db) {
this.db = db;
}
public async Task<IObservable<DownloadItem>> Start() {
await semaphore.WaitAsync();
try {
current ??= new DownloaderTask(db);
return current.FinishedItems;
} finally {
semaphore.Release();
}
}
public async Task Stop() {
await semaphore.WaitAsync();
try {
if (current != null) {
await current.DisposeAsync();
current = null;
}
} finally {
semaphore.Release();
}
}
}

View File

@@ -0,0 +1,110 @@
using System;
using System.Linq;
using System.Net.Http;
using System.Reactive.Subjects;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using DHT.Server.Database;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
namespace DHT.Server.Download;
sealed class DownloaderTask : IAsyncDisposable {
private static readonly Log Log = Log.ForType<DownloaderTask>();
private const int DownloadTasks = 4;
private const int QueueSize = 25;
private const string UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
private readonly Channel<DownloadItem> downloadQueue = Channel.CreateBounded<DownloadItem>(new BoundedChannelOptions(QueueSize) {
SingleReader = false,
SingleWriter = true,
AllowSynchronousContinuations = false,
FullMode = BoundedChannelFullMode.Wait
});
private readonly CancellationTokenSource cancellationTokenSource = new ();
private readonly CancellationToken cancellationToken;
private readonly IDatabaseFile db;
private readonly ISubject<DownloadItem> finishedItemPublisher = Subject.Synchronize(new Subject<DownloadItem>());
private readonly Task queueWriterTask;
private readonly Task[] downloadTasks;
public IObservable<DownloadItem> FinishedItems => finishedItemPublisher;
internal DownloaderTask(IDatabaseFile db) {
this.db = db;
this.cancellationToken = cancellationTokenSource.Token;
this.queueWriterTask = Task.Run(RunQueueWriterTask);
this.downloadTasks = Enumerable.Range(1, DownloadTasks).Select(taskIndex => Task.Run(() => RunDownloadTask(taskIndex))).ToArray();
}
private async Task RunQueueWriterTask() {
while (await downloadQueue.Writer.WaitToWriteAsync(cancellationToken)) {
var newItems = await db.Downloads.PullEnqueuedDownloadItems(QueueSize, cancellationToken).ToListAsync(cancellationToken);
if (newItems.Count == 0) {
await Task.Delay(TimeSpan.FromMilliseconds(50), cancellationToken);
continue;
}
foreach (var newItem in newItems) {
await downloadQueue.Writer.WriteAsync(newItem, cancellationToken);
}
}
}
private async Task RunDownloadTask(int taskIndex) {
var log = Log.ForType<DownloaderTask>("Task " + taskIndex);
var client = new HttpClient();
client.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent);
client.Timeout = TimeSpan.FromSeconds(30);
while (!cancellationToken.IsCancellationRequested) {
var item = await downloadQueue.Reader.ReadAsync(cancellationToken);
log.Debug("Downloading " + item.DownloadUrl + "...");
try {
var downloadedBytes = await client.GetByteArrayAsync(item.DownloadUrl, cancellationToken);
await db.Downloads.AddDownload(Data.Download.NewSuccess(item, downloadedBytes));
} catch (OperationCanceledException) {
// Ignore.
} catch (HttpRequestException e) {
await db.Downloads.AddDownload(Data.Download.NewFailure(item, e.StatusCode, item.Size));
log.Error(e);
} catch (Exception e) {
await db.Downloads.AddDownload(Data.Download.NewFailure(item, null, item.Size));
log.Error(e);
} finally {
try {
finishedItemPublisher.OnNext(item);
} catch (Exception e) {
log.Error("Caught exception in event handler: " + e);
}
}
}
}
public async ValueTask DisposeAsync() {
try {
await cancellationTokenSource.CancelAsync();
} catch (Exception) {
Log.Warn("Attempted to stop background download twice.");
return;
}
downloadQueue.Writer.Complete();
try {
await queueWriterTask.WaitIgnoringCancellation();
await Task.WhenAll(downloadTasks).WaitIgnoringCancellation();
} finally {
cancellationTokenSource.Dispose();
finishedItemPublisher.OnCompleted();
}
}
}

View File

@@ -3,12 +3,9 @@ using System.Net;
using System.Text.Json;
using System.Threading.Tasks;
using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http;
using DHT.Utils.Logging;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.Extensions.Primitives;
namespace DHT.Server.Endpoints;
@@ -16,25 +13,14 @@ abstract class BaseEndpoint {
private static readonly Log Log = Log.ForType<BaseEndpoint>();
protected IDatabaseFile Db { get; }
protected ServerParameters Parameters { get; }
protected BaseEndpoint(IDatabaseFile db, ServerParameters parameters) {
protected BaseEndpoint(IDatabaseFile db) {
this.Db = db;
this.Parameters = parameters;
}
private async Task Handle(HttpContext ctx, StringValues token) {
var request = ctx.Request;
public async Task Handle(HttpContext ctx) {
var response = ctx.Response;
Log.Info("Request: " + request.GetDisplayUrl() + " (" + request.ContentLength + " B)");
if (token.Count != 1 || token[0] != Parameters.Token) {
Log.Error("Token: " + (token.Count == 1 ? token[0] : "<missing>"));
response.StatusCode = (int) HttpStatusCode.Forbidden;
return;
}
try {
response.StatusCode = (int) HttpStatusCode.OK;
var output = await Respond(ctx);
@@ -49,17 +35,13 @@ abstract class BaseEndpoint {
}
}
public async Task HandleGet(HttpContext ctx) {
await Handle(ctx, ctx.Request.Query["token"]);
}
public async Task HandlePost(HttpContext ctx) {
await Handle(ctx, ctx.Request.Headers["X-DHT-Token"]);
}
protected abstract Task<IHttpOutput> Respond(HttpContext ctx);
protected static async Task<JsonElement> ReadJson(HttpContext ctx) {
return await ctx.Request.ReadFromJsonAsync<JsonElement?>() ?? throw new HttpException(HttpStatusCode.UnsupportedMediaType, "This endpoint only accepts JSON.");
try {
return await ctx.Request.ReadFromJsonAsync(JsonElementContext.Default.JsonElement);
} catch (JsonException) {
throw new HttpException(HttpStatusCode.UnsupportedMediaType, "This endpoint only accepts JSON.");
}
}
}

View File

@@ -2,24 +2,23 @@ using System.Net;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetAttachmentEndpoint : BaseEndpoint {
public GetAttachmentEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {}
public GetAttachmentEndpoint(IDatabaseFile db) : base(db) {}
protected override Task<IHttpOutput> Respond(HttpContext ctx) {
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!);
DownloadedAttachment? maybeDownloadedAttachment = Db.GetDownloadedAttachment(attachmentUrl);
DownloadedAttachment? maybeDownloadedAttachment = await Db.Downloads.GetDownloadedAttachment(attachmentUrl);
if (maybeDownloadedAttachment is {} downloadedAttachment) {
return Task.FromResult<IHttpOutput>(new HttpOutput.File(downloadedAttachment.Type, downloadedAttachment.Data));
return new HttpOutput.File(downloadedAttachment.Type, downloadedAttachment.Data);
}
else {
return Task.FromResult<IHttpOutput>(new HttpOutput.Redirect(attachmentUrl, permanent: false));
return new HttpOutput.Redirect(attachmentUrl, permanent: false);
}
}
}

View File

@@ -13,12 +13,16 @@ namespace DHT.Server.Endpoints;
sealed class GetTrackingScriptEndpoint : BaseEndpoint {
private static ResourceLoader Resources { get; } = new (Assembly.GetExecutingAssembly());
public GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {}
private readonly ServerParameters serverParameters;
public GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db) {
serverParameters = parameters;
}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
string bootstrap = await Resources.ReadTextAsync("Tracker/bootstrap.js");
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + Parameters.Port + ";")
.Replace("/*[TOKEN]*/", HttpUtility.JavaScriptStringEncode(Parameters.Token))
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + serverParameters.Port + ";")
.Replace("/*[TOKEN]*/", HttpUtility.JavaScriptStringEncode(serverParameters.Token))
.Replace("/*[IMPORTS]*/", await Resources.ReadJoinedAsync("Tracker/scripts/", '\n'))
.Replace("/*[CSS-CONTROLLER]*/", await Resources.ReadTextAsync("Tracker/styles/controller.css"))
.Replace("/*[CSS-SETTINGS]*/", await Resources.ReadTextAsync("Tracker/styles/settings.css"))

View File

@@ -3,33 +3,32 @@ using System.Text.Json;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackChannelEndpoint : BaseEndpoint {
public TrackChannelEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {}
public TrackChannelEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
var server = ReadServer(root.RequireObject("server"), "server");
var channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id);
Db.AddServer(server);
Db.AddChannel(channel);
await Db.Servers.Add([server]);
await Db.Channels.Add([channel]);
return HttpOutput.None;
}
private static Data.Server ReadServer(JsonElement json, string path) => new() {
private static Data.Server ReadServer(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path),
Name = json.RequireString("name", path),
Type = ServerTypes.FromString(json.RequireString("type", path)) ?? throw new HttpException(HttpStatusCode.BadRequest, "Server type must be either 'SERVER', 'GROUP', or 'DM'.")
};
private static Channel ReadChannel(JsonElement json, string path, ulong serverId) => new() {
private static Channel ReadChannel(JsonElement json, string path, ulong serverId) => new () {
Id = json.RequireSnowflake("id", path),
Server = serverId,
Name = json.RequireString("name", path),

View File

@@ -9,7 +9,6 @@ using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Server.Download;
using DHT.Server.Service;
using DHT.Utils.Collections;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
@@ -17,7 +16,10 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackMessagesEndpoint : BaseEndpoint {
public TrackMessagesEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {}
private const string HasNewMessages = "1";
private const string NoNewMessages = "0";
public TrackMessagesEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
@@ -37,14 +39,14 @@ sealed class TrackMessagesEndpoint : BaseEndpoint {
}
var addedMessageFilter = new MessageFilter { MessageIds = addedMessageIds };
bool anyNewMessages = Db.CountMessages(addedMessageFilter) < messages.Length;
bool anyNewMessages = await Db.Messages.Count(addedMessageFilter) < addedMessageIds.Count;
Db.AddMessages(messages);
await Db.Messages.Add(messages);
return new HttpOutput.Json(anyNewMessages ? 1 : 0);
return new HttpOutput.Text(anyNewMessages ? HasNewMessages : NoNewMessages);
}
private static Message ReadMessage(JsonElement json, string path) => new() {
private static Message ReadMessage(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path),
Sender = json.RequireSnowflake("sender", path),
Channel = json.RequireSnowflake("channel", path),
@@ -52,9 +54,9 @@ sealed class TrackMessagesEndpoint : BaseEndpoint {
Timestamp = json.RequireLong("timestamp", path),
EditTimestamp = json.HasKey("editTimestamp") ? json.RequireLong("editTimestamp", path) : null,
RepliedToId = json.HasKey("repliedToId") ? json.RequireSnowflake("repliedToId", path) : null,
Attachments = json.HasKey("attachments") ? ReadAttachments(json.RequireArray("attachments", path + ".attachments"), path + ".attachments[]").ToImmutableArray() : ImmutableArray<Attachment>.Empty,
Embeds = json.HasKey("embeds") ? ReadEmbeds(json.RequireArray("embeds", path + ".embeds"), path + ".embeds[]").ToImmutableArray() : ImmutableArray<Embed>.Empty,
Reactions = json.HasKey("reactions") ? ReadReactions(json.RequireArray("reactions", path + ".reactions"), path + ".reactions[]").ToImmutableArray() : ImmutableArray<Reaction>.Empty,
Attachments = json.HasKey("attachments") ? ReadAttachments(json.RequireArray("attachments", path + ".attachments"), path + ".attachments[]").ToImmutableList() : ImmutableList<Attachment>.Empty,
Embeds = json.HasKey("embeds") ? ReadEmbeds(json.RequireArray("embeds", path + ".embeds"), path + ".embeds[]").ToImmutableList() : ImmutableList<Embed>.Empty,
Reactions = json.HasKey("reactions") ? ReadReactions(json.RequireArray("reactions", path + ".reactions"), path + ".reactions[]").ToImmutableList() : ImmutableList<Reaction>.Empty,
};
[SuppressMessage("ReSharper", "ConvertToLambdaExpression")]

View File

@@ -3,14 +3,13 @@ using System.Text.Json;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackUsersEndpoint : BaseEndpoint {
public TrackUsersEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db, parameters) {}
public TrackUsersEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
@@ -26,12 +25,12 @@ sealed class TrackUsersEndpoint : BaseEndpoint {
users[i++] = ReadUser(user, "user");
}
Db.AddUsers(users);
await Db.Users.Add(users);
return HttpOutput.None;
}
private static User ReadUser(JsonElement json, string path) => new() {
private static User ReadUser(JsonElement json, string path) => new () {
Id = json.RequireSnowflake("id", path),
Name = json.RequireString("name", path),
AvatarUrl = json.HasKey("avatar") ? json.RequireString("avatar", path) : null,

Some files were not shown because too many files have changed in this diff Show More