1
0
mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2025-09-15 19:32:09 +02:00

1 Commits

Author SHA1 Message Date
11f7d4a49f Rewrite database interface to be asynchronous and improve UI 2023-12-29 21:26:21 +01:00
119 changed files with 2286 additions and 2421 deletions

View File

@@ -9,15 +9,15 @@
<entry key="Desktop/Dialogs/Progress/ProgressDialog.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Dialogs/TextBox/TextBoxDialog.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/AboutWindow.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/DownloadItemFilterPanel.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/AttachmentFilterPanel.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/MessageFilterPanel.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/ServerConfigurationPanel.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Controls/StatusBar.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/MainWindow.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/AdvancedPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/AttachmentsPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/DatabasePage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/DebugPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/DownloadsPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/TrackingPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Pages/ViewerPage.axaml" value="Desktop/Desktop.csproj" />
<entry key="Desktop/Main/Screens/MainContentScreen.axaml" value="Desktop/Desktop.csproj" />
@@ -25,4 +25,4 @@
</map>
</option>
</component>
</project>
</project>

View File

@@ -1,31 +1,29 @@
using System;
using System.Collections.Generic;
using DHT.Utils.Logging;
namespace DHT.Desktop;
sealed class Arguments {
private static readonly Log Log = Log.ForType<Arguments>();
private const int FirstArgument = 1;
public static Arguments Empty => new (Array.Empty<string>());
public static Arguments Empty => new(Array.Empty<string>());
public bool Console { get; }
public string? DatabaseFile { get; }
public ushort? ServerPort { get; }
public string? ServerToken { get; }
public byte? ConcurrentDownloads { get; }
public Arguments(IReadOnlyList<string> args) {
for (int i = FirstArgument; i < args.Count; i++) {
public Arguments(string[] args) {
for (int i = FirstArgument; i < args.Length; i++) {
string key = args[i];
switch (key) {
case "-debug":
Log.IsDebugEnabled = true;
continue;
case "-console":
Console = true;
continue;
@@ -37,7 +35,7 @@ sealed class Arguments {
value = key;
key = "-db";
}
else if (i >= args.Count - 1) {
else if (i >= args.Length - 1) {
Log.Warn("Missing value for command line argument: " + key);
continue;
}
@@ -51,11 +49,11 @@ sealed class Arguments {
continue;
case "-port": {
if (!ushort.TryParse(value, out var port)) {
Log.Warn("Invalid port number: " + value);
if (ushort.TryParse(value, out var port)) {
ServerPort = port;
}
else {
ServerPort = port;
Log.Warn("Invalid port number: " + value);
}
continue;
@@ -64,20 +62,6 @@ sealed class Arguments {
case "-token":
ServerToken = value;
continue;
case "-concurrentdownloads":
if (!ulong.TryParse(value, out var concurrentDownloads) || concurrentDownloads == 0) {
Log.Warn("Invalid concurrent downloads count: " + value);
}
else if (concurrentDownloads > 10) {
Log.Warn("Limiting concurrent downloads to 10");
ConcurrentDownloads = 10;
}
else {
ConcurrentDownloads = (byte) concurrentDownloads;
}
continue;
default:
Log.Warn("Unknown command line argument: " + key);

View File

@@ -19,17 +19,17 @@ sealed class BytesValueConverter : IValueConverter {
}
}
private static readonly Unit[] Units = [
new Unit("B", decimalPlaces: 0),
new Unit("kB", decimalPlaces: 0),
new Unit("MB", decimalPlaces: 1),
new Unit("GB", decimalPlaces: 1),
new Unit("TB", decimalPlaces: 1)
];
private static readonly Unit[] Units = {
new ("B", decimalPlaces: 0),
new ("kB", decimalPlaces: 0),
new ("MB", decimalPlaces: 1),
new ("GB", decimalPlaces: 1),
new ("TB", decimalPlaces: 1)
};
private const int Scale = 1000;
public static string Convert(ulong size) {
private static string Convert(ulong size) {
int power = size == 0L ? 0 : (int) Math.Log(size, Scale);
int unit = power >= Units.Length ? Units.Length - 1 : power;
return Units[unit].Format(unit == 0 ? size : size / Math.Pow(Scale, unit));

View File

@@ -1,15 +1,17 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Avalonia.Controls;
using Avalonia.Platform.Storage;
using Avalonia.Threading;
using DHT.Desktop.Dialogs.File;
using DHT.Desktop.Dialogs.Message;
using DHT.Server.Database;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
namespace DHT.Desktop.Common;
@@ -43,10 +45,15 @@ static class DatabaseGui {
}
public static async Task<IDatabaseFile?> TryOpenOrCreateDatabaseFromPath(string path, Window window, ISchemaUpgradeCallbacks schemaUpgradeCallbacks) {
var prevSynchronizationContext = SynchronizationContext.Current;
SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext());
var taskScheduler = TaskScheduler.FromCurrentSynchronizationContext();
SynchronizationContext.SetSynchronizationContext(prevSynchronizationContext);
IDatabaseFile? file = null;
try {
file = await SqliteDatabaseFile.OpenOrCreate(path, schemaUpgradeCallbacks);
file = await SqliteDatabaseFile.OpenOrCreate(path, schemaUpgradeCallbacks, taskScheduler);
} catch (InvalidDatabaseVersionException ex) {
await Dialog.ShowOk(window, "Database Error", "Database '" + Path.GetFileName(path) + "' appears to be corrupted (invalid version: " + ex.Version + ").");
} catch (DatabaseTooNewException ex) {

View File

@@ -23,7 +23,6 @@
<PackageReference Include="Avalonia.Fonts.Inter" Version="11.0.6" />
<PackageReference Include="Avalonia.ReactiveUI" Version="11.0.6" />
<PackageReference Include="Avalonia.Themes.Fluent" Version="11.0.6" />
<PackageReference Include="CommunityToolkit.Mvvm" Version="999.0.0-build.0.g0d941a6a62" />
</ItemGroup>
<ItemGroup>
@@ -32,6 +31,10 @@
<ItemGroup>
<Compile Include="..\Version.cs" Link="Version.cs" />
<Compile Update="Dialogs\TextBox\TextBoxDialog.axaml.cs">
<DependentUpon>CheckBoxDialog.axaml</DependentUpon>
<SubType>Code</SubType>
</Compile>
</ItemGroup>
<ItemGroup>

View File

@@ -38,7 +38,7 @@
<ItemsRepeater ItemsSource="{Binding Items}">
<ItemsRepeater.ItemTemplate>
<DataTemplate>
<CheckBox IsChecked="{Binding IsChecked}">
<CheckBox IsChecked="{Binding Checked}">
<Label>
<TextBlock Text="{Binding Title}" TextWrapping="Wrap" />
</Label>

View File

@@ -2,11 +2,11 @@ using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.CheckBox;
class CheckBoxDialogModel : ObservableObject {
class CheckBoxDialogModel : BaseModel {
public string Title { get; init; } = "";
private IReadOnlyList<CheckBoxItem> items = Array.Empty<CheckBoxItem>();
@@ -29,8 +29,8 @@ class CheckBoxDialogModel : ObservableObject {
private bool pauseCheckEvents = false;
public bool AreAllSelected => Items.All(static item => item.IsChecked);
public bool AreNoneSelected => Items.All(static item => !item.IsChecked);
public bool AreAllSelected => Items.All(static item => item.Checked);
public bool AreNoneSelected => Items.All(static item => !item.Checked);
public void SelectAll() => SetAllChecked(true);
public void SelectNone() => SetAllChecked(false);
@@ -39,7 +39,7 @@ class CheckBoxDialogModel : ObservableObject {
pauseCheckEvents = true;
foreach (var item in Items) {
item.IsChecked = isChecked;
item.Checked = isChecked;
}
pauseCheckEvents = false;
@@ -52,16 +52,16 @@ class CheckBoxDialogModel : ObservableObject {
}
private void OnItemPropertyChanged(object? sender, PropertyChangedEventArgs e) {
if (!pauseCheckEvents && e.PropertyName == nameof(CheckBoxItem.IsChecked)) {
if (!pauseCheckEvents && e.PropertyName == nameof(CheckBoxItem.Checked)) {
UpdateBulkButtons();
}
}
}
sealed class CheckBoxDialogModel<T> : CheckBoxDialogModel {
private new IReadOnlyList<CheckBoxItem<T>> Items { get; }
public new IReadOnlyList<CheckBoxItem<T>> Items { get; }
public IEnumerable<CheckBoxItem<T>> SelectedItems => Items.Where(static item => item.IsChecked);
public IEnumerable<CheckBoxItem<T>> SelectedItems => Items.Where(static item => item.Checked);
public CheckBoxDialogModel(IEnumerable<CheckBoxItem<T>> items) {
this.Items = new List<CheckBoxItem<T>>(items);

View File

@@ -1,13 +1,17 @@
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.CheckBox;
partial class CheckBoxItem : ObservableObject {
class CheckBoxItem : BaseModel {
public string Title { get; init; } = "";
public object? Item { get; init; } = null;
[ObservableProperty]
private bool isChecked = false;
public bool Checked {
get => isChecked;
set => Change(ref isChecked, value);
}
}
sealed class CheckBoxItem<T> : CheckBoxItem {

View File

@@ -4,10 +4,11 @@ using System.Linq;
using System.Threading.Tasks;
using Avalonia.Threading;
using DHT.Desktop.Common;
using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.Progress;
sealed class ProgressDialogModel {
sealed class ProgressDialogModel : BaseModel {
public string Title { get; init; } = "";
public IReadOnlyList<ProgressItem> Items { get; } = Array.Empty<ProgressItem>();

View File

@@ -1,12 +1,18 @@
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.Progress;
sealed partial class ProgressItem : ObservableObject {
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(Opacity))]
sealed class ProgressItem : BaseModel {
private bool isVisible = false;
public bool IsVisible {
get => isVisible;
private set {
Change(ref isVisible, value);
OnPropertyChanged(nameof(Opacity));
}
}
public double Opacity => IsVisible ? 1.0 : 0.0;
private string message = "";
@@ -14,17 +20,29 @@ sealed partial class ProgressItem : ObservableObject {
public string Message {
get => message;
set {
SetProperty(ref message, value);
Change(ref message, value);
IsVisible = !string.IsNullOrEmpty(value);
}
}
[ObservableProperty]
private string items = "";
[ObservableProperty]
public string Items {
get => items;
set => Change(ref items, value);
}
private int progress = 0;
[ObservableProperty]
public int Progress {
get => progress;
set => Change(ref progress, value);
}
private bool isIndeterminate;
public bool IsIndeterminate {
get => isIndeterminate;
set => Change(ref isIndeterminate, value);
}
}

View File

@@ -2,11 +2,11 @@ using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.TextBox;
class TextBoxDialogModel : ObservableObject {
class TextBoxDialogModel : BaseModel {
public string Title { get; init; } = "";
public string Description { get; init; } = "";
@@ -36,7 +36,7 @@ class TextBoxDialogModel : ObservableObject {
}
sealed class TextBoxDialogModel<T> : TextBoxDialogModel {
private new IReadOnlyList<TextBoxItem<T>> Items { get; }
public new IReadOnlyList<TextBoxItem<T>> Items { get; }
public IEnumerable<TextBoxItem<T>> ValidItems => Items.Where(static item => item.IsValid);

View File

@@ -1,11 +1,11 @@
using System;
using System.Collections;
using System.ComponentModel;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.TextBox;
class TextBoxItem : ObservableObject, INotifyDataErrorInfo {
class TextBoxItem : BaseModel, INotifyDataErrorInfo {
public string Title { get; init; } = "";
public object? Item { get; init; } = null;
@@ -17,7 +17,7 @@ class TextBoxItem : ObservableObject, INotifyDataErrorInfo {
public string Value {
get => this.value;
set {
SetProperty(ref this.value, value);
Change(ref this.value, value);
ErrorsChanged?.Invoke(this, new DataErrorsChangedEventArgs(nameof(Value)));
}
}

View File

@@ -39,7 +39,7 @@ static class DiscordAppSettings {
public static async Task<bool?> AreDevToolsEnabled() {
try {
var settingsJson = await ReadSettingsJson();
var settingsJson = await ReadSettingsJson().ConfigureAwait(false);
return AreDevToolsEnabled(settingsJson);
} catch (Exception e) {
Log.Error("Cannot read settings file.");

View File

@@ -5,4 +5,4 @@ namespace DHT.Desktop.Discord;
[JsonSourceGenerationOptions(GenerationMode = JsonSourceGenerationMode.Default, WriteIndented = true)]
[JsonSerializable(typeof(JsonObject))]
sealed partial class DiscordAppSettingsJsonContext : JsonSerializerContext;
sealed partial class DiscordAppSettingsJsonContext : JsonSerializerContext {}

View File

@@ -4,11 +4,11 @@
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d"
x:Class="DHT.Desktop.Main.Controls.DownloadItemFilterPanel"
x:DataType="controls:DownloadItemFilterPanelModel">
x:Class="DHT.Desktop.Main.Controls.AttachmentFilterPanel"
x:DataType="controls:AttachmentFilterPanelModel">
<Design.DataContext>
<controls:DownloadItemFilterPanelModel />
<controls:AttachmentFilterPanelModel />
</Design.DataContext>
<UserControl.Styles>

View File

@@ -4,8 +4,8 @@ using Avalonia.Controls;
namespace DHT.Desktop.Main.Controls;
[SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class DownloadItemFilterPanel : UserControl {
public DownloadItemFilterPanel() {
public sealed partial class AttachmentFilterPanel : UserControl {
public AttachmentFilterPanel() {
InitializeComponent();
}
}

View File

@@ -0,0 +1,131 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Threading.Tasks;
using DHT.Desktop.Common;
using DHT.Server;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls;
sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
public sealed record Unit(string Name, uint Scale);
private static readonly Unit[] AllUnits = [
new Unit("B", 1),
new Unit("kB", 1024),
new Unit("MB", 1024 * 1024)
];
private static readonly HashSet<string> FilterProperties = [
nameof(LimitSize),
nameof(MaximumSize),
nameof(MaximumSizeUnit)
];
public string FilterStatisticsText { get; private set; } = "";
private bool limitSize = false;
private ulong maximumSize = 0L;
private Unit maximumSizeUnit = AllUnits[0];
public bool LimitSize {
get => limitSize;
set => Change(ref limitSize, value);
}
public ulong MaximumSize {
get => maximumSize;
set => Change(ref maximumSize, value);
}
public Unit MaximumSizeUnit {
get => maximumSizeUnit;
set => Change(ref maximumSizeUnit, value);
}
public IEnumerable<Unit> Units => AllUnits;
private readonly State state;
private readonly string verb;
private readonly RestartableTask<long> matchingAttachmentCountTask;
private long? matchingAttachmentCount;
private long? totalAttachmentCount;
[Obsolete("Designer")]
public AttachmentFilterPanelModel() : this(State.Dummy) {}
public AttachmentFilterPanelModel(State state, string verb = "Matches") {
this.state = state;
this.verb = verb;
this.matchingAttachmentCountTask = new RestartableTask<long>(SetAttachmentCounts, TaskScheduler.FromCurrentSynchronizationContext());
UpdateFilterStatistics();
PropertyChanged += OnPropertyChanged;
state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged;
}
public void Dispose() {
state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
}
private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName != null && FilterProperties.Contains(e.PropertyName)) {
UpdateFilterStatistics();
}
}
private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) {
totalAttachmentCount = state.Db.Statistics.TotalAttachments;
UpdateFilterStatistics();
}
}
private void UpdateFilterStatistics() {
var filter = CreateFilter();
if (filter.IsEmpty) {
matchingAttachmentCountTask.Cancel();
matchingAttachmentCount = totalAttachmentCount;
UpdateFilterStatisticsText();
}
else {
matchingAttachmentCount = null;
UpdateFilterStatisticsText();
matchingAttachmentCountTask.Restart(cancellationToken => state.Db.Downloads.CountAttachments(filter, cancellationToken));
}
}
private void SetAttachmentCounts(long matchingAttachmentCount) {
this.matchingAttachmentCount = matchingAttachmentCount;
UpdateFilterStatisticsText();
}
private void UpdateFilterStatisticsText() {
var matchingAttachmentCountStr = matchingAttachmentCount?.Format() ?? "(...)";
var totalAttachmentCountStr = totalAttachmentCount?.Format() ?? "(...)";
FilterStatisticsText = verb + " " + matchingAttachmentCountStr + " out of " + totalAttachmentCountStr + " attachment" + (totalAttachmentCount is null or 1 ? "." : "s.");
OnPropertyChanged(nameof(FilterStatisticsText));
}
public AttachmentFilter CreateFilter() {
AttachmentFilter filter = new ();
if (LimitSize) {
try {
filter.MaxBytes = maximumSize * maximumSizeUnit.Scale;
} catch (ArithmeticException) {
// set no size limit, because the overflown size is larger than any file could possibly be
}
}
return filter;
}
}

View File

@@ -1,124 +0,0 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Server;
using DHT.Server.Data.Filters;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls;
sealed partial class DownloadItemFilterPanelModel : ObservableObject, IDisposable {
public sealed record Unit(string Name, uint Scale);
private static readonly Unit[] AllUnits = [
new Unit("B", 1),
new Unit("kB", 1024),
new Unit("MB", 1024 * 1024)
];
private static readonly HashSet<string> FilterProperties = [
nameof(LimitSize),
nameof(MaximumSize),
nameof(MaximumSizeUnit)
];
public string FilterStatisticsText { get; private set; } = "";
[ObservableProperty]
private bool limitSize = false;
[ObservableProperty]
private ulong maximumSize = 0L;
[ObservableProperty]
private Unit maximumSizeUnit = AllUnits[0];
public IEnumerable<Unit> Units => AllUnits;
private readonly State state;
private readonly string verb;
private readonly RestartableTask<long> downloadItemCountTask;
private long? matchingItemCount;
private readonly IDisposable downloadItemCountSubscription;
private long? totalItemCount;
[Obsolete("Designer")]
public DownloadItemFilterPanelModel() : this(State.Dummy) {}
public DownloadItemFilterPanelModel(State state, string verb = "Matches") {
this.state = state;
this.verb = verb;
this.downloadItemCountTask = new RestartableTask<long>(SetMatchingCount, TaskScheduler.FromCurrentSynchronizationContext());
this.downloadItemCountSubscription = state.Db.Downloads.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnDownloadItemCountChanged);
UpdateFilterStatistics();
PropertyChanged += OnPropertyChanged;
}
public void Dispose() {
downloadItemCountTask.Cancel();
downloadItemCountSubscription.Dispose();
}
private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName != null && FilterProperties.Contains(e.PropertyName)) {
UpdateFilterStatistics();
}
}
private void OnDownloadItemCountChanged(long newItemCount) {
totalItemCount = newItemCount;
UpdateFilterStatistics();
}
private void UpdateFilterStatistics() {
var filter = CreateFilter();
if (filter.IsEmpty) {
downloadItemCountTask.Cancel();
matchingItemCount = totalItemCount;
UpdateFilterStatisticsText();
}
else {
matchingItemCount = null;
UpdateFilterStatisticsText();
downloadItemCountTask.Restart(cancellationToken => state.Db.Downloads.Count(filter, cancellationToken));
}
}
private void SetMatchingCount(long matchingAttachmentCount) {
this.matchingItemCount = matchingAttachmentCount;
UpdateFilterStatisticsText();
}
private void UpdateFilterStatisticsText() {
var matchingItemCountStr = matchingItemCount?.Format() ?? "(...)";
var totalItemCountStr = totalItemCount?.Format() ?? "(...)";
FilterStatisticsText = verb + " " + matchingItemCountStr + " out of " + totalItemCountStr + " file" + (totalItemCount is null or 1 ? "." : "s.");
OnPropertyChanged(nameof(FilterStatisticsText));
}
public DownloadItemFilter CreateFilter() {
DownloadItemFilter filter = new ();
if (LimitSize) {
try {
filter.MaxBytes = maximumSize * maximumSizeUnit.Scale;
} catch (ArithmeticException) {
// set no size limit, because the overflown size is larger than any file could possibly be
}
}
return filter;
}
}

View File

@@ -2,12 +2,9 @@ using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Reactive.Linq;
using System.Text;
using System.Threading.Tasks;
using Avalonia.Controls;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.CheckBox;
using DHT.Desktop.Dialogs.Message;
@@ -15,11 +12,13 @@ using DHT.Desktop.Dialogs.Progress;
using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls;
sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
sealed class MessageFilterPanelModel : BaseModel, IDisposable {
private static readonly HashSet<string> FilterProperties = [
nameof(FilterByDate),
nameof(StartDate),
@@ -36,48 +35,70 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
public bool HasAnyFilters => FilterByDate || FilterByChannel || FilterByUser;
[ObservableProperty]
private bool filterByDate = false;
[ObservableProperty]
private DateTime? startDate = null;
[ObservableProperty]
private DateTime? endDate = null;
[ObservableProperty]
private bool filterByChannel = false;
[ObservableProperty]
private HashSet<ulong>? includedChannels = null;
[ObservableProperty]
private bool filterByUser = false;
[ObservableProperty]
private HashSet<ulong>? includedUsers = null;
[ObservableProperty]
public bool FilterByDate {
get => filterByDate;
set => Change(ref filterByDate, value);
}
public DateTime? StartDate {
get => startDate;
set => Change(ref startDate, value);
}
public DateTime? EndDate {
get => endDate;
set => Change(ref endDate, value);
}
public bool FilterByChannel {
get => filterByChannel;
set => Change(ref filterByChannel, value);
}
public HashSet<ulong>? IncludedChannels {
get => includedChannels;
set => Change(ref includedChannels, value);
}
public bool FilterByUser {
get => filterByUser;
set => Change(ref filterByUser, value);
}
public HashSet<ulong>? IncludedUsers {
get => includedUsers;
set => Change(ref includedUsers, value);
}
private string channelFilterLabel = "";
[ObservableProperty]
public string ChannelFilterLabel {
get => channelFilterLabel;
set => Change(ref channelFilterLabel, value);
}
private string userFilterLabel = "";
public string UserFilterLabel {
get => userFilterLabel;
set => Change(ref userFilterLabel, value);
}
private readonly Window window;
private readonly State state;
private readonly string verb;
private readonly RestartableTask<long> exportedMessageCountTask;
private long? exportedMessageCount;
private readonly IDisposable messageCountSubscription;
private long? totalMessageCount;
private readonly IDisposable channelCountSubscription;
private long? totalChannelCount;
private readonly IDisposable userCountSubscription;
private long? totalUserCount;
[Obsolete("Designer")]
public MessageFilterPanelModel() : this(null!, State.Dummy) {}
@@ -88,24 +109,18 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
this.verb = verb;
this.exportedMessageCountTask = new RestartableTask<long>(SetExportedMessageCount, TaskScheduler.FromCurrentSynchronizationContext());
this.messageCountSubscription = state.Db.Messages.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnMessageCountChanged);
this.channelCountSubscription = state.Db.Channels.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnChannelCountChanged);
this.userCountSubscription = state.Db.Users.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnUserCountChanged);
UpdateFilterStatistics();
UpdateChannelFilterLabel();
UpdateUserFilterLabel();
PropertyChanged += OnPropertyChanged;
state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged;
}
public void Dispose() {
exportedMessageCountTask.Cancel();
messageCountSubscription.Dispose();
channelCountSubscription.Dispose();
userCountSubscription.Dispose();
state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
}
private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
@@ -122,41 +137,29 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
}
}
private void OnMessageCountChanged(long newMessageCount) {
totalMessageCount = newMessageCount;
UpdateFilterStatistics();
}
private void OnChannelCountChanged(long newChannelCount) {
totalChannelCount = newChannelCount;
UpdateChannelFilterLabel();
}
private void OnUserCountChanged(long newUserCount) {
totalUserCount = newUserCount;
UpdateUserFilterLabel();
private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalMessages)) {
totalMessageCount = state.Db.Statistics.TotalMessages;
UpdateFilterStatistics();
}
else if (e.PropertyName == nameof(DatabaseStatistics.TotalChannels)) {
UpdateChannelFilterLabel();
}
else if (e.PropertyName == nameof(DatabaseStatistics.TotalUsers)) {
UpdateUserFilterLabel();
}
}
private void UpdateChannelFilterLabel() {
if (totalChannelCount.HasValue) {
long total = totalChannelCount.Value;
long included = FilterByChannel && IncludedChannels != null ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
}
else {
ChannelFilterLabel = "Loading...";
}
long total = state.Db.Statistics.TotalChannels;
long included = FilterByChannel && IncludedChannels != null ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
}
private void UpdateUserFilterLabel() {
if (totalUserCount.HasValue) {
long total = totalUserCount.Value;
long included = FilterByUser && IncludedUsers != null ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
}
else {
UserFilterLabel = "Loading...";
}
long total = state.Db.Statistics.TotalUsers;
long included = FilterByUser && IncludedUsers != null ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
}
private void UpdateFilterStatistics() {
@@ -221,13 +224,13 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
items.Add(new CheckBoxItem<ulong>(channelId) {
Title = title,
IsChecked = IncludedChannels == null || IncludedChannels.Contains(channelId)
Checked = IncludedChannels == null || IncludedChannels.Contains(channelId)
});
}
return items;
}
const string Title = "Included Channels";
List<CheckBoxItem<ulong>> items;
@@ -254,7 +257,7 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
checkBoxItems.Add(new CheckBoxItem<ulong>(user.Id) {
Title = discriminator == null ? name : name + " #" + discriminator,
IsChecked = IncludedUsers == null || IncludedUsers.Contains(user.Id)
Checked = IncludedUsers == null || IncludedUsers.Contains(user.Id)
});
}
@@ -262,7 +265,7 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
}
const string Title = "Included Users";
List<CheckBoxItem<ulong>> items;
try {
items = await ProgressDialog.ShowIndeterminate(window, Title, "Loading users...", PrepareUserItems);

View File

@@ -2,31 +2,47 @@ using System;
using System.Threading.Tasks;
using Avalonia.Controls;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Server;
using DHT.Server;
using DHT.Server.Service;
using DHT.Utils.Logging;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Controls;
sealed partial class ServerConfigurationPanelModel : ObservableObject, IDisposable {
sealed class ServerConfigurationPanelModel : BaseModel, IDisposable {
private static readonly Log Log = Log.ForType<ServerConfigurationPanelModel>();
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(HasMadeChanges))]
private string inputPort;
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(HasMadeChanges))]
public string InputPort {
get => inputPort;
set {
Change(ref inputPort, value);
OnPropertyChanged(nameof(HasMadeChanges));
}
}
private string inputToken;
public string InputToken {
get => inputToken;
set {
Change(ref inputToken, value);
OnPropertyChanged(nameof(HasMadeChanges));
}
}
public bool HasMadeChanges => ServerConfiguration.Port.ToString() != InputPort || ServerConfiguration.Token != InputToken;
[ObservableProperty(Setter = Access.Private)]
private bool isToggleServerButtonEnabled = true;
public bool IsToggleServerButtonEnabled {
get => isToggleServerButtonEnabled;
set => Change(ref isToggleServerButtonEnabled, value);
}
public string ToggleServerButtonText => server.IsRunning ? "Stop Server" : "Start Server";
private readonly Window window;

View File

@@ -45,17 +45,17 @@
<Rectangle />
<StackPanel Orientation="Vertical">
<TextBlock Classes="label">Servers</TextBlock>
<TextBlock Classes="value" Text="{Binding ServerCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
<TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalServers, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel>
<Rectangle />
<StackPanel Orientation="Vertical">
<TextBlock Classes="label">Channels</TextBlock>
<TextBlock Classes="value" Text="{Binding ChannelCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
<TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalChannels, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel>
<Rectangle />
<StackPanel Orientation="Vertical">
<TextBlock Classes="label">Messages</TextBlock>
<TextBlock Classes="value" Text="{Binding MessageCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
<TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalMessages, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel>
</StackPanel>

View File

@@ -1,25 +1,15 @@
using System;
using System.Reactive.Linq;
using Avalonia.ReactiveUI;
using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Server;
using DHT.Server.Database;
using DHT.Server.Service;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Controls;
sealed partial class StatusBarModel : ObservableObject, IDisposable {
[ObservableProperty(Setter = Access.Private)]
private long? serverCount;
[ObservableProperty(Setter = Access.Private)]
private long? channelCount;
[ObservableProperty(Setter = Access.Private)]
private long? messageCount;
sealed class StatusBarModel : BaseModel, IDisposable {
public DatabaseStatistics DatabaseStatistics { get; }
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(ServerStatusText))]
private ServerManager.Status serverStatus;
public string ServerStatusText => serverStatus switch {
@@ -31,33 +21,26 @@ sealed partial class StatusBarModel : ObservableObject, IDisposable {
};
private readonly State state;
private readonly IDisposable serverCountSubscription;
private readonly IDisposable channelCountSubscription;
private readonly IDisposable messageCountSubscription;
[Obsolete("Designer")]
public StatusBarModel() : this(State.Dummy) {}
public StatusBarModel(State state) {
this.state = state;
serverCountSubscription = state.Db.Servers.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newServerCount => ServerCount = newServerCount);
channelCountSubscription = state.Db.Channels.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newChannelCount => ChannelCount = newChannelCount);
messageCountSubscription = state.Db.Messages.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newMessageCount => MessageCount = newMessageCount);
this.DatabaseStatistics = state.Db.Statistics;
state.Server.StatusChanged += OnServerStatusChanged;
serverStatus = state.Server.IsRunning ? ServerManager.Status.Started : ServerManager.Status.Stopped;
}
public void Dispose() {
serverCountSubscription.Dispose();
channelCountSubscription.Dispose();
messageCountSubscription.Dispose();
state.Server.StatusChanged -= OnServerStatusChanged;
state.Server.StatusChanged += OnServerStatusChanged;
}
private void OnServerStatusChanged(object? sender, ServerManager.Status e) {
Dispatcher.UIThread.InvokeAsync(() => ServerStatus = e);
Dispatcher.UIThread.InvokeAsync(() => {
serverStatus = e;
OnPropertyChanged(nameof(ServerStatusText));
});
}
}

View File

@@ -8,7 +8,7 @@
x:DataType="main:MainWindowModel"
Title="{Binding Title}"
Icon="avares://DiscordHistoryTracker/Resources/icon.ico"
Width="820" Height="520"
Width="800" Height="500"
MinWidth="520" MinHeight="300"
WindowStartupLocation="CenterScreen"
Closing="OnClosing">

View File

@@ -3,26 +3,24 @@ using System.IO;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Main.Screens;
using DHT.Desktop.Server;
using DHT.Server;
using DHT.Server.Database;
using DHT.Utils.Logging;
using DHT.Utils.Models;
namespace DHT.Desktop.Main;
sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
sealed class MainWindowModel : BaseModel, IAsyncDisposable {
private const string DefaultTitle = "Discord History Tracker";
private static readonly Log Log = Log.ForType<MainWindowModel>();
[ObservableProperty(Setter = Access.Private)]
private string title = DefaultTitle;
public string Title { get; private set; } = DefaultTitle;
[ObservableProperty(Setter = Access.Private)]
private UserControl currentScreen;
public UserControl CurrentScreen { get; private set; }
private readonly WelcomeScreen welcomeScreen;
private readonly WelcomeScreenModel welcomeScreenModel;
@@ -30,7 +28,6 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
private MainContentScreenModel? mainContentScreenModel;
private readonly Window window;
private readonly int? concurrentDownloads;
private State? state;
@@ -44,7 +41,7 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel };
currentScreen = welcomeScreen;
CurrentScreen = welcomeScreen;
var dbFile = args.DatabaseFile;
if (!string.IsNullOrWhiteSpace(dbFile)) {
@@ -74,8 +71,6 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
if (args.ServerToken != null) {
ServerConfiguration.Token = args.ServerToken;
}
concurrentDownloads = args.ConcurrentDownloads;
}
private async void OnDatabaseSelected(object? sender, IDatabaseFile db) {
@@ -83,7 +78,7 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
await DisposeState();
state = new State(db, concurrentDownloads);
state = new State(db);
try {
await state.Server.Start(ServerConfiguration.Port, ServerConfiguration.Token);
@@ -98,6 +93,9 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
Title = Path.GetFileName(state.Db.Path) + " - " + DefaultTitle;
CurrentScreen = new MainContentScreen { DataContext = mainContentScreenModel };
OnPropertyChanged(nameof(Title));
OnPropertyChanged(nameof(CurrentScreen));
window.Focus();
}
@@ -114,6 +112,9 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
CurrentScreen = welcomeScreen;
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
OnPropertyChanged(nameof(Title));
OnPropertyChanged(nameof(CurrentScreen));
}
private async Task DisposeState() {

View File

@@ -5,10 +5,11 @@ using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Main.Controls;
using DHT.Server;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages;
sealed class AdvancedPageModel : IDisposable {
sealed class AdvancedPageModel : BaseModel, IDisposable {
public ServerConfigurationPanelModel ServerConfigurationModel { get; }
private readonly Window window;

View File

@@ -5,11 +5,11 @@
xmlns:pages="clr-namespace:DHT.Desktop.Main.Pages"
xmlns:controls="clr-namespace:DHT.Desktop.Main.Controls"
mc:Ignorable="d" d:DesignWidth="800" d:DesignHeight="450"
x:Class="DHT.Desktop.Main.Pages.DownloadsPage"
x:DataType="pages:DownloadsPageModel">
x:Class="DHT.Desktop.Main.Pages.AttachmentsPage"
x:DataType="pages:AttachmentsPageModel">
<Design.DataContext>
<pages:DownloadsPageModel />
<pages:AttachmentsPageModel />
</Design.DataContext>
<UserControl.Styles>
@@ -31,15 +31,19 @@
</UserControl.Styles>
<StackPanel Orientation="Vertical" Spacing="20">
<Button Command="{Binding OnClickToggleDownload}" Content="{Binding ToggleDownloadButtonText}" IsEnabled="{Binding IsToggleDownloadButtonEnabled}" />
<controls:DownloadItemFilterPanel DataContext="{Binding FilterModel}" IsEnabled="{Binding !$parent[UserControl].((pages:DownloadsPageModel)DataContext).IsDownloading}" />
<DockPanel>
<Button Command="{Binding OnClickToggleDownload}" Content="{Binding ToggleDownloadButtonText}" IsEnabled="{Binding IsToggleDownloadButtonEnabled}" DockPanel.Dock="Left" />
<TextBlock Text="{Binding DownloadMessage}" MinWidth="100" Margin="10 0 0 0" VerticalAlignment="Center" TextAlignment="Right" DockPanel.Dock="Left" />
<ProgressBar Value="{Binding DownloadProgress}" IsVisible="{Binding IsDownloading}" Margin="15 0" VerticalAlignment="Center" DockPanel.Dock="Right" />
</DockPanel>
<controls:AttachmentFilterPanel DataContext="{Binding FilterModel}" IsEnabled="{Binding !IsDownloading, RelativeSource={RelativeSource AncestorType=pages:AttachmentsPageModel}}" />
<StackPanel Orientation="Vertical" Spacing="12">
<Expander Header="Download Status" IsExpanded="True">
<DataGrid ItemsSource="{Binding StatisticsRows}" AutoGenerateColumns="False" CanUserReorderColumns="False" CanUserResizeColumns="False" CanUserSortColumns="False" IsReadOnly="True">
<DataGrid.Columns>
<DataGridTextColumn Header="State" Binding="{Binding State, Mode=OneWay}" Width="*" />
<DataGridTextColumn Header="Files" Binding="{Binding Items, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="Size" Binding="{Binding SizeText, Mode=OneWay}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="State" Binding="{Binding State}" Width="*" />
<DataGridTextColumn Header="Attachments" Binding="{Binding Items, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" Width="*" CellStyleClasses="right" />
<DataGridTextColumn Header="Size" Binding="{Binding Size, Mode=OneWay, Converter={StaticResource BytesValueConverter}}" Width="*" CellStyleClasses="right" />
</DataGrid.Columns>
</DataGrid>
</Expander>

View File

@@ -4,8 +4,8 @@ using Avalonia.Controls;
namespace DHT.Desktop.Main.Pages;
[SuppressMessage("ReSharper", "MemberCanBeInternal")]
public sealed partial class DownloadsPage : UserControl {
public DownloadsPage() {
public sealed partial class AttachmentsPage : UserControl {
public AttachmentsPage() {
InitializeComponent();
}
}

View File

@@ -0,0 +1,243 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.ReactiveUI;
using DHT.Desktop.Common;
using DHT.Desktop.Main.Controls;
using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Logging;
using DHT.Utils.Models;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Pages;
sealed class AttachmentsPageModel : BaseModel, IDisposable {
private static readonly Log Log = Log.ForType<AttachmentsPageModel>();
private static readonly DownloadItemFilter EnqueuedItemFilter = new () {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued,
DownloadStatus.Downloading
}
};
private bool isToggleDownloadButtonEnabled = true;
public bool IsToggleDownloadButtonEnabled {
get => isToggleDownloadButtonEnabled;
set => Change(ref isToggleDownloadButtonEnabled, value);
}
public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading";
private bool isRetryingFailedDownloads = false;
public bool IsRetryingFailedDownloads {
get => isRetryingFailedDownloads;
set {
isRetryingFailedDownloads = value;
OnPropertyChanged(nameof(IsRetryFailedOnDownloadsButtonEnabled));
}
}
public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && HasFailedDownloads;
public string DownloadMessage { get; set; } = "";
public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value;
public AttachmentFilterPanelModel FilterModel { get; }
private readonly StatisticsRow statisticsEnqueued = new ("Enqueued");
private readonly StatisticsRow statisticsDownloaded = new ("Downloaded");
private readonly StatisticsRow statisticsFailed = new ("Failed");
private readonly StatisticsRow statisticsSkipped = new ("Skipped");
public List<StatisticsRow> StatisticsRows => [
statisticsEnqueued,
statisticsDownloaded,
statisticsFailed,
statisticsSkipped
];
public bool IsDownloading => state.Downloader.IsDownloading;
public bool HasFailedDownloads => statisticsFailed.Items > 0;
private readonly State state;
private readonly ThrottledTask enqueueDownloadItemsTask;
private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask;
private IDisposable? finishedItemsSubscription;
private int doneItemsCount;
private int initialFinishedCount;
private int? totalItemsToDownloadCount;
public AttachmentsPageModel() : this(State.Dummy) {}
public AttachmentsPageModel(State state) {
this.state = state;
FilterModel = new AttachmentFilterPanelModel(state);
enqueueDownloadItemsTask = new ThrottledTask(RecomputeDownloadStatistics, TaskScheduler.FromCurrentSynchronizationContext());
downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext());
RecomputeDownloadStatistics();
state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged;
}
public void Dispose() {
state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
downloadStatisticsTask.Dispose();
finishedItemsSubscription?.Dispose();
FilterModel.Dispose();
}
private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) {
if (IsDownloading) {
EnqueueDownloadItemsLater();
}
else {
RecomputeDownloadStatistics();
}
}
else if (e.PropertyName == nameof(DatabaseStatistics.TotalDownloads)) {
RecomputeDownloadStatistics();
}
}
private async Task EnqueueDownloadItems() {
await state.Db.Downloads.EnqueueDownloadItems(CreateAttachmentFilter());
}
private void EnqueueDownloadItemsLater() {
var filter = CreateAttachmentFilter();
enqueueDownloadItemsTask.Post(cancellationToken => state.Db.Downloads.EnqueueDownloadItems(filter, cancellationToken));
}
private AttachmentFilter CreateAttachmentFilter() {
var filter = FilterModel.CreateFilter();
filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent;
return filter;
}
private void RecomputeDownloadStatistics() {
downloadStatisticsTask.Post(state.Db.Downloads.GetStatistics);
}
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
var hadFailedDownloads = HasFailedDownloads;
statisticsEnqueued.Items = statusStatistics.EnqueuedCount;
statisticsEnqueued.Size = statusStatistics.EnqueuedSize;
statisticsDownloaded.Items = statusStatistics.SuccessfulCount;
statisticsDownloaded.Size = statusStatistics.SuccessfulSize;
statisticsFailed.Items = statusStatistics.FailedCount;
statisticsFailed.Size = statusStatistics.FailedSize;
statisticsSkipped.Items = statusStatistics.SkippedCount;
statisticsSkipped.Size = statusStatistics.SkippedSize;
OnPropertyChanged(nameof(StatisticsRows));
if (hadFailedDownloads != HasFailedDownloads) {
OnPropertyChanged(nameof(HasFailedDownloads));
OnPropertyChanged(nameof(IsRetryFailedOnDownloadsButtonEnabled));
}
totalItemsToDownloadCount = statisticsEnqueued.Items + statisticsDownloaded.Items + statisticsFailed.Items - initialFinishedCount;
UpdateDownloadMessage();
}
private void UpdateDownloadMessage() {
DownloadMessage = IsDownloading ? doneItemsCount.Format() + " / " + (totalItemsToDownloadCount?.Format() ?? "?") : "";
OnPropertyChanged(nameof(DownloadMessage));
OnPropertyChanged(nameof(DownloadProgress));
}
public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false;
if (IsDownloading) {
await state.Downloader.Stop();
finishedItemsSubscription?.Dispose();
finishedItemsSubscription = null;
RecomputeDownloadStatistics();
await state.Db.Downloads.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching);
doneItemsCount = 0;
initialFinishedCount = 0;
totalItemsToDownloadCount = null;
UpdateDownloadMessage();
}
else {
var finishedItems = await state.Downloader.Start();
initialFinishedCount = statisticsDownloaded.Items + statisticsFailed.Items;
finishedItemsSubscription = finishedItems.Select(static _ => true)
.Buffer(TimeSpan.FromMilliseconds(100))
.Select(static items => items.Count)
.Where(static items => items > 0)
.ObserveOn(AvaloniaScheduler.Instance)
.Subscribe(OnItemsFinished);
await EnqueueDownloadItems();
}
OnPropertyChanged(nameof(ToggleDownloadButtonText));
OnPropertyChanged(nameof(IsDownloading));
IsToggleDownloadButtonEnabled = true;
}
private void OnItemsFinished(int finishedItemCount) {
doneItemsCount += finishedItemCount;
UpdateDownloadMessage();
RecomputeDownloadStatistics();
}
public async Task OnClickRetryFailedDownloads() {
IsRetryingFailedDownloads = true;
try {
var allExceptFailedFilter = new DownloadItemFilter {
IncludeStatuses = new HashSet<DownloadStatus> {
DownloadStatus.Enqueued,
DownloadStatus.Downloading,
DownloadStatus.Success
}
};
await state.Db.Downloads.RemoveDownloadItems(allExceptFailedFilter, FilterRemovalMode.KeepMatching);
if (IsDownloading) {
await EnqueueDownloadItems();
}
} catch (Exception e) {
Log.Error(e);
} finally {
IsRetryingFailedDownloads = false;
}
}
public sealed class StatisticsRow {
public string State { get; }
public int Items { get; set; }
public ulong? Size { get; set; }
public StatisticsRow(string state) {
State = state;
}
}
}

View File

@@ -18,12 +18,13 @@ using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Server.Database.Import;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages;
sealed class DatabasePageModel {
sealed class DatabasePageModel : BaseModel {
private static readonly Log Log = Log.ForType<DatabasePageModel>();
public IDatabaseFile Db { get; }
@@ -80,7 +81,7 @@ sealed class DatabasePageModel {
private static async Task MergeWithDatabaseFromPaths(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback) {
var schemaUpgradeCallbacks = new SchemaUpgradeCallbacks(dialog, paths.Length);
await PerformImport(target, paths, dialog, callback, "Database Merge", "Database Error", "database file", async path => {
IDatabaseFile? db = await DatabaseGui.TryOpenOrCreateDatabaseFromPath(path, dialog, schemaUpgradeCallbacks);
@@ -92,7 +93,7 @@ sealed class DatabasePageModel {
await target.AddFrom(db);
return true;
} finally {
await db.DisposeAsync();
db.Dispose();
}
});
}
@@ -101,7 +102,7 @@ sealed class DatabasePageModel {
private readonly ProgressDialog dialog;
private readonly int total;
private bool? decision;
public SchemaUpgradeCallbacks(ProgressDialog dialog, int total) {
this.total = total;
this.dialog = dialog;
@@ -149,7 +150,7 @@ sealed class DatabasePageModel {
await PerformImport(target, paths, dialog, callback, "Legacy Archive Import", "Legacy Archive Error", "archive file", async path => {
await using var jsonStream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read);
return await LegacyArchiveImport.Read(jsonStream, target, fakeSnowflake, async servers => {
SynchronizationContext? prevSyncContext = SynchronizationContext.Current;
SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext());
@@ -164,7 +165,7 @@ sealed class DatabasePageModel {
static bool IsValidSnowflake(string value) {
return string.IsNullOrEmpty(value) || ulong.TryParse(value, out _);
}
var items = new List<TextBoxItem<DHT.Server.Data.Server>>();
foreach (var server in servers.OrderBy(static server => server.Type).ThenBy(static server => server.Name)) {
@@ -193,7 +194,7 @@ sealed class DatabasePageModel {
private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func<string, Task<bool>> performImport) {
int total = paths.Length;
var oldStatistics = await DatabaseStatistics.Take(target);
var oldStatistics = await target.SnapshotStatistics();
int successful = 0;
int finished = 0;
@@ -224,26 +225,15 @@ sealed class DatabasePageModel {
return;
}
var newStatistics = await DatabaseStatistics.Take(target);
var newStatistics = await target.SnapshotStatistics();
await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, newStatistics, successful, total, itemName));
}
private sealed record DatabaseStatistics(long ServerCount, long ChannelCount, long UserCount, long MessageCount) {
public static async Task<DatabaseStatistics> Take(IDatabaseFile db) {
return new DatabaseStatistics(
await db.Servers.Count(),
await db.Channels.Count(),
await db.Users.Count(),
await db.Messages.Count()
);
}
}
private static string GetImportDialogMessage(DatabaseStatistics oldStatistics, DatabaseStatistics newStatistics, int successfulItems, int totalItems, string itemName) {
long newServers = newStatistics.ServerCount - oldStatistics.ServerCount;
long newChannels = newStatistics.ChannelCount - oldStatistics.ChannelCount;
long newUsers = newStatistics.UserCount - oldStatistics.UserCount;
long newMessages = newStatistics.MessageCount - oldStatistics.MessageCount;
private static string GetImportDialogMessage(DatabaseStatisticsSnapshot oldStatistics, DatabaseStatisticsSnapshot newStatistics, int successfulItems, int totalItems, string itemName) {
long newServers = newStatistics.TotalServers - oldStatistics.TotalServers;
long newChannels = newStatistics.TotalChannels - oldStatistics.TotalChannels;
long newUsers = newStatistics.TotalUsers - oldStatistics.TotalUsers;
long newMessages = newStatistics.TotalMessages - oldStatistics.TotalMessages;
StringBuilder message = new StringBuilder();
message.Append("Processed ");

View File

@@ -9,165 +9,168 @@ using DHT.Desktop.Dialogs.Progress;
using DHT.Server;
using DHT.Server.Data;
using DHT.Server.Service;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages;
namespace DHT.Desktop.Main.Pages {
sealed class DebugPageModel : BaseModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
sealed class DebugPageModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
private readonly Window window;
private readonly State state;
private readonly Window window;
private readonly State state;
[Obsolete("Designer")]
public DebugPageModel() : this(null!, State.Dummy) {}
[Obsolete("Designer")]
public DebugPageModel() : this(null!, State.Dummy) {}
public DebugPageModel(Window window, State state) {
this.window = window;
this.state = state;
}
public async void OnClickAddRandomDataToDatabase() {
if (!int.TryParse(GenerateChannels, out int channels) || channels < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of channels must be at least 1!");
return;
public DebugPageModel(Window window, State state) {
this.window = window;
this.state = state;
}
if (!int.TryParse(GenerateUsers, out int users) || users < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of users must be at least 1!");
return;
public async void OnClickAddRandomDataToDatabase() {
if (!int.TryParse(GenerateChannels, out int channels) || channels < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of channels must be at least 1!");
return;
}
if (!int.TryParse(GenerateUsers, out int users) || users < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of users must be at least 1!");
return;
}
if (!int.TryParse(GenerateMessages, out int messages) || messages < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of messages must be at least 1!");
return;
}
await ProgressDialog.Show(window, "Generating Random Data", async (_, callback) => await GenerateRandomData(channels, users, messages, callback));
}
if (!int.TryParse(GenerateMessages, out int messages) || messages < 1) {
await Dialog.ShowOk(window, "Generate Random Data", "Amount of messages must be at least 1!");
return;
}
private const int BatchSize = 500;
await ProgressDialog.Show(window, "Generating Random Data", async (_, callback) => await GenerateRandomData(channels, users, messages, callback));
}
private async Task GenerateRandomData(int channelCount, int userCount, int messageCount, IProgressCallback callback) {
int batchCount = (messageCount + BatchSize - 1) / BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, 0, batchCount);
private const int BatchSize = 500;
var rand = new Random();
var server = new DHT.Server.Data.Server {
Id = RandomId(rand),
Name = RandomName("s"),
Type = ServerType.Server,
};
private async Task GenerateRandomData(int channelCount, int userCount, int messageCount, IProgressCallback callback) {
int batchCount = (messageCount + BatchSize - 1) / BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, 0, batchCount);
var rand = new Random();
var server = new DHT.Server.Data.Server {
Id = RandomId(rand),
Name = RandomName("s"),
Type = ServerType.Server,
};
var channels = Enumerable.Range(0, channelCount).Select(i => new Channel {
Id = RandomId(rand),
Server = server.Id,
Name = RandomName("c"),
ParentId = null,
Position = i,
Topic = RandomText(rand, 10),
Nsfw = rand.Next(4) == 0,
}).ToArray();
var users = Enumerable.Range(0, userCount).Select(_ => new User {
Id = RandomId(rand),
Name = RandomName("u"),
AvatarUrl = null,
Discriminator = rand.Next(0, 9999).ToString(),
}).ToArray();
await state.Db.Users.Add(users);
await state.Db.Servers.Add([server]);
await state.Db.Channels.Add(channels);
var now = DateTimeOffset.Now;
int batchIndex = 0;
while (messageCount > 0) {
int hourOffset = batchIndex;
var messages = Enumerable.Range(0, Math.Min(messageCount, BatchSize)).Select(i => {
DateTimeOffset time = now.AddHours(hourOffset).AddMinutes(i * 60.0 / BatchSize);
DateTimeOffset? edit = rand.Next(100) == 0 ? time.AddSeconds(rand.Next(1, 60)) : null;
var timeMillis = time.ToUnixTimeMilliseconds();
var editMillis = edit?.ToUnixTimeMilliseconds();
return new Message {
Id = (ulong) timeMillis,
Sender = RandomBiasedIndex(rand, users).Id,
Channel = RandomBiasedIndex(rand, channels).Id,
Text = RandomText(rand, 100),
Timestamp = timeMillis,
EditTimestamp = editMillis,
RepliedToId = null,
Attachments = ImmutableList<Attachment>.Empty,
Embeds = ImmutableList<Embed>.Empty,
Reactions = ImmutableList<Reaction>.Empty,
};
var channels = Enumerable.Range(0, channelCount).Select(i => new Channel {
Id = RandomId(rand),
Server = server.Id,
Name = RandomName("c"),
ParentId = null,
Position = i,
Topic = RandomText(rand, 10),
Nsfw = rand.Next(4) == 0,
}).ToArray();
await state.Db.Messages.Add(messages);
var users = Enumerable.Range(0, userCount).Select(_ => new User {
Id = RandomId(rand),
Name = RandomName("u"),
AvatarUrl = null,
Discriminator = rand.Next(0, 9999).ToString(),
}).ToArray();
messageCount -= BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount);
await state.Db.Users.Add(users);
await state.Db.Servers.Add([server]);
await state.Db.Channels.Add(channels);
var now = DateTimeOffset.Now;
int batchIndex = 0;
while (messageCount > 0) {
int hourOffset = batchIndex;
var messages = Enumerable.Range(0, Math.Min(messageCount, BatchSize)).Select(i => {
DateTimeOffset time = now.AddHours(hourOffset).AddMinutes(i * 60.0 / BatchSize);
DateTimeOffset? edit = rand.Next(100) == 0 ? time.AddSeconds(rand.Next(1, 60)) : null;
var timeMillis = time.ToUnixTimeMilliseconds();
var editMillis = edit?.ToUnixTimeMilliseconds();
return new Message {
Id = (ulong) timeMillis,
Sender = RandomBiasedIndex(rand, users).Id,
Channel = RandomBiasedIndex(rand, channels).Id,
Text = RandomText(rand, 100),
Timestamp = timeMillis,
EditTimestamp = editMillis,
RepliedToId = null,
Attachments = ImmutableList<Attachment>.Empty,
Embeds = ImmutableList<Embed>.Empty,
Reactions = ImmutableList<Reaction>.Empty,
};
}).ToArray();
await state.Db.Messages.Add(messages);
messageCount -= BatchSize;
await callback.Update("Adding messages in batches of " + BatchSize, ++batchIndex, batchCount);
}
}
}
private static ulong RandomId(Random rand) {
ulong h = unchecked((ulong) rand.Next());
ulong l = unchecked((ulong) rand.Next());
return (h << 32) | l;
}
private static ulong RandomId(Random rand) {
ulong h = unchecked((ulong) rand.Next());
ulong l = unchecked((ulong) rand.Next());
return (h << 32) | l;
}
private static string RandomName(string prefix) {
return prefix + "-" + ServerUtils.GenerateRandomToken(5);
}
private static string RandomName(string prefix) {
return prefix + "-" + ServerUtils.GenerateRandomToken(5);
}
private static T RandomBiasedIndex<T>(Random rand, T[] options) {
return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())];
}
private static T RandomBiasedIndex<T>(Random rand, T[] options) {
return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())];
}
private static readonly string[] RandomWords = [
"apple", "apricot", "artichoke", "arugula", "asparagus", "avocado",
"banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli",
"cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant",
"daikon", "date", "dewberry", "durian",
"edamame", "eggplant", "elderberry", "endive",
"fig",
"garlic", "ginger", "gooseberry", "grape", "grapefruit", "guava",
"honeysuckle", "horseradish", "huckleberry",
"jackfruit", "jicama",
"kale", "kiwi", "kohlrabi", "kumquat",
"leek", "lemon", "lentil", "lettuce", "lime",
"mandarin", "mango", "mushroom", "myrtle",
"nectarine", "nut",
"olive", "okra", "onion", "orange",
"papaya", "parsnip", "pawpaw", "peach", "pear", "pea", "pepper", "persimmon", "pineapple", "plum", "plantain", "pomegranate", "pomelo", "potato", "prune", "pumpkin",
"quandong", "quinoa",
"radicchio", "radish", "raisin", "raspberry", "redcurrant", "rhubarb", "rutabaga",
"spinach", "strawberry", "squash",
"tamarind", "tangerine", "tomatillo", "tomato", "turnip",
"vanilla",
"watercress", "watermelon",
"yam",
"zucchini"
];
private static readonly string[] RandomWords = {
"apple", "apricot", "artichoke", "arugula", "asparagus", "avocado",
"banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli",
"cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant",
"daikon", "date", "dewberry", "durian",
"edamame", "eggplant", "elderberry", "endive",
"fig",
"garlic", "ginger", "gooseberry", "grape", "grapefruit", "guava",
"honeysuckle", "horseradish", "huckleberry",
"jackfruit", "jicama",
"kale", "kiwi", "kohlrabi", "kumquat",
"leek", "lemon", "lentil", "lettuce", "lime",
"mandarin", "mango", "mushroom", "myrtle",
"nectarine", "nut",
"olive", "okra", "onion", "orange",
"papaya", "parsnip", "pawpaw", "peach", "pear", "pea", "pepper", "persimmon", "pineapple", "plum", "plantain", "pomegranate", "pomelo", "potato", "prune", "pumpkin",
"quandong", "quinoa",
"radicchio", "radish", "raisin", "raspberry", "redcurrant", "rhubarb", "rutabaga",
"spinach", "strawberry", "squash",
"tamarind", "tangerine", "tomatillo", "tomato", "turnip",
"vanilla",
"watercress", "watermelon",
"yam",
"zucchini",
};
private static string RandomText(Random rand, int maxWords) {
int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3));
return string.Join(' ', Enumerable.Range(0, wordCount).Select(_ => RandomWords[rand.Next(RandomWords.Length)]));
private static string RandomText(Random rand, int maxWords) {
int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3));
return string.Join(' ', Enumerable.Range(0, wordCount).Select(_ => RandomWords[rand.Next(RandomWords.Length)]));
}
}
}
#else
namespace DHT.Desktop.Main.Pages;
using DHT.Utils.Models;
sealed class DebugPageModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
namespace DHT.Desktop.Main.Pages {
sealed class DebugPageModel : BaseModel {
public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0";
public void OnClickAddRandomDataToDatabase() {}
public void OnClickAddRandomDataToDatabase() {}
}
}
#endif

View File

@@ -1,186 +0,0 @@
using System;
using System.Collections.ObjectModel;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Main.Controls;
using DHT.Server;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Pages;
sealed partial class DownloadsPageModel : ObservableObject, IDisposable {
private static readonly Log Log = Log.ForType<DownloadsPageModel>();
[ObservableProperty(Setter = Access.Private)]
private bool isToggleDownloadButtonEnabled = true;
public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading";
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool isRetryingFailedDownloads = false;
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool hasFailedDownloads;
public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && HasFailedDownloads;
[ObservableProperty(Setter = Access.Private)]
private string downloadMessage = "";
public DownloadItemFilterPanelModel FilterModel { get; }
private readonly StatisticsRow statisticsPending = new ("Pending");
private readonly StatisticsRow statisticsDownloaded = new ("Downloaded");
private readonly StatisticsRow statisticsFailed = new ("Failed");
private readonly StatisticsRow statisticsSkipped = new ("Skipped");
public ObservableCollection<StatisticsRow> StatisticsRows { get; }
public bool IsDownloading => state.Downloader.IsDownloading;
private readonly State state;
private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask;
private readonly IDisposable downloadItemCountSubscription;
private IDisposable? finishedItemsSubscription;
private DownloadItemFilter? currentDownloadFilter;
public DownloadsPageModel() : this(State.Dummy) {}
public DownloadsPageModel(State state) {
this.state = state;
FilterModel = new DownloadItemFilterPanelModel(state);
StatisticsRows = [
statisticsPending,
statisticsDownloaded,
statisticsFailed,
statisticsSkipped
];
downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(Log, UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext());
downloadItemCountSubscription = state.Db.Downloads.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnDownloadCountChanged);
RecomputeDownloadStatistics();
}
public void Dispose() {
finishedItemsSubscription?.Dispose();
downloadItemCountSubscription.Dispose();
downloadStatisticsTask.Dispose();
FilterModel.Dispose();
}
private void OnDownloadCountChanged(long newDownloadCount) {
RecomputeDownloadStatistics();
}
public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false;
if (IsDownloading) {
await state.Downloader.Stop();
await state.Db.Downloads.MoveDownloadingItemsBackToQueue();
finishedItemsSubscription?.Dispose();
finishedItemsSubscription = null;
currentDownloadFilter = null;
}
else {
await state.Db.Downloads.MoveDownloadingItemsBackToQueue();
var finishedItems = await state.Downloader.Start(currentDownloadFilter = FilterModel.CreateFilter());
finishedItemsSubscription = finishedItems.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnItemFinished);
}
RecomputeDownloadStatistics();
OnPropertyChanged(nameof(ToggleDownloadButtonText));
OnPropertyChanged(nameof(IsDownloading));
IsToggleDownloadButtonEnabled = true;
}
private void OnItemFinished(DownloadItem item) {
RecomputeDownloadStatistics();
}
public async Task OnClickRetryFailedDownloads() {
IsRetryingFailedDownloads = true;
try {
await state.Db.Downloads.RetryFailed();
RecomputeDownloadStatistics();
} catch (Exception e) {
Log.Error(e);
} finally {
IsRetryingFailedDownloads = false;
}
}
private void RecomputeDownloadStatistics() {
downloadStatisticsTask.Post(cancellationToken => state.Db.Downloads.GetStatistics(currentDownloadFilter ?? new DownloadItemFilter(), cancellationToken));
}
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
statisticsPending.Items = statusStatistics.PendingCount;
statisticsPending.Size = statusStatistics.PendingTotalSize;
statisticsPending.HasFilesWithUnknownSize = statusStatistics.PendingWithUnknownSizeCount > 0;
statisticsDownloaded.Items = statusStatistics.SuccessfulCount;
statisticsDownloaded.Size = statusStatistics.SuccessfulTotalSize;
statisticsDownloaded.HasFilesWithUnknownSize = statusStatistics.SuccessfulWithUnknownSizeCount > 0;
statisticsFailed.Items = statusStatistics.FailedCount;
statisticsFailed.Size = statusStatistics.FailedTotalSize;
statisticsFailed.HasFilesWithUnknownSize = statusStatistics.FailedWithUnknownSizeCount > 0;
statisticsSkipped.Items = statusStatistics.SkippedCount;
statisticsSkipped.Size = statusStatistics.SkippedTotalSize;
statisticsSkipped.HasFilesWithUnknownSize = statusStatistics.SkippedWithUnknownSizeCount > 0;
HasFailedDownloads = statusStatistics.FailedCount > 0;
}
[ObservableObject]
public sealed partial class StatisticsRow(string state) {
public string State { get; } = state;
[ObservableProperty]
private int items;
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(SizeText))]
private ulong? size;
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(SizeText))]
private bool hasFilesWithUnknownSize;
public string SizeText {
get {
if (size == null) {
return "-";
}
else if (hasFilesWithUnknownSize) {
return "\u2265 " + BytesValueConverter.Convert(size.Value);
}
else {
return BytesValueConverter.Convert(size.Value);
}
}
}
}
}

View File

@@ -3,25 +3,33 @@ using System.Threading;
using System.Threading.Tasks;
using System.Web;
using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Discord;
using DHT.Desktop.Server;
using DHT.Utils.Models;
using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages;
sealed partial class TrackingPageModel : ObservableObject {
[ObservableProperty(Setter = Access.Private)]
sealed class TrackingPageModel : BaseModel {
private bool isCopyTrackingScriptButtonEnabled = true;
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(ToggleAppDevToolsButtonText))]
private bool? areDevToolsEnabled = null;
public bool IsCopyTrackingScriptButtonEnabled {
get => isCopyTrackingScriptButtonEnabled;
set => Change(ref isCopyTrackingScriptButtonEnabled, value);
}
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(ToggleAppDevToolsButtonText))]
private bool isToggleAppDevToolsButtonEnabled = false;
private bool? areDevToolsEnabled;
private bool? AreDevToolsEnabled {
get => areDevToolsEnabled;
set {
Change(ref areDevToolsEnabled, value);
OnPropertyChanged(nameof(ToggleAppDevToolsButtonText));
}
}
public bool IsToggleAppDevToolsButtonEnabled { get; private set; } = false;
public string ToggleAppDevToolsButtonText {
get {
@@ -78,15 +86,17 @@ sealed partial class TrackingPageModel : ObservableObject {
}
private async Task InitializeDevToolsToggle() {
bool? devToolsEnabled = await Task.Run(DiscordAppSettings.AreDevToolsEnabled);
bool? devToolsEnabled = await DiscordAppSettings.AreDevToolsEnabled();
if (devToolsEnabled.HasValue) {
AreDevToolsEnabled = devToolsEnabled.Value;
IsToggleAppDevToolsButtonEnabled = true;
AreDevToolsEnabled = devToolsEnabled.Value;
}
else {
IsToggleAppDevToolsButtonEnabled = false;
}
OnPropertyChanged(nameof(IsToggleAppDevToolsButtonEnabled));
}
public async Task OnClickToggleAppDevTools() {

View File

@@ -8,7 +8,6 @@ using System.Threading.Tasks;
using System.Web;
using Avalonia.Controls;
using Avalonia.Platform.Storage;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.File;
using DHT.Desktop.Dialogs.Message;
@@ -18,11 +17,13 @@ using DHT.Desktop.Server;
using DHT.Server;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Export;
using DHT.Server.Database.Export.Strategy;
using DHT.Utils.Models;
using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages;
sealed partial class ViewerPageModel : ObservableObject, IDisposable {
sealed class ViewerPageModel : BaseModel, IDisposable {
public static readonly ConcurrentBag<string> TemporaryFiles = [];
private static readonly FilePickerFileType[] ViewerFileTypes = [
@@ -32,9 +33,13 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
public bool DatabaseToolFilterModeKeep { get; set; } = true;
public bool DatabaseToolFilterModeRemove { get; set; } = false;
[ObservableProperty]
private bool hasFilters = false;
public bool HasFilters {
get => hasFilters;
set => Change(ref hasFilters, value);
}
public MessageFilterPanelModel FilterModel { get; }
private readonly Window window;
@@ -62,19 +67,12 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
public async void OnClickOpenViewer() {
try {
var fullPath = await PrepareTemporaryViewerFile();
string jsConstants = $"""
window.DHT_SERVER_URL = "{HttpUtility.JavaScriptStringEncode("http://127.0.0.1:" + ServerConfiguration.Port)}";
window.DHT_SERVER_TOKEN = "{HttpUtility.JavaScriptStringEncode(ServerConfiguration.Token)}";
""";
var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token);
await ProgressDialog.ShowIndeterminate(window, "Open Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(fullPath, jsConstants)));
Process.Start(new ProcessStartInfo(fullPath) {
UseShellExecute = true
});
await WriteViewerFile(fullPath, strategy);
Process.Start(new ProcessStartInfo(fullPath) { UseShellExecute = true });
} catch (Exception e) {
await Dialog.ShowOk(window, "Open Viewer", "Could not create or save viewer: " + e.Message);
await Dialog.ShowOk(window, "Open Viewer", "Could not save viewer: " + e.Message);
}
}
@@ -112,18 +110,17 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
}
try {
await ProgressDialog.ShowIndeterminate(window, "Save Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(path, string.Empty)));
await WriteViewerFile(path, StandaloneViewerExportStrategy.Instance);
} catch (Exception e) {
await Dialog.ShowOk(window, "Save Viewer", "Could not create or save viewer: " + e.Message);
await Dialog.ShowOk(window, "Save Viewer", "Could not save viewer: " + e.Message);
}
}
private async Task WriteViewerFile(string path, string jsConstants) {
private async Task WriteViewerFile(string path, IViewerExportStrategy strategy) {
const string ArchiveTag = "/*[ARCHIVE]*/";
string indexFile = await Resources.ReadTextAsync("Viewer/index.html");
string viewerTemplate = indexFile.Replace("/*[CONSTANTS]*/", jsConstants)
.Replace("/*[JS]*/", await Resources.ReadJoinedAsync("Viewer/scripts/", '\n'))
string viewerTemplate = indexFile.Replace("/*[JS]*/", await Resources.ReadJoinedAsync("Viewer/scripts/", '\n'))
.Replace("/*[CSS]*/", await Resources.ReadJoinedAsync("Viewer/styles/", '\n'));
int viewerArchiveTagStart = viewerTemplate.IndexOf(ArchiveTag);
@@ -132,7 +129,7 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
string jsonTempFile = path + ".tmp";
await using (var jsonStream = new FileStream(jsonTempFile, FileMode.Create, FileAccess.ReadWrite, FileShare.Read)) {
await ViewerJsonExport.Generate(jsonStream, state.Db, FilterModel.CreateFilter());
await ViewerJsonExport.Generate(jsonStream, strategy, state.Db, FilterModel.CreateFilter());
char[] jsonBuffer = new char[Math.Min(32768, jsonStream.Position)];
jsonStream.Position = 0;

View File

@@ -74,7 +74,7 @@
<DockPanel>
<Border Classes="statusBar" DockPanel.Dock="Bottom">
<DockPanel>
<TextBlock Classes="invisibleTabItem" DockPanel.Dock="Left">Downloads</TextBlock>
<TextBlock Classes="invisibleTabItem" DockPanel.Dock="Left">Attachments</TextBlock>
<controls:StatusBar DataContext="{Binding StatusBarModel}" DockPanel.Dock="Right" />
</DockPanel>
</Border>
@@ -94,9 +94,9 @@
<ContentPresenter Content="{Binding TrackingPage}" Classes="page" />
</ScrollViewer>
</TabItem>
<TabItem x:Name="TabDownloads" Header="Downloads" Grid.Row="2">
<TabItem x:Name="TabAttachments" Header="Attachments" Grid.Row="2">
<ScrollViewer>
<ContentPresenter Content="{Binding DownloadsPage}" Classes="page" />
<ContentPresenter Content="{Binding AttachmentsPage}" Classes="page" />
</ScrollViewer>
</TabItem>
<TabItem x:Name="TabViewer" Header="Viewer" Grid.Row="3">

View File

@@ -13,8 +13,8 @@ sealed class MainContentScreenModel : IDisposable {
public TrackingPage TrackingPage { get; }
private TrackingPageModel TrackingPageModel { get; }
public DownloadsPage DownloadsPage { get; }
private DownloadsPageModel DownloadsPageModel { get; }
public AttachmentsPage AttachmentsPage { get; }
private AttachmentsPageModel AttachmentsPageModel { get; }
public ViewerPage ViewerPage { get; }
private ViewerPageModel ViewerPageModel { get; }
@@ -52,8 +52,8 @@ sealed class MainContentScreenModel : IDisposable {
TrackingPageModel = new TrackingPageModel(window);
TrackingPage = new TrackingPage { DataContext = TrackingPageModel };
DownloadsPageModel = new DownloadsPageModel(state);
DownloadsPage = new DownloadsPage { DataContext = DownloadsPageModel };
AttachmentsPageModel = new AttachmentsPageModel(state);
AttachmentsPage = new AttachmentsPage { DataContext = AttachmentsPageModel };
ViewerPageModel = new ViewerPageModel(window, state);
ViewerPage = new ViewerPage { DataContext = ViewerPageModel };
@@ -72,7 +72,7 @@ sealed class MainContentScreenModel : IDisposable {
}
public void Dispose() {
DownloadsPageModel.Dispose();
AttachmentsPageModel.Dispose();
ViewerPageModel.Dispose();
AdvancedPageModel.Dispose();
StatusBarModel.Dispose();

View File

@@ -3,21 +3,25 @@ using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress;
using DHT.Server.Database;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Screens;
sealed partial class WelcomeScreenModel : ObservableObject {
sealed class WelcomeScreenModel : BaseModel {
public string Version => Program.Version;
[ObservableProperty(Setter = Access.Private)]
private bool isOpenOrCreateDatabaseButtonEnabled = true;
public bool IsOpenOrCreateDatabaseButtonEnabled {
get => isOpenOrCreateDatabaseButtonEnabled;
set => Change(ref isOpenOrCreateDatabaseButtonEnabled, value);
}
public event EventHandler<IDatabaseFile>? DatabaseSelected;
private readonly Window window;

View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<packageSources>
<add key="chylex's repository" value="https://nuget.chylex.com/feed/index.json" />
</packageSources>
</configuration>

View File

@@ -48,16 +48,6 @@ const STATE = (function() {
});
};
const getDate = function(date) {
if (date instanceof Date) {
return date;
}
else {
// noinspection JSUnresolvedReference
return date.toDate();
}
};
const trackingStateChangedListeners = [];
let isTracking = false;
@@ -79,8 +69,8 @@ const STATE = (function() {
* @property {String} channel_id
* @property {DiscordUser} author
* @property {String} content
* @property {Date} timestamp
* @property {Date|null} editedTimestamp
* @property {Timestamp} timestamp
* @property {Timestamp|null} editedTimestamp
* @property {DiscordAttachment[]} attachments
* @property {Object[]} embeds
* @property {DiscordMessageReaction[]} [reactions]
@@ -116,6 +106,11 @@ const STATE = (function() {
* @property {Boolean} animated
*/
/**
* @name Timestamp
* @property {Function} toDate
*/
return {
setup(port, token) {
serverPort = port;
@@ -228,12 +223,12 @@ const STATE = (function() {
sender: msg.author.id,
channel: msg.channel_id,
text: msg.content,
timestamp: getDate(msg.timestamp).getTime()
timestamp: msg.timestamp.toDate().getTime()
};
if (msg.editedTimestamp !== null) {
// noinspection JSUnusedGlobalSymbols
obj.editTimestamp = getDate(msg.editedTimestamp).getTime();
obj.editTimestamp = msg.editedTimestamp.toDate().getTime();
}
if (msg.messageReference !== null) {

View File

@@ -6,7 +6,6 @@
<script type="text/javascript">
window.DHT_EMBEDDED = "/*[ARCHIVE]*/";
/*[CONSTANTS]*/
/*[JS]*/
</script>
<style>

View File

@@ -35,23 +35,6 @@ const DISCORD = (function() {
let templateReaction;
let templateReactionCustom;
const fileUrlProcessor = function(serverUrl, serverToken) {
if (typeof serverUrl === "string" && typeof serverToken === "string") {
return url => serverUrl + "/get-downloaded-file/" + encodeURIComponent(url) + "?token=" + encodeURIComponent(serverToken);
}
else {
return url => url;
}
}(
window["DHT_SERVER_URL"],
window["DHT_SERVER_TOKEN"]
);
const getEmoji = function(name, id, extension) {
const tag = ":" + name + ":";
return "<img src='" + fileUrlProcessor("https://cdn.discordapp.com/emojis/" + id + "." + extension) + "' alt='" + tag + "' title='" + tag + "' class='emoji'>";
};
const processMessageContents = function(contents) {
let processed = DOM.escapeHTML(contents.replace(regex.formatUrlNoEmbed, "$1"));
@@ -71,33 +54,29 @@ const DISCORD = (function() {
.replace(regex.formatStrike, "<s>$1</s>");
}
const animatedEmojiExtension = SETTINGS.enableAnimatedEmoji ? "gif" : "webp";
const animatedEmojiExtension = SETTINGS.enableAnimatedEmoji ? "gif" : "png";
// noinspection HtmlUnknownTarget
processed = processed
.replace(regex.formatUrl, "<a href='$1' target='_blank' rel='noreferrer'>$1</a>")
.replace(regex.mentionChannel, (full, match) => "<span class='link mention-chat'>#" + STATE.getChannelName(match) + "</span>")
.replace(regex.mentionUser, (full, match) => "<span class='link mention-user' title='#" + (STATE.getUserTag(match) || "????") + "'>@" + STATE.getUserName(match) + "</span>")
.replace(regex.customEmojiStatic, (full, m1, m2) => getEmoji(m1, m2, "webp"))
.replace(regex.customEmojiAnimated, (full, m1, m2) => getEmoji(m1, m2, animatedEmojiExtension));
.replace(regex.customEmojiStatic, "<img src='https://cdn.discordapp.com/emojis/$2.png' alt=':$1:' title=':$1:' class='emoji'>")
.replace(regex.customEmojiAnimated, "<img src='https://cdn.discordapp.com/emojis/$2." + animatedEmojiExtension + "' alt=':$1:' title=':$1:' class='emoji'>");
return "<p>" + processed + "</p>";
};
const getAvatarUrlObject = function(avatar) {
return { url: fileUrlProcessor("https://cdn.discordapp.com/avatars/" + avatar.id + "/" + avatar.path + ".webp") };
};
const getImageEmbed = function(url, image) {
if (!SETTINGS.enableImagePreviews) {
return "";
}
if (image.width && image.height) {
return templateEmbedImageWithSize.apply({ url: fileUrlProcessor(url), src: fileUrlProcessor(image.url), width: image.width, height: image.height });
return templateEmbedImageWithSize.apply({ url, src: image.url, width: image.width, height: image.height });
}
else {
return templateEmbedImage.apply({ url: fileUrlProcessor(url), src: fileUrlProcessor(image.url) });
return templateEmbedImage.apply({ url, src: image.url });
}
};
@@ -146,9 +125,8 @@ const DISCORD = (function() {
"</div>"
].join(""));
// noinspection HtmlUnknownTarget
templateUserAvatar = new TEMPLATE([
"<img src='{url}' alt=''>"
"<img src='https://cdn.discordapp.com/avatars/{id}/{path}.webp?size=128' alt=''>"
].join(""));
// noinspection HtmlUnknownTarget
@@ -189,9 +167,8 @@ const DISCORD = (function() {
"<span class='reaction-wrapper'><span class='reaction-emoji'>{n}</span><span class='count'>{c}</span></span>"
].join(""));
// noinspection HtmlUnknownTarget
templateReactionCustom = new TEMPLATE([
"<span class='reaction-wrapper'><img src='{url}' alt=':{n}:' title=':{n}:' class='reaction-emoji-custom'><span class='count'>{c}</span></span>"
"<span class='reaction-wrapper'><img src='https://cdn.discordapp.com/emojis/{id}.{ext}' alt=':{n}:' title=':{n}:' class='reaction-emoji-custom'><span class='count'>{c}</span></span>"
].join(""));
},
@@ -222,7 +199,7 @@ const DISCORD = (function() {
getMessageHTML(message) { // noinspection FunctionWithInconsistentReturnsJS
return (SETTINGS.enableUserAvatars ? templateMessageWithAvatar : templateMessageNoAvatar).apply(message, (property, value) => {
if (property === "avatar") {
return value ? templateUserAvatar.apply(getAvatarUrlObject(value)) : "";
return value ? templateUserAvatar.apply(value) : "";
}
else if (property === "user.tag") {
return value ? value : "????";
@@ -243,10 +220,10 @@ const DISCORD = (function() {
return templateEmbedUnsupported.apply(embed);
}
else if ("image" in embed && embed.image.url) {
return getImageEmbed(fileUrlProcessor(embed.url), embed.image);
return getImageEmbed(embed.url, embed.image);
}
else if ("thumbnail" in embed && embed.thumbnail.url) {
return getImageEmbed(fileUrlProcessor(embed.url), embed.thumbnail);
return getImageEmbed(embed.url, embed.thumbnail);
}
else if ("title" in embed && "description" in embed) {
return templateEmbedRich.apply(embed);
@@ -265,16 +242,14 @@ const DISCORD = (function() {
}
return value.map(attachment => {
const url = fileUrlProcessor(attachment.url);
if (!DISCORD.isImageAttachment(attachment) || !SETTINGS.enableImagePreviews) {
return templateAttachmentDownload.apply({ url, name: attachment.name });
return templateAttachmentDownload.apply(attachment);
}
else if ("width" in attachment && "height" in attachment) {
return templateEmbedImageWithSize.apply({ url, src: url, width: attachment.width, height: attachment.height });
return templateEmbedImageWithSize.apply({ url: attachment.url, src: attachment.url, width: attachment.width, height: attachment.height });
}
else {
return templateEmbedImage.apply({ url, src: url });
return templateEmbedImage.apply({ url: attachment.url, src: attachment.url });
}
}).join("");
}
@@ -290,7 +265,7 @@ const DISCORD = (function() {
}
const user = "<span class='reply-username' title='#" + (value.user.tag ? value.user.tag : "????") + "'>" + value.user.name + "</span>";
const avatar = SETTINGS.enableUserAvatars && value.avatar ? "<span class='reply-avatar'>" + templateUserAvatar.apply(getAvatarUrlObject(value.avatar)) + "</span>" : "";
const avatar = SETTINGS.enableUserAvatars && value.avatar ? "<span class='reply-avatar'>" + templateUserAvatar.apply(value.avatar) + "</span>" : "";
const contents = value.contents ? "<span class='reply-contents'>" + processMessageContents(value.contents) + "</span>" : "";
return "<span class='jump' data-jump='" + value.id + "'>Jump to reply</span><span class='user'>" + avatar + user + "</span>" + contents;
@@ -302,10 +277,9 @@ const DISCORD = (function() {
return "<div class='reactions'>" + value.map(reaction => {
if ("id" in reaction){
const ext = reaction.a && SETTINGS.enableAnimatedEmoji ? "gif" : "webp";
const url = fileUrlProcessor("https://cdn.discordapp.com/emojis/" + reaction.id + "." + ext);
// noinspection JSUnusedGlobalSymbols
return templateReactionCustom.apply({ url, n: reaction.n, c: reaction.c });
// noinspection JSUnusedGlobalSymbols, JSUnresolvedVariable
reaction.ext = reaction.a && SETTINGS.enableAnimatedEmoji ? "gif" : "png";
return templateReactionCustom.apply(reaction);
}
else {
return templateReaction.apply(reaction);

View File

@@ -1,19 +1,15 @@
namespace DHT.Server.Data.Aggregations;
public sealed class DownloadStatusStatistics {
public int PendingCount { get; internal init; }
public ulong PendingTotalSize { get; internal init; }
public int PendingWithUnknownSizeCount { get; internal init; }
public int EnqueuedCount { get; internal set; }
public ulong EnqueuedSize { get; internal set; }
public int SuccessfulCount { get; internal init; }
public ulong SuccessfulTotalSize { get; internal init; }
public int SuccessfulWithUnknownSizeCount { get; internal init; }
public int SuccessfulCount { get; internal set; }
public ulong SuccessfulSize { get; internal set; }
public int FailedCount { get; internal init; }
public ulong FailedTotalSize { get; internal init; }
public int FailedWithUnknownSizeCount { get; internal init; }
public int SkippedCount { get; internal init; }
public ulong SkippedTotalSize { get; internal init; }
public int SkippedWithUnknownSizeCount { get; internal init; }
public int FailedCount { get; internal set; }
public ulong FailedSize { get; internal set; }
public int SkippedCount { get; internal set; }
public ulong SkippedSize { get; internal set; }
}

View File

@@ -1,17 +1,33 @@
using System;
using System.Net;
using DHT.Server.Download;
namespace DHT.Server.Data;
public sealed class Download {
public readonly struct Download {
internal static Download NewSuccess(DownloadItem item, byte[] data) {
return new Download(item.NormalizedUrl, item.DownloadUrl, DownloadStatus.Success, (ulong) Math.Max(data.LongLength, 0), data);
}
internal static Download NewFailure(DownloadItem item, HttpStatusCode? statusCode, ulong size) {
return new Download(item.NormalizedUrl, item.DownloadUrl, statusCode.HasValue ? (DownloadStatus) (int) statusCode : DownloadStatus.GenericError, size);
}
public string NormalizedUrl { get; }
public string DownloadUrl { get; }
public DownloadStatus Status { get; }
public string? Type { get; }
public ulong? Size { get; }
public ulong Size { get; }
public byte[]? Data { get; }
internal Download(string normalizedUrl, string downloadUrl, DownloadStatus status, string? type, ulong? size) {
internal Download(string normalizedUrl, string downloadUrl, DownloadStatus status, ulong size, byte[]? data = null) {
NormalizedUrl = normalizedUrl;
DownloadUrl = downloadUrl;
Status = status;
Type = type;
Size = size;
Data = data;
}
internal Download WithData(byte[] data) {
return new Download(NormalizedUrl, DownloadUrl, Status, Size, data);
}
}

View File

@@ -6,9 +6,8 @@ namespace DHT.Server.Data;
/// Extends <see cref="HttpStatusCode"/> with custom status codes in the range 0-99.
/// </summary>
public enum DownloadStatus {
Pending = 0,
Enqueued = 0,
GenericError = 1,
Downloading = 2,
LastCustomCode = 99,
Success = HttpStatusCode.OK
}

View File

@@ -0,0 +1,6 @@
namespace DHT.Server.Data;
public readonly struct DownloadedAttachment {
public string? Type { get; internal init; }
public byte[] Data { get; internal init; }
}

View File

@@ -1,17 +0,0 @@
namespace DHT.Server.Data.Embeds;
sealed class DiscordEmbedJson {
public string? Type { get; set; }
public string? Url { get; set; }
public JsonImage? Image { get; set; }
public JsonImage? Thumbnail { get; set; }
public JsonImage? Video { get; set; }
public sealed class JsonImage {
public string? Url { get; set; }
public string? ProxyUrl { get; set; }
public int? Width { get; set; }
public int? Height { get; set; }
}
}

View File

@@ -1,7 +0,0 @@
using System.Text.Json.Serialization;
namespace DHT.Server.Data.Embeds;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(DiscordEmbedJson))]
sealed partial class DiscordEmbedJsonContext : JsonSerializerContext;

View File

@@ -0,0 +1,15 @@
namespace DHT.Server.Data.Filters;
public sealed class AttachmentFilter {
public ulong? MaxBytes { get; set; } = null;
public DownloadItemRules? DownloadItemRule { get; set; } = null;
public bool IsEmpty => MaxBytes == null &&
DownloadItemRule == null;
public enum DownloadItemRules {
OnlyNotPresent,
OnlyPresent
}
}

View File

@@ -3,10 +3,8 @@ using System.Collections.Generic;
namespace DHT.Server.Data.Filters;
public sealed class DownloadItemFilter {
public HashSet<DownloadStatus>? IncludeStatuses { get; set; } = null;
public HashSet<DownloadStatus>? ExcludeStatuses { get; set; } = null;
public ulong? MaxBytes { get; set; } = null;
public HashSet<DownloadStatus>? IncludeStatuses { get; init; } = null;
public HashSet<DownloadStatus>? ExcludeStatuses { get; init; } = null;
public bool IsEmpty => IncludeStatuses == null && ExcludeStatuses == null && MaxBytes == null;
public bool IsEmpty => IncludeStatuses == null && ExcludeStatuses == null;
}

View File

@@ -25,10 +25,8 @@ public static class DatabaseExtensions {
await target.Messages.Add(batchedMessages);
await foreach (var download in source.Downloads.Get()) {
if (download.Status != DownloadStatus.Success || !await source.Downloads.GetDownloadData(download.NormalizedUrl, stream => target.Downloads.AddDownload(download, stream))) {
await target.Downloads.AddDownload(download, stream: null);
}
await foreach (var download in source.Downloads.GetWithoutData()) {
await target.Downloads.AddDownload(download.Status == DownloadStatus.Success ? await source.Downloads.HydrateWithData(download) : download);
}
}
}

View File

@@ -0,0 +1,46 @@
using DHT.Utils.Models;
namespace DHT.Server.Database;
/// <summary>
/// A live view of database statistics.
/// Some of the totals are computed asynchronously and may not reflect the most recent version of the database, or may not be available at all until computed for the first time.
/// </summary>
public sealed class DatabaseStatistics : BaseModel {
private long totalServers;
private long totalChannels;
private long totalUsers;
private long? totalMessages;
private long? totalAttachments;
private long? totalDownloads;
public long TotalServers {
get => totalServers;
internal set => Change(ref totalServers, value);
}
public long TotalChannels {
get => totalChannels;
internal set => Change(ref totalChannels, value);
}
public long TotalUsers {
get => totalUsers;
internal set => Change(ref totalUsers, value);
}
public long? TotalMessages {
get => totalMessages;
internal set => Change(ref totalMessages, value);
}
public long? TotalAttachments {
get => totalAttachments;
internal set => Change(ref totalAttachments, value);
}
public long? TotalDownloads {
get => totalDownloads;
internal set => Change(ref totalDownloads, value);
}
}

View File

@@ -0,0 +1,11 @@
namespace DHT.Server.Database;
/// <summary>
/// A complete snapshot of database statistics at a particular point in time.
/// </summary>
public readonly struct DatabaseStatisticsSnapshot {
public long TotalServers { get; internal init; }
public long TotalChannels { get; internal init; }
public long TotalUsers { get; internal init; }
public long TotalMessages { get; internal init; }
}

View File

@@ -9,6 +9,7 @@ sealed class DummyDatabaseFile : IDatabaseFile {
public static DummyDatabaseFile Instance { get; } = new ();
public string Path => "";
public DatabaseStatistics Statistics { get; } = new ();
public IUserRepository Users { get; } = new IUserRepository.Dummy();
public IServerRepository Servers { get; } = new IServerRepository.Dummy();
@@ -18,11 +19,13 @@ sealed class DummyDatabaseFile : IDatabaseFile {
private DummyDatabaseFile() {}
public Task<DatabaseStatisticsSnapshot> SnapshotStatistics() {
return Task.FromResult(new DatabaseStatisticsSnapshot());
}
public Task Vacuum() {
return Task.CompletedTask;
}
public ValueTask DisposeAsync() {
return ValueTask.CompletedTask;
}
public void Dispose() {}
}

View File

@@ -5,9 +5,9 @@ namespace DHT.Server.Database.Exceptions;
public sealed class DatabaseTooNewException : Exception {
public int DatabaseVersion { get; }
public int CurrentVersion => SqliteSchema.Version;
public int CurrentVersion => Schema.Version;
internal DatabaseTooNewException(int databaseVersion) : base("Database is too new: " + databaseVersion + " > " + SqliteSchema.Version) {
internal DatabaseTooNewException(int databaseVersion) : base("Database is too new: " + databaseVersion + " > " + Schema.Version) {
this.DatabaseVersion = databaseVersion;
}
}

View File

@@ -0,0 +1,7 @@
using DHT.Server.Data;
namespace DHT.Server.Database.Export.Strategy;
public interface IViewerExportStrategy {
string GetAttachmentUrl(Attachment attachment);
}

View File

@@ -0,0 +1,18 @@
using System.Net;
using DHT.Server.Data;
namespace DHT.Server.Database.Export.Strategy;
public sealed class LiveViewerExportStrategy : IViewerExportStrategy {
private readonly string safePort;
private readonly string safeToken;
public LiveViewerExportStrategy(ushort port, string token) {
this.safePort = port.ToString();
this.safeToken = WebUtility.UrlEncode(token);
}
public string GetAttachmentUrl(Attachment attachment) {
return "http://127.0.0.1:" + safePort + "/get-attachment/" + WebUtility.UrlEncode(attachment.NormalizedUrl) + "?token=" + safeToken;
}
}

View File

@@ -0,0 +1,18 @@
using DHT.Server.Data;
namespace DHT.Server.Database.Export.Strategy;
public sealed class StandaloneViewerExportStrategy : IViewerExportStrategy {
public static StandaloneViewerExportStrategy Instance { get; } = new ();
private StandaloneViewerExportStrategy() {}
public string GetAttachmentUrl(Attachment attachment) {
// The normalized URL will not load files from Discord CDN once the time limit is enforced.
// The downloaded URL would work, but only for a limited time, so it is better for the links to not work
// rather than give users a false sense of security.
return attachment.NormalizedUrl;
}
}

View File

@@ -3,9 +3,9 @@ using System.Text.Json.Serialization;
namespace DHT.Server.Database.Export;
[JsonSourceGenerationOptions(
Converters = [typeof(SnowflakeJsonSerializer)],
Converters = new [] { typeof(SnowflakeJsonSerializer) },
PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
GenerationMode = JsonSourceGenerationMode.Default
)]
[JsonSerializable(typeof(ViewerJson))]
sealed partial class ViewerJsonContext : JsonSerializerContext;
sealed partial class ViewerJsonContext : JsonSerializerContext {}

View File

@@ -6,6 +6,7 @@ using System.Text.Json;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Export.Strategy;
using DHT.Utils.Logging;
namespace DHT.Server.Database.Export;
@@ -13,7 +14,7 @@ namespace DHT.Server.Database.Export;
public static class ViewerJsonExport {
private static readonly Log Log = Log.ForType(typeof(ViewerJsonExport));
public static async Task Generate(Stream stream, IDatabaseFile db, MessageFilter? filter = null) {
public static async Task Generate(Stream stream, IViewerExportStrategy strategy, IDatabaseFile db, MessageFilter? filter = null) {
var perf = Log.Start();
var includedUserIds = new HashSet<ulong>();
@@ -48,7 +49,7 @@ public static class ViewerJsonExport {
Servers = servers,
Channels = channels
},
Data = GenerateMessageList(includedMessages, userIndices)
Data = GenerateMessageList(includedMessages, userIndices, strategy)
};
perf.Step("Generate value object");
@@ -124,7 +125,7 @@ public static class ViewerJsonExport {
return channels;
}
private static Dictionary<Snowflake, Dictionary<Snowflake, ViewerJson.JsonMessage>> GenerateMessageList(List<Message> includedMessages, Dictionary<ulong, int> userIndices) {
private static Dictionary<Snowflake, Dictionary<Snowflake, ViewerJson.JsonMessage>> GenerateMessageList(List<Message> includedMessages, Dictionary<ulong, int> userIndices, IViewerExportStrategy strategy) {
var data = new Dictionary<Snowflake, Dictionary<Snowflake, ViewerJson.JsonMessage>>();
foreach (var grouping in includedMessages.GroupBy(static message => message.Channel)) {
@@ -141,9 +142,9 @@ public static class ViewerJsonExport {
Te = message.EditTimestamp,
R = message.RepliedToId?.ToString(),
A = message.Attachments.IsEmpty ? null : message.Attachments.Select(static attachment => {
A = message.Attachments.IsEmpty ? null : message.Attachments.Select(attachment => {
var a = new ViewerJson.JsonMessageAttachment {
Url = attachment.DownloadUrl,
Url = strategy.GetAttachmentUrl(attachment),
Name = Uri.TryCreate(attachment.NormalizedUrl, UriKind.Absolute, out var uri) ? Path.GetFileName(uri.LocalPath) : attachment.NormalizedUrl
};

View File

@@ -4,8 +4,10 @@ using DHT.Server.Database.Repositories;
namespace DHT.Server.Database;
public interface IDatabaseFile : IAsyncDisposable {
public interface IDatabaseFile : IDisposable {
string Path { get; }
DatabaseStatistics Statistics { get; }
Task<DatabaseStatisticsSnapshot> SnapshotStatistics();
IUserRepository Users { get; }
IServerRepository Servers { get; }

View File

@@ -4,4 +4,4 @@ namespace DHT.Server.Database.Import;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(DiscordEmbedLegacyJson))]
sealed partial class DiscordEmbedLegacyJsonContext : JsonSerializerContext;
sealed partial class DiscordEmbedLegacyJsonContext : JsonSerializerContext {}

View File

@@ -161,7 +161,7 @@ public static class LegacyArchiveImport {
var messagesObj = data.HasKey(channelIdStr) ? data.RequireObject(channelIdStr, DataPath) : (JsonElement?) null;
if (messagesObj == null) {
return [];
return Array.Empty<Message>();
}
return messagesObj.Value.EnumerateObject().Select(item => {

View File

@@ -1,33 +1,20 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database.Repositories;
public interface IChannelRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Channel> channels);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<Channel> Get();
internal sealed class Dummy : IChannelRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Channel> channels) {
return Task.CompletedTask;
}
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Channel> Get() {
return AsyncEnumerable.Empty<Channel>();
}

View File

@@ -1,10 +1,8 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters;
using DHT.Server.Download;
@@ -12,63 +10,59 @@ using DHT.Server.Download;
namespace DHT.Server.Database.Repositories;
public interface IDownloadRepository {
IObservable<long> TotalCount { get; }
Task<long> CountAttachments(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
Task AddDownload(Data.Download item, Stream? stream);
Task AddDownload(Data.Download download);
Task<long> Count(DownloadItemFilter filter, CancellationToken cancellationToken = default);
Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken = default);
Task<DownloadStatusStatistics> GetStatistics(DownloadItemFilter nonSkippedFilter, CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Download> Get();
IAsyncEnumerable<Data.Download> GetWithoutData();
Task<bool> GetDownloadData(string normalizedUrl, Func<Stream, Task> dataProcessor);
Task<bool> GetSuccessfulDownloadWithData(string normalizedUrl, Func<Data.Download, Stream, Task> dataProcessor);
Task<Data.Download> HydrateWithData(Data.Download download);
IAsyncEnumerable<DownloadItem> PullPendingDownloadItems(int count, DownloadItemFilter filter, CancellationToken cancellationToken = default);
Task MoveDownloadingItemsBackToQueue(CancellationToken cancellationToken = default);
Task<int> RetryFailed(CancellationToken cancellationToken = default);
Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl);
Task EnqueueDownloadItems(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken = default);
Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
internal sealed class Dummy : IDownloadRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task AddDownload(Data.Download item, Stream? stream) {
return Task.CompletedTask;
}
public Task<long> Count(DownloadItemFilter filter, CancellationToken cancellationToken) {
public Task<long> CountAttachments(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public Task<DownloadStatusStatistics> GetStatistics(DownloadItemFilter nonSkippedFilter, CancellationToken cancellationToken) {
return Task.FromResult(new DownloadStatusStatistics());
}
public IAsyncEnumerable<Data.Download> Get() {
return AsyncEnumerable.Empty<Data.Download>();
}
public Task<bool> GetDownloadData(string normalizedUrl, Func<Stream, Task> dataProcessor) {
return Task.FromResult(false);
}
public Task<bool> GetSuccessfulDownloadWithData(string normalizedUrl, Func<Data.Download, Stream, Task> dataProcessor) {
return Task.FromResult(false);
}
public IAsyncEnumerable<DownloadItem> PullPendingDownloadItems(int count, DownloadItemFilter filter, CancellationToken cancellationToken) {
return AsyncEnumerable.Empty<DownloadItem>();
}
public Task MoveDownloadingItemsBackToQueue(CancellationToken cancellationToken) {
public Task AddDownload(Data.Download download) {
return Task.CompletedTask;
}
public Task<int> RetryFailed(CancellationToken cancellationToken) {
return Task.FromResult(0);
public Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
return Task.FromResult(new DownloadStatusStatistics());
}
public IAsyncEnumerable<Data.Download> GetWithoutData() {
return AsyncEnumerable.Empty<Data.Download>();
}
public Task<Data.Download> HydrateWithData(Data.Download download) {
return Task.FromResult(download);
}
public Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
return Task.FromResult<DownloadedAttachment?>(null);
}
public Task EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.CompletedTask;
}
public IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken) {
return AsyncEnumerable.Empty<DownloadItem>();
}
public Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
return Task.CompletedTask;
}
}
}

View File

@@ -1,7 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
@@ -10,8 +8,6 @@ using DHT.Server.Data.Filters;
namespace DHT.Server.Database.Repositories;
public interface IMessageRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Message> messages);
Task<long> Count(MessageFilter? filter = null, CancellationToken cancellationToken = default);
@@ -23,8 +19,6 @@ public interface IMessageRepository {
Task Remove(MessageFilter filter, FilterRemovalMode mode);
internal sealed class Dummy : IMessageRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Message> messages) {
return Task.CompletedTask;
}

View File

@@ -1,32 +1,19 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace DHT.Server.Database.Repositories;
public interface IServerRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Data.Server> servers);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Server> Get();
internal sealed class Dummy : IServerRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Data.Server> servers) {
return Task.CompletedTask;
}
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Data.Server> Get() {
return AsyncEnumerable.Empty<Data.Server>();
}

View File

@@ -1,33 +1,20 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
namespace DHT.Server.Database.Repositories;
public interface IUserRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<User> users);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<User> Get();
internal sealed class Dummy : IUserRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<User> users) {
return Task.CompletedTask;
}
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<User> Get() {
return AsyncEnumerable.Empty<User>();
}

View File

@@ -1,30 +0,0 @@
using System;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
namespace DHT.Server.Database.Sqlite.Repositories;
abstract class BaseSqliteRepository : IDisposable {
private readonly ObservableThrottledTask<long> totalCountTask;
public IObservable<long> TotalCount { get; }
protected BaseSqliteRepository(Log log) {
totalCountTask = new ObservableThrottledTask<long>(log, TaskScheduler.Default);
TotalCount = totalCountTask.DistinctUntilChanged();
UpdateTotalCount();
}
public void Dispose() {
totalCountTask.Dispose();
}
protected void UpdateTotalCount() {
totalCountTask.Post(Count);
}
public abstract Task<long> Count(CancellationToken cancellationToken);
}

View File

@@ -1,27 +1,34 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository {
private static readonly Log Log = Log.ForType<SqliteChannelRepository>();
sealed class SqliteChannelRepository : IChannelRepository {
private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteChannelRepository(SqliteConnectionPool pool) : base(Log) {
public SqliteChannelRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateChannelStatistics(conn);
}
private async Task UpdateChannelStatistics(ISqliteConnection conn) {
statistics.TotalChannels = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<Channel> channels) {
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("channels", [
("id", SqliteType.Integer),
("server", SqliteType.Integer),
@@ -43,24 +50,19 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository
await cmd.ExecuteNonQueryAsync();
}
await conn.CommitTransactionAsync();
await tx.CommitAsync();
}
UpdateTotalCount();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
await UpdateChannelStatistics(conn);
}
public async IAsyncEnumerable<Channel> Get() {
await using var conn = await pool.Take();
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
while (reader.Read()) {
yield return new Channel {
Id = reader.GetUint64(0),
Server = reader.GetUint64(1),

View File

@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -10,274 +9,201 @@ using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteDownloadRepository(SqliteConnectionPool pool) : BaseSqliteRepository(Log), IDownloadRepository {
private static readonly Log Log = Log.ForType<SqliteDownloadRepository>();
sealed class SqliteDownloadRepository : IDownloadRepository {
private readonly SqliteConnectionPool pool;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
internal sealed class NewDownloadCollector : IAsyncDisposable {
private readonly SqliteDownloadRepository repository;
private bool hasAdded = false;
private readonly SqliteCommand metadataCmd;
public NewDownloadCollector(SqliteDownloadRepository repository, ISqliteConnection conn) {
this.repository = repository;
metadataCmd = conn.Command(
"""
INSERT INTO download_metadata (normalized_url, download_url, status, type, size)
VALUES (:normalized_url, :download_url, :status, :type, :size)
ON CONFLICT DO NOTHING
"""
);
metadataCmd.Add(":normalized_url", SqliteType.Text);
metadataCmd.Add(":download_url", SqliteType.Text);
metadataCmd.Add(":status", SqliteType.Integer);
metadataCmd.Add(":type", SqliteType.Text);
metadataCmd.Add(":size", SqliteType.Integer);
}
public async Task Add(Data.Download download) {
metadataCmd.Set(":normalized_url", download.NormalizedUrl);
metadataCmd.Set(":download_url", download.DownloadUrl);
metadataCmd.Set(":status", (int) download.Status);
metadataCmd.Set(":type", download.Type);
metadataCmd.Set(":size", download.Size);
hasAdded |= await metadataCmd.ExecuteNonQueryAsync() > 0;
}
public void OnCommitted() {
if (hasAdded) {
repository.UpdateTotalCount();
}
}
public async ValueTask DisposeAsync() {
await metadataCmd.DisposeAsync();
}
public SqliteDownloadRepository(SqliteConnectionPool pool, AsyncValueComputer<long>.Single totalDownloadsComputer) {
this.pool = pool;
this.totalDownloadsComputer = totalDownloadsComputer;
}
public async Task AddDownload(Data.Download item, Stream? stream) {
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
await using var metadataCmd = conn.Upsert("download_metadata", [
public async Task<long> CountAttachments(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
public async Task AddDownload(Data.Download download) {
using (var conn = pool.Take()) {
await using var cmd = conn.Upsert("downloads", [
("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text),
("status", SqliteType.Integer),
("type", SqliteType.Text),
("size", SqliteType.Integer),
("blob", SqliteType.Blob)
]);
metadataCmd.Set(":normalized_url", item.NormalizedUrl);
metadataCmd.Set(":download_url", item.DownloadUrl);
metadataCmd.Set(":status", (int) item.Status);
metadataCmd.Set(":type", item.Type);
metadataCmd.Set(":size", item.Size);
await metadataCmd.ExecuteNonQueryAsync();
if (stream == null) {
await using var deleteBlobCmd = conn.Command("DELETE FROM download_blobs WHERE normalized_url = :normalized_url");
deleteBlobCmd.AddAndSet(":normalized_url", SqliteType.Text, item.NormalizedUrl);
await deleteBlobCmd.ExecuteNonQueryAsync();
}
else {
await using var upsertBlobCmd = conn.Command(
"""
INSERT INTO download_blobs (normalized_url, blob)
VALUES (:normalized_url, ZEROBLOB(:blob_length))
ON CONFLICT (normalized_url) DO UPDATE SET blob = excluded.blob
RETURNING rowid
"""
);
upsertBlobCmd.AddAndSet(":normalized_url", SqliteType.Text, item.NormalizedUrl);
upsertBlobCmd.AddAndSet(":blob_length", SqliteType.Integer, item.Size);
long rowid = await upsertBlobCmd.ExecuteLongScalarAsync();
await using var blob = new SqliteBlob(conn.InnerConnection, "download_blobs", "blob", rowid);
await stream.CopyToAsync(blob);
}
await conn.CommitTransactionAsync();
cmd.Set(":normalized_url", download.NormalizedUrl);
cmd.Set(":download_url", download.DownloadUrl);
cmd.Set(":status", (int) download.Status);
cmd.Set(":size", download.Size);
cmd.Set(":blob", download.Data);
await cmd.ExecuteNonQueryAsync();
}
UpdateTotalCount();
totalDownloadsComputer.Recompute();
}
public override Task<long> Count(CancellationToken cancellationToken) {
return Count(filter: null, cancellationToken);
}
public async Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
static async Task LoadUndownloadedStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT IFNULL(COUNT(size), 0), IFNULL(SUM(size), 0)
FROM (SELECT MAX(a.size) size
FROM attachments a
WHERE a.normalized_url NOT IN (SELECT d.normalized_url FROM downloads d)
GROUP BY a.normalized_url)
""");
public async Task<long> Count(DownloadItemFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM download_metadata" + filter.GenerateConditions().BuildWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
public async Task<DownloadStatusStatistics> GetStatistics(DownloadItemFilter nonSkippedFilter, CancellationToken cancellationToken) {
nonSkippedFilter.IncludeStatuses = null;
nonSkippedFilter.ExcludeStatuses = null;
string nonSkippedFilterConditions = nonSkippedFilter.GenerateConditions().Build();
await using var conn = await pool.Take();
await using var cmd = conn.Command(
$"""
SELECT
IFNULL(SUM(CASE WHEN (status = :downloading) OR (status = :pending AND {nonSkippedFilterConditions}) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN (status = :downloading) OR (status = :pending AND {nonSkippedFilterConditions}) THEN IFNULL(size, 0) ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN ((status = :downloading) OR (status = :pending AND {nonSkippedFilterConditions})) AND size IS NULL THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN IFNULL(size, 0) ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success AND size IS NULL THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:pending, :downloading, :success) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:pending, :downloading, :success) THEN IFNULL(size, 0) ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:pending, :downloading, :success) AND size IS NULL THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :pending AND NOT ({nonSkippedFilterConditions}) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :pending AND NOT ({nonSkippedFilterConditions}) THEN IFNULL(size, 0) ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :pending AND NOT ({nonSkippedFilterConditions}) AND size IS NULL THEN 1 ELSE 0 END), 0)
FROM download_metadata
"""
);
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (!await reader.ReadAsync(cancellationToken)) {
return new DownloadStatusStatistics();
if (reader.Read()) {
result.SkippedCount = reader.GetInt32(0);
result.SkippedSize = reader.GetUint64(1);
}
}
return new DownloadStatusStatistics {
PendingCount = reader.GetInt32(0),
PendingTotalSize = reader.GetUint64(1),
PendingWithUnknownSizeCount = reader.GetInt32(2),
SuccessfulCount = reader.GetInt32(3),
SuccessfulTotalSize = reader.GetUint64(4),
SuccessfulWithUnknownSizeCount = reader.GetInt32(5),
FailedCount = reader.GetInt32(6),
FailedTotalSize = reader.GetUint64(7),
FailedWithUnknownSizeCount = reader.GetInt32(8),
SkippedCount = reader.GetInt32(9),
SkippedTotalSize = reader.GetUint64(10),
SkippedWithUnknownSizeCount = reader.GetInt32(11)
};
static async Task LoadSuccessStatistics(ISqliteConnection conn, DownloadStatusStatistics result, CancellationToken cancellationToken) {
await using var cmd = conn.Command(
"""
SELECT
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status IN (:enqueued, :downloading) THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status = :success THEN size ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN 1 ELSE 0 END), 0),
IFNULL(SUM(CASE WHEN status NOT IN (:enqueued, :downloading) AND status != :success THEN size ELSE 0 END), 0)
FROM downloads
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
if (reader.Read()) {
result.EnqueuedCount = reader.GetInt32(0);
result.EnqueuedSize = reader.GetUint64(1);
result.SuccessfulCount = reader.GetInt32(2);
result.SuccessfulSize = reader.GetUint64(3);
result.FailedCount = reader.GetInt32(4);
result.FailedSize = reader.GetUint64(5);
}
}
var result = new DownloadStatusStatistics();
using var conn = pool.Take();
await LoadUndownloadedStatistics(conn, result, cancellationToken);
await LoadSuccessStatistics(conn, result, cancellationToken);
return result;
}
public async IAsyncEnumerable<Data.Download> Get() {
await using var conn = await pool.Take();
public async IAsyncEnumerable<Data.Download> GetWithoutData() {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, type, size FROM download_metadata");
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
while (reader.Read()) {
string normalizedUrl = reader.GetString(0);
string downloadUrl = reader.GetString(1);
var status = (DownloadStatus) reader.GetInt32(2);
string? type = reader.IsDBNull(3) ? null : reader.GetString(3);
ulong? size = reader.IsDBNull(4) ? null : reader.GetUint64(4);
ulong size = reader.GetUint64(3);
yield return new Data.Download(normalizedUrl, downloadUrl, status, type, size);
yield return new Data.Download(normalizedUrl, downloadUrl, status, size);
}
}
public async Task<bool> GetDownloadData(string normalizedUrl, Func<Stream, Task> dataProcessor) {
await using var conn = await pool.Take();
public async Task<Data.Download> HydrateWithData(Data.Download download) {
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT rowid FROM download_blobs WHERE normalized_url = :normalized_url");
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
long rowid;
await using (var reader = await cmd.ExecuteReaderAsync()) {
if (!await reader.ReadAsync()) {
return false;
}
await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
rowid = reader.GetInt64(0);
await using var reader = await cmd.ExecuteReaderAsync();
if (reader.Read() && !reader.IsDBNull(0)) {
return download.WithData((byte[]) reader["blob"]);
}
await using (var blob = new SqliteBlob(conn.InnerConnection, "download_blobs", "blob", rowid, readOnly: true)) {
await dataProcessor(blob);
else {
return download;
}
return true;
}
public async Task<bool> GetSuccessfulDownloadWithData(string normalizedUrl, Func<Data.Download, Stream, Task> dataProcessor) {
await using var conn = await pool.Take();
public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
using var conn = pool.Take();
await using var cmd = conn.Command(
"""
SELECT dm.download_url, dm.type, db.rowid FROM download_metadata dm
JOIN download_blobs db ON dm.normalized_url = db.normalized_url
WHERE dm.normalized_url = :normalized_url AND dm.status = :success IS NOT NULL
SELECT a.type, d.blob FROM downloads d
LEFT JOIN attachments a ON d.normalized_url = a.normalized_url
WHERE d.normalized_url = :normalized_url AND d.status = :success AND d.blob IS NOT NULL
"""
);
cmd.AddAndSet(":normalized_url", SqliteType.Text, normalizedUrl);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
string downloadUrl;
string? type;
long rowid;
await using (var reader = await cmd.ExecuteReaderAsync()) {
if (!await reader.ReadAsync()) {
return false;
}
await using var reader = await cmd.ExecuteReaderAsync();
downloadUrl = reader.GetString(0);
type = reader.IsDBNull(1) ? null : reader.GetString(1);
rowid = reader.GetInt64(2);
}
await using (var blob = new SqliteBlob(conn.InnerConnection, "download_blobs", "blob", rowid, readOnly: true)) {
await dataProcessor(new Data.Download(normalizedUrl, downloadUrl, DownloadStatus.Success, type, (ulong) blob.Length), blob);
if (!reader.Read()) {
return null;
}
return true;
return new DownloadedAttachment {
Type = reader.IsDBNull(0) ? null : reader.GetString(0),
Data = (byte[]) reader["blob"],
};
}
public async IAsyncEnumerable<DownloadItem> PullPendingDownloadItems(int count, DownloadItemFilter filter, [EnumeratorCancellation] CancellationToken cancellationToken) {
filter.IncludeStatuses = [DownloadStatus.Pending];
filter.ExcludeStatuses = null;
public async Task EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
await using var cmd = conn.Command(
$"""
INSERT INTO downloads (normalized_url, download_url, status, size)
SELECT a.normalized_url, a.download_url, :enqueued, MAX(a.size)
FROM attachments a
{filter.GenerateWhereClause("a")}
GROUP BY a.normalized_url
"""
);
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) {
var found = new List<DownloadItem>();
await using var conn = await pool.Take();
using var conn = pool.Take();
var sql = $"""
SELECT normalized_url, download_url, type, size
FROM download_metadata
{filter.GenerateConditions().BuildWhereClause()}
LIMIT :limit
""";
await using (var cmd = conn.Command(sql)) {
await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":limit", SqliteType.Integer, Math.Max(0, count));
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
while (await reader.ReadAsync(cancellationToken)) {
while (reader.Read()) {
found.Add(new DownloadItem {
NormalizedUrl = reader.GetString(0),
DownloadUrl = reader.GetString(1),
Type = reader.IsDBNull(2) ? null : reader.GetString(2),
Size = reader.IsDBNull(3) ? null : reader.GetUint64(3)
Size = reader.GetUint64(2),
});
}
}
if (found.Count != 0) {
await using var cmd = conn.Command("UPDATE download_metadata SET status = :downloading WHERE normalized_url = :normalized_url AND status = :pending");
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
await using var cmd = conn.Command("UPDATE downloads SET status = :downloading WHERE normalized_url = :normalized_url AND status = :enqueued");
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
cmd.Add(":normalized_url", SqliteType.Text);
@@ -291,23 +217,17 @@ sealed class SqliteDownloadRepository(SqliteConnectionPool pool) : BaseSqliteRep
}
}
public async Task MoveDownloadingItemsBackToQueue(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
using (var conn = pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM downloads
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
await using var cmd = conn.Command("UPDATE download_metadata SET status = :pending WHERE status = :downloading");
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
cmd.AddAndSet(":downloading", SqliteType.Integer, (int) DownloadStatus.Downloading);
await cmd.ExecuteNonQueryAsync(cancellationToken);
}
public async Task<int> RetryFailed(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
await using var cmd = conn.Command("UPDATE download_metadata SET status = :pending WHERE status = :generic_error OR (status > :last_custom_code AND status != :success)");
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
cmd.AddAndSet(":generic_error", SqliteType.Integer, (int) DownloadStatus.GenericError);
cmd.AddAndSet(":last_custom_code", SqliteType.Integer, (int) DownloadStatus.LastCustomCode);
cmd.AddAndSet(":success", SqliteType.Integer, (int) DownloadStatus.Success);
return await cmd.ExecuteNonQueryAsync(cancellationToken);
totalDownloadsComputer.Recompute();
}
}

View File

@@ -7,21 +7,20 @@ using DHT.Server.Data;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository {
private static readonly Log Log = Log.ForType<SqliteMessageRepository>();
sealed class SqliteMessageRepository : IMessageRepository {
private readonly SqliteConnectionPool pool;
private readonly SqliteDownloadRepository downloads;
private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
public SqliteMessageRepository(SqliteConnectionPool pool, SqliteDownloadRepository downloads) : base(Log) {
public SqliteMessageRepository(SqliteConnectionPool pool, AsyncValueComputer<long>.Single totalMessagesComputer, AsyncValueComputer<long>.Single totalAttachmentsComputer) {
this.pool = pool;
this.downloads = downloads;
this.totalMessagesComputer = totalMessagesComputer;
this.totalAttachmentsComputer = totalAttachmentsComputer;
}
public async Task Add(IReadOnlyList<Message> messages) {
@@ -38,8 +37,10 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
await cmd.ExecuteNonQueryAsync();
}
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
bool addedAttachments = false;
using (var conn = pool.Take()) {
await using var tx = await conn.BeginTransactionAsync();
await using var messageCmd = conn.Upsert("messages", [
("message_id", SqliteType.Integer),
@@ -90,8 +91,6 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
("emoji_flags", SqliteType.Integer),
("count", SqliteType.Integer)
]);
await using var downloadCollector = new SqliteDownloadRepository.NewDownloadCollector(downloads, conn);
foreach (var message in messages) {
object messageId = message.Id;
@@ -123,6 +122,8 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
}
if (!message.Attachments.IsEmpty) {
addedAttachments = true;
foreach (var attachment in message.Attachments) {
attachmentCmd.Set(":message_id", messageId);
attachmentCmd.Set(":attachment_id", attachment.Id);
@@ -134,8 +135,6 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
attachmentCmd.Set(":width", attachment.Width);
attachmentCmd.Set(":height", attachment.Height);
await attachmentCmd.ExecuteNonQueryAsync();
await downloadCollector.Add(DownloadLinkExtractor.FromAttachment(attachment));
}
}
@@ -144,10 +143,6 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
embedCmd.Set(":message_id", messageId);
embedCmd.Set(":json", embed.Json);
await embedCmd.ExecuteNonQueryAsync();
if (DownloadLinkExtractor.TryFromEmbedJson(embed.Json) is {} download) {
await downloadCollector.Add(download);
}
}
}
@@ -159,35 +154,30 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
reactionCmd.Set(":emoji_flags", (int) reaction.EmojiFlags);
reactionCmd.Set(":count", reaction.Count);
await reactionCmd.ExecuteNonQueryAsync();
if (reaction.EmojiId is {} emojiId) {
await downloadCollector.Add(DownloadLinkExtractor.FromEmoji(emojiId, reaction.EmojiFlags));
}
}
}
}
await conn.CommitTransactionAsync();
downloadCollector.OnCommitted();
await tx.CommitAsync();
}
UpdateTotalCount();
totalMessagesComputer.Recompute();
if (addedAttachments) {
totalAttachmentsComputer.Recompute();
}
}
public override Task<long> Count(CancellationToken cancellationToken) {
return Count(filter: null, cancellationToken);
}
public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateConditions().BuildWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
private sealed class MessageToManyCommand<T> : IAsyncDisposable {
private sealed class MesageToManyCommand<T> : IAsyncDisposable {
private readonly SqliteCommand cmd;
private readonly Func<SqliteDataReader, T> readItem;
public MessageToManyCommand(ISqliteConnection conn, string sql, Func<SqliteDataReader, T> readItem) {
public MesageToManyCommand(ISqliteConnection conn, string sql, Func<SqliteDataReader, T> readItem) {
this.cmd = conn.Command(sql);
this.cmd.Add(":message_id", SqliteType.Integer);
@@ -214,7 +204,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
}
public async IAsyncEnumerable<Message> Get(MessageFilter? filter) {
await using var conn = await pool.Take();
using var conn = pool.Take();
const string AttachmentSql =
"""
@@ -223,7 +213,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
WHERE message_id = :message_id
""";
await using var attachmentCmd = new MessageToManyCommand<Attachment>(conn, AttachmentSql, static reader => new Attachment {
await using var attachmentCmd = new MesageToManyCommand<Attachment>(conn, AttachmentSql, static reader => new Attachment {
Id = reader.GetUint64(0),
Name = reader.GetString(1),
Type = reader.IsDBNull(2) ? null : reader.GetString(2),
@@ -241,7 +231,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
WHERE message_id = :message_id
""";
await using var embedCmd = new MessageToManyCommand<Embed>(conn, EmbedSql, static reader => new Embed {
await using var embedCmd = new MesageToManyCommand<Embed>(conn, EmbedSql, static reader => new Embed {
Json = reader.GetString(0)
});
@@ -252,7 +242,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
WHERE message_id = :message_id
""";
await using var reactionsCmd = new MessageToManyCommand<Reaction>(conn, ReactionSql, static reader => new Reaction {
await using var reactionsCmd = new MesageToManyCommand<Reaction>(conn, ReactionSql, static reader => new Reaction {
EmojiId = reader.IsDBNull(0) ? null : reader.GetUint64(0),
EmojiName = reader.IsDBNull(1) ? null : reader.GetString(1),
EmojiFlags = (EmojiFlags) reader.GetInt16(2),
@@ -265,7 +255,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
FROM messages m
LEFT JOIN edit_timestamps et ON m.message_id = et.message_id
LEFT JOIN replied_to rt ON m.message_id = rt.message_id
{filter.GenerateConditions("m").BuildWhereClause()}
{filter.GenerateWhereClause("m")}
"""
);
@@ -290,9 +280,9 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
}
public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
await using var conn = await pool.Take();
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateConditions().BuildWhereClause());
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
@@ -301,16 +291,16 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
}
public async Task Remove(MessageFilter filter, FilterRemovalMode mode) {
await using (var conn = await pool.Take()) {
using (var conn = pool.Take()) {
await conn.ExecuteAsync(
$"""
-- noinspection SqlWithoutWhere
DELETE FROM messages
{filter.GenerateConditions(invert: mode == FilterRemovalMode.KeepMatching).BuildWhereClause()}
{filter.GenerateWhereClause(invert: mode == FilterRemovalMode.KeepMatching)}
"""
);
}
UpdateTotalCount();
totalMessagesComputer.Recompute();
}
}

View File

@@ -1,27 +1,34 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
private static readonly Log Log = Log.ForType<SqliteServerRepository>();
sealed class SqliteServerRepository : IServerRepository {
private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteServerRepository(SqliteConnectionPool pool) : base(Log) {
public SqliteServerRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateServerStatistics(conn);
}
private async Task UpdateServerStatistics(ISqliteConnection conn) {
statistics.TotalServers = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<Data.Server> servers) {
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("servers", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
@@ -35,24 +42,19 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
await cmd.ExecuteNonQueryAsync();
}
await conn.CommitTransactionAsync();
await tx.CommitAsync();
}
UpdateTotalCount();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
await UpdateServerStatistics(conn);
}
public async IAsyncEnumerable<Data.Server> Get() {
await using var conn = await pool.Take();
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, name, type FROM servers");
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
while (reader.Read()) {
yield return new Data.Server {
Id = reader.GetUint64(0),
Name = reader.GetString(1),

View File

@@ -1,30 +1,34 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
private static readonly Log Log = Log.ForType<SqliteUserRepository>();
sealed class SqliteUserRepository : IUserRepository {
private readonly SqliteConnectionPool pool;
private readonly SqliteDownloadRepository downloads;
public SqliteUserRepository(SqliteConnectionPool pool, SqliteDownloadRepository downloads) : base(Log) {
private readonly DatabaseStatistics statistics;
public SqliteUserRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool;
this.downloads = downloads;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateUserStatistics(conn);
}
private async Task UpdateUserStatistics(ISqliteConnection conn) {
statistics.TotalUsers = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L);
}
public async Task Add(IReadOnlyList<User> users) {
await using (var conn = await pool.Take()) {
await conn.BeginTransactionAsync();
using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("users", [
("id", SqliteType.Integer),
("name", SqliteType.Text),
@@ -32,39 +36,27 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
("discriminator", SqliteType.Text)
]);
await using var downloadCollector = new SqliteDownloadRepository.NewDownloadCollector(downloads, conn);
foreach (var user in users) {
cmd.Set(":id", user.Id);
cmd.Set(":name", user.Name);
cmd.Set(":avatar_url", user.AvatarUrl);
cmd.Set(":discriminator", user.Discriminator);
await cmd.ExecuteNonQueryAsync();
if (user.AvatarUrl is {} avatarUrl) {
await downloadCollector.Add(DownloadLinkExtractor.FromUserAvatar(user.Id, avatarUrl));
}
}
await conn.CommitTransactionAsync();
downloadCollector.OnCommitted();
await tx.CommitAsync();
}
UpdateTotalCount();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
await UpdateUserStatistics(conn);
}
public async IAsyncEnumerable<User> Get() {
await using var conn = await pool.Take();
using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
await using var reader = await cmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
while (reader.Read()) {
yield return new User {
Id = reader.GetUint64(0),
Name = reader.GetString(1),

View File

@@ -0,0 +1,358 @@
using System.Collections.Generic;
using System.Data.Common;
using System.Threading.Tasks;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite;
sealed class Schema {
internal const int Version = 6;
private static readonly Log Log = Log.ForType<Schema>();
private readonly ISqliteConnection conn;
public Schema(ISqliteConnection conn) {
this.conn = conn;
}
public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) {
conn.Execute(@"CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = conn.SelectScalar("SELECT value FROM metadata WHERE key = 'version'");
if (dbVersionStr == null) {
InitializeSchemas();
}
else if (!int.TryParse(dbVersionStr.ToString(), out int dbVersion) || dbVersion < 1) {
throw new InvalidDatabaseVersionException(dbVersionStr.ToString() ?? "<null>");
}
else if (dbVersion > Version) {
throw new DatabaseTooNewException(dbVersion);
}
else if (dbVersion < Version) {
var proceed = await callbacks.CanUpgrade();
if (!proceed) {
return false;
}
await callbacks.Start(Version - dbVersion, async reporter => await UpgradeSchemas(dbVersion, reporter));
}
return true;
}
private void InitializeSchemas() {
conn.Execute("""
CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
avatar_url TEXT,
discriminator TEXT
)
""");
conn.Execute("""
CREATE TABLE servers (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT NOT NULL
)
""");
conn.Execute("""
CREATE TABLE channels (
id INTEGER PRIMARY KEY NOT NULL,
server INTEGER NOT NULL,
name TEXT NOT NULL,
parent_id INTEGER,
position INTEGER,
topic TEXT,
nsfw INTEGER
)
""");
conn.Execute("""
CREATE TABLE messages (
message_id INTEGER PRIMARY KEY NOT NULL,
sender_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL,
text TEXT NOT NULL,
timestamp INTEGER NOT NULL
)
""");
conn.Execute("""
CREATE TABLE attachments (
message_id INTEGER NOT NULL,
attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT,
normalized_url TEXT NOT NULL,
download_url TEXT,
size INTEGER NOT NULL,
width INTEGER,
height INTEGER
)
""");
conn.Execute("""
CREATE TABLE embeds (
message_id INTEGER NOT NULL,
json TEXT NOT NULL
)
""");
conn.Execute("""
CREATE TABLE downloads (
normalized_url TEXT NOT NULL PRIMARY KEY,
download_url TEXT,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
conn.Execute("""
CREATE TABLE reactions (
message_id INTEGER NOT NULL,
emoji_id INTEGER,
emoji_name TEXT,
emoji_flags INTEGER NOT NULL,
count INTEGER NOT NULL
)
""");
CreateMessageEditTimestampTable();
CreateMessageRepliedToTable();
conn.Execute("CREATE INDEX attachments_message_ix ON attachments(message_id)");
conn.Execute("CREATE INDEX embeds_message_ix ON embeds(message_id)");
conn.Execute("CREATE INDEX reactions_message_ix ON reactions(message_id)");
conn.Execute("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")");
}
private void CreateMessageEditTimestampTable() {
conn.Execute("""
CREATE TABLE edit_timestamps (
message_id INTEGER PRIMARY KEY NOT NULL,
edit_timestamp INTEGER NOT NULL
)
""");
}
private void CreateMessageRepliedToTable() {
conn.Execute("""
CREATE TABLE replied_to (
message_id INTEGER PRIMARY KEY NOT NULL,
replied_to_id INTEGER NOT NULL
)
""");
}
private async Task NormalizeAttachmentUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing attachments...", 0, 0);
var normalizedUrls = new Dictionary<long, string>();
await using (var selectCmd = conn.Command("SELECT attachment_id, url FROM attachments")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (reader.Read()) {
var attachmentId = reader.GetInt64(0);
var originalUrl = reader.GetString(1);
normalizedUrls[attachmentId] = DiscordCdn.NormalizeUrl(originalUrl);
}
}
await using var tx = await conn.BeginTransactionAsync();
int totalUrls = normalizedUrls.Count;
int processedUrls = -1;
await using (var updateCmd = conn.Command("UPDATE attachments SET download_url = url, url = :normalized_url WHERE attachment_id = :attachment_id")) {
updateCmd.Add(":attachment_id", SqliteType.Integer);
updateCmd.Add(":normalized_url", SqliteType.Text);
foreach (var (attachmentId, normalizedUrl) in normalizedUrls) {
if (++processedUrls % 1000 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
}
updateCmd.Set(":attachment_id", attachmentId);
updateCmd.Set(":normalized_url", normalizedUrl);
updateCmd.ExecuteNonQuery();
}
}
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync();
}
private async Task NormalizeDownloadUrls(ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing downloads...", 0, 0);
var normalizedUrlsToOriginalUrls = new Dictionary<string, string>();
var duplicateUrlsToDelete = new HashSet<string>();
await using (var selectCmd = conn.Command("SELECT url FROM downloads ORDER BY CASE WHEN status = 200 THEN 0 ELSE 1 END")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (reader.Read()) {
var originalUrl = reader.GetString(0);
var normalizedUrl = DiscordCdn.NormalizeUrl(originalUrl);
if (!normalizedUrlsToOriginalUrls.TryAdd(normalizedUrl, originalUrl)) {
duplicateUrlsToDelete.Add(originalUrl);
}
}
}
conn.Execute("PRAGMA cache_size = -20000");
DbTransaction tx;
await using (tx = await conn.BeginTransactionAsync()) {
await reporter.SubWork("Deleting duplicates...", 0, 0);
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
foreach (var duplicateUrl in duplicateUrlsToDelete) {
deleteCmd.Set(":url", duplicateUrl);
deleteCmd.ExecuteNonQuery();
}
}
await tx.CommitAsync();
}
int totalUrls = normalizedUrlsToOriginalUrls.Count;
int processedUrls = -1;
tx = await conn.BeginTransactionAsync();
await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Add(":normalized_url", SqliteType.Text);
updateCmd.Add(":download_url", SqliteType.Text);
foreach (var (normalizedUrl, downloadUrl) in normalizedUrlsToOriginalUrls) {
if (++processedUrls % 100 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
// Not proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyways, so this doesn't make it worse.
await tx.CommitAsync();
await tx.DisposeAsync();
tx = await conn.BeginTransactionAsync();
updateCmd.Transaction = (SqliteTransaction) tx;
}
updateCmd.Set(":normalized_url", normalizedUrl);
updateCmd.Set(":download_url", downloadUrl);
updateCmd.ExecuteNonQuery();
}
}
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await tx.CommitAsync();
await tx.DisposeAsync();
conn.Execute("PRAGMA cache_size = -2000");
}
private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
var perf = Log.Start("from version " + dbVersion);
conn.Execute("UPDATE metadata SET value = " + Version + " WHERE key = 'version'");
if (dbVersion <= 1) {
await reporter.MainWork("Applying schema changes...", 0, 1);
conn.Execute("ALTER TABLE channels ADD parent_id INTEGER");
perf.Step("Upgrade to version 2");
await reporter.NextVersion();
}
if (dbVersion <= 2) {
await reporter.MainWork("Applying schema changes...", 0, 1);
CreateMessageEditTimestampTable();
CreateMessageRepliedToTable();
conn.Execute("""
INSERT INTO edit_timestamps (message_id, edit_timestamp)
SELECT message_id, edit_timestamp
FROM messages
WHERE edit_timestamp IS NOT NULL
""");
conn.Execute("""
INSERT INTO replied_to (message_id, replied_to_id)
SELECT message_id, replied_to_id
FROM messages
WHERE replied_to_id IS NOT NULL
""");
conn.Execute("ALTER TABLE messages DROP COLUMN replied_to_id");
conn.Execute("ALTER TABLE messages DROP COLUMN edit_timestamp");
perf.Step("Upgrade to version 3");
await reporter.MainWork("Vacuuming the database...", 1, 1);
conn.Execute("VACUUM");
perf.Step("Vacuum");
await reporter.NextVersion();
}
if (dbVersion <= 3) {
conn.Execute("""
CREATE TABLE downloads (
url TEXT NOT NULL PRIMARY KEY,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
perf.Step("Upgrade to version 4");
await reporter.NextVersion();
}
if (dbVersion <= 4) {
await reporter.MainWork("Applying schema changes...", 0, 1);
conn.Execute("ALTER TABLE attachments ADD width INTEGER");
conn.Execute("ALTER TABLE attachments ADD height INTEGER");
perf.Step("Upgrade to version 5");
await reporter.NextVersion();
}
if (dbVersion <= 5) {
await reporter.MainWork("Applying schema changes...", 0, 3);
conn.Execute("ALTER TABLE attachments ADD download_url TEXT");
conn.Execute("ALTER TABLE downloads ADD download_url TEXT");
await reporter.MainWork("Updating attachments...", 1, 3);
await NormalizeAttachmentUrls(reporter);
await reporter.MainWork("Updating downloads...", 2, 3);
await NormalizeDownloadUrls(reporter);
await reporter.MainWork("Applying schema changes...", 3, 3);
conn.Execute("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
conn.Execute("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
perf.Step("Upgrade to version 6");
await reporter.NextVersion();
}
perf.End();
}
}

View File

@@ -1,8 +0,0 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
interface ISchemaUpgrade {
Task Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter);
}

View File

@@ -1,11 +0,0 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo2 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE channels ADD parent_id INTEGER");
}
}

View File

@@ -1,33 +0,0 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo3 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await SqliteSchema.CreateMessageEditTimestampTable(conn);
await SqliteSchema.CreateMessageRepliedToTable(conn);
await conn.ExecuteAsync("""
INSERT INTO edit_timestamps (message_id, edit_timestamp)
SELECT message_id, edit_timestamp
FROM messages
WHERE edit_timestamp IS NOT NULL
""");
await conn.ExecuteAsync("""
INSERT INTO replied_to (message_id, replied_to_id)
SELECT message_id, replied_to_id
FROM messages
WHERE replied_to_id IS NOT NULL
""");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN replied_to_id");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN edit_timestamp");
await reporter.MainWork("Vacuuming the database...", 1, 1);
await conn.ExecuteAsync("VACUUM");
}
}

View File

@@ -1,19 +0,0 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo4 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("""
CREATE TABLE downloads (
url TEXT NOT NULL PRIMARY KEY,
status INTEGER NOT NULL,
size INTEGER NOT NULL,
blob BLOB
)
""");
}
}

View File

@@ -1,12 +0,0 @@
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo5 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE attachments ADD width INTEGER");
await conn.ExecuteAsync("ALTER TABLE attachments ADD height INTEGER");
}
}

View File

@@ -1,132 +0,0 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo6 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 3);
await conn.ExecuteAsync("ALTER TABLE attachments ADD download_url TEXT");
await conn.ExecuteAsync("ALTER TABLE downloads ADD download_url TEXT");
await reporter.MainWork("Updating attachments...", 1, 3);
await NormalizeAttachmentUrls(conn, reporter);
await reporter.MainWork("Updating downloads...", 2, 3);
await NormalizeDownloadUrls(conn, reporter);
await reporter.MainWork("Applying schema changes...", 3, 3);
await conn.ExecuteAsync("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
await conn.ExecuteAsync("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
}
private async Task NormalizeAttachmentUrls(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing attachments...", 0, 0);
var normalizedUrls = new Dictionary<long, string>();
await using (var selectCmd = conn.Command("SELECT attachment_id, url FROM attachments")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
var attachmentId = reader.GetInt64(0);
var originalUrl = reader.GetString(1);
normalizedUrls[attachmentId] = DiscordCdn.NormalizeUrl(originalUrl);
}
}
await conn.BeginTransactionAsync();
int totalUrls = normalizedUrls.Count;
int processedUrls = -1;
await using (var updateCmd = conn.Command("UPDATE attachments SET download_url = url, url = :normalized_url WHERE attachment_id = :attachment_id")) {
updateCmd.Add(":attachment_id", SqliteType.Integer);
updateCmd.Add(":normalized_url", SqliteType.Text);
foreach (var (attachmentId, normalizedUrl) in normalizedUrls) {
if (++processedUrls % 1000 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
}
updateCmd.Set(":attachment_id", attachmentId);
updateCmd.Set(":normalized_url", normalizedUrl);
await updateCmd.ExecuteNonQueryAsync();
}
}
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await conn.CommitTransactionAsync();
}
private async Task NormalizeDownloadUrls(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Preparing downloads...", 0, 0);
var normalizedUrlsToOriginalUrls = new Dictionary<string, string>();
var duplicateUrlsToDelete = new HashSet<string>();
await using (var selectCmd = conn.Command("SELECT url FROM downloads ORDER BY CASE WHEN status = 200 THEN 0 ELSE 1 END")) {
await using var reader = await selectCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
var originalUrl = reader.GetString(0);
var normalizedUrl = DiscordCdn.NormalizeUrl(originalUrl);
if (!normalizedUrlsToOriginalUrls.TryAdd(normalizedUrl, originalUrl)) {
duplicateUrlsToDelete.Add(originalUrl);
}
}
}
await conn.ExecuteAsync("PRAGMA cache_size = -20000");
await conn.BeginTransactionAsync();
await reporter.SubWork("Deleting duplicates...", 0, 0);
await using (var deleteCmd = conn.Delete("downloads", ("url", SqliteType.Text))) {
foreach (var duplicateUrl in duplicateUrlsToDelete) {
deleteCmd.Set(":url", duplicateUrl);
await deleteCmd.ExecuteNonQueryAsync();
}
}
await conn.CommitTransactionAsync();
int totalUrls = normalizedUrlsToOriginalUrls.Count;
int processedUrls = -1;
await conn.BeginTransactionAsync();
await using (var updateCmd = conn.Command("UPDATE downloads SET download_url = :download_url, url = :normalized_url WHERE url = :download_url")) {
updateCmd.Add(":normalized_url", SqliteType.Text);
updateCmd.Add(":download_url", SqliteType.Text);
foreach (var (normalizedUrl, downloadUrl) in normalizedUrlsToOriginalUrls) {
if (++processedUrls % 100 == 0) {
await reporter.SubWork("Updating URLs...", processedUrls, totalUrls);
// Not proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyways, so this doesn't make it worse.
await conn.CommitTransactionAsync();
await conn.BeginTransactionAsync();
conn.AssignActiveTransaction(updateCmd);
}
updateCmd.Set(":normalized_url", normalizedUrl);
updateCmd.Set(":download_url", downloadUrl);
await updateCmd.ExecuteNonQueryAsync();
}
}
await reporter.SubWork("Updating URLs...", totalUrls, totalUrls);
await conn.CommitTransactionAsync();
await conn.ExecuteAsync("PRAGMA cache_size = -2000");
}
}

View File

@@ -1,153 +0,0 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Schema;
sealed class SqliteSchemaUpgradeTo7 : ISchemaUpgrade {
async Task ISchemaUpgrade.Run(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.MainWork("Applying schema changes...", 0, 6);
await SqliteSchema.CreateDownloadTables(conn);
await reporter.MainWork("Migrating download metadata...", 1, 6);
await conn.ExecuteAsync("INSERT INTO download_metadata (normalized_url, download_url, status, size) SELECT normalized_url, download_url, status, size FROM downloads");
await reporter.MainWork("Merging attachment metadata...", 2, 6);
await conn.ExecuteAsync("UPDATE download_metadata SET type = (SELECT type FROM attachments WHERE download_metadata.normalized_url = attachments.normalized_url)");
await reporter.MainWork("Migrating downloaded files...", 3, 6);
await MigrateDownloadBlobsToNewTable(conn, reporter);
await reporter.MainWork("Applying schema changes...", 4, 6);
await conn.ExecuteAsync("DROP TABLE downloads");
await reporter.MainWork("Discovering downloadable links...", 5, 6);
await DiscoverDownloadableLinks(conn, reporter);
}
private async Task MigrateDownloadBlobsToNewTable(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Listing downloaded files...", 0, 0);
var urlsToMigrate = await GetDownloadedFileUrls(conn);
int totalFiles = urlsToMigrate.Count;
int processedFiles = -1;
await reporter.SubWork("Processing downloaded files...", 0, totalFiles);
await conn.BeginTransactionAsync();
await using (var insertCmd = conn.Command("INSERT INTO download_blobs (normalized_url, blob) SELECT normalized_url, blob FROM downloads WHERE normalized_url = :normalized_url"))
await using (var deleteCmd = conn.Command("DELETE FROM downloads WHERE normalized_url = :normalized_url")) {
insertCmd.Add(":normalized_url", SqliteType.Text);
deleteCmd.Add(":normalized_url", SqliteType.Text);
foreach (var url in urlsToMigrate) {
if (++processedFiles % 10 == 0) {
await reporter.SubWork("Processing downloaded files...", processedFiles, totalFiles);
// Not proper way of dealing with transactions, but it avoids a long commit at the end.
// Schema upgrades are already non-atomic anyways, so this doesn't make it worse.
await conn.CommitTransactionAsync();
await conn.BeginTransactionAsync();
conn.AssignActiveTransaction(insertCmd);
conn.AssignActiveTransaction(deleteCmd);
}
insertCmd.Set(":normalized_url", url);
await insertCmd.ExecuteNonQueryAsync();
deleteCmd.Set(":normalized_url", url);
await deleteCmd.ExecuteNonQueryAsync();
}
}
await reporter.SubWork("Processing downloaded files...", totalFiles, totalFiles);
await conn.CommitTransactionAsync();
}
private async Task<List<string>> GetDownloadedFileUrls(ISqliteConnection conn) {
var urls = new List<string>();
await using var selectCmd = conn.Command("SELECT normalized_url FROM downloads WHERE blob IS NOT NULL");
await using var reader = await selectCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
urls.Add(reader.GetString(0));
}
return urls;
}
private async Task DiscoverDownloadableLinks(ISqliteConnection conn, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
await reporter.SubWork("Processing attachments...", 0, 4);
await using (var cmd = conn.Command("""
INSERT OR IGNORE INTO download_metadata (normalized_url, download_url, status, type, size)
SELECT a.normalized_url, a.download_url, :pending, a.type, MAX(a.size)
FROM attachments a
GROUP BY a.normalized_url
""")) {
cmd.AddAndSet(":pending", SqliteType.Integer, (int) DownloadStatus.Pending);
await cmd.ExecuteNonQueryAsync();
}
static async Task InsertDownload(SqliteCommand insertCmd, Data.Download? download) {
if (download == null) {
return;
}
insertCmd.Set(":normalized_url", download.NormalizedUrl);
insertCmd.Set(":download_url", download.DownloadUrl);
insertCmd.Set(":status", (int) download.Status);
insertCmd.Set(":type", download.Type);
insertCmd.Set(":size", download.Size);
await insertCmd.ExecuteNonQueryAsync();
}
await conn.BeginTransactionAsync();
await using var insertCmd = conn.Command("INSERT OR IGNORE INTO download_metadata (normalized_url, download_url, status, type, size) VALUES (:normalized_url, :download_url, :status, :type, :size)");
insertCmd.Add(":normalized_url", SqliteType.Text);
insertCmd.Add(":download_url", SqliteType.Text);
insertCmd.Add(":status", SqliteType.Integer);
insertCmd.Add(":type", SqliteType.Text);
insertCmd.Add(":size", SqliteType.Integer);
await reporter.SubWork("Processing embeds...", 1, 4);
await using (var embedCmd = conn.Command("SELECT json FROM embeds")) {
await using var reader = await embedCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, await DownloadLinkExtractor.TryFromEmbedJson(reader.GetStream(0)));
}
}
await reporter.SubWork("Processing users...", 2, 4);
await using (var avatarCmd = conn.Command("SELECT id, avatar_url FROM users WHERE avatar_url IS NOT NULL")) {
await using var reader = await avatarCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, DownloadLinkExtractor.FromUserAvatar(reader.GetUint64(0), reader.GetString(1)));
}
}
await reporter.SubWork("Processing reactions...", 3, 4);
await using (var avatarCmd = conn.Command("SELECT DISTINCT emoji_id, emoji_flags FROM reactions WHERE emoji_id IS NOT NULL")) {
await using var reader = await avatarCmd.ExecuteReaderAsync();
while (await reader.ReadAsync()) {
await InsertDownload(insertCmd, DownloadLinkExtractor.FromEmoji(reader.GetUint64(0), (EmojiFlags) reader.GetInt16(1)));
}
}
await conn.CommitTransactionAsync();
}
}

View File

@@ -2,8 +2,8 @@ using System;
using System.Threading.Tasks;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Repositories;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite;
@@ -11,33 +11,36 @@ namespace DHT.Server.Database.Sqlite;
public sealed class SqliteDatabaseFile : IDatabaseFile {
private const int DefaultPoolSize = 5;
public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, ISchemaUpgradeCallbacks schemaUpgradeCallbacks) {
public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, ISchemaUpgradeCallbacks schemaUpgradeCallbacks, TaskScheduler computeTaskResultScheduler) {
var connectionString = new SqliteConnectionStringBuilder {
DataSource = path,
Mode = SqliteOpenMode.ReadWriteCreate,
};
var pool = await SqliteConnectionPool.Create(connectionString, DefaultPoolSize);
var pool = new SqliteConnectionPool(connectionString, DefaultPoolSize);
bool wasOpened;
try {
await using var conn = await pool.Take();
wasOpened = await new SqliteSchema(conn).Setup(schemaUpgradeCallbacks);
using var conn = pool.Take();
wasOpened = await new Schema(conn).Setup(schemaUpgradeCallbacks);
} catch (Exception) {
await pool.DisposeAsync();
pool.Dispose();
throw;
}
if (wasOpened) {
return new SqliteDatabaseFile(path, pool);
var db = new SqliteDatabaseFile(path, pool, computeTaskResultScheduler);
await db.Initialize();
return db;
}
else {
await pool.DisposeAsync();
pool.Dispose();
return null;
}
}
public string Path { get; }
public DatabaseStatistics Statistics { get; }
public IUserRepository Users => users;
public IServerRepository Servers => servers;
@@ -53,28 +56,82 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
private readonly SqliteMessageRepository messages;
private readonly SqliteDownloadRepository downloads;
private SqliteDatabaseFile(string path, SqliteConnectionPool pool) {
this.Path = path;
private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
private SqliteDatabaseFile(string path, SqliteConnectionPool pool, TaskScheduler computeTaskResultScheduler) {
this.pool = pool;
downloads = new SqliteDownloadRepository(pool);
users = new SqliteUserRepository(pool, downloads);
servers = new SqliteServerRepository(pool);
channels = new SqliteChannelRepository(pool);
messages = new SqliteMessageRepository(pool, downloads);
this.totalMessagesComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateMessageStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeMessageStatistics);
this.totalAttachmentsComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateAttachmentStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeAttachmentStatistics);
this.totalDownloadsComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateDownloadStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeDownloadStatistics);
this.Path = path;
this.Statistics = new DatabaseStatistics();
this.users = new SqliteUserRepository(pool, Statistics);
this.servers = new SqliteServerRepository(pool, Statistics);
this.channels = new SqliteChannelRepository(pool, Statistics);
this.messages = new SqliteMessageRepository(pool, totalMessagesComputer, totalAttachmentsComputer);
this.downloads = new SqliteDownloadRepository(pool, totalDownloadsComputer);
totalMessagesComputer.Recompute();
totalAttachmentsComputer.Recompute();
totalDownloadsComputer.Recompute();
}
public async ValueTask DisposeAsync() {
users.Dispose();
servers.Dispose();
channels.Dispose();
messages.Dispose();
downloads.Dispose();
await pool.DisposeAsync();
private async Task Initialize() {
await users.Initialize();
await servers.Initialize();
await channels.Initialize();
}
public void Dispose() {
totalMessagesComputer.Cancel();
totalAttachmentsComputer.Cancel();
totalDownloadsComputer.Cancel();
pool.Dispose();
}
public async Task<DatabaseStatisticsSnapshot> SnapshotStatistics() {
return new DatabaseStatisticsSnapshot {
TotalServers = Statistics.TotalServers,
TotalChannels = Statistics.TotalChannels,
TotalUsers = Statistics.TotalUsers,
TotalMessages = await ComputeMessageStatistics(),
};
}
public async Task Vacuum() {
await using var conn = await pool.Take();
using var conn = pool.Take();
await conn.ExecuteAsync("VACUUM");
}
private async Task<long> ComputeMessageStatistics() {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages", static reader => reader?.GetInt64(0) ?? 0L);
}
private void UpdateMessageStatistics(long totalMessages) {
Statistics.TotalMessages = totalMessages;
}
private async Task<long> ComputeAttachmentStatistics() {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments", static reader => reader?.GetInt64(0) ?? 0L);
}
private void UpdateAttachmentStatistics(long totalAttachments) {
Statistics.TotalAttachments = totalAttachments;
}
private async Task<long> ComputeDownloadStatistics() {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L);
}
private void UpdateDownloadStatistics(long totalDownloads) {
Statistics.TotalDownloads = totalDownloads;
}
}

View File

@@ -8,53 +8,77 @@ using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite;
static class SqliteFilters {
public static SqliteConditionBuilder GenerateConditions(this MessageFilter? filter, string? tableAlias = null, bool invert = false) {
var builder = new SqliteConditionBuilder(tableAlias, invert);
if (filter != null) {
if (filter.StartDate != null) {
builder.AddCondition("timestamp >= " + new DateTimeOffset(filter.StartDate.Value).ToUnixTimeMilliseconds());
}
if (filter.EndDate != null) {
builder.AddCondition("timestamp <= " + new DateTimeOffset(filter.EndDate.Value).ToUnixTimeMilliseconds());
}
if (filter.ChannelIds != null) {
builder.AddCondition("channel_id IN (" + string.Join(",", filter.ChannelIds) + ")");
}
if (filter.UserIds != null) {
builder.AddCondition("sender_id IN (" + string.Join(",", filter.UserIds) + ")");
}
if (filter.MessageIds != null) {
builder.AddCondition("message_id IN (" + string.Join(",", filter.MessageIds) + ")");
}
}
return builder;
private static string WhereAll(bool invert) {
return invert ? "WHERE FALSE" : "";
}
public static SqliteConditionBuilder GenerateConditions(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) {
var builder = new SqliteConditionBuilder(tableAlias, invert);
if (filter != null) {
if (filter.IncludeStatuses != null) {
builder.AddCondition("status IN (" + filter.IncludeStatuses.In() + ")");
}
if (filter.ExcludeStatuses != null) {
builder.AddCondition("status NOT IN (" + filter.ExcludeStatuses.In() + ")");
}
if (filter.MaxBytes != null) {
builder.AddCondition("size IS NOT NULL");
builder.AddCondition("size <= " + filter.MaxBytes);
}
public static string GenerateWhereClause(this MessageFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
}
return builder;
var where = new SqliteWhereGenerator(tableAlias, invert);
if (filter.StartDate != null) {
where.AddCondition("timestamp >= " + new DateTimeOffset(filter.StartDate.Value).ToUnixTimeMilliseconds());
}
if (filter.EndDate != null) {
where.AddCondition("timestamp <= " + new DateTimeOffset(filter.EndDate.Value).ToUnixTimeMilliseconds());
}
if (filter.ChannelIds != null) {
where.AddCondition("channel_id IN (" + string.Join(",", filter.ChannelIds) + ")");
}
if (filter.UserIds != null) {
where.AddCondition("sender_id IN (" + string.Join(",", filter.UserIds) + ")");
}
if (filter.MessageIds != null) {
where.AddCondition("message_id IN (" + string.Join(",", filter.MessageIds) + ")");
}
return where.Generate();
}
public static string GenerateWhereClause(this AttachmentFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
}
var where = new SqliteWhereGenerator(tableAlias, invert);
if (filter.MaxBytes != null) {
where.AddCondition("size <= " + filter.MaxBytes);
}
if (filter.DownloadItemRule == AttachmentFilter.DownloadItemRules.OnlyNotPresent) {
where.AddCondition("normalized_url NOT IN (SELECT normalized_url FROM downloads)");
}
else if (filter.DownloadItemRule == AttachmentFilter.DownloadItemRules.OnlyPresent) {
where.AddCondition("normalized_url IN (SELECT normalized_url FROM downloads)");
}
return where.Generate();
}
public static string GenerateWhereClause(this DownloadItemFilter? filter, string? tableAlias = null, bool invert = false) {
if (filter == null || filter.IsEmpty) {
return WhereAll(invert);
}
var where = new SqliteWhereGenerator(tableAlias, invert);
if (filter.IncludeStatuses != null) {
where.AddCondition("status IN (" + filter.IncludeStatuses.In() + ")");
}
if (filter.ExcludeStatuses != null) {
where.AddCondition("status NOT IN (" + filter.ExcludeStatuses.In() + ")");
}
return where.Generate();
}
private static string In(this ISet<DownloadStatus> statuses) {

View File

@@ -1,193 +0,0 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using DHT.Server.Database.Exceptions;
using DHT.Server.Database.Sqlite.Schema;
using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging;
namespace DHT.Server.Database.Sqlite;
sealed class SqliteSchema {
internal const int Version = 7;
private static readonly Log Log = Log.ForType<SqliteSchema>();
private readonly ISqliteConnection conn;
public SqliteSchema(ISqliteConnection conn) {
this.conn = conn;
}
public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) {
await conn.ExecuteAsync("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = await conn.ExecuteReaderAsync("SELECT value FROM metadata WHERE key = 'version'", static reader => reader?.GetString(0));
if (dbVersionStr == null) {
await InitializeSchemas();
}
else if (!int.TryParse(dbVersionStr, out int dbVersion) || dbVersion < 1) {
throw new InvalidDatabaseVersionException(dbVersionStr);
}
else if (dbVersion > Version) {
throw new DatabaseTooNewException(dbVersion);
}
else if (dbVersion < Version) {
var proceed = await callbacks.CanUpgrade();
if (!proceed) {
return false;
}
await callbacks.Start(Version - dbVersion, async reporter => await UpgradeSchemas(dbVersion, reporter));
}
return true;
}
private async Task InitializeSchemas() {
await conn.ExecuteAsync("""
CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
avatar_url TEXT,
discriminator TEXT
)
""");
await conn.ExecuteAsync("""
CREATE TABLE servers (
id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE channels (
id INTEGER PRIMARY KEY NOT NULL,
server INTEGER NOT NULL,
name TEXT NOT NULL,
parent_id INTEGER,
position INTEGER,
topic TEXT,
nsfw INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE messages (
message_id INTEGER PRIMARY KEY NOT NULL,
sender_id INTEGER NOT NULL,
channel_id INTEGER NOT NULL,
text TEXT NOT NULL,
timestamp INTEGER NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE attachments (
message_id INTEGER NOT NULL,
attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL,
name TEXT NOT NULL,
type TEXT,
normalized_url TEXT NOT NULL,
download_url TEXT,
size INTEGER NOT NULL,
width INTEGER,
height INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE embeds (
message_id INTEGER NOT NULL,
json TEXT NOT NULL
)
""");
await conn.ExecuteAsync("""
CREATE TABLE reactions (
message_id INTEGER NOT NULL,
emoji_id INTEGER,
emoji_name TEXT,
emoji_flags INTEGER NOT NULL,
count INTEGER NOT NULL
)
""");
await CreateMessageEditTimestampTable(conn);
await CreateMessageRepliedToTable(conn);
await CreateDownloadTables(conn);
await conn.ExecuteAsync("CREATE INDEX attachments_message_ix ON attachments(message_id)");
await conn.ExecuteAsync("CREATE INDEX embeds_message_ix ON embeds(message_id)");
await conn.ExecuteAsync("CREATE INDEX reactions_message_ix ON reactions(message_id)");
await conn.ExecuteAsync("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")");
}
internal static async Task CreateMessageEditTimestampTable(ISqliteConnection conn) {
await conn.ExecuteAsync("""
CREATE TABLE edit_timestamps (
message_id INTEGER PRIMARY KEY NOT NULL,
edit_timestamp INTEGER NOT NULL
)
""");
}
internal static async Task CreateMessageRepliedToTable(ISqliteConnection conn) {
await conn.ExecuteAsync("""
CREATE TABLE replied_to (
message_id INTEGER PRIMARY KEY NOT NULL,
replied_to_id INTEGER NOT NULL
)
""");
}
internal static async Task CreateDownloadTables(ISqliteConnection conn) {
await conn.ExecuteAsync("""
CREATE TABLE download_metadata (
normalized_url TEXT NOT NULL PRIMARY KEY,
download_url TEXT NOT NULL,
status INTEGER NOT NULL,
type TEXT,
size INTEGER
)
""");
await conn.ExecuteAsync("""
CREATE TABLE download_blobs (
normalized_url TEXT NOT NULL PRIMARY KEY,
blob BLOB NOT NULL,
FOREIGN KEY (normalized_url) REFERENCES download_metadata (normalized_url) ON UPDATE CASCADE ON DELETE CASCADE
)
""");
}
private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
var upgrades = new Dictionary<int, ISchemaUpgrade> {
{ 1, new SqliteSchemaUpgradeTo2() },
{ 2, new SqliteSchemaUpgradeTo3() },
{ 3, new SqliteSchemaUpgradeTo4() },
{ 4, new SqliteSchemaUpgradeTo5() },
{ 5, new SqliteSchemaUpgradeTo6() },
{ 6, new SqliteSchemaUpgradeTo7() },
};
var perf = Log.Start("from version " + dbVersion);
for (int fromVersion = dbVersion; fromVersion < Version; fromVersion++) {
var toVersion = fromVersion + 1;
if (upgrades.TryGetValue(fromVersion, out var upgrade)) {
await upgrade.Run(conn, reporter);
}
await conn.ExecuteAsync("UPDATE metadata SET value = " + toVersion + " WHERE key = 'version'");
perf.Step("Upgrade to version " + toVersion);
await reporter.NextVersion();
}
perf.End();
}
}

View File

@@ -1,7 +1,7 @@
using System;
using System.Threading.Tasks;
namespace DHT.Server.Database.Sqlite.Schema;
namespace DHT.Server.Database.Sqlite.Utils;
public interface ISchemaUpgradeCallbacks {
Task<bool> CanUpgrade();

View File

@@ -1,15 +1,8 @@
using System;
using System.Threading.Tasks;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
interface ISqliteConnection : IAsyncDisposable {
interface ISqliteConnection : IDisposable {
SqliteConnection InnerConnection { get; }
Task BeginTransactionAsync();
Task CommitTransactionAsync();
Task RollbackTransactionAsync();
void AssignActiveTransaction(SqliteCommand command);
}

View File

@@ -1,122 +1,114 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using DHT.Utils.Collections;
using DHT.Utils.Logging;
using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteConnectionPool : IAsyncDisposable {
public static async Task<SqliteConnectionPool> Create(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
var pool = new SqliteConnectionPool(poolSize);
await pool.InitializePooledConnections(connectionStringBuilder);
return pool;
}
sealed class SqliteConnectionPool : IDisposable {
private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) {
connectionStringBuilder.Pooling = false;
return connectionStringBuilder.ToString();
}
private readonly int poolSize;
private readonly List<PooledConnection> all;
private readonly ConcurrentPool<PooledConnection> free;
private readonly object monitor = new ();
private readonly Random rand = new ();
private volatile bool isDisposed;
private readonly CancellationTokenSource disposalTokenSource = new ();
private readonly CancellationToken disposalToken;
private readonly BlockingCollection<PooledConnection> free = new (new ConcurrentStack<PooledConnection>());
private readonly List<PooledConnection> used;
private SqliteConnectionPool(int poolSize) {
this.poolSize = poolSize;
this.all = new List<PooledConnection>(poolSize);
this.free = new ConcurrentPool<PooledConnection>(poolSize);
this.disposalToken = disposalTokenSource.Token;
}
private async Task InitializePooledConnections(SqliteConnectionStringBuilder connectionStringBuilder) {
public SqliteConnectionPool(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
var connectionString = GetConnectionString(connectionStringBuilder);
for (int i = 0; i < poolSize; i++) {
var conn = new SqliteConnection(connectionString);
conn.Open();
var pooledConnection = new PooledConnection(this, conn);
var pooledConn = new PooledConnection(this, conn);
await pooledConnection.ExecuteAsync("PRAGMA journal_mode=WAL", disposalToken);
await pooledConnection.ExecuteAsync("PRAGMA foreign_keys=ON", disposalToken);
using (var cmd = pooledConn.Command("PRAGMA journal_mode=WAL")) {
cmd.ExecuteNonQuery();
}
all.Add(pooledConnection);
await free.Push(pooledConnection, disposalToken);
free.Add(pooledConn);
}
used = new List<PooledConnection>(poolSize);
}
private void ThrowIfDisposed() {
ObjectDisposedException.ThrowIf(isDisposed, nameof(SqliteConnectionPool));
}
public ISqliteConnection Take() {
while (true) {
ThrowIfDisposed();
lock (monitor) {
if (free.TryTake(out var conn)) {
used.Add(conn);
return conn;
}
else {
Log.ForType<SqliteConnectionPool>().Warn("Thread " + Environment.CurrentManagedThreadId + " is starving for connections.");
}
}
Thread.Sleep(TimeSpan.FromMilliseconds(rand.Next(100, 200)));
}
}
public async Task<ISqliteConnection> Take() {
return await free.Pop(disposalToken);
private void Return(PooledConnection conn) {
ThrowIfDisposed();
lock (monitor) {
if (used.Remove(conn)) {
free.Add(conn);
}
}
}
private async Task Return(PooledConnection conn) {
await free.Push(conn, disposalToken);
}
public async ValueTask DisposeAsync() {
if (disposalToken.IsCancellationRequested) {
public void Dispose() {
if (isDisposed) {
return;
}
await disposalTokenSource.CancelAsync();
foreach (var conn in all) {
await conn.InnerConnection.CloseAsync();
await conn.InnerConnection.DisposeAsync();
isDisposed = true;
lock (monitor) {
while (free.TryTake(out var conn)) {
Close(conn.InnerConnection);
}
foreach (var conn in used) {
Close(conn.InnerConnection);
}
free.Dispose();
used.Clear();
}
disposalTokenSource.Dispose();
}
private sealed class PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) : ISqliteConnection {
public SqliteConnection InnerConnection { get; } = conn;
private static void Close(SqliteConnection conn) {
conn.Close();
conn.Dispose();
}
private DbTransaction? activeTransaction;
private sealed class PooledConnection : ISqliteConnection {
public SqliteConnection InnerConnection { get; }
public async Task BeginTransactionAsync() {
if (activeTransaction != null) {
throw new InvalidOperationException("A transaction is already active.");
}
activeTransaction = await InnerConnection.BeginTransactionAsync();
private readonly SqliteConnectionPool pool;
public PooledConnection(SqliteConnectionPool pool, SqliteConnection conn) {
this.pool = pool;
this.InnerConnection = conn;
}
public async Task CommitTransactionAsync() {
if (activeTransaction == null) {
throw new InvalidOperationException("No active transaction to commit.");
}
await activeTransaction.CommitAsync();
await activeTransaction.DisposeAsync();
activeTransaction = null;
}
public async Task RollbackTransactionAsync() {
if (activeTransaction == null) {
throw new InvalidOperationException("No active transaction to rollback.");
}
await activeTransaction.RollbackAsync();
await activeTransaction.DisposeAsync();
activeTransaction = null;
}
public void AssignActiveTransaction(SqliteCommand command) {
command.Transaction = (SqliteTransaction?) activeTransaction;
}
public async ValueTask DisposeAsync() {
if (activeTransaction != null) {
await RollbackTransactionAsync();
}
await pool.Return(this);
void IDisposable.Dispose() {
pool.Return(this);
}
}
}

View File

@@ -1,4 +1,5 @@
using System;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@@ -8,12 +9,21 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils;
static class SqliteExtensions {
public static SqliteCommand Command(this ISqliteConnection conn, [LanguageInjection("sql")] string sql) {
public static ValueTask<DbTransaction> BeginTransactionAsync(this ISqliteConnection conn) {
return conn.InnerConnection.BeginTransactionAsync();
}
public static SqliteCommand Command(this ISqliteConnection conn, string sql) {
var cmd = conn.InnerConnection.CreateCommand();
cmd.CommandText = sql;
return cmd;
}
public static void Execute(this ISqliteConnection conn, string sql) {
using var cmd = conn.Command(sql);
cmd.ExecuteNonQuery();
}
public static async Task<int> ExecuteAsync(this ISqliteConnection conn, [LanguageInjection("sql")] string sql, CancellationToken cancellationToken = default) {
await using var cmd = conn.Command(sql);
return await cmd.ExecuteNonQueryAsync(cancellationToken);
@@ -23,11 +33,12 @@ static class SqliteExtensions {
await using var cmd = conn.Command(sql);
await using var reader = await cmd.ExecuteReaderAsync(cancellationToken);
return await reader.ReadAsync(cancellationToken) ? readFunction(reader) : readFunction(null);
return reader.Read() ? readFunction(reader) : readFunction(null);
}
public static async Task<long> ExecuteLongScalarAsync(this SqliteCommand command) {
return (long) (await command.ExecuteScalarAsync())!;
public static object? SelectScalar(this ISqliteConnection conn, string sql) {
using var cmd = conn.Command(sql);
return cmd.ExecuteScalar();
}
public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
@@ -57,7 +68,7 @@ static class SqliteExtensions {
public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, [column]);
CreateParameters(cmd, new[] { column });
return cmd;
}

View File

@@ -2,12 +2,12 @@ using System.Collections.Generic;
namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteConditionBuilder {
sealed class SqliteWhereGenerator {
private readonly string? tableAlias;
private readonly bool invert;
private readonly List<string> conditions = [];
private readonly List<string> conditions = new ();
public SqliteConditionBuilder(string? tableAlias, bool invert) {
public SqliteWhereGenerator(string? tableAlias, bool invert) {
this.tableAlias = tableAlias;
this.invert = invert;
}
@@ -16,20 +16,16 @@ sealed class SqliteConditionBuilder {
conditions.Add(tableAlias == null ? condition : tableAlias + '.' + condition);
}
public string Build() {
public string Generate() {
if (conditions.Count == 0) {
return invert ? "FALSE" : "TRUE";
return "";
}
if (invert) {
return "NOT (" + string.Join(" AND ", conditions) + ")";
return " WHERE NOT (" + string.Join(" AND ", conditions) + ")";
}
else {
return string.Join(" AND ", conditions);
return " WHERE " + string.Join(" AND ", conditions);
}
}
public string BuildWhereClause() {
return " WHERE " + Build();
}
}

View File

@@ -1,43 +1,15 @@
using System;
using System.Collections.Frozen;
using System.Diagnostics.CodeAnalysis;
using System.Web;
namespace DHT.Server.Download;
static class DiscordCdn {
private static FrozenSet<string> CdnHosts { get; } = new[] {
private static FrozenSet<string> CdnHosts { get; } = new [] {
"cdn.discordapp.com",
"cdn.discord.com",
"media.discordapp.net"
}.ToFrozenSet();
private static bool IsCdnUrl(string originalUrl, [NotNullWhen(true)] out Uri? uri) {
return Uri.TryCreate(originalUrl, UriKind.Absolute, out uri) && CdnHosts.Contains(uri.Host);
}
public static string NormalizeUrl(string originalUrl) {
return IsCdnUrl(originalUrl, out var uri) ? DoNormalize(uri) : originalUrl;
}
public static bool NormalizeUrlAndReturnIfCdn(string originalUrl, out string normalizedUrl) {
if (IsCdnUrl(originalUrl, out var uri)) {
normalizedUrl = DoNormalize(uri);
return true;
}
else {
normalizedUrl = originalUrl;
return false;
}
}
private static string DoNormalize(Uri uri) {
var query = HttpUtility.ParseQueryString(uri.Query);
query.Remove("ex");
query.Remove("is");
query.Remove("hm");
return new UriBuilder(uri) { Query = query.ToString() }.Uri.ToString();
return Uri.TryCreate(originalUrl, UriKind.Absolute, out var uri) && CdnHosts.Contains(uri.Host) ? uri.GetLeftPart(UriPartial.Path) : originalUrl;
}
}

View File

@@ -1,21 +1,7 @@
using System;
using System.Net;
using DHT.Server.Data;
namespace DHT.Server.Download;
public readonly struct DownloadItem {
public string NormalizedUrl { get; init; }
public string DownloadUrl { get; init; }
public string? Type { get; init; }
public ulong? Size { get; init; }
internal Data.Download ToSuccess(long size) {
return new Data.Download(NormalizedUrl, DownloadUrl, DownloadStatus.Success, Type, (ulong) Math.Max(size, 0));
}
internal Data.Download ToFailure(HttpStatusCode? statusCode = null) {
var status = statusCode.HasValue ? (DownloadStatus) (int) statusCode : DownloadStatus.GenericError;
return new Data.Download(NormalizedUrl, DownloadUrl, status, Type, Size);
}
public ulong Size { get; init; }
}

View File

@@ -1,122 +0,0 @@
using System;
using System.IO;
using System.Net.Mime;
using System.Text.Json;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Data.Embeds;
using DHT.Utils.Logging;
namespace DHT.Server.Download;
static class DownloadLinkExtractor {
private static readonly Log Log = Log.ForType(typeof(DownloadLinkExtractor));
public static Data.Download FromUserAvatar(ulong userId, string avatarPath) {
string url = $"https://cdn.discordapp.com/avatars/{userId}/{avatarPath}.webp";
return new Data.Download(url, url, DownloadStatus.Pending, MediaTypeNames.Image.Webp, size: null);
}
public static Data.Download FromEmoji(ulong emojiId, EmojiFlags flags) {
var isAnimated = flags.HasFlag(EmojiFlags.Animated);
string ext = isAnimated ? "gif" : "webp";
string type = isAnimated ? MediaTypeNames.Image.Gif : MediaTypeNames.Image.Webp;
string url = $"https://cdn.discordapp.com/emojis/{emojiId}.{ext}";
return new Data.Download(url, url, DownloadStatus.Pending, type, size: null);
}
public static Data.Download FromAttachment(Attachment attachment) {
return new Data.Download(attachment.NormalizedUrl, attachment.DownloadUrl, DownloadStatus.Pending, attachment.Type, attachment.Size);
}
public static async Task<Data.Download?> TryFromEmbedJson(Stream jsonStream) {
try {
return FromEmbed(await JsonSerializer.DeserializeAsync(jsonStream, DiscordEmbedJsonContext.Default.DiscordEmbedJson));
} catch (Exception e) {
Log.Error("Could not parse embed json: " + e);
return null;
}
}
public static Data.Download? TryFromEmbedJson(string json) {
try {
return FromEmbed(JsonSerializer.Deserialize(json, DiscordEmbedJsonContext.Default.DiscordEmbedJson));
} catch (Exception e) {
Log.Error("Could not parse embed json: " + e);
return null;
}
}
private static Data.Download? FromEmbed(DiscordEmbedJson? embed) {
if (embed is { Type: "image", Image.Url: {} imageUrl }) {
return FromEmbedImage(imageUrl);
}
else if (embed is { Type: "video", Video.Url: {} videoUrl }) {
return FromEmbedVideo(videoUrl);
}
else {
return null;
}
}
private static Data.Download? FromEmbedImage(string url) {
if (DiscordCdn.NormalizeUrlAndReturnIfCdn(url, out var normalizedUrl)) {
return new Data.Download(normalizedUrl, url, DownloadStatus.Pending, GuessImageType(normalizedUrl), size: null);
}
else {
Log.Debug("Skipping non-CDN image url: " + url);
return null;
}
}
private static Data.Download? FromEmbedVideo(string url) {
if (DiscordCdn.NormalizeUrlAndReturnIfCdn(url, out var normalizedUrl)) {
return new Data.Download(normalizedUrl, url, DownloadStatus.Pending, GuessVideoType(normalizedUrl), size: null);
}
else {
Log.Debug("Skipping non-CDN video url: " + url);
return null;
}
}
private static string? GuessImageType(string url) {
if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) {
return null;
}
ReadOnlySpan<char> extension = Path.GetExtension(uri.AbsolutePath).ToLowerInvariant();
// Remove Twitter quality suffix.
int colonIndex = extension.IndexOf(':');
if (colonIndex != -1) {
extension = extension[..colonIndex];
}
return extension switch {
".jpg" => MediaTypeNames.Image.Jpeg,
".jpeg" => MediaTypeNames.Image.Jpeg,
".png" => MediaTypeNames.Image.Png,
".gif" => MediaTypeNames.Image.Gif,
".webp" => MediaTypeNames.Image.Webp,
".bmp" => MediaTypeNames.Image.Bmp,
_ => null
};
}
private static string? GuessVideoType(string url) {
if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) {
return null;
}
string extension = Path.GetExtension(uri.AbsolutePath).ToLowerInvariant();
return extension switch {
".mp4" => "video/mp4",
".mpeg" => "video/mpeg",
".webm" => "video/webm",
".mov" => "video/quicktime",
_ => null
};
}
}

View File

@@ -1,7 +1,6 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
namespace DHT.Server.Download;
@@ -11,18 +10,16 @@ public sealed class Downloader {
public bool IsDownloading => current != null;
private readonly IDatabaseFile db;
private readonly int? concurrentDownloads;
private readonly SemaphoreSlim semaphore = new (1, 1);
internal Downloader(IDatabaseFile db, int? concurrentDownloads) {
internal Downloader(IDatabaseFile db) {
this.db = db;
this.concurrentDownloads = concurrentDownloads;
}
public async Task<IObservable<DownloadItem>> Start(DownloadItemFilter filter) {
public async Task<IObservable<DownloadItem>> Start() {
await semaphore.WaitAsync();
try {
current ??= new DownloaderTask(db, filter, concurrentDownloads);
current ??= new DownloaderTask(db);
return current.FinishedItems;
} finally {
semaphore.Release();

View File

@@ -5,7 +5,6 @@ using System.Reactive.Subjects;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Logging;
using DHT.Utils.Tasks;
@@ -15,14 +14,10 @@ namespace DHT.Server.Download;
sealed class DownloaderTask : IAsyncDisposable {
private static readonly Log Log = Log.ForType<DownloaderTask>();
private const int DefaultConcurrentDownloads = 4;
private const int DownloadTasks = 4;
private const int QueueSize = 25;
private const string UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
private static int GetDownloadTaskCount(int? concurrentDownloads) {
return Math.Max(1, concurrentDownloads ?? DefaultConcurrentDownloads);
}
private readonly Channel<DownloadItem> downloadQueue = Channel.CreateBounded<DownloadItem>(new BoundedChannelOptions(QueueSize) {
SingleReader = false,
SingleWriter = true,
@@ -34,25 +29,23 @@ sealed class DownloaderTask : IAsyncDisposable {
private readonly CancellationToken cancellationToken;
private readonly IDatabaseFile db;
private readonly DownloadItemFilter filter;
private readonly ISubject<DownloadItem> finishedItemPublisher = Subject.Synchronize(new Subject<DownloadItem>());
private readonly Subject<DownloadItem> finishedItemPublisher = new ();
private readonly Task queueWriterTask;
private readonly Task[] downloadTasks;
public IObservable<DownloadItem> FinishedItems => finishedItemPublisher;
internal DownloaderTask(IDatabaseFile db, DownloadItemFilter filter, int? concurrentDownloads) {
internal DownloaderTask(IDatabaseFile db) {
this.db = db;
this.filter = filter;
this.cancellationToken = cancellationTokenSource.Token;
this.queueWriterTask = Task.Run(RunQueueWriterTask);
this.downloadTasks = Enumerable.Range(1, GetDownloadTaskCount(concurrentDownloads)).Select(taskIndex => Task.Run(() => RunDownloadTask(taskIndex))).ToArray();
this.downloadTasks = Enumerable.Range(1, DownloadTasks).Select(taskIndex => Task.Run(() => RunDownloadTask(taskIndex))).ToArray();
}
private async Task RunQueueWriterTask() {
while (await downloadQueue.Writer.WaitToWriteAsync(cancellationToken)) {
var newItems = await db.Downloads.PullPendingDownloadItems(QueueSize, filter, cancellationToken).ToListAsync(cancellationToken);
var newItems = await db.Downloads.PullEnqueuedDownloadItems(QueueSize, cancellationToken).ToListAsync(cancellationToken);
if (newItems.Count == 0) {
await Task.Delay(TimeSpan.FromMilliseconds(50), cancellationToken);
continue;
@@ -67,39 +60,24 @@ sealed class DownloaderTask : IAsyncDisposable {
private async Task RunDownloadTask(int taskIndex) {
var log = Log.ForType<DownloaderTask>("Task " + taskIndex);
var client = new HttpClient(new SocketsHttpHandler {
ConnectTimeout = TimeSpan.FromSeconds(30)
});
client.Timeout = Timeout.InfiniteTimeSpan;
var client = new HttpClient();
client.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent);
client.Timeout = TimeSpan.FromSeconds(30);
while (!cancellationToken.IsCancellationRequested) {
var item = await downloadQueue.Reader.ReadAsync(cancellationToken);
log.Debug("Downloading " + item.DownloadUrl + "...");
try {
var response = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, item.DownloadUrl), HttpCompletionOption.ResponseHeadersRead, cancellationToken);
response.EnsureSuccessStatusCode();
if (response.Content.Headers.ContentLength is {} contentLength) {
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
await db.Downloads.AddDownload(item.ToSuccess(contentLength), stream);
}
else {
await db.Downloads.AddDownload(item.ToFailure(), stream: null);
log.Error("Download response has no content length: " + item.DownloadUrl);
}
} catch (OperationCanceledException e) when (e.CancellationToken == cancellationToken) {
var downloadedBytes = await client.GetByteArrayAsync(item.DownloadUrl, cancellationToken);
await db.Downloads.AddDownload(Data.Download.NewSuccess(item, downloadedBytes));
} catch (OperationCanceledException) {
// Ignore.
} catch (TaskCanceledException e) when (e.InnerException is TimeoutException) {
await db.Downloads.AddDownload(item.ToFailure(), stream: null);
log.Error("Download timed out: " + item.DownloadUrl);
} catch (HttpRequestException e) {
await db.Downloads.AddDownload(item.ToFailure(e.StatusCode), stream: null);
await db.Downloads.AddDownload(Data.Download.NewFailure(item, e.StatusCode, item.Size));
log.Error(e);
} catch (Exception e) {
await db.Downloads.AddDownload(item.ToFailure(), stream: null);
await db.Downloads.AddDownload(Data.Download.NewFailure(item, null, item.Size));
log.Error(e);
} finally {
try {

View File

@@ -9,37 +9,37 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
abstract class BaseEndpoint(IDatabaseFile db) {
abstract class BaseEndpoint {
private static readonly Log Log = Log.ForType<BaseEndpoint>();
protected IDatabaseFile Db { get; } = db;
protected IDatabaseFile Db { get; }
protected BaseEndpoint(IDatabaseFile db) {
this.Db = db;
}
public async Task Handle(HttpContext ctx) {
var response = ctx.Response;
try {
response.StatusCode = (int) HttpStatusCode.OK;
await Respond(ctx.Request, response);
var output = await Respond(ctx);
await output.WriteTo(response);
} catch (HttpException e) {
Log.Error(e);
response.StatusCode = (int) e.StatusCode;
if (response.HasStarted) {
Log.Warn("Response has already started, cannot write status message: " + e.Message);
}
else {
await response.WriteAsync(e.Message);
}
await response.WriteAsync(e.Message);
} catch (Exception e) {
Log.Error(e);
response.StatusCode = (int) HttpStatusCode.InternalServerError;
}
}
protected abstract Task Respond(HttpRequest request, HttpResponse response);
protected abstract Task<IHttpOutput> Respond(HttpContext ctx);
protected static async Task<JsonElement> ReadJson(HttpRequest request) {
protected static async Task<JsonElement> ReadJson(HttpContext ctx) {
try {
return await request.ReadFromJsonAsync(JsonElementContext.Default.JsonElement);
return await ctx.Request.ReadFromJsonAsync(JsonElementContext.Default.JsonElement);
} catch (JsonException) {
throw new HttpException(HttpStatusCode.UnsupportedMediaType, "This endpoint only accepts JSON.");
}

View File

@@ -0,0 +1,24 @@
using System.Net;
using System.Threading.Tasks;
using DHT.Server.Data;
using DHT.Server.Database;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetAttachmentEndpoint : BaseEndpoint {
public GetAttachmentEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
string attachmentUrl = WebUtility.UrlDecode((string) ctx.Request.RouteValues["url"]!);
DownloadedAttachment? maybeDownloadedAttachment = await Db.Downloads.GetDownloadedAttachment(attachmentUrl);
if (maybeDownloadedAttachment is {} downloadedAttachment) {
return new HttpOutput.File(downloadedAttachment.Type, downloadedAttachment.Data);
}
else {
return new HttpOutput.Redirect(attachmentUrl, permanent: false);
}
}
}

View File

@@ -1,19 +0,0 @@
using System.Net;
using System.Threading.Tasks;
using DHT.Server.Database;
using DHT.Server.Download;
using DHT.Utils.Http;
using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetDownloadedFileEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
protected override async Task Respond(HttpRequest request, HttpResponse response) {
string url = WebUtility.UrlDecode((string) request.RouteValues["url"]!);
string normalizedUrl = DiscordCdn.NormalizeUrl(url);
if (!await Db.Downloads.GetSuccessfulDownloadWithData(normalizedUrl, (download, stream) => response.WriteStreamAsync(download.Type, download.Size, stream))) {
response.Redirect(url, permanent: false);
}
}
}

View File

@@ -1,5 +1,5 @@
using System.Net.Mime;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using System.Web;
using DHT.Server.Database;
@@ -10,19 +10,25 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters) : BaseEndpoint(db) {
sealed class GetTrackingScriptEndpoint : BaseEndpoint {
private static ResourceLoader Resources { get; } = new (Assembly.GetExecutingAssembly());
private readonly ServerParameters serverParameters;
public GetTrackingScriptEndpoint(IDatabaseFile db, ServerParameters parameters) : base(db) {
serverParameters = parameters;
}
protected override async Task Respond(HttpRequest request, HttpResponse response) {
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
string bootstrap = await Resources.ReadTextAsync("Tracker/bootstrap.js");
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + parameters.Port + ";")
.Replace("/*[TOKEN]*/", HttpUtility.JavaScriptStringEncode(parameters.Token))
string script = bootstrap.Replace("= 0; /*[PORT]*/", "= " + serverParameters.Port + ";")
.Replace("/*[TOKEN]*/", HttpUtility.JavaScriptStringEncode(serverParameters.Token))
.Replace("/*[IMPORTS]*/", await Resources.ReadJoinedAsync("Tracker/scripts/", '\n'))
.Replace("/*[CSS-CONTROLLER]*/", await Resources.ReadTextAsync("Tracker/styles/controller.css"))
.Replace("/*[CSS-SETTINGS]*/", await Resources.ReadTextAsync("Tracker/styles/settings.css"))
.Replace("/*[DEBUGGER]*/", request.Query.ContainsKey("debug") ? "debugger;" : "");
.Replace("/*[DEBUGGER]*/", ctx.Request.Query.ContainsKey("debug") ? "debugger;" : "");
response.Headers.Append("X-DHT", "1");
await response.WriteTextAsync(MediaTypeNames.Text.JavaScript, script);
ctx.Response.Headers.Append("X-DHT", "1");
return new HttpOutput.File("text/javascript", Encoding.UTF8.GetBytes(script));
}
}

View File

@@ -8,14 +8,18 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackChannelEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
protected override async Task Respond(HttpRequest request, HttpResponse response) {
var root = await ReadJson(request);
sealed class TrackChannelEndpoint : BaseEndpoint {
public TrackChannelEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
var server = ReadServer(root.RequireObject("server"), "server");
var channel = ReadChannel(root.RequireObject("channel"), "channel", server.Id);
await Db.Servers.Add([server]);
await Db.Channels.Add([channel]);
return HttpOutput.None;
}
private static Data.Server ReadServer(JsonElement json, string path) => new () {

View File

@@ -15,12 +15,14 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
sealed class TrackMessagesEndpoint : BaseEndpoint {
private const string HasNewMessages = "1";
private const string NoNewMessages = "0";
protected override async Task Respond(HttpRequest request, HttpResponse response) {
var root = await ReadJson(request);
public TrackMessagesEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
if (root.ValueKind != JsonValueKind.Array) {
throw new HttpException(HttpStatusCode.BadRequest, "Expected root element to be an array.");
@@ -41,7 +43,7 @@ sealed class TrackMessagesEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
await Db.Messages.Add(messages);
await response.WriteTextAsync(anyNewMessages ? HasNewMessages : NoNewMessages);
return new HttpOutput.Text(anyNewMessages ? HasNewMessages : NoNewMessages);
}
private static Message ReadMessage(JsonElement json, string path) => new () {

View File

@@ -8,9 +8,11 @@ using Microsoft.AspNetCore.Http;
namespace DHT.Server.Endpoints;
sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
protected override async Task Respond(HttpRequest request, HttpResponse response) {
var root = await ReadJson(request);
sealed class TrackUsersEndpoint : BaseEndpoint {
public TrackUsersEndpoint(IDatabaseFile db) : base(db) {}
protected override async Task<IHttpOutput> Respond(HttpContext ctx) {
var root = await ReadJson(ctx);
if (root.ValueKind != JsonValueKind.Array) {
throw new HttpException(HttpStatusCode.BadRequest, "Expected root element to be an array.");
@@ -24,6 +26,8 @@ sealed class TrackUsersEndpoint(IDatabaseFile db) : BaseEndpoint(db) {
}
await Db.Users.Add(users);
return HttpOutput.None;
}
private static User ReadUser(JsonElement json, string path) => new () {

Some files were not shown because too many files have changed in this diff Show More