1
0
mirror of https://github.com/chylex/Discord-History-Tracker.git synced 2025-12-19 04:58:53 +01:00

1 Commits

Author SHA1 Message Date
11f7d4a49f Rewrite database interface to be asynchronous and improve UI 2023-12-29 21:26:21 +01:00
69 changed files with 1108 additions and 857 deletions

View File

@@ -1,5 +1,4 @@
using System; using System;
using System.Collections.Generic;
using DHT.Utils.Logging; using DHT.Utils.Logging;
namespace DHT.Desktop; namespace DHT.Desktop;
@@ -9,15 +8,15 @@ sealed class Arguments {
private const int FirstArgument = 1; private const int FirstArgument = 1;
public static Arguments Empty => new (Array.Empty<string>()); public static Arguments Empty => new(Array.Empty<string>());
public bool Console { get; } public bool Console { get; }
public string? DatabaseFile { get; } public string? DatabaseFile { get; }
public ushort? ServerPort { get; } public ushort? ServerPort { get; }
public string? ServerToken { get; } public string? ServerToken { get; }
public Arguments(IReadOnlyList<string> args) { public Arguments(string[] args) {
for (int i = FirstArgument; i < args.Count; i++) { for (int i = FirstArgument; i < args.Length; i++) {
string key = args[i]; string key = args[i];
switch (key) { switch (key) {
@@ -36,7 +35,7 @@ sealed class Arguments {
value = key; value = key;
key = "-db"; key = "-db";
} }
else if (i >= args.Count - 1) { else if (i >= args.Length - 1) {
Log.Warn("Missing value for command line argument: " + key); Log.Warn("Missing value for command line argument: " + key);
continue; continue;
} }

View File

@@ -19,13 +19,13 @@ sealed class BytesValueConverter : IValueConverter {
} }
} }
private static readonly Unit[] Units = [ private static readonly Unit[] Units = {
new Unit("B", decimalPlaces: 0), new ("B", decimalPlaces: 0),
new Unit("kB", decimalPlaces: 0), new ("kB", decimalPlaces: 0),
new Unit("MB", decimalPlaces: 1), new ("MB", decimalPlaces: 1),
new Unit("GB", decimalPlaces: 1), new ("GB", decimalPlaces: 1),
new Unit("TB", decimalPlaces: 1) new ("TB", decimalPlaces: 1)
]; };
private const int Scale = 1000; private const int Scale = 1000;

View File

@@ -1,9 +1,11 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.Platform.Storage; using Avalonia.Platform.Storage;
using Avalonia.Threading;
using DHT.Desktop.Dialogs.File; using DHT.Desktop.Dialogs.File;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Server.Database; using DHT.Server.Database;
@@ -43,10 +45,15 @@ static class DatabaseGui {
} }
public static async Task<IDatabaseFile?> TryOpenOrCreateDatabaseFromPath(string path, Window window, ISchemaUpgradeCallbacks schemaUpgradeCallbacks) { public static async Task<IDatabaseFile?> TryOpenOrCreateDatabaseFromPath(string path, Window window, ISchemaUpgradeCallbacks schemaUpgradeCallbacks) {
var prevSynchronizationContext = SynchronizationContext.Current;
SynchronizationContext.SetSynchronizationContext(new AvaloniaSynchronizationContext());
var taskScheduler = TaskScheduler.FromCurrentSynchronizationContext();
SynchronizationContext.SetSynchronizationContext(prevSynchronizationContext);
IDatabaseFile? file = null; IDatabaseFile? file = null;
try { try {
file = await SqliteDatabaseFile.OpenOrCreate(path, schemaUpgradeCallbacks); file = await SqliteDatabaseFile.OpenOrCreate(path, schemaUpgradeCallbacks, taskScheduler);
} catch (InvalidDatabaseVersionException ex) { } catch (InvalidDatabaseVersionException ex) {
await Dialog.ShowOk(window, "Database Error", "Database '" + Path.GetFileName(path) + "' appears to be corrupted (invalid version: " + ex.Version + ")."); await Dialog.ShowOk(window, "Database Error", "Database '" + Path.GetFileName(path) + "' appears to be corrupted (invalid version: " + ex.Version + ").");
} catch (DatabaseTooNewException ex) { } catch (DatabaseTooNewException ex) {

View File

@@ -23,7 +23,6 @@
<PackageReference Include="Avalonia.Fonts.Inter" Version="11.0.6" /> <PackageReference Include="Avalonia.Fonts.Inter" Version="11.0.6" />
<PackageReference Include="Avalonia.ReactiveUI" Version="11.0.6" /> <PackageReference Include="Avalonia.ReactiveUI" Version="11.0.6" />
<PackageReference Include="Avalonia.Themes.Fluent" Version="11.0.6" /> <PackageReference Include="Avalonia.Themes.Fluent" Version="11.0.6" />
<PackageReference Include="CommunityToolkit.Mvvm" Version="999.0.0-build.0.g0d941a6a62" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -38,7 +38,7 @@
<ItemsRepeater ItemsSource="{Binding Items}"> <ItemsRepeater ItemsSource="{Binding Items}">
<ItemsRepeater.ItemTemplate> <ItemsRepeater.ItemTemplate>
<DataTemplate> <DataTemplate>
<CheckBox IsChecked="{Binding IsChecked}"> <CheckBox IsChecked="{Binding Checked}">
<Label> <Label>
<TextBlock Text="{Binding Title}" TextWrapping="Wrap" /> <TextBlock Text="{Binding Title}" TextWrapping="Wrap" />
</Label> </Label>

View File

@@ -2,11 +2,11 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Linq; using System.Linq;
using CommunityToolkit.Mvvm.ComponentModel; using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.CheckBox; namespace DHT.Desktop.Dialogs.CheckBox;
class CheckBoxDialogModel : ObservableObject { class CheckBoxDialogModel : BaseModel {
public string Title { get; init; } = ""; public string Title { get; init; } = "";
private IReadOnlyList<CheckBoxItem> items = Array.Empty<CheckBoxItem>(); private IReadOnlyList<CheckBoxItem> items = Array.Empty<CheckBoxItem>();
@@ -29,8 +29,8 @@ class CheckBoxDialogModel : ObservableObject {
private bool pauseCheckEvents = false; private bool pauseCheckEvents = false;
public bool AreAllSelected => Items.All(static item => item.IsChecked); public bool AreAllSelected => Items.All(static item => item.Checked);
public bool AreNoneSelected => Items.All(static item => !item.IsChecked); public bool AreNoneSelected => Items.All(static item => !item.Checked);
public void SelectAll() => SetAllChecked(true); public void SelectAll() => SetAllChecked(true);
public void SelectNone() => SetAllChecked(false); public void SelectNone() => SetAllChecked(false);
@@ -39,7 +39,7 @@ class CheckBoxDialogModel : ObservableObject {
pauseCheckEvents = true; pauseCheckEvents = true;
foreach (var item in Items) { foreach (var item in Items) {
item.IsChecked = isChecked; item.Checked = isChecked;
} }
pauseCheckEvents = false; pauseCheckEvents = false;
@@ -52,16 +52,16 @@ class CheckBoxDialogModel : ObservableObject {
} }
private void OnItemPropertyChanged(object? sender, PropertyChangedEventArgs e) { private void OnItemPropertyChanged(object? sender, PropertyChangedEventArgs e) {
if (!pauseCheckEvents && e.PropertyName == nameof(CheckBoxItem.IsChecked)) { if (!pauseCheckEvents && e.PropertyName == nameof(CheckBoxItem.Checked)) {
UpdateBulkButtons(); UpdateBulkButtons();
} }
} }
} }
sealed class CheckBoxDialogModel<T> : CheckBoxDialogModel { sealed class CheckBoxDialogModel<T> : CheckBoxDialogModel {
private new IReadOnlyList<CheckBoxItem<T>> Items { get; } public new IReadOnlyList<CheckBoxItem<T>> Items { get; }
public IEnumerable<CheckBoxItem<T>> SelectedItems => Items.Where(static item => item.IsChecked); public IEnumerable<CheckBoxItem<T>> SelectedItems => Items.Where(static item => item.Checked);
public CheckBoxDialogModel(IEnumerable<CheckBoxItem<T>> items) { public CheckBoxDialogModel(IEnumerable<CheckBoxItem<T>> items) {
this.Items = new List<CheckBoxItem<T>>(items); this.Items = new List<CheckBoxItem<T>>(items);

View File

@@ -1,13 +1,17 @@
using CommunityToolkit.Mvvm.ComponentModel; using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.CheckBox; namespace DHT.Desktop.Dialogs.CheckBox;
partial class CheckBoxItem : ObservableObject { class CheckBoxItem : BaseModel {
public string Title { get; init; } = ""; public string Title { get; init; } = "";
public object? Item { get; init; } = null; public object? Item { get; init; } = null;
[ObservableProperty]
private bool isChecked = false; private bool isChecked = false;
public bool Checked {
get => isChecked;
set => Change(ref isChecked, value);
}
} }
sealed class CheckBoxItem<T> : CheckBoxItem { sealed class CheckBoxItem<T> : CheckBoxItem {

View File

@@ -4,10 +4,11 @@ using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Threading; using Avalonia.Threading;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.Progress; namespace DHT.Desktop.Dialogs.Progress;
sealed class ProgressDialogModel { sealed class ProgressDialogModel : BaseModel {
public string Title { get; init; } = ""; public string Title { get; init; } = "";
public IReadOnlyList<ProgressItem> Items { get; } = Array.Empty<ProgressItem>(); public IReadOnlyList<ProgressItem> Items { get; } = Array.Empty<ProgressItem>();

View File

@@ -1,12 +1,18 @@
using CommunityToolkit.Mvvm.ComponentModel; using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.Progress; namespace DHT.Desktop.Dialogs.Progress;
sealed partial class ProgressItem : ObservableObject { sealed class ProgressItem : BaseModel {
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(Opacity))]
private bool isVisible = false; private bool isVisible = false;
public bool IsVisible {
get => isVisible;
private set {
Change(ref isVisible, value);
OnPropertyChanged(nameof(Opacity));
}
}
public double Opacity => IsVisible ? 1.0 : 0.0; public double Opacity => IsVisible ? 1.0 : 0.0;
private string message = ""; private string message = "";
@@ -14,17 +20,29 @@ sealed partial class ProgressItem : ObservableObject {
public string Message { public string Message {
get => message; get => message;
set { set {
SetProperty(ref message, value); Change(ref message, value);
IsVisible = !string.IsNullOrEmpty(value); IsVisible = !string.IsNullOrEmpty(value);
} }
} }
[ObservableProperty]
private string items = ""; private string items = "";
[ObservableProperty] public string Items {
get => items;
set => Change(ref items, value);
}
private int progress = 0; private int progress = 0;
[ObservableProperty] public int Progress {
get => progress;
set => Change(ref progress, value);
}
private bool isIndeterminate; private bool isIndeterminate;
public bool IsIndeterminate {
get => isIndeterminate;
set => Change(ref isIndeterminate, value);
}
} }

View File

@@ -2,11 +2,11 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Linq; using System.Linq;
using CommunityToolkit.Mvvm.ComponentModel; using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.TextBox; namespace DHT.Desktop.Dialogs.TextBox;
class TextBoxDialogModel : ObservableObject { class TextBoxDialogModel : BaseModel {
public string Title { get; init; } = ""; public string Title { get; init; } = "";
public string Description { get; init; } = ""; public string Description { get; init; } = "";
@@ -36,7 +36,7 @@ class TextBoxDialogModel : ObservableObject {
} }
sealed class TextBoxDialogModel<T> : TextBoxDialogModel { sealed class TextBoxDialogModel<T> : TextBoxDialogModel {
private new IReadOnlyList<TextBoxItem<T>> Items { get; } public new IReadOnlyList<TextBoxItem<T>> Items { get; }
public IEnumerable<TextBoxItem<T>> ValidItems => Items.Where(static item => item.IsValid); public IEnumerable<TextBoxItem<T>> ValidItems => Items.Where(static item => item.IsValid);

View File

@@ -1,11 +1,11 @@
using System; using System;
using System.Collections; using System.Collections;
using System.ComponentModel; using System.ComponentModel;
using CommunityToolkit.Mvvm.ComponentModel; using DHT.Utils.Models;
namespace DHT.Desktop.Dialogs.TextBox; namespace DHT.Desktop.Dialogs.TextBox;
class TextBoxItem : ObservableObject, INotifyDataErrorInfo { class TextBoxItem : BaseModel, INotifyDataErrorInfo {
public string Title { get; init; } = ""; public string Title { get; init; } = "";
public object? Item { get; init; } = null; public object? Item { get; init; } = null;
@@ -17,7 +17,7 @@ class TextBoxItem : ObservableObject, INotifyDataErrorInfo {
public string Value { public string Value {
get => this.value; get => this.value;
set { set {
SetProperty(ref this.value, value); Change(ref this.value, value);
ErrorsChanged?.Invoke(this, new DataErrorsChangedEventArgs(nameof(Value))); ErrorsChanged?.Invoke(this, new DataErrorsChangedEventArgs(nameof(Value)));
} }
} }

View File

@@ -39,7 +39,7 @@ static class DiscordAppSettings {
public static async Task<bool?> AreDevToolsEnabled() { public static async Task<bool?> AreDevToolsEnabled() {
try { try {
var settingsJson = await ReadSettingsJson(); var settingsJson = await ReadSettingsJson().ConfigureAwait(false);
return AreDevToolsEnabled(settingsJson); return AreDevToolsEnabled(settingsJson);
} catch (Exception e) { } catch (Exception e) {
Log.Error("Cannot read settings file."); Log.Error("Cannot read settings file.");

View File

@@ -5,4 +5,4 @@ namespace DHT.Desktop.Discord;
[JsonSourceGenerationOptions(GenerationMode = JsonSourceGenerationMode.Default, WriteIndented = true)] [JsonSourceGenerationOptions(GenerationMode = JsonSourceGenerationMode.Default, WriteIndented = true)]
[JsonSerializable(typeof(JsonObject))] [JsonSerializable(typeof(JsonObject))]
sealed partial class DiscordAppSettingsJsonContext : JsonSerializerContext; sealed partial class DiscordAppSettingsJsonContext : JsonSerializerContext {}

View File

@@ -1,18 +1,17 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Reactive.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Server; using DHT.Server;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Utils.Tasks; using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls; namespace DHT.Desktop.Main.Controls;
sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable { sealed class AttachmentFilterPanelModel : BaseModel, IDisposable {
public sealed record Unit(string Name, uint Scale); public sealed record Unit(string Name, uint Scale);
private static readonly Unit[] AllUnits = [ private static readonly Unit[] AllUnits = [
@@ -29,15 +28,25 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
public string FilterStatisticsText { get; private set; } = ""; public string FilterStatisticsText { get; private set; } = "";
[ObservableProperty]
private bool limitSize = false; private bool limitSize = false;
[ObservableProperty]
private ulong maximumSize = 0L; private ulong maximumSize = 0L;
[ObservableProperty]
private Unit maximumSizeUnit = AllUnits[0]; private Unit maximumSizeUnit = AllUnits[0];
public bool LimitSize {
get => limitSize;
set => Change(ref limitSize, value);
}
public ulong MaximumSize {
get => maximumSize;
set => Change(ref maximumSize, value);
}
public Unit MaximumSizeUnit {
get => maximumSizeUnit;
set => Change(ref maximumSizeUnit, value);
}
public IEnumerable<Unit> Units => AllUnits; public IEnumerable<Unit> Units => AllUnits;
private readonly State state; private readonly State state;
@@ -45,8 +54,6 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
private readonly RestartableTask<long> matchingAttachmentCountTask; private readonly RestartableTask<long> matchingAttachmentCountTask;
private long? matchingAttachmentCount; private long? matchingAttachmentCount;
private readonly IDisposable attachmentCountSubscription;
private long? totalAttachmentCount; private long? totalAttachmentCount;
[Obsolete("Designer")] [Obsolete("Designer")]
@@ -57,15 +64,15 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
this.verb = verb; this.verb = verb;
this.matchingAttachmentCountTask = new RestartableTask<long>(SetAttachmentCounts, TaskScheduler.FromCurrentSynchronizationContext()); this.matchingAttachmentCountTask = new RestartableTask<long>(SetAttachmentCounts, TaskScheduler.FromCurrentSynchronizationContext());
this.attachmentCountSubscription = state.Db.Attachments.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnAttachmentCountChanged);
UpdateFilterStatistics(); UpdateFilterStatistics();
PropertyChanged += OnPropertyChanged; PropertyChanged += OnPropertyChanged;
state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged;
} }
public void Dispose() { public void Dispose() {
attachmentCountSubscription.Dispose(); state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
} }
private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) { private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
@@ -74,11 +81,12 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
} }
} }
private void OnAttachmentCountChanged(long newAttachmentCount) { private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
totalAttachmentCount = newAttachmentCount; if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) {
totalAttachmentCount = state.Db.Statistics.TotalAttachments;
UpdateFilterStatistics(); UpdateFilterStatistics();
} }
}
private void UpdateFilterStatistics() { private void UpdateFilterStatistics() {
var filter = CreateFilter(); var filter = CreateFilter();
@@ -90,7 +98,7 @@ sealed partial class AttachmentFilterPanelModel : ObservableObject, IDisposable
else { else {
matchingAttachmentCount = null; matchingAttachmentCount = null;
UpdateFilterStatisticsText(); UpdateFilterStatisticsText();
matchingAttachmentCountTask.Restart(cancellationToken => state.Db.Attachments.Count(filter, cancellationToken)); matchingAttachmentCountTask.Restart(cancellationToken => state.Db.Downloads.CountAttachments(filter, cancellationToken));
} }
} }

View File

@@ -2,12 +2,9 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Linq; using System.Linq;
using System.Reactive.Linq;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.CheckBox; using DHT.Desktop.Dialogs.CheckBox;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
@@ -15,11 +12,13 @@ using DHT.Desktop.Dialogs.Progress;
using DHT.Server; using DHT.Server;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Models;
using DHT.Utils.Tasks; using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Controls; namespace DHT.Desktop.Main.Controls;
sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable { sealed class MessageFilterPanelModel : BaseModel, IDisposable {
private static readonly HashSet<string> FilterProperties = [ private static readonly HashSet<string> FilterProperties = [
nameof(FilterByDate), nameof(FilterByDate),
nameof(StartDate), nameof(StartDate),
@@ -36,49 +35,71 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
public bool HasAnyFilters => FilterByDate || FilterByChannel || FilterByUser; public bool HasAnyFilters => FilterByDate || FilterByChannel || FilterByUser;
[ObservableProperty]
private bool filterByDate = false; private bool filterByDate = false;
[ObservableProperty]
private DateTime? startDate = null; private DateTime? startDate = null;
[ObservableProperty]
private DateTime? endDate = null; private DateTime? endDate = null;
[ObservableProperty]
private bool filterByChannel = false; private bool filterByChannel = false;
[ObservableProperty]
private HashSet<ulong>? includedChannels = null; private HashSet<ulong>? includedChannels = null;
[ObservableProperty]
private bool filterByUser = false; private bool filterByUser = false;
[ObservableProperty]
private HashSet<ulong>? includedUsers = null; private HashSet<ulong>? includedUsers = null;
[ObservableProperty] public bool FilterByDate {
get => filterByDate;
set => Change(ref filterByDate, value);
}
public DateTime? StartDate {
get => startDate;
set => Change(ref startDate, value);
}
public DateTime? EndDate {
get => endDate;
set => Change(ref endDate, value);
}
public bool FilterByChannel {
get => filterByChannel;
set => Change(ref filterByChannel, value);
}
public HashSet<ulong>? IncludedChannels {
get => includedChannels;
set => Change(ref includedChannels, value);
}
public bool FilterByUser {
get => filterByUser;
set => Change(ref filterByUser, value);
}
public HashSet<ulong>? IncludedUsers {
get => includedUsers;
set => Change(ref includedUsers, value);
}
private string channelFilterLabel = ""; private string channelFilterLabel = "";
[ObservableProperty] public string ChannelFilterLabel {
get => channelFilterLabel;
set => Change(ref channelFilterLabel, value);
}
private string userFilterLabel = ""; private string userFilterLabel = "";
public string UserFilterLabel {
get => userFilterLabel;
set => Change(ref userFilterLabel, value);
}
private readonly Window window; private readonly Window window;
private readonly State state; private readonly State state;
private readonly string verb; private readonly string verb;
private readonly RestartableTask<long> exportedMessageCountTask; private readonly RestartableTask<long> exportedMessageCountTask;
private long? exportedMessageCount; private long? exportedMessageCount;
private readonly IDisposable messageCountSubscription;
private long? totalMessageCount; private long? totalMessageCount;
private readonly IDisposable channelCountSubscription;
private long? totalChannelCount;
private readonly IDisposable userCountSubscription;
private long? totalUserCount;
[Obsolete("Designer")] [Obsolete("Designer")]
public MessageFilterPanelModel() : this(null!, State.Dummy) {} public MessageFilterPanelModel() : this(null!, State.Dummy) {}
@@ -89,23 +110,17 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
this.exportedMessageCountTask = new RestartableTask<long>(SetExportedMessageCount, TaskScheduler.FromCurrentSynchronizationContext()); this.exportedMessageCountTask = new RestartableTask<long>(SetExportedMessageCount, TaskScheduler.FromCurrentSynchronizationContext());
this.messageCountSubscription = state.Db.Messages.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnMessageCountChanged);
this.channelCountSubscription = state.Db.Channels.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnChannelCountChanged);
this.userCountSubscription = state.Db.Users.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnUserCountChanged);
UpdateFilterStatistics(); UpdateFilterStatistics();
UpdateChannelFilterLabel(); UpdateChannelFilterLabel();
UpdateUserFilterLabel(); UpdateUserFilterLabel();
PropertyChanged += OnPropertyChanged; PropertyChanged += OnPropertyChanged;
state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged;
} }
public void Dispose() { public void Dispose() {
exportedMessageCountTask.Cancel(); exportedMessageCountTask.Cancel();
state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
messageCountSubscription.Dispose();
channelCountSubscription.Dispose();
userCountSubscription.Dispose();
} }
private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) { private void OnPropertyChanged(object? sender, PropertyChangedEventArgs e) {
@@ -122,42 +137,30 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
} }
} }
private void OnMessageCountChanged(long newMessageCount) { private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
totalMessageCount = newMessageCount; if (e.PropertyName == nameof(DatabaseStatistics.TotalMessages)) {
totalMessageCount = state.Db.Statistics.TotalMessages;
UpdateFilterStatistics(); UpdateFilterStatistics();
} }
else if (e.PropertyName == nameof(DatabaseStatistics.TotalChannels)) {
private void OnChannelCountChanged(long newChannelCount) {
totalChannelCount = newChannelCount;
UpdateChannelFilterLabel(); UpdateChannelFilterLabel();
} }
else if (e.PropertyName == nameof(DatabaseStatistics.TotalUsers)) {
private void OnUserCountChanged(long newUserCount) {
totalUserCount = newUserCount;
UpdateUserFilterLabel(); UpdateUserFilterLabel();
} }
}
private void UpdateChannelFilterLabel() { private void UpdateChannelFilterLabel() {
if (totalChannelCount.HasValue) { long total = state.Db.Statistics.TotalChannels;
long total = totalChannelCount.Value;
long included = FilterByChannel && IncludedChannels != null ? IncludedChannels.Count : total; long included = FilterByChannel && IncludedChannels != null ? IncludedChannels.Count : total;
ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + "."; ChannelFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("channel") + ".";
} }
else {
ChannelFilterLabel = "Loading...";
}
}
private void UpdateUserFilterLabel() { private void UpdateUserFilterLabel() {
if (totalUserCount.HasValue) { long total = state.Db.Statistics.TotalUsers;
long total = totalUserCount.Value;
long included = FilterByUser && IncludedUsers != null ? IncludedUsers.Count : total; long included = FilterByUser && IncludedUsers != null ? IncludedUsers.Count : total;
UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + "."; UserFilterLabel = "Selected " + included.Format() + " / " + total.Pluralize("user") + ".";
} }
else {
UserFilterLabel = "Loading...";
}
}
private void UpdateFilterStatistics() { private void UpdateFilterStatistics() {
var filter = CreateFilter(); var filter = CreateFilter();
@@ -221,7 +224,7 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
items.Add(new CheckBoxItem<ulong>(channelId) { items.Add(new CheckBoxItem<ulong>(channelId) {
Title = title, Title = title,
IsChecked = IncludedChannels == null || IncludedChannels.Contains(channelId) Checked = IncludedChannels == null || IncludedChannels.Contains(channelId)
}); });
} }
@@ -254,7 +257,7 @@ sealed partial class MessageFilterPanelModel : ObservableObject, IDisposable {
checkBoxItems.Add(new CheckBoxItem<ulong>(user.Id) { checkBoxItems.Add(new CheckBoxItem<ulong>(user.Id) {
Title = discriminator == null ? name : name + " #" + discriminator, Title = discriminator == null ? name : name + " #" + discriminator,
IsChecked = IncludedUsers == null || IncludedUsers.Contains(user.Id) Checked = IncludedUsers == null || IncludedUsers.Contains(user.Id)
}); });
} }

View File

@@ -2,31 +2,47 @@ using System;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Server; using DHT.Desktop.Server;
using DHT.Server; using DHT.Server;
using DHT.Server.Service; using DHT.Server.Service;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Controls; namespace DHT.Desktop.Main.Controls;
sealed partial class ServerConfigurationPanelModel : ObservableObject, IDisposable { sealed class ServerConfigurationPanelModel : BaseModel, IDisposable {
private static readonly Log Log = Log.ForType<ServerConfigurationPanelModel>(); private static readonly Log Log = Log.ForType<ServerConfigurationPanelModel>();
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(HasMadeChanges))]
private string inputPort; private string inputPort;
[ObservableProperty] public string InputPort {
[NotifyPropertyChangedFor(nameof(HasMadeChanges))] get => inputPort;
set {
Change(ref inputPort, value);
OnPropertyChanged(nameof(HasMadeChanges));
}
}
private string inputToken; private string inputToken;
public string InputToken {
get => inputToken;
set {
Change(ref inputToken, value);
OnPropertyChanged(nameof(HasMadeChanges));
}
}
public bool HasMadeChanges => ServerConfiguration.Port.ToString() != InputPort || ServerConfiguration.Token != InputToken; public bool HasMadeChanges => ServerConfiguration.Port.ToString() != InputPort || ServerConfiguration.Token != InputToken;
[ObservableProperty(Setter = Access.Private)]
private bool isToggleServerButtonEnabled = true; private bool isToggleServerButtonEnabled = true;
public bool IsToggleServerButtonEnabled {
get => isToggleServerButtonEnabled;
set => Change(ref isToggleServerButtonEnabled, value);
}
public string ToggleServerButtonText => server.IsRunning ? "Stop Server" : "Start Server"; public string ToggleServerButtonText => server.IsRunning ? "Stop Server" : "Start Server";
private readonly Window window; private readonly Window window;

View File

@@ -45,17 +45,17 @@
<Rectangle /> <Rectangle />
<StackPanel Orientation="Vertical"> <StackPanel Orientation="Vertical">
<TextBlock Classes="label">Servers</TextBlock> <TextBlock Classes="label">Servers</TextBlock>
<TextBlock Classes="value" Text="{Binding ServerCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" /> <TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalServers, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel> </StackPanel>
<Rectangle /> <Rectangle />
<StackPanel Orientation="Vertical"> <StackPanel Orientation="Vertical">
<TextBlock Classes="label">Channels</TextBlock> <TextBlock Classes="label">Channels</TextBlock>
<TextBlock Classes="value" Text="{Binding ChannelCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" /> <TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalChannels, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel> </StackPanel>
<Rectangle /> <Rectangle />
<StackPanel Orientation="Vertical"> <StackPanel Orientation="Vertical">
<TextBlock Classes="label">Messages</TextBlock> <TextBlock Classes="label">Messages</TextBlock>
<TextBlock Classes="value" Text="{Binding MessageCount, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" /> <TextBlock Classes="value" Text="{Binding DatabaseStatistics.TotalMessages, Mode=OneWay, Converter={StaticResource NumberValueConverter}}" />
</StackPanel> </StackPanel>
</StackPanel> </StackPanel>

View File

@@ -1,25 +1,15 @@
using System; using System;
using System.Reactive.Linq;
using Avalonia.ReactiveUI;
using Avalonia.Threading; using Avalonia.Threading;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Server; using DHT.Server;
using DHT.Server.Database;
using DHT.Server.Service; using DHT.Server.Service;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Controls; namespace DHT.Desktop.Main.Controls;
sealed partial class StatusBarModel : ObservableObject, IDisposable { sealed class StatusBarModel : BaseModel, IDisposable {
[ObservableProperty(Setter = Access.Private)] public DatabaseStatistics DatabaseStatistics { get; }
private long? serverCount;
[ObservableProperty(Setter = Access.Private)]
private long? channelCount;
[ObservableProperty(Setter = Access.Private)]
private long? messageCount;
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(ServerStatusText))]
private ServerManager.Status serverStatus; private ServerManager.Status serverStatus;
public string ServerStatusText => serverStatus switch { public string ServerStatusText => serverStatus switch {
@@ -31,33 +21,26 @@ sealed partial class StatusBarModel : ObservableObject, IDisposable {
}; };
private readonly State state; private readonly State state;
private readonly IDisposable serverCountSubscription;
private readonly IDisposable channelCountSubscription;
private readonly IDisposable messageCountSubscription;
[Obsolete("Designer")] [Obsolete("Designer")]
public StatusBarModel() : this(State.Dummy) {} public StatusBarModel() : this(State.Dummy) {}
public StatusBarModel(State state) { public StatusBarModel(State state) {
this.state = state; this.state = state;
this.DatabaseStatistics = state.Db.Statistics;
serverCountSubscription = state.Db.Servers.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newServerCount => ServerCount = newServerCount);
channelCountSubscription = state.Db.Channels.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newChannelCount => ChannelCount = newChannelCount);
messageCountSubscription = state.Db.Messages.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(newMessageCount => MessageCount = newMessageCount);
state.Server.StatusChanged += OnServerStatusChanged; state.Server.StatusChanged += OnServerStatusChanged;
serverStatus = state.Server.IsRunning ? ServerManager.Status.Started : ServerManager.Status.Stopped; serverStatus = state.Server.IsRunning ? ServerManager.Status.Started : ServerManager.Status.Stopped;
} }
public void Dispose() { public void Dispose() {
serverCountSubscription.Dispose(); state.Server.StatusChanged += OnServerStatusChanged;
channelCountSubscription.Dispose();
messageCountSubscription.Dispose();
state.Server.StatusChanged -= OnServerStatusChanged;
} }
private void OnServerStatusChanged(object? sender, ServerManager.Status e) { private void OnServerStatusChanged(object? sender, ServerManager.Status e) {
Dispatcher.UIThread.InvokeAsync(() => ServerStatus = e); Dispatcher.UIThread.InvokeAsync(() => {
serverStatus = e;
OnPropertyChanged(nameof(ServerStatusText));
});
} }
} }

View File

@@ -8,7 +8,7 @@
x:DataType="main:MainWindowModel" x:DataType="main:MainWindowModel"
Title="{Binding Title}" Title="{Binding Title}"
Icon="avares://DiscordHistoryTracker/Resources/icon.ico" Icon="avares://DiscordHistoryTracker/Resources/icon.ico"
Width="820" Height="520" Width="800" Height="500"
MinWidth="520" MinHeight="300" MinWidth="520" MinHeight="300"
WindowStartupLocation="CenterScreen" WindowStartupLocation="CenterScreen"
Closing="OnClosing"> Closing="OnClosing">

View File

@@ -3,26 +3,24 @@ using System.IO;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Main.Screens; using DHT.Desktop.Main.Screens;
using DHT.Desktop.Server; using DHT.Desktop.Server;
using DHT.Server; using DHT.Server;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using DHT.Utils.Models;
namespace DHT.Desktop.Main; namespace DHT.Desktop.Main;
sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable { sealed class MainWindowModel : BaseModel, IAsyncDisposable {
private const string DefaultTitle = "Discord History Tracker"; private const string DefaultTitle = "Discord History Tracker";
private static readonly Log Log = Log.ForType<MainWindowModel>(); private static readonly Log Log = Log.ForType<MainWindowModel>();
[ObservableProperty(Setter = Access.Private)] public string Title { get; private set; } = DefaultTitle;
private string title = DefaultTitle;
[ObservableProperty(Setter = Access.Private)] public UserControl CurrentScreen { get; private set; }
private UserControl currentScreen;
private readonly WelcomeScreen welcomeScreen; private readonly WelcomeScreen welcomeScreen;
private readonly WelcomeScreenModel welcomeScreenModel; private readonly WelcomeScreenModel welcomeScreenModel;
@@ -43,7 +41,7 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected; welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel }; welcomeScreen = new WelcomeScreen { DataContext = welcomeScreenModel };
currentScreen = welcomeScreen; CurrentScreen = welcomeScreen;
var dbFile = args.DatabaseFile; var dbFile = args.DatabaseFile;
if (!string.IsNullOrWhiteSpace(dbFile)) { if (!string.IsNullOrWhiteSpace(dbFile)) {
@@ -95,6 +93,9 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
Title = Path.GetFileName(state.Db.Path) + " - " + DefaultTitle; Title = Path.GetFileName(state.Db.Path) + " - " + DefaultTitle;
CurrentScreen = new MainContentScreen { DataContext = mainContentScreenModel }; CurrentScreen = new MainContentScreen { DataContext = mainContentScreenModel };
OnPropertyChanged(nameof(Title));
OnPropertyChanged(nameof(CurrentScreen));
window.Focus(); window.Focus();
} }
@@ -111,6 +112,9 @@ sealed partial class MainWindowModel : ObservableObject, IAsyncDisposable {
CurrentScreen = welcomeScreen; CurrentScreen = welcomeScreen;
welcomeScreenModel.DatabaseSelected += OnDatabaseSelected; welcomeScreenModel.DatabaseSelected += OnDatabaseSelected;
OnPropertyChanged(nameof(Title));
OnPropertyChanged(nameof(CurrentScreen));
} }
private async Task DisposeState() { private async Task DisposeState() {

View File

@@ -5,10 +5,11 @@ using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress; using DHT.Desktop.Dialogs.Progress;
using DHT.Desktop.Main.Controls; using DHT.Desktop.Main.Controls;
using DHT.Server; using DHT.Server;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed class AdvancedPageModel : IDisposable { sealed class AdvancedPageModel : BaseModel, IDisposable {
public ServerConfigurationPanelModel ServerConfigurationModel { get; } public ServerConfigurationPanelModel ServerConfigurationModel { get; }
private readonly Window window; private readonly Window window;

View File

@@ -1,22 +1,23 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.ObjectModel; using System.ComponentModel;
using System.Reactive.Linq; using System.Reactive.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.ReactiveUI; using Avalonia.ReactiveUI;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Desktop.Main.Controls; using DHT.Desktop.Main.Controls;
using DHT.Server; using DHT.Server;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Data.Aggregations; using DHT.Server.Data.Aggregations;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using DHT.Utils.Models;
using DHT.Utils.Tasks; using DHT.Utils.Tasks;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed partial class AttachmentsPageModel : ObservableObject, IDisposable { sealed class AttachmentsPageModel : BaseModel, IDisposable {
private static readonly Log Log = Log.ForType<AttachmentsPageModel>(); private static readonly Log Log = Log.ForType<AttachmentsPageModel>();
private static readonly DownloadItemFilter EnqueuedItemFilter = new () { private static readonly DownloadItemFilter EnqueuedItemFilter = new () {
@@ -26,24 +27,28 @@ sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
} }
}; };
[ObservableProperty(Setter = Access.Private)]
private bool isToggleDownloadButtonEnabled = true; private bool isToggleDownloadButtonEnabled = true;
public bool IsToggleDownloadButtonEnabled {
get => isToggleDownloadButtonEnabled;
set => Change(ref isToggleDownloadButtonEnabled, value);
}
public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading"; public string ToggleDownloadButtonText => IsDownloading ? "Stop Downloading" : "Start Downloading";
[ObservableProperty(Setter = Access.Private)]
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))]
private bool isRetryingFailedDownloads = false; private bool isRetryingFailedDownloads = false;
[ObservableProperty(Setter = Access.Private)] public bool IsRetryingFailedDownloads {
[NotifyPropertyChangedFor(nameof(IsRetryFailedOnDownloadsButtonEnabled))] get => isRetryingFailedDownloads;
private bool hasFailedDownloads; set {
isRetryingFailedDownloads = value;
OnPropertyChanged(nameof(IsRetryFailedOnDownloadsButtonEnabled));
}
}
public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && hasFailedDownloads; public bool IsRetryFailedOnDownloadsButtonEnabled => !IsRetryingFailedDownloads && HasFailedDownloads;
[ObservableProperty(Setter = Access.Private)]
private string downloadMessage = "";
public string DownloadMessage { get; set; } = "";
public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value; public double DownloadProgress => totalItemsToDownloadCount is null or 0 ? 0.0 : 100.0 * doneItemsCount / totalItemsToDownloadCount.Value;
public AttachmentFilterPanelModel FilterModel { get; } public AttachmentFilterPanelModel FilterModel { get; }
@@ -53,20 +58,23 @@ sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
private readonly StatisticsRow statisticsFailed = new ("Failed"); private readonly StatisticsRow statisticsFailed = new ("Failed");
private readonly StatisticsRow statisticsSkipped = new ("Skipped"); private readonly StatisticsRow statisticsSkipped = new ("Skipped");
public ObservableCollection<StatisticsRow> StatisticsRows { get; } public List<StatisticsRow> StatisticsRows => [
statisticsEnqueued,
statisticsDownloaded,
statisticsFailed,
statisticsSkipped
];
public bool IsDownloading => state.Downloader.IsDownloading; public bool IsDownloading => state.Downloader.IsDownloading;
public bool HasFailedDownloads => statisticsFailed.Items > 0;
private readonly State state; private readonly State state;
private readonly ThrottledTask<int> enqueueDownloadItemsTask; private readonly ThrottledTask enqueueDownloadItemsTask;
private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask; private readonly ThrottledTask<DownloadStatusStatistics> downloadStatisticsTask;
private readonly IDisposable attachmentCountSubscription;
private readonly IDisposable downloadCountSubscription;
private IDisposable? finishedItemsSubscription; private IDisposable? finishedItemsSubscription;
private int doneItemsCount; private int doneItemsCount;
private int totalEnqueuedItemCount; private int initialFinishedCount;
private int? totalItemsToDownloadCount; private int? totalItemsToDownloadCount;
public AttachmentsPageModel() : this(State.Dummy) {} public AttachmentsPageModel() : this(State.Dummy) {}
@@ -76,34 +84,22 @@ sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
FilterModel = new AttachmentFilterPanelModel(state); FilterModel = new AttachmentFilterPanelModel(state);
StatisticsRows = [ enqueueDownloadItemsTask = new ThrottledTask(RecomputeDownloadStatistics, TaskScheduler.FromCurrentSynchronizationContext());
statisticsEnqueued,
statisticsDownloaded,
statisticsFailed,
statisticsSkipped
];
enqueueDownloadItemsTask = new ThrottledTask<int>(OnItemsEnqueued, TaskScheduler.FromCurrentSynchronizationContext());
downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext()); downloadStatisticsTask = new ThrottledTask<DownloadStatusStatistics>(UpdateStatistics, TaskScheduler.FromCurrentSynchronizationContext());
attachmentCountSubscription = state.Db.Attachments.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnAttachmentCountChanged);
downloadCountSubscription = state.Db.Downloads.TotalCount.ObserveOn(AvaloniaScheduler.Instance).Subscribe(OnDownloadCountChanged);
RecomputeDownloadStatistics(); RecomputeDownloadStatistics();
state.Db.Statistics.PropertyChanged += OnDbStatisticsChanged;
} }
public void Dispose() { public void Dispose() {
attachmentCountSubscription.Dispose(); state.Db.Statistics.PropertyChanged -= OnDbStatisticsChanged;
downloadCountSubscription.Dispose();
finishedItemsSubscription?.Dispose();
enqueueDownloadItemsTask.Dispose();
downloadStatisticsTask.Dispose(); downloadStatisticsTask.Dispose();
finishedItemsSubscription?.Dispose();
FilterModel.Dispose(); FilterModel.Dispose();
} }
private void OnAttachmentCountChanged(long newAttachmentCount) { private void OnDbStatisticsChanged(object? sender, PropertyChangedEventArgs e) {
if (e.PropertyName == nameof(DatabaseStatistics.TotalAttachments)) {
if (IsDownloading) { if (IsDownloading) {
EnqueueDownloadItemsLater(); EnqueueDownloadItemsLater();
} }
@@ -111,13 +107,13 @@ sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
RecomputeDownloadStatistics(); RecomputeDownloadStatistics();
} }
} }
else if (e.PropertyName == nameof(DatabaseStatistics.TotalDownloads)) {
private void OnDownloadCountChanged(long newDownloadCount) {
RecomputeDownloadStatistics(); RecomputeDownloadStatistics();
} }
}
private async Task EnqueueDownloadItems() { private async Task EnqueueDownloadItems() {
OnItemsEnqueued(await state.Db.Downloads.EnqueueDownloadItems(CreateAttachmentFilter())); await state.Db.Downloads.EnqueueDownloadItems(CreateAttachmentFilter());
} }
private void EnqueueDownloadItemsLater() { private void EnqueueDownloadItemsLater() {
@@ -125,19 +121,49 @@ sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
enqueueDownloadItemsTask.Post(cancellationToken => state.Db.Downloads.EnqueueDownloadItems(filter, cancellationToken)); enqueueDownloadItemsTask.Post(cancellationToken => state.Db.Downloads.EnqueueDownloadItems(filter, cancellationToken));
} }
private void OnItemsEnqueued(int itemCount) {
totalEnqueuedItemCount += itemCount;
totalItemsToDownloadCount = totalEnqueuedItemCount;
UpdateDownloadMessage();
RecomputeDownloadStatistics();
}
private AttachmentFilter CreateAttachmentFilter() { private AttachmentFilter CreateAttachmentFilter() {
var filter = FilterModel.CreateFilter(); var filter = FilterModel.CreateFilter();
filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent; filter.DownloadItemRule = AttachmentFilter.DownloadItemRules.OnlyNotPresent;
return filter; return filter;
} }
private void RecomputeDownloadStatistics() {
downloadStatisticsTask.Post(state.Db.Downloads.GetStatistics);
}
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
var hadFailedDownloads = HasFailedDownloads;
statisticsEnqueued.Items = statusStatistics.EnqueuedCount;
statisticsEnqueued.Size = statusStatistics.EnqueuedSize;
statisticsDownloaded.Items = statusStatistics.SuccessfulCount;
statisticsDownloaded.Size = statusStatistics.SuccessfulSize;
statisticsFailed.Items = statusStatistics.FailedCount;
statisticsFailed.Size = statusStatistics.FailedSize;
statisticsSkipped.Items = statusStatistics.SkippedCount;
statisticsSkipped.Size = statusStatistics.SkippedSize;
OnPropertyChanged(nameof(StatisticsRows));
if (hadFailedDownloads != HasFailedDownloads) {
OnPropertyChanged(nameof(HasFailedDownloads));
OnPropertyChanged(nameof(IsRetryFailedOnDownloadsButtonEnabled));
}
totalItemsToDownloadCount = statisticsEnqueued.Items + statisticsDownloaded.Items + statisticsFailed.Items - initialFinishedCount;
UpdateDownloadMessage();
}
private void UpdateDownloadMessage() {
DownloadMessage = IsDownloading ? doneItemsCount.Format() + " / " + (totalItemsToDownloadCount?.Format() ?? "?") : "";
OnPropertyChanged(nameof(DownloadMessage));
OnPropertyChanged(nameof(DownloadProgress));
}
public async Task OnClickToggleDownload() { public async Task OnClickToggleDownload() {
IsToggleDownloadButtonEnabled = false; IsToggleDownloadButtonEnabled = false;
@@ -152,13 +178,14 @@ sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
await state.Db.Downloads.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching); await state.Db.Downloads.RemoveDownloadItems(EnqueuedItemFilter, FilterRemovalMode.RemoveMatching);
doneItemsCount = 0; doneItemsCount = 0;
totalEnqueuedItemCount = 0; initialFinishedCount = 0;
totalItemsToDownloadCount = null; totalItemsToDownloadCount = null;
UpdateDownloadMessage(); UpdateDownloadMessage();
} }
else { else {
var finishedItems = await state.Downloader.Start(); var finishedItems = await state.Downloader.Start();
initialFinishedCount = statisticsDownloaded.Items + statisticsFailed.Items;
finishedItemsSubscription = finishedItems.Select(static _ => true) finishedItemsSubscription = finishedItems.Select(static _ => true)
.Buffer(TimeSpan.FromMilliseconds(100)) .Buffer(TimeSpan.FromMilliseconds(100))
.Select(static items => items.Count) .Select(static items => items.Count)
@@ -177,6 +204,7 @@ sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
private void OnItemsFinished(int finishedItemCount) { private void OnItemsFinished(int finishedItemCount) {
doneItemsCount += finishedItemCount; doneItemsCount += finishedItemCount;
UpdateDownloadMessage(); UpdateDownloadMessage();
RecomputeDownloadStatistics();
} }
public async Task OnClickRetryFailedDownloads() { public async Task OnClickRetryFailedDownloads() {
@@ -203,42 +231,13 @@ sealed partial class AttachmentsPageModel : ObservableObject, IDisposable {
} }
} }
private void RecomputeDownloadStatistics() { public sealed class StatisticsRow {
downloadStatisticsTask.Post(state.Db.Downloads.GetStatistics); public string State { get; }
public int Items { get; set; }
public ulong? Size { get; set; }
public StatisticsRow(string state) {
State = state;
} }
private void UpdateStatistics(DownloadStatusStatistics statusStatistics) {
statisticsEnqueued.Items = statusStatistics.EnqueuedCount;
statisticsEnqueued.Size = statusStatistics.EnqueuedSize;
statisticsDownloaded.Items = statusStatistics.SuccessfulCount;
statisticsDownloaded.Size = statusStatistics.SuccessfulSize;
statisticsFailed.Items = statusStatistics.FailedCount;
statisticsFailed.Size = statusStatistics.FailedSize;
statisticsSkipped.Items = statusStatistics.SkippedCount;
statisticsSkipped.Size = statusStatistics.SkippedSize;
hasFailedDownloads = statusStatistics.FailedCount > 0;
UpdateDownloadMessage();
}
private void UpdateDownloadMessage() {
DownloadMessage = IsDownloading ? doneItemsCount.Format() + " / " + (totalItemsToDownloadCount?.Format() ?? "?") : "";
OnPropertyChanged(nameof(DownloadProgress));
}
[ObservableObject]
public sealed partial class StatisticsRow(string state) {
public string State { get; } = state;
[ObservableProperty]
private int items;
[ObservableProperty]
private ulong? size;
} }
} }

View File

@@ -20,10 +20,11 @@ using DHT.Server.Database;
using DHT.Server.Database.Import; using DHT.Server.Database.Import;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Logging; using DHT.Utils.Logging;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed class DatabasePageModel { sealed class DatabasePageModel : BaseModel {
private static readonly Log Log = Log.ForType<DatabasePageModel>(); private static readonly Log Log = Log.ForType<DatabasePageModel>();
public IDatabaseFile Db { get; } public IDatabaseFile Db { get; }
@@ -92,7 +93,7 @@ sealed class DatabasePageModel {
await target.AddFrom(db); await target.AddFrom(db);
return true; return true;
} finally { } finally {
await db.DisposeAsync(); db.Dispose();
} }
}); });
} }
@@ -193,7 +194,7 @@ sealed class DatabasePageModel {
private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func<string, Task<bool>> performImport) { private static async Task PerformImport(IDatabaseFile target, string[] paths, ProgressDialog dialog, IProgressCallback callback, string neutralDialogTitle, string errorDialogTitle, string itemName, Func<string, Task<bool>> performImport) {
int total = paths.Length; int total = paths.Length;
var oldStatistics = await DatabaseStatistics.Take(target); var oldStatistics = await target.SnapshotStatistics();
int successful = 0; int successful = 0;
int finished = 0; int finished = 0;
@@ -224,26 +225,15 @@ sealed class DatabasePageModel {
return; return;
} }
var newStatistics = await DatabaseStatistics.Take(target); var newStatistics = await target.SnapshotStatistics();
await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, newStatistics, successful, total, itemName)); await Dialog.ShowOk(dialog, neutralDialogTitle, GetImportDialogMessage(oldStatistics, newStatistics, successful, total, itemName));
} }
private sealed record DatabaseStatistics(long ServerCount, long ChannelCount, long UserCount, long MessageCount) { private static string GetImportDialogMessage(DatabaseStatisticsSnapshot oldStatistics, DatabaseStatisticsSnapshot newStatistics, int successfulItems, int totalItems, string itemName) {
public static async Task<DatabaseStatistics> Take(IDatabaseFile db) { long newServers = newStatistics.TotalServers - oldStatistics.TotalServers;
return new DatabaseStatistics( long newChannels = newStatistics.TotalChannels - oldStatistics.TotalChannels;
await db.Servers.Count(), long newUsers = newStatistics.TotalUsers - oldStatistics.TotalUsers;
await db.Channels.Count(), long newMessages = newStatistics.TotalMessages - oldStatistics.TotalMessages;
await db.Users.Count(),
await db.Messages.Count()
);
}
}
private static string GetImportDialogMessage(DatabaseStatistics oldStatistics, DatabaseStatistics newStatistics, int successfulItems, int totalItems, string itemName) {
long newServers = newStatistics.ServerCount - oldStatistics.ServerCount;
long newChannels = newStatistics.ChannelCount - oldStatistics.ChannelCount;
long newUsers = newStatistics.UserCount - oldStatistics.UserCount;
long newMessages = newStatistics.MessageCount - oldStatistics.MessageCount;
StringBuilder message = new StringBuilder(); StringBuilder message = new StringBuilder();
message.Append("Processed "); message.Append("Processed ");

View File

@@ -9,9 +9,10 @@ using DHT.Desktop.Dialogs.Progress;
using DHT.Server; using DHT.Server;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Service; using DHT.Server.Service;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages { namespace DHT.Desktop.Main.Pages {
sealed class DebugPageModel { sealed class DebugPageModel : BaseModel {
public string GenerateChannels { get; set; } = "0"; public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0"; public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0"; public string GenerateMessages { get; set; } = "0";
@@ -128,7 +129,7 @@ namespace DHT.Desktop.Main.Pages {
return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())]; return options[(int) Math.Floor(options.Length * rand.NextDouble() * rand.NextDouble())];
} }
private static readonly string[] RandomWords = [ private static readonly string[] RandomWords = {
"apple", "apricot", "artichoke", "arugula", "asparagus", "avocado", "apple", "apricot", "artichoke", "arugula", "asparagus", "avocado",
"banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli", "banana", "bean", "beechnut", "beet", "blackberry", "blackcurrant", "blueberry", "boysenberry", "bramble", "broccoli",
"cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant", "cabbage", "cacao", "cantaloupe", "caper", "carambola", "carrot", "cauliflower", "celery", "chard", "cherry", "chokeberry", "citron", "clementine", "coconut", "corn", "crabapple", "cranberry", "cucumber", "currant",
@@ -151,8 +152,8 @@ namespace DHT.Desktop.Main.Pages {
"vanilla", "vanilla",
"watercress", "watermelon", "watercress", "watermelon",
"yam", "yam",
"zucchini" "zucchini",
]; };
private static string RandomText(Random rand, int maxWords) { private static string RandomText(Random rand, int maxWords) {
int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3)); int wordCount = 1 + (int) Math.Floor(maxWords * Math.Pow(rand.NextDouble(), 3));
@@ -161,8 +162,10 @@ namespace DHT.Desktop.Main.Pages {
} }
} }
#else #else
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Pages { namespace DHT.Desktop.Main.Pages {
sealed class DebugPageModel { sealed class DebugPageModel : BaseModel {
public string GenerateChannels { get; set; } = "0"; public string GenerateChannels { get; set; } = "0";
public string GenerateUsers { get; set; } = "0"; public string GenerateUsers { get; set; } = "0";
public string GenerateMessages { get; set; } = "0"; public string GenerateMessages { get; set; } = "0";

View File

@@ -3,25 +3,33 @@ using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Web; using System.Web;
using Avalonia.Controls; using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Discord; using DHT.Desktop.Discord;
using DHT.Desktop.Server; using DHT.Desktop.Server;
using DHT.Utils.Models;
using static DHT.Desktop.Program; using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed partial class TrackingPageModel : ObservableObject { sealed class TrackingPageModel : BaseModel {
[ObservableProperty(Setter = Access.Private)]
private bool isCopyTrackingScriptButtonEnabled = true; private bool isCopyTrackingScriptButtonEnabled = true;
[ObservableProperty(Setter = Access.Private)] public bool IsCopyTrackingScriptButtonEnabled {
[NotifyPropertyChangedFor(nameof(ToggleAppDevToolsButtonText))] get => isCopyTrackingScriptButtonEnabled;
private bool? areDevToolsEnabled = null; set => Change(ref isCopyTrackingScriptButtonEnabled, value);
}
[ObservableProperty(Setter = Access.Private)] private bool? areDevToolsEnabled;
[NotifyPropertyChangedFor(nameof(ToggleAppDevToolsButtonText))]
private bool isToggleAppDevToolsButtonEnabled = false; private bool? AreDevToolsEnabled {
get => areDevToolsEnabled;
set {
Change(ref areDevToolsEnabled, value);
OnPropertyChanged(nameof(ToggleAppDevToolsButtonText));
}
}
public bool IsToggleAppDevToolsButtonEnabled { get; private set; } = false;
public string ToggleAppDevToolsButtonText { public string ToggleAppDevToolsButtonText {
get { get {
@@ -78,15 +86,17 @@ sealed partial class TrackingPageModel : ObservableObject {
} }
private async Task InitializeDevToolsToggle() { private async Task InitializeDevToolsToggle() {
bool? devToolsEnabled = await Task.Run(DiscordAppSettings.AreDevToolsEnabled); bool? devToolsEnabled = await DiscordAppSettings.AreDevToolsEnabled();
if (devToolsEnabled.HasValue) { if (devToolsEnabled.HasValue) {
AreDevToolsEnabled = devToolsEnabled.Value;
IsToggleAppDevToolsButtonEnabled = true; IsToggleAppDevToolsButtonEnabled = true;
AreDevToolsEnabled = devToolsEnabled.Value;
} }
else { else {
IsToggleAppDevToolsButtonEnabled = false; IsToggleAppDevToolsButtonEnabled = false;
} }
OnPropertyChanged(nameof(IsToggleAppDevToolsButtonEnabled));
} }
public async Task OnClickToggleAppDevTools() { public async Task OnClickToggleAppDevTools() {

View File

@@ -8,7 +8,6 @@ using System.Threading.Tasks;
using System.Web; using System.Web;
using Avalonia.Controls; using Avalonia.Controls;
using Avalonia.Platform.Storage; using Avalonia.Platform.Storage;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.File; using DHT.Desktop.Dialogs.File;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
@@ -19,11 +18,12 @@ using DHT.Server;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database.Export; using DHT.Server.Database.Export;
using DHT.Server.Database.Export.Strategy; using DHT.Server.Database.Export.Strategy;
using DHT.Utils.Models;
using static DHT.Desktop.Program; using static DHT.Desktop.Program;
namespace DHT.Desktop.Main.Pages; namespace DHT.Desktop.Main.Pages;
sealed partial class ViewerPageModel : ObservableObject, IDisposable { sealed class ViewerPageModel : BaseModel, IDisposable {
public static readonly ConcurrentBag<string> TemporaryFiles = []; public static readonly ConcurrentBag<string> TemporaryFiles = [];
private static readonly FilePickerFileType[] ViewerFileTypes = [ private static readonly FilePickerFileType[] ViewerFileTypes = [
@@ -33,9 +33,13 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
public bool DatabaseToolFilterModeKeep { get; set; } = true; public bool DatabaseToolFilterModeKeep { get; set; } = true;
public bool DatabaseToolFilterModeRemove { get; set; } = false; public bool DatabaseToolFilterModeRemove { get; set; } = false;
[ObservableProperty]
private bool hasFilters = false; private bool hasFilters = false;
public bool HasFilters {
get => hasFilters;
set => Change(ref hasFilters, value);
}
public MessageFilterPanelModel FilterModel { get; } public MessageFilterPanelModel FilterModel { get; }
private readonly Window window; private readonly Window window;
@@ -65,13 +69,10 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
var fullPath = await PrepareTemporaryViewerFile(); var fullPath = await PrepareTemporaryViewerFile();
var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token); var strategy = new LiveViewerExportStrategy(ServerConfiguration.Port, ServerConfiguration.Token);
await ProgressDialog.ShowIndeterminate(window, "Open Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(fullPath, strategy))); await WriteViewerFile(fullPath, strategy);
Process.Start(new ProcessStartInfo(fullPath) { UseShellExecute = true });
Process.Start(new ProcessStartInfo(fullPath) {
UseShellExecute = true
});
} catch (Exception e) { } catch (Exception e) {
await Dialog.ShowOk(window, "Open Viewer", "Could not create or save viewer: " + e.Message); await Dialog.ShowOk(window, "Open Viewer", "Could not save viewer: " + e.Message);
} }
} }
@@ -109,9 +110,9 @@ sealed partial class ViewerPageModel : ObservableObject, IDisposable {
} }
try { try {
await ProgressDialog.ShowIndeterminate(window, "Save Viewer", "Creating viewer...", _ => Task.Run(() => WriteViewerFile(path, StandaloneViewerExportStrategy.Instance))); await WriteViewerFile(path, StandaloneViewerExportStrategy.Instance);
} catch (Exception e) { } catch (Exception e) {
await Dialog.ShowOk(window, "Save Viewer", "Could not create or save viewer: " + e.Message); await Dialog.ShowOk(window, "Save Viewer", "Could not save viewer: " + e.Message);
} }
} }

View File

@@ -3,21 +3,25 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Threading.Tasks; using System.Threading.Tasks;
using Avalonia.Controls; using Avalonia.Controls;
using CommunityToolkit.Mvvm.ComponentModel;
using DHT.Desktop.Common; using DHT.Desktop.Common;
using DHT.Desktop.Dialogs.Message; using DHT.Desktop.Dialogs.Message;
using DHT.Desktop.Dialogs.Progress; using DHT.Desktop.Dialogs.Progress;
using DHT.Server.Database; using DHT.Server.Database;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Models;
namespace DHT.Desktop.Main.Screens; namespace DHT.Desktop.Main.Screens;
sealed partial class WelcomeScreenModel : ObservableObject { sealed class WelcomeScreenModel : BaseModel {
public string Version => Program.Version; public string Version => Program.Version;
[ObservableProperty(Setter = Access.Private)]
private bool isOpenOrCreateDatabaseButtonEnabled = true; private bool isOpenOrCreateDatabaseButtonEnabled = true;
public bool IsOpenOrCreateDatabaseButtonEnabled {
get => isOpenOrCreateDatabaseButtonEnabled;
set => Change(ref isOpenOrCreateDatabaseButtonEnabled, value);
}
public event EventHandler<IDatabaseFile>? DatabaseSelected; public event EventHandler<IDatabaseFile>? DatabaseSelected;
private readonly Window window; private readonly Window window;

View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<packageSources>
<add key="chylex's repository" value="https://nuget.chylex.com/feed/index.json" />
</packageSources>
</configuration>

View File

@@ -0,0 +1,46 @@
using DHT.Utils.Models;
namespace DHT.Server.Database;
/// <summary>
/// A live view of database statistics.
/// Some of the totals are computed asynchronously and may not reflect the most recent version of the database, or may not be available at all until computed for the first time.
/// </summary>
public sealed class DatabaseStatistics : BaseModel {
private long totalServers;
private long totalChannels;
private long totalUsers;
private long? totalMessages;
private long? totalAttachments;
private long? totalDownloads;
public long TotalServers {
get => totalServers;
internal set => Change(ref totalServers, value);
}
public long TotalChannels {
get => totalChannels;
internal set => Change(ref totalChannels, value);
}
public long TotalUsers {
get => totalUsers;
internal set => Change(ref totalUsers, value);
}
public long? TotalMessages {
get => totalMessages;
internal set => Change(ref totalMessages, value);
}
public long? TotalAttachments {
get => totalAttachments;
internal set => Change(ref totalAttachments, value);
}
public long? TotalDownloads {
get => totalDownloads;
internal set => Change(ref totalDownloads, value);
}
}

View File

@@ -0,0 +1,11 @@
namespace DHT.Server.Database;
/// <summary>
/// A complete snapshot of database statistics at a particular point in time.
/// </summary>
public readonly struct DatabaseStatisticsSnapshot {
public long TotalServers { get; internal init; }
public long TotalChannels { get; internal init; }
public long TotalUsers { get; internal init; }
public long TotalMessages { get; internal init; }
}

View File

@@ -9,21 +9,23 @@ sealed class DummyDatabaseFile : IDatabaseFile {
public static DummyDatabaseFile Instance { get; } = new (); public static DummyDatabaseFile Instance { get; } = new ();
public string Path => ""; public string Path => "";
public DatabaseStatistics Statistics { get; } = new ();
public IUserRepository Users { get; } = new IUserRepository.Dummy(); public IUserRepository Users { get; } = new IUserRepository.Dummy();
public IServerRepository Servers { get; } = new IServerRepository.Dummy(); public IServerRepository Servers { get; } = new IServerRepository.Dummy();
public IChannelRepository Channels { get; } = new IChannelRepository.Dummy(); public IChannelRepository Channels { get; } = new IChannelRepository.Dummy();
public IMessageRepository Messages { get; } = new IMessageRepository.Dummy(); public IMessageRepository Messages { get; } = new IMessageRepository.Dummy();
public IAttachmentRepository Attachments { get; } = new IAttachmentRepository.Dummy();
public IDownloadRepository Downloads { get; } = new IDownloadRepository.Dummy(); public IDownloadRepository Downloads { get; } = new IDownloadRepository.Dummy();
private DummyDatabaseFile() {} private DummyDatabaseFile() {}
public Task<DatabaseStatisticsSnapshot> SnapshotStatistics() {
return Task.FromResult(new DatabaseStatisticsSnapshot());
}
public Task Vacuum() { public Task Vacuum() {
return Task.CompletedTask; return Task.CompletedTask;
} }
public ValueTask DisposeAsync() { public void Dispose() {}
return ValueTask.CompletedTask;
}
} }

View File

@@ -3,9 +3,9 @@ using System.Text.Json.Serialization;
namespace DHT.Server.Database.Export; namespace DHT.Server.Database.Export;
[JsonSourceGenerationOptions( [JsonSourceGenerationOptions(
Converters = [typeof(SnowflakeJsonSerializer)], Converters = new [] { typeof(SnowflakeJsonSerializer) },
PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase,
GenerationMode = JsonSourceGenerationMode.Default GenerationMode = JsonSourceGenerationMode.Default
)] )]
[JsonSerializable(typeof(ViewerJson))] [JsonSerializable(typeof(ViewerJson))]
sealed partial class ViewerJsonContext : JsonSerializerContext; sealed partial class ViewerJsonContext : JsonSerializerContext {}

View File

@@ -4,14 +4,15 @@ using DHT.Server.Database.Repositories;
namespace DHT.Server.Database; namespace DHT.Server.Database;
public interface IDatabaseFile : IAsyncDisposable { public interface IDatabaseFile : IDisposable {
string Path { get; } string Path { get; }
DatabaseStatistics Statistics { get; }
Task<DatabaseStatisticsSnapshot> SnapshotStatistics();
IUserRepository Users { get; } IUserRepository Users { get; }
IServerRepository Servers { get; } IServerRepository Servers { get; }
IChannelRepository Channels { get; } IChannelRepository Channels { get; }
IMessageRepository Messages { get; } IMessageRepository Messages { get; }
IAttachmentRepository Attachments { get; }
IDownloadRepository Downloads { get; } IDownloadRepository Downloads { get; }
Task Vacuum(); Task Vacuum();

View File

@@ -4,4 +4,4 @@ namespace DHT.Server.Database.Import;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, GenerationMode = JsonSourceGenerationMode.Default)] [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(DiscordEmbedLegacyJson))] [JsonSerializable(typeof(DiscordEmbedLegacyJson))]
sealed partial class DiscordEmbedLegacyJsonContext : JsonSerializerContext; sealed partial class DiscordEmbedLegacyJsonContext : JsonSerializerContext {}

View File

@@ -1,21 +0,0 @@
using System;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
namespace DHT.Server.Database.Repositories;
public interface IAttachmentRepository {
IObservable<long> TotalCount { get; }
Task<long> Count(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
internal sealed class Dummy : IAttachmentRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task<long> Count(AttachmentFilter? filter = null, CancellationToken cancellationToken = default) {
return Task.FromResult(0L);
}
}
}

View File

@@ -1,33 +1,20 @@
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
namespace DHT.Server.Database.Repositories; namespace DHT.Server.Database.Repositories;
public interface IChannelRepository { public interface IChannelRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Channel> channels); Task Add(IReadOnlyList<Channel> channels);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<Channel> Get(); IAsyncEnumerable<Channel> Get();
internal sealed class Dummy : IChannelRepository { internal sealed class Dummy : IChannelRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Channel> channels) { public Task Add(IReadOnlyList<Channel> channels) {
return Task.CompletedTask; return Task.CompletedTask;
} }
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Channel> Get() { public IAsyncEnumerable<Channel> Get() {
return AsyncEnumerable.Empty<Channel>(); return AsyncEnumerable.Empty<Channel>();
} }

View File

@@ -1,7 +1,5 @@
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Reactive.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
@@ -12,7 +10,7 @@ using DHT.Server.Download;
namespace DHT.Server.Database.Repositories; namespace DHT.Server.Database.Repositories;
public interface IDownloadRepository { public interface IDownloadRepository {
IObservable<long> TotalCount { get; } Task<long> CountAttachments(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
Task AddDownload(Data.Download download); Task AddDownload(Data.Download download);
@@ -24,14 +22,16 @@ public interface IDownloadRepository {
Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl); Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl);
Task<int> EnqueueDownloadItems(AttachmentFilter? filter = null, CancellationToken cancellationToken = default); Task EnqueueDownloadItems(AttachmentFilter? filter = null, CancellationToken cancellationToken = default);
IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken = default); IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken = default);
Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode); Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode);
internal sealed class Dummy : IDownloadRepository { internal sealed class Dummy : IDownloadRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L); public Task<long> CountAttachments(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public Task AddDownload(Data.Download download) { public Task AddDownload(Data.Download download) {
return Task.CompletedTask; return Task.CompletedTask;
@@ -53,8 +53,8 @@ public interface IDownloadRepository {
return Task.FromResult<DownloadedAttachment?>(null); return Task.FromResult<DownloadedAttachment?>(null);
} }
public Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) { public Task EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
return Task.FromResult(0); return Task.CompletedTask;
} }
public IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken) { public IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, CancellationToken cancellationToken) {

View File

@@ -1,7 +1,5 @@
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Reactive.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
@@ -10,8 +8,6 @@ using DHT.Server.Data.Filters;
namespace DHT.Server.Database.Repositories; namespace DHT.Server.Database.Repositories;
public interface IMessageRepository { public interface IMessageRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Message> messages); Task Add(IReadOnlyList<Message> messages);
Task<long> Count(MessageFilter? filter = null, CancellationToken cancellationToken = default); Task<long> Count(MessageFilter? filter = null, CancellationToken cancellationToken = default);
@@ -23,8 +19,6 @@ public interface IMessageRepository {
Task Remove(MessageFilter filter, FilterRemovalMode mode); Task Remove(MessageFilter filter, FilterRemovalMode mode);
internal sealed class Dummy : IMessageRepository { internal sealed class Dummy : IMessageRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Message> messages) { public Task Add(IReadOnlyList<Message> messages) {
return Task.CompletedTask; return Task.CompletedTask;
} }

View File

@@ -1,32 +1,19 @@
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace DHT.Server.Database.Repositories; namespace DHT.Server.Database.Repositories;
public interface IServerRepository { public interface IServerRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<Data.Server> servers); Task Add(IReadOnlyList<Data.Server> servers);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<Data.Server> Get(); IAsyncEnumerable<Data.Server> Get();
internal sealed class Dummy : IServerRepository { internal sealed class Dummy : IServerRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<Data.Server> servers) { public Task Add(IReadOnlyList<Data.Server> servers) {
return Task.CompletedTask; return Task.CompletedTask;
} }
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<Data.Server> Get() { public IAsyncEnumerable<Data.Server> Get() {
return AsyncEnumerable.Empty<Data.Server>(); return AsyncEnumerable.Empty<Data.Server>();
} }

View File

@@ -1,33 +1,20 @@
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
namespace DHT.Server.Database.Repositories; namespace DHT.Server.Database.Repositories;
public interface IUserRepository { public interface IUserRepository {
IObservable<long> TotalCount { get; }
Task Add(IReadOnlyList<User> users); Task Add(IReadOnlyList<User> users);
Task<long> Count(CancellationToken cancellationToken = default);
IAsyncEnumerable<User> Get(); IAsyncEnumerable<User> Get();
internal sealed class Dummy : IUserRepository { internal sealed class Dummy : IUserRepository {
public IObservable<long> TotalCount { get; } = Observable.Return(0L);
public Task Add(IReadOnlyList<User> users) { public Task Add(IReadOnlyList<User> users) {
return Task.CompletedTask; return Task.CompletedTask;
} }
public Task<long> Count(CancellationToken cancellationToken) {
return Task.FromResult(0L);
}
public IAsyncEnumerable<User> Get() { public IAsyncEnumerable<User> Get() {
return AsyncEnumerable.Empty<User>(); return AsyncEnumerable.Empty<User>();
} }

View File

@@ -1,26 +0,0 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using DHT.Utils.Tasks;
namespace DHT.Server.Database.Sqlite.Repositories;
abstract class BaseSqliteRepository : IDisposable {
private readonly ObservableThrottledTask<long> totalCountTask = new (TaskScheduler.Default);
public IObservable<long> TotalCount => totalCountTask;
protected BaseSqliteRepository() {
UpdateTotalCount();
}
public void Dispose() {
totalCountTask.Dispose();
}
protected void UpdateTotalCount() {
totalCountTask.Post(Count);
}
public abstract Task<long> Count(CancellationToken cancellationToken);
}

View File

@@ -1,28 +0,0 @@
using System.Threading;
using System.Threading.Tasks;
using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils;
namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteAttachmentRepository : BaseSqliteRepository, IAttachmentRepository {
private readonly SqliteConnectionPool pool;
public SqliteAttachmentRepository(SqliteConnectionPool pool) {
this.pool = pool;
}
internal new void UpdateTotalCount() {
base.UpdateTotalCount();
}
public override Task<long> Count(CancellationToken cancellationToken) {
return Count(filter: null, cancellationToken);
}
public async Task<long> Count(AttachmentFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
}
}

View File

@@ -1,5 +1,4 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Database.Repositories; using DHT.Server.Database.Repositories;
@@ -8,15 +7,26 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories; namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository { sealed class SqliteChannelRepository : IChannelRepository {
private readonly SqliteConnectionPool pool; private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteChannelRepository(SqliteConnectionPool pool) { public SqliteChannelRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool; this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateChannelStatistics(conn);
}
private async Task UpdateChannelStatistics(ISqliteConnection conn) {
statistics.TotalChannels = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L);
} }
public async Task Add(IReadOnlyList<Channel> channels) { public async Task Add(IReadOnlyList<Channel> channels) {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) { await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("channels", [ await using var cmd = conn.Upsert("channels", [
@@ -43,16 +53,11 @@ sealed class SqliteChannelRepository : BaseSqliteRepository, IChannelRepository
await tx.CommitAsync(); await tx.CommitAsync();
} }
UpdateTotalCount(); await UpdateChannelStatistics(conn);
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM channels", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
public async IAsyncEnumerable<Channel> Get() { public async IAsyncEnumerable<Channel> Get() {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels"); await using var cmd = conn.Command("SELECT id, server, name, parent_id, position, topic, nsfw FROM channels");
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();

View File

@@ -9,19 +9,27 @@ using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories; using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
using DHT.Server.Download; using DHT.Server.Download;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories; namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepository { sealed class SqliteDownloadRepository : IDownloadRepository {
private readonly SqliteConnectionPool pool; private readonly SqliteConnectionPool pool;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
public SqliteDownloadRepository(SqliteConnectionPool pool) { public SqliteDownloadRepository(SqliteConnectionPool pool, AsyncValueComputer<long>.Single totalDownloadsComputer) {
this.pool = pool; this.pool = pool;
this.totalDownloadsComputer = totalDownloadsComputer;
}
public async Task<long> CountAttachments(AttachmentFilter? filter, CancellationToken cancellationToken) {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments a" + filter.GenerateWhereClause("a"), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
public async Task AddDownload(Data.Download download) { public async Task AddDownload(Data.Download download) {
await using (var conn = await pool.Take()) { using (var conn = pool.Take()) {
await using var cmd = conn.Upsert("downloads", [ await using var cmd = conn.Upsert("downloads", [
("normalized_url", SqliteType.Text), ("normalized_url", SqliteType.Text),
("download_url", SqliteType.Text), ("download_url", SqliteType.Text),
@@ -38,12 +46,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
await cmd.ExecuteNonQueryAsync(); await cmd.ExecuteNonQueryAsync();
} }
UpdateTotalCount(); totalDownloadsComputer.Recompute();
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
public async Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) { public async Task<DownloadStatusStatistics> GetStatistics(CancellationToken cancellationToken) {
@@ -97,14 +100,14 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
var result = new DownloadStatusStatistics(); var result = new DownloadStatusStatistics();
await using var conn = await pool.Take(); using var conn = pool.Take();
await LoadUndownloadedStatistics(conn, result, cancellationToken); await LoadUndownloadedStatistics(conn, result, cancellationToken);
await LoadSuccessStatistics(conn, result, cancellationToken); await LoadSuccessStatistics(conn, result, cancellationToken);
return result; return result;
} }
public async IAsyncEnumerable<Data.Download> GetWithoutData() { public async IAsyncEnumerable<Data.Download> GetWithoutData() {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads"); await using var cmd = conn.Command("SELECT normalized_url, download_url, status, size FROM downloads");
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();
@@ -120,7 +123,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public async Task<Data.Download> HydrateWithData(Data.Download download) { public async Task<Data.Download> HydrateWithData(Data.Download download) {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url"); await using var cmd = conn.Command("SELECT blob FROM downloads WHERE normalized_url = :url");
cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl); cmd.AddAndSet(":url", SqliteType.Text, download.NormalizedUrl);
@@ -136,7 +139,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) { public async Task<DownloadedAttachment?> GetDownloadedAttachment(string normalizedUrl) {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using var cmd = conn.Command( await using var cmd = conn.Command(
""" """
@@ -161,8 +164,8 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
}; };
} }
public async Task<int> EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) { public async Task EnqueueDownloadItems(AttachmentFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using var cmd = conn.Command( await using var cmd = conn.Command(
$""" $"""
@@ -175,13 +178,13 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
); );
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued); cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
return await cmd.ExecuteNonQueryAsync(cancellationToken); await cmd.ExecuteNonQueryAsync(cancellationToken);
} }
public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) { public async IAsyncEnumerable<DownloadItem> PullEnqueuedDownloadItems(int count, [EnumeratorCancellation] CancellationToken cancellationToken) {
var found = new List<DownloadItem>(); var found = new List<DownloadItem>();
await using var conn = await pool.Take(); using var conn = pool.Take();
await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) { await using (var cmd = conn.Command("SELECT normalized_url, download_url, size FROM downloads WHERE status = :enqueued LIMIT :limit")) {
cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued); cmd.AddAndSet(":enqueued", SqliteType.Integer, (int) DownloadStatus.Enqueued);
@@ -215,7 +218,7 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
} }
public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) { public async Task RemoveDownloadItems(DownloadItemFilter? filter, FilterRemovalMode mode) {
await using (var conn = await pool.Take()) { using (var conn = pool.Take()) {
await conn.ExecuteAsync( await conn.ExecuteAsync(
$""" $"""
-- noinspection SqlWithoutWhere -- noinspection SqlWithoutWhere
@@ -225,6 +228,6 @@ sealed class SqliteDownloadRepository : BaseSqliteRepository, IDownloadRepositor
); );
} }
UpdateTotalCount(); totalDownloadsComputer.Recompute();
} }
} }

View File

@@ -7,17 +7,20 @@ using DHT.Server.Data;
using DHT.Server.Data.Filters; using DHT.Server.Data.Filters;
using DHT.Server.Database.Repositories; using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories; namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository { sealed class SqliteMessageRepository : IMessageRepository {
private readonly SqliteConnectionPool pool; private readonly SqliteConnectionPool pool;
private readonly SqliteAttachmentRepository attachments; private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
public SqliteMessageRepository(SqliteConnectionPool pool, SqliteAttachmentRepository attachments) { public SqliteMessageRepository(SqliteConnectionPool pool, AsyncValueComputer<long>.Single totalMessagesComputer, AsyncValueComputer<long>.Single totalAttachmentsComputer) {
this.pool = pool; this.pool = pool;
this.attachments = attachments; this.totalMessagesComputer = totalMessagesComputer;
this.totalAttachmentsComputer = totalAttachmentsComputer;
} }
public async Task Add(IReadOnlyList<Message> messages) { public async Task Add(IReadOnlyList<Message> messages) {
@@ -36,7 +39,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
bool addedAttachments = false; bool addedAttachments = false;
await using (var conn = await pool.Take()) { using (var conn = pool.Take()) {
await using var tx = await conn.BeginTransactionAsync(); await using var tx = await conn.BeginTransactionAsync();
await using var messageCmd = conn.Upsert("messages", [ await using var messageCmd = conn.Upsert("messages", [
@@ -158,19 +161,15 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
await tx.CommitAsync(); await tx.CommitAsync();
} }
UpdateTotalCount(); totalMessagesComputer.Recompute();
if (addedAttachments) { if (addedAttachments) {
attachments.UpdateTotalCount(); totalAttachmentsComputer.Recompute();
} }
} }
public override Task<long> Count(CancellationToken cancellationToken) {
return Count(filter: null, cancellationToken);
}
public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) { public async Task<long> Count(MessageFilter? filter, CancellationToken cancellationToken) {
await using var conn = await pool.Take(); using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken); return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages" + filter.GenerateWhereClause(), static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
@@ -205,7 +204,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
} }
public async IAsyncEnumerable<Message> Get(MessageFilter? filter) { public async IAsyncEnumerable<Message> Get(MessageFilter? filter) {
await using var conn = await pool.Take(); using var conn = pool.Take();
const string AttachmentSql = const string AttachmentSql =
""" """
@@ -281,7 +280,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
} }
public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) { public async IAsyncEnumerable<ulong> GetIds(MessageFilter? filter) {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause()); await using var cmd = conn.Command("SELECT message_id FROM messages" + filter.GenerateWhereClause());
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();
@@ -292,7 +291,7 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
} }
public async Task Remove(MessageFilter filter, FilterRemovalMode mode) { public async Task Remove(MessageFilter filter, FilterRemovalMode mode) {
await using (var conn = await pool.Take()) { using (var conn = pool.Take()) {
await conn.ExecuteAsync( await conn.ExecuteAsync(
$""" $"""
-- noinspection SqlWithoutWhere -- noinspection SqlWithoutWhere
@@ -302,6 +301,6 @@ sealed class SqliteMessageRepository : BaseSqliteRepository, IMessageRepository
); );
} }
UpdateTotalCount(); totalMessagesComputer.Recompute();
} }
} }

View File

@@ -1,5 +1,4 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Database.Repositories; using DHT.Server.Database.Repositories;
@@ -8,15 +7,26 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories; namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository { sealed class SqliteServerRepository : IServerRepository {
private readonly SqliteConnectionPool pool; private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteServerRepository(SqliteConnectionPool pool) { public SqliteServerRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool; this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateServerStatistics(conn);
}
private async Task UpdateServerStatistics(ISqliteConnection conn) {
statistics.TotalServers = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L);
} }
public async Task Add(IReadOnlyList<Data.Server> servers) { public async Task Add(IReadOnlyList<Data.Server> servers) {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using (var tx = await conn.BeginTransactionAsync()) { await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("servers", [ await using var cmd = conn.Upsert("servers", [
@@ -35,16 +45,11 @@ sealed class SqliteServerRepository : BaseSqliteRepository, IServerRepository {
await tx.CommitAsync(); await tx.CommitAsync();
} }
UpdateTotalCount(); await UpdateServerStatistics(conn);
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM servers", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
public async IAsyncEnumerable<Data.Server> Get() { public async IAsyncEnumerable<Data.Server> Get() {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, name, type FROM servers"); await using var cmd = conn.Command("SELECT id, name, type FROM servers");
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();

View File

@@ -1,5 +1,4 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using DHT.Server.Data; using DHT.Server.Data;
using DHT.Server.Database.Repositories; using DHT.Server.Database.Repositories;
@@ -8,17 +7,28 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Repositories; namespace DHT.Server.Database.Sqlite.Repositories;
sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository { sealed class SqliteUserRepository : IUserRepository {
private readonly SqliteConnectionPool pool; private readonly SqliteConnectionPool pool;
private readonly DatabaseStatistics statistics;
public SqliteUserRepository(SqliteConnectionPool pool) { public SqliteUserRepository(SqliteConnectionPool pool, DatabaseStatistics statistics) {
this.pool = pool; this.pool = pool;
this.statistics = statistics;
}
internal async Task Initialize() {
using var conn = pool.Take();
await UpdateUserStatistics(conn);
}
private async Task UpdateUserStatistics(ISqliteConnection conn) {
statistics.TotalUsers = await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L);
} }
public async Task Add(IReadOnlyList<User> users) { public async Task Add(IReadOnlyList<User> users) {
await using (var conn = await pool.Take()) { using var conn = pool.Take();
await using var tx = await conn.BeginTransactionAsync();
await using (var tx = await conn.BeginTransactionAsync()) {
await using var cmd = conn.Upsert("users", [ await using var cmd = conn.Upsert("users", [
("id", SqliteType.Integer), ("id", SqliteType.Integer),
("name", SqliteType.Text), ("name", SqliteType.Text),
@@ -37,16 +47,11 @@ sealed class SqliteUserRepository : BaseSqliteRepository, IUserRepository {
await tx.CommitAsync(); await tx.CommitAsync();
} }
UpdateTotalCount(); await UpdateUserStatistics(conn);
}
public override async Task<long> Count(CancellationToken cancellationToken) {
await using var conn = await pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM users", static reader => reader?.GetInt64(0) ?? 0L, cancellationToken);
} }
public async IAsyncEnumerable<User> Get() { public async IAsyncEnumerable<User> Get() {
await using var conn = await pool.Take(); using var conn = pool.Take();
await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users"); await using var cmd = conn.Command("SELECT id, name, avatar_url, discriminator FROM users");
await using var reader = await cmd.ExecuteReaderAsync(); await using var reader = await cmd.ExecuteReaderAsync();

View File

@@ -21,14 +21,14 @@ sealed class Schema {
} }
public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) { public async Task<bool> Setup(ISchemaUpgradeCallbacks callbacks) {
await conn.ExecuteAsync("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)"); conn.Execute(@"CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)");
var dbVersionStr = await conn.ExecuteReaderAsync("SELECT value FROM metadata WHERE key = 'version'", static reader => reader?.GetString(0)); var dbVersionStr = conn.SelectScalar("SELECT value FROM metadata WHERE key = 'version'");
if (dbVersionStr == null) { if (dbVersionStr == null) {
await InitializeSchemas(); InitializeSchemas();
} }
else if (!int.TryParse(dbVersionStr, out int dbVersion) || dbVersion < 1) { else if (!int.TryParse(dbVersionStr.ToString(), out int dbVersion) || dbVersion < 1) {
throw new InvalidDatabaseVersionException(dbVersionStr); throw new InvalidDatabaseVersionException(dbVersionStr.ToString() ?? "<null>");
} }
else if (dbVersion > Version) { else if (dbVersion > Version) {
throw new DatabaseTooNewException(dbVersion); throw new DatabaseTooNewException(dbVersion);
@@ -45,8 +45,8 @@ sealed class Schema {
return true; return true;
} }
private async Task InitializeSchemas() { private void InitializeSchemas() {
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE users ( CREATE TABLE users (
id INTEGER PRIMARY KEY NOT NULL, id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -55,7 +55,7 @@ sealed class Schema {
) )
"""); """);
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE servers ( CREATE TABLE servers (
id INTEGER PRIMARY KEY NOT NULL, id INTEGER PRIMARY KEY NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
@@ -63,7 +63,7 @@ sealed class Schema {
) )
"""); """);
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE channels ( CREATE TABLE channels (
id INTEGER PRIMARY KEY NOT NULL, id INTEGER PRIMARY KEY NOT NULL,
server INTEGER NOT NULL, server INTEGER NOT NULL,
@@ -75,7 +75,7 @@ sealed class Schema {
) )
"""); """);
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE messages ( CREATE TABLE messages (
message_id INTEGER PRIMARY KEY NOT NULL, message_id INTEGER PRIMARY KEY NOT NULL,
sender_id INTEGER NOT NULL, sender_id INTEGER NOT NULL,
@@ -85,7 +85,7 @@ sealed class Schema {
) )
"""); """);
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE attachments ( CREATE TABLE attachments (
message_id INTEGER NOT NULL, message_id INTEGER NOT NULL,
attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL, attachment_id INTEGER NOT NULL PRIMARY KEY NOT NULL,
@@ -99,14 +99,14 @@ sealed class Schema {
) )
"""); """);
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE embeds ( CREATE TABLE embeds (
message_id INTEGER NOT NULL, message_id INTEGER NOT NULL,
json TEXT NOT NULL json TEXT NOT NULL
) )
"""); """);
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE downloads ( CREATE TABLE downloads (
normalized_url TEXT NOT NULL PRIMARY KEY, normalized_url TEXT NOT NULL PRIMARY KEY,
download_url TEXT, download_url TEXT,
@@ -116,7 +116,7 @@ sealed class Schema {
) )
"""); """);
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE reactions ( CREATE TABLE reactions (
message_id INTEGER NOT NULL, message_id INTEGER NOT NULL,
emoji_id INTEGER, emoji_id INTEGER,
@@ -126,18 +126,18 @@ sealed class Schema {
) )
"""); """);
await CreateMessageEditTimestampTable(); CreateMessageEditTimestampTable();
await CreateMessageRepliedToTable(); CreateMessageRepliedToTable();
await conn.ExecuteAsync("CREATE INDEX attachments_message_ix ON attachments(message_id)"); conn.Execute("CREATE INDEX attachments_message_ix ON attachments(message_id)");
await conn.ExecuteAsync("CREATE INDEX embeds_message_ix ON embeds(message_id)"); conn.Execute("CREATE INDEX embeds_message_ix ON embeds(message_id)");
await conn.ExecuteAsync("CREATE INDEX reactions_message_ix ON reactions(message_id)"); conn.Execute("CREATE INDEX reactions_message_ix ON reactions(message_id)");
await conn.ExecuteAsync("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")"); conn.Execute("INSERT INTO metadata (key, value) VALUES ('version', " + Version + ")");
} }
private async Task CreateMessageEditTimestampTable() { private void CreateMessageEditTimestampTable() {
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE edit_timestamps ( CREATE TABLE edit_timestamps (
message_id INTEGER PRIMARY KEY NOT NULL, message_id INTEGER PRIMARY KEY NOT NULL,
edit_timestamp INTEGER NOT NULL edit_timestamp INTEGER NOT NULL
@@ -145,8 +145,8 @@ sealed class Schema {
"""); """);
} }
private async Task CreateMessageRepliedToTable() { private void CreateMessageRepliedToTable() {
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE replied_to ( CREATE TABLE replied_to (
message_id INTEGER PRIMARY KEY NOT NULL, message_id INTEGER PRIMARY KEY NOT NULL,
replied_to_id INTEGER NOT NULL replied_to_id INTEGER NOT NULL
@@ -213,7 +213,7 @@ sealed class Schema {
} }
} }
await conn.ExecuteAsync("PRAGMA cache_size = -20000"); conn.Execute("PRAGMA cache_size = -20000");
DbTransaction tx; DbTransaction tx;
@@ -263,17 +263,17 @@ sealed class Schema {
await tx.CommitAsync(); await tx.CommitAsync();
await tx.DisposeAsync(); await tx.DisposeAsync();
await conn.ExecuteAsync("PRAGMA cache_size = -2000"); conn.Execute("PRAGMA cache_size = -2000");
} }
private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) { private async Task UpgradeSchemas(int dbVersion, ISchemaUpgradeCallbacks.IProgressReporter reporter) {
var perf = Log.Start("from version " + dbVersion); var perf = Log.Start("from version " + dbVersion);
await conn.ExecuteAsync("UPDATE metadata SET value = " + Version + " WHERE key = 'version'"); conn.Execute("UPDATE metadata SET value = " + Version + " WHERE key = 'version'");
if (dbVersion <= 1) { if (dbVersion <= 1) {
await reporter.MainWork("Applying schema changes...", 0, 1); await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE channels ADD parent_id INTEGER"); conn.Execute("ALTER TABLE channels ADD parent_id INTEGER");
perf.Step("Upgrade to version 2"); perf.Step("Upgrade to version 2");
await reporter.NextVersion(); await reporter.NextVersion();
@@ -282,37 +282,37 @@ sealed class Schema {
if (dbVersion <= 2) { if (dbVersion <= 2) {
await reporter.MainWork("Applying schema changes...", 0, 1); await reporter.MainWork("Applying schema changes...", 0, 1);
await CreateMessageEditTimestampTable(); CreateMessageEditTimestampTable();
await CreateMessageRepliedToTable(); CreateMessageRepliedToTable();
await conn.ExecuteAsync(""" conn.Execute("""
INSERT INTO edit_timestamps (message_id, edit_timestamp) INSERT INTO edit_timestamps (message_id, edit_timestamp)
SELECT message_id, edit_timestamp SELECT message_id, edit_timestamp
FROM messages FROM messages
WHERE edit_timestamp IS NOT NULL WHERE edit_timestamp IS NOT NULL
"""); """);
await conn.ExecuteAsync(""" conn.Execute("""
INSERT INTO replied_to (message_id, replied_to_id) INSERT INTO replied_to (message_id, replied_to_id)
SELECT message_id, replied_to_id SELECT message_id, replied_to_id
FROM messages FROM messages
WHERE replied_to_id IS NOT NULL WHERE replied_to_id IS NOT NULL
"""); """);
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN replied_to_id"); conn.Execute("ALTER TABLE messages DROP COLUMN replied_to_id");
await conn.ExecuteAsync("ALTER TABLE messages DROP COLUMN edit_timestamp"); conn.Execute("ALTER TABLE messages DROP COLUMN edit_timestamp");
perf.Step("Upgrade to version 3"); perf.Step("Upgrade to version 3");
await reporter.MainWork("Vacuuming the database...", 1, 1); await reporter.MainWork("Vacuuming the database...", 1, 1);
await conn.ExecuteAsync("VACUUM"); conn.Execute("VACUUM");
perf.Step("Vacuum"); perf.Step("Vacuum");
await reporter.NextVersion(); await reporter.NextVersion();
} }
if (dbVersion <= 3) { if (dbVersion <= 3) {
await conn.ExecuteAsync(""" conn.Execute("""
CREATE TABLE downloads ( CREATE TABLE downloads (
url TEXT NOT NULL PRIMARY KEY, url TEXT NOT NULL PRIMARY KEY,
status INTEGER NOT NULL, status INTEGER NOT NULL,
@@ -327,8 +327,8 @@ sealed class Schema {
if (dbVersion <= 4) { if (dbVersion <= 4) {
await reporter.MainWork("Applying schema changes...", 0, 1); await reporter.MainWork("Applying schema changes...", 0, 1);
await conn.ExecuteAsync("ALTER TABLE attachments ADD width INTEGER"); conn.Execute("ALTER TABLE attachments ADD width INTEGER");
await conn.ExecuteAsync("ALTER TABLE attachments ADD height INTEGER"); conn.Execute("ALTER TABLE attachments ADD height INTEGER");
perf.Step("Upgrade to version 5"); perf.Step("Upgrade to version 5");
await reporter.NextVersion(); await reporter.NextVersion();
@@ -336,8 +336,8 @@ sealed class Schema {
if (dbVersion <= 5) { if (dbVersion <= 5) {
await reporter.MainWork("Applying schema changes...", 0, 3); await reporter.MainWork("Applying schema changes...", 0, 3);
await conn.ExecuteAsync("ALTER TABLE attachments ADD download_url TEXT"); conn.Execute("ALTER TABLE attachments ADD download_url TEXT");
await conn.ExecuteAsync("ALTER TABLE downloads ADD download_url TEXT"); conn.Execute("ALTER TABLE downloads ADD download_url TEXT");
await reporter.MainWork("Updating attachments...", 1, 3); await reporter.MainWork("Updating attachments...", 1, 3);
await NormalizeAttachmentUrls(reporter); await NormalizeAttachmentUrls(reporter);
@@ -346,8 +346,8 @@ sealed class Schema {
await NormalizeDownloadUrls(reporter); await NormalizeDownloadUrls(reporter);
await reporter.MainWork("Applying schema changes...", 3, 3); await reporter.MainWork("Applying schema changes...", 3, 3);
await conn.ExecuteAsync("ALTER TABLE attachments RENAME COLUMN url TO normalized_url"); conn.Execute("ALTER TABLE attachments RENAME COLUMN url TO normalized_url");
await conn.ExecuteAsync("ALTER TABLE downloads RENAME COLUMN url TO normalized_url"); conn.Execute("ALTER TABLE downloads RENAME COLUMN url TO normalized_url");
perf.Step("Upgrade to version 6"); perf.Step("Upgrade to version 6");
await reporter.NextVersion(); await reporter.NextVersion();

View File

@@ -3,6 +3,7 @@ using System.Threading.Tasks;
using DHT.Server.Database.Repositories; using DHT.Server.Database.Repositories;
using DHT.Server.Database.Sqlite.Repositories; using DHT.Server.Database.Sqlite.Repositories;
using DHT.Server.Database.Sqlite.Utils; using DHT.Server.Database.Sqlite.Utils;
using DHT.Utils.Tasks;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite; namespace DHT.Server.Database.Sqlite;
@@ -10,39 +11,41 @@ namespace DHT.Server.Database.Sqlite;
public sealed class SqliteDatabaseFile : IDatabaseFile { public sealed class SqliteDatabaseFile : IDatabaseFile {
private const int DefaultPoolSize = 5; private const int DefaultPoolSize = 5;
public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, ISchemaUpgradeCallbacks schemaUpgradeCallbacks) { public static async Task<SqliteDatabaseFile?> OpenOrCreate(string path, ISchemaUpgradeCallbacks schemaUpgradeCallbacks, TaskScheduler computeTaskResultScheduler) {
var connectionString = new SqliteConnectionStringBuilder { var connectionString = new SqliteConnectionStringBuilder {
DataSource = path, DataSource = path,
Mode = SqliteOpenMode.ReadWriteCreate, Mode = SqliteOpenMode.ReadWriteCreate,
}; };
var pool = await SqliteConnectionPool.Create(connectionString, DefaultPoolSize); var pool = new SqliteConnectionPool(connectionString, DefaultPoolSize);
bool wasOpened; bool wasOpened;
try { try {
await using var conn = await pool.Take(); using var conn = pool.Take();
wasOpened = await new Schema(conn).Setup(schemaUpgradeCallbacks); wasOpened = await new Schema(conn).Setup(schemaUpgradeCallbacks);
} catch (Exception) { } catch (Exception) {
await pool.DisposeAsync(); pool.Dispose();
throw; throw;
} }
if (wasOpened) { if (wasOpened) {
return new SqliteDatabaseFile(path, pool); var db = new SqliteDatabaseFile(path, pool, computeTaskResultScheduler);
await db.Initialize();
return db;
} }
else { else {
await pool.DisposeAsync(); pool.Dispose();
return null; return null;
} }
} }
public string Path { get; } public string Path { get; }
public DatabaseStatistics Statistics { get; }
public IUserRepository Users => users; public IUserRepository Users => users;
public IServerRepository Servers => servers; public IServerRepository Servers => servers;
public IChannelRepository Channels => channels; public IChannelRepository Channels => channels;
public IMessageRepository Messages => messages; public IMessageRepository Messages => messages;
public IAttachmentRepository Attachments => attachments;
public IDownloadRepository Downloads => downloads; public IDownloadRepository Downloads => downloads;
private readonly SqliteConnectionPool pool; private readonly SqliteConnectionPool pool;
@@ -51,32 +54,84 @@ public sealed class SqliteDatabaseFile : IDatabaseFile {
private readonly SqliteServerRepository servers; private readonly SqliteServerRepository servers;
private readonly SqliteChannelRepository channels; private readonly SqliteChannelRepository channels;
private readonly SqliteMessageRepository messages; private readonly SqliteMessageRepository messages;
private readonly SqliteAttachmentRepository attachments;
private readonly SqliteDownloadRepository downloads; private readonly SqliteDownloadRepository downloads;
private SqliteDatabaseFile(string path, SqliteConnectionPool pool) { private readonly AsyncValueComputer<long>.Single totalMessagesComputer;
this.Path = path; private readonly AsyncValueComputer<long>.Single totalAttachmentsComputer;
private readonly AsyncValueComputer<long>.Single totalDownloadsComputer;
private SqliteDatabaseFile(string path, SqliteConnectionPool pool, TaskScheduler computeTaskResultScheduler) {
this.pool = pool; this.pool = pool;
users = new SqliteUserRepository(pool); this.totalMessagesComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateMessageStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeMessageStatistics);
servers = new SqliteServerRepository(pool); this.totalAttachmentsComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateAttachmentStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeAttachmentStatistics);
channels = new SqliteChannelRepository(pool); this.totalDownloadsComputer = AsyncValueComputer<long>.WithResultProcessor(UpdateDownloadStatistics, computeTaskResultScheduler).WithOutdatedResults().BuildWithComputer(ComputeDownloadStatistics);
messages = new SqliteMessageRepository(pool, attachments = new SqliteAttachmentRepository(pool));
downloads = new SqliteDownloadRepository(pool); this.Path = path;
this.Statistics = new DatabaseStatistics();
this.users = new SqliteUserRepository(pool, Statistics);
this.servers = new SqliteServerRepository(pool, Statistics);
this.channels = new SqliteChannelRepository(pool, Statistics);
this.messages = new SqliteMessageRepository(pool, totalMessagesComputer, totalAttachmentsComputer);
this.downloads = new SqliteDownloadRepository(pool, totalDownloadsComputer);
totalMessagesComputer.Recompute();
totalAttachmentsComputer.Recompute();
totalDownloadsComputer.Recompute();
} }
public async ValueTask DisposeAsync() { private async Task Initialize() {
users.Dispose(); await users.Initialize();
servers.Dispose(); await servers.Initialize();
channels.Dispose(); await channels.Initialize();
messages.Dispose(); }
attachments.Dispose();
downloads.Dispose(); public void Dispose() {
await pool.DisposeAsync(); totalMessagesComputer.Cancel();
totalAttachmentsComputer.Cancel();
totalDownloadsComputer.Cancel();
pool.Dispose();
}
public async Task<DatabaseStatisticsSnapshot> SnapshotStatistics() {
return new DatabaseStatisticsSnapshot {
TotalServers = Statistics.TotalServers,
TotalChannels = Statistics.TotalChannels,
TotalUsers = Statistics.TotalUsers,
TotalMessages = await ComputeMessageStatistics(),
};
} }
public async Task Vacuum() { public async Task Vacuum() {
await using var conn = await pool.Take(); using var conn = pool.Take();
await conn.ExecuteAsync("VACUUM"); await conn.ExecuteAsync("VACUUM");
} }
private async Task<long> ComputeMessageStatistics() {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM messages", static reader => reader?.GetInt64(0) ?? 0L);
}
private void UpdateMessageStatistics(long totalMessages) {
Statistics.TotalMessages = totalMessages;
}
private async Task<long> ComputeAttachmentStatistics() {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(DISTINCT normalized_url) FROM attachments", static reader => reader?.GetInt64(0) ?? 0L);
}
private void UpdateAttachmentStatistics(long totalAttachments) {
Statistics.TotalAttachments = totalAttachments;
}
private async Task<long> ComputeDownloadStatistics() {
using var conn = pool.Take();
return await conn.ExecuteReaderAsync("SELECT COUNT(*) FROM downloads", static reader => reader?.GetInt64(0) ?? 0L);
}
private void UpdateDownloadStatistics(long totalDownloads) {
Statistics.TotalDownloads = totalDownloads;
}
} }

View File

@@ -3,6 +3,6 @@ using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils; namespace DHT.Server.Database.Sqlite.Utils;
interface ISqliteConnection : IAsyncDisposable { interface ISqliteConnection : IDisposable {
SqliteConnection InnerConnection { get; } SqliteConnection InnerConnection { get; }
} }

View File

@@ -1,77 +1,100 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using DHT.Utils.Logging;
using DHT.Utils.Collections;
using Microsoft.Data.Sqlite; using Microsoft.Data.Sqlite;
namespace DHT.Server.Database.Sqlite.Utils; namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteConnectionPool : IAsyncDisposable { sealed class SqliteConnectionPool : IDisposable {
public static async Task<SqliteConnectionPool> Create(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
var pool = new SqliteConnectionPool(poolSize);
await pool.InitializePooledConnections(connectionStringBuilder);
return pool;
}
private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) { private static string GetConnectionString(SqliteConnectionStringBuilder connectionStringBuilder) {
connectionStringBuilder.Pooling = false; connectionStringBuilder.Pooling = false;
return connectionStringBuilder.ToString(); return connectionStringBuilder.ToString();
} }
private readonly int poolSize; private readonly object monitor = new ();
private readonly List<PooledConnection> all; private readonly Random rand = new ();
private readonly ConcurrentPool<PooledConnection> free; private volatile bool isDisposed;
private readonly CancellationTokenSource disposalTokenSource = new (); private readonly BlockingCollection<PooledConnection> free = new (new ConcurrentStack<PooledConnection>());
private readonly CancellationToken disposalToken; private readonly List<PooledConnection> used;
private SqliteConnectionPool(int poolSize) { public SqliteConnectionPool(SqliteConnectionStringBuilder connectionStringBuilder, int poolSize) {
this.poolSize = poolSize;
this.all = new List<PooledConnection>(poolSize);
this.free = new ConcurrentPool<PooledConnection>(poolSize);
this.disposalToken = disposalTokenSource.Token;
}
private async Task InitializePooledConnections(SqliteConnectionStringBuilder connectionStringBuilder) {
var connectionString = GetConnectionString(connectionStringBuilder); var connectionString = GetConnectionString(connectionStringBuilder);
for (int i = 0; i < poolSize; i++) { for (int i = 0; i < poolSize; i++) {
var conn = new SqliteConnection(connectionString); var conn = new SqliteConnection(connectionString);
conn.Open(); conn.Open();
var pooledConnection = new PooledConnection(this, conn); var pooledConn = new PooledConnection(this, conn);
await using (var cmd = pooledConnection.Command("PRAGMA journal_mode=WAL")) { using (var cmd = pooledConn.Command("PRAGMA journal_mode=WAL")) {
await cmd.ExecuteNonQueryAsync(disposalToken); cmd.ExecuteNonQuery();
} }
all.Add(pooledConnection); free.Add(pooledConn);
await free.Push(pooledConnection, disposalToken); }
used = new List<PooledConnection>(poolSize);
}
private void ThrowIfDisposed() {
ObjectDisposedException.ThrowIf(isDisposed, nameof(SqliteConnectionPool));
}
public ISqliteConnection Take() {
while (true) {
ThrowIfDisposed();
lock (monitor) {
if (free.TryTake(out var conn)) {
used.Add(conn);
return conn;
}
else {
Log.ForType<SqliteConnectionPool>().Warn("Thread " + Environment.CurrentManagedThreadId + " is starving for connections.");
} }
} }
public async Task<ISqliteConnection> Take() { Thread.Sleep(TimeSpan.FromMilliseconds(rand.Next(100, 200)));
return await free.Pop(disposalToken); }
} }
private async Task Return(PooledConnection conn) { private void Return(PooledConnection conn) {
await free.Push(conn, disposalToken); ThrowIfDisposed();
lock (monitor) {
if (used.Remove(conn)) {
free.Add(conn);
}
}
} }
public async ValueTask DisposeAsync() { public void Dispose() {
if (disposalToken.IsCancellationRequested) { if (isDisposed) {
return; return;
} }
await disposalTokenSource.CancelAsync(); isDisposed = true;
foreach (var conn in all) { lock (monitor) {
await conn.InnerConnection.CloseAsync(); while (free.TryTake(out var conn)) {
await conn.InnerConnection.DisposeAsync(); Close(conn.InnerConnection);
} }
disposalTokenSource.Dispose(); foreach (var conn in used) {
Close(conn.InnerConnection);
}
free.Dispose();
used.Clear();
}
}
private static void Close(SqliteConnection conn) {
conn.Close();
conn.Dispose();
} }
private sealed class PooledConnection : ISqliteConnection { private sealed class PooledConnection : ISqliteConnection {
@@ -84,8 +107,8 @@ sealed class SqliteConnectionPool : IAsyncDisposable {
this.InnerConnection = conn; this.InnerConnection = conn;
} }
public async ValueTask DisposeAsync() { void IDisposable.Dispose() {
await pool.Return(this); pool.Return(this);
} }
} }
} }

View File

@@ -19,6 +19,11 @@ static class SqliteExtensions {
return cmd; return cmd;
} }
public static void Execute(this ISqliteConnection conn, string sql) {
using var cmd = conn.Command(sql);
cmd.ExecuteNonQuery();
}
public static async Task<int> ExecuteAsync(this ISqliteConnection conn, [LanguageInjection("sql")] string sql, CancellationToken cancellationToken = default) { public static async Task<int> ExecuteAsync(this ISqliteConnection conn, [LanguageInjection("sql")] string sql, CancellationToken cancellationToken = default) {
await using var cmd = conn.Command(sql); await using var cmd = conn.Command(sql);
return await cmd.ExecuteNonQueryAsync(cancellationToken); return await cmd.ExecuteNonQueryAsync(cancellationToken);
@@ -31,6 +36,11 @@ static class SqliteExtensions {
return reader.Read() ? readFunction(reader) : readFunction(null); return reader.Read() ? readFunction(reader) : readFunction(null);
} }
public static object? SelectScalar(this ISqliteConnection conn, string sql) {
using var cmd = conn.Command(sql);
return cmd.ExecuteScalar();
}
public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) { public static SqliteCommand Insert(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type)[] columns) {
string columnNames = string.Join(',', columns.Select(static c => c.Name)); string columnNames = string.Join(',', columns.Select(static c => c.Name));
string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name)); string columnParams = string.Join(',', columns.Select(static c => ':' + c.Name));
@@ -58,7 +68,7 @@ static class SqliteExtensions {
public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) { public static SqliteCommand Delete(this ISqliteConnection conn, string tableName, (string Name, SqliteType Type) column) {
var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name); var cmd = conn.Command("DELETE FROM " + tableName + " WHERE " + column.Name + " = :" + column.Name);
CreateParameters(cmd, [column]); CreateParameters(cmd, new[] { column });
return cmd; return cmd;
} }

View File

@@ -5,7 +5,7 @@ namespace DHT.Server.Database.Sqlite.Utils;
sealed class SqliteWhereGenerator { sealed class SqliteWhereGenerator {
private readonly string? tableAlias; private readonly string? tableAlias;
private readonly bool invert; private readonly bool invert;
private readonly List<string> conditions = []; private readonly List<string> conditions = new ();
public SqliteWhereGenerator(string? tableAlias, bool invert) { public SqliteWhereGenerator(string? tableAlias, bool invert) {
this.tableAlias = tableAlias; this.tableAlias = tableAlias;

View File

@@ -4,7 +4,7 @@ using System.Collections.Frozen;
namespace DHT.Server.Download; namespace DHT.Server.Download;
static class DiscordCdn { static class DiscordCdn {
private static FrozenSet<string> CdnHosts { get; } = new[] { private static FrozenSet<string> CdnHosts { get; } = new [] {
"cdn.discordapp.com", "cdn.discordapp.com",
"cdn.discord.com", "cdn.discord.com",
}.ToFrozenSet(); }.ToFrozenSet();

View File

@@ -29,7 +29,7 @@ sealed class DownloaderTask : IAsyncDisposable {
private readonly CancellationToken cancellationToken; private readonly CancellationToken cancellationToken;
private readonly IDatabaseFile db; private readonly IDatabaseFile db;
private readonly ISubject<DownloadItem> finishedItemPublisher = Subject.Synchronize(new Subject<DownloadItem>()); private readonly Subject<DownloadItem> finishedItemPublisher = new ();
private readonly Task queueWriterTask; private readonly Task queueWriterTask;
private readonly Task[] downloadTasks; private readonly Task[] downloadTasks;

View File

@@ -13,6 +13,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.Data.Sqlite" Version="8.0.0" /> <PackageReference Include="Microsoft.Data.Sqlite" Version="8.0.0" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" /> <PackageReference Include="System.Linq.Async" Version="6.0.1" />
<PackageReference Include="System.Reactive" Version="6.0.0" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -11,12 +11,12 @@ using Microsoft.Extensions.Hosting;
namespace DHT.Server.Service; namespace DHT.Server.Service;
sealed class Startup { sealed class Startup {
private static readonly string[] AllowedOrigins = [ private static readonly string[] AllowedOrigins = {
"https://discord.com", "https://discord.com",
"https://ptb.discord.com", "https://ptb.discord.com",
"https://canary.discord.com", "https://canary.discord.com",
"https://discordapp.com" "https://discordapp.com",
]; };
public void ConfigureServices(IServiceCollection services) { public void ConfigureServices(IServiceCollection services) {
services.Configure<JsonOptions>(static options => { services.Configure<JsonOptions>(static options => {

View File

@@ -22,6 +22,6 @@ public sealed class State : IAsyncDisposable {
public async ValueTask DisposeAsync() { public async ValueTask DisposeAsync() {
await Downloader.Stop(); await Downloader.Stop();
await Server.Stop(); await Server.Stop();
await Db.DisposeAsync(); Db.Dispose();
} }
} }

View File

@@ -1,45 +0,0 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace DHT.Utils.Collections;
public sealed class ConcurrentPool<T> {
private readonly SemaphoreSlim mutexSemaphore;
private readonly SemaphoreSlim availableItemSemaphore;
private readonly Stack<T> items;
public ConcurrentPool(int size) {
mutexSemaphore = new SemaphoreSlim(1);
availableItemSemaphore = new SemaphoreSlim(0, size);
items = new Stack<T>();
}
public async Task Push(T item, CancellationToken cancellationToken) {
await PushItem(item, cancellationToken);
availableItemSemaphore.Release();
}
public async Task<T> Pop(CancellationToken cancellationToken) {
await availableItemSemaphore.WaitAsync(cancellationToken);
return await PopItem(cancellationToken);
}
private async Task PushItem(T item, CancellationToken cancellationToken) {
await mutexSemaphore.WaitAsync(cancellationToken);
try {
items.Push(item);
} finally {
mutexSemaphore.Release();
}
}
private async Task<T> PopItem(CancellationToken cancellationToken) {
await mutexSemaphore.WaitAsync(cancellationToken);
try {
return items.Pop();
} finally {
mutexSemaphore.Release();
}
}
}

View File

@@ -0,0 +1,19 @@
using System.Collections.Generic;
namespace DHT.Utils.Collections;
public sealed class MultiDictionary<TKey, TValue> where TKey : notnull {
private readonly Dictionary<TKey, List<TValue>> dict = new();
public void Add(TKey key, TValue value) {
if (!dict.TryGetValue(key, out var list)) {
dict[key] = list = new List<TValue>();
}
list.Add(value);
}
public List<TValue>? GetListOrNull(TKey key) {
return dict.TryGetValue(key, out var list) ? list : null;
}
}

View File

@@ -13,20 +13,42 @@ public static class HttpOutput {
} }
} }
public sealed class Text(string text) : IHttpOutput { public sealed class Text : IHttpOutput {
private readonly string text;
public Text(string text) {
this.text = text;
}
public Task WriteTo(HttpResponse response) { public Task WriteTo(HttpResponse response) {
return response.WriteAsync(text, Encoding.UTF8); return response.WriteAsync(text, Encoding.UTF8);
} }
} }
public sealed class File(string? contentType, byte[] bytes) : IHttpOutput { public sealed class File : IHttpOutput {
private readonly string? contentType;
private readonly byte[] bytes;
public File(string? contentType, byte[] bytes) {
this.contentType = contentType;
this.bytes = bytes;
}
public async Task WriteTo(HttpResponse response) { public async Task WriteTo(HttpResponse response) {
response.ContentType = contentType ?? string.Empty; response.ContentType = contentType ?? string.Empty;
await response.Body.WriteAsync(bytes); await response.Body.WriteAsync(bytes);
} }
} }
public sealed class Redirect(string url, bool permanent) : IHttpOutput { public sealed class Redirect : IHttpOutput {
private readonly string url;
private readonly bool permanent;
public Redirect(string url, bool permanent) {
this.url = url;
this.permanent = permanent;
}
public Task WriteTo(HttpResponse response) { public Task WriteTo(HttpResponse response) {
response.Redirect(url, permanent); response.Redirect(url, permanent);
return Task.CompletedTask; return Task.CompletedTask;

View File

@@ -5,4 +5,4 @@ namespace DHT.Utils.Http;
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, GenerationMode = JsonSourceGenerationMode.Default)] [JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(JsonElement))] [JsonSerializable(typeof(JsonElement))]
public sealed partial class JsonElementContext : JsonSerializerContext; public sealed partial class JsonElementContext : JsonSerializerContext {}

View File

@@ -0,0 +1,22 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Runtime.CompilerServices;
using JetBrains.Annotations;
namespace DHT.Utils.Models;
public abstract class BaseModel : INotifyPropertyChanged {
public event PropertyChangedEventHandler? PropertyChanged;
[NotifyPropertyChangedInvocator]
protected void OnPropertyChanged([CallerMemberName] string? propertyName = null) {
PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName));
}
protected void Change<T>(ref T field, T newValue, [CallerMemberName] string? propertyName = null) {
if (!EqualityComparer<T>.Default.Equals(field, newValue)) {
field = newValue;
OnPropertyChanged(propertyName);
}
}
}

View File

@@ -6,7 +6,13 @@ using System.Threading.Tasks;
namespace DHT.Utils.Resources; namespace DHT.Utils.Resources;
public sealed class ResourceLoader(Assembly assembly) { public sealed class ResourceLoader {
private readonly Assembly assembly;
public ResourceLoader(Assembly assembly) {
this.assembly = assembly;
}
private Stream GetEmbeddedStream(string filename) { private Stream GetEmbeddedStream(string filename) {
Stream? stream = null; Stream? stream = null;
foreach (var embeddedName in assembly.GetManifestResourceNames()) { foreach (var embeddedName in assembly.GetManifestResourceNames()) {
@@ -29,7 +35,7 @@ public sealed class ResourceLoader(Assembly assembly) {
} }
public async Task<string> ReadJoinedAsync(string path, char separator) { public async Task<string> ReadJoinedAsync(string path, char separator) {
StringBuilder joined = new (); StringBuilder joined = new();
foreach (var embeddedName in assembly.GetManifestResourceNames()) { foreach (var embeddedName in assembly.GetManifestResourceNames()) {
if (embeddedName.Replace('\\', '/').StartsWith(path)) { if (embeddedName.Replace('\\', '/').StartsWith(path)) {

View File

@@ -0,0 +1,130 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
namespace DHT.Utils.Tasks;
public sealed class AsyncValueComputer<TValue> {
private readonly Action<TValue> resultProcessor;
private readonly TaskScheduler resultTaskScheduler;
private readonly bool processOutdatedResults;
private readonly object stateLock = new ();
private SoftHardCancellationToken? currentCancellationTokenSource;
private bool wasHardCancelled = false;
private Func<Task<TValue>>? currentComputeFunction;
private bool hasComputeFunctionChanged = false;
private AsyncValueComputer(Action<TValue> resultProcessor, TaskScheduler resultTaskScheduler, bool processOutdatedResults) {
this.resultProcessor = resultProcessor;
this.resultTaskScheduler = resultTaskScheduler;
this.processOutdatedResults = processOutdatedResults;
}
public void Cancel() {
lock (stateLock) {
wasHardCancelled = true;
currentCancellationTokenSource?.RequestHardCancellation();
}
}
public void Compute(Func<Task<TValue>> func) {
lock (stateLock) {
wasHardCancelled = false;
if (currentComputeFunction != null) {
currentComputeFunction = func;
hasComputeFunctionChanged = true;
currentCancellationTokenSource?.RequestSoftCancellation();
}
else {
EnqueueComputation(func);
}
}
}
[SuppressMessage("ReSharper", "MethodSupportsCancellation")]
private void EnqueueComputation(Func<Task<TValue>> func) {
var cancellationTokenSource = new SoftHardCancellationToken();
currentCancellationTokenSource?.RequestSoftCancellation();
currentCancellationTokenSource = cancellationTokenSource;
currentComputeFunction = func;
hasComputeFunctionChanged = false;
var task = Task.Run(func);
task.ContinueWith(t => {
if (!cancellationTokenSource.IsCancelled(processOutdatedResults)) {
resultProcessor(t.Result);
}
}, CancellationToken.None, TaskContinuationOptions.NotOnFaulted, resultTaskScheduler);
task.ContinueWith(_ => {
lock (stateLock) {
cancellationTokenSource.Dispose();
if (currentCancellationTokenSource == cancellationTokenSource) {
currentCancellationTokenSource = null;
}
if (hasComputeFunctionChanged && !wasHardCancelled) {
EnqueueComputation(currentComputeFunction);
}
else {
currentComputeFunction = null;
hasComputeFunctionChanged = false;
}
}
});
}
public sealed class Single {
private readonly AsyncValueComputer<TValue> baseComputer;
private readonly Func<Task<TValue>> resultComputer;
internal Single(AsyncValueComputer<TValue> baseComputer, Func<Task<TValue>> resultComputer) {
this.baseComputer = baseComputer;
this.resultComputer = resultComputer;
}
public void Recompute() {
baseComputer.Compute(resultComputer);
}
public void Cancel() {
baseComputer.Cancel();
}
}
public static Builder WithResultProcessor(Action<TValue> resultProcessor, TaskScheduler? scheduler = null) {
return new Builder(resultProcessor, scheduler ?? TaskScheduler.FromCurrentSynchronizationContext());
}
public sealed class Builder {
private readonly Action<TValue> resultProcessor;
private readonly TaskScheduler resultTaskScheduler;
private bool processOutdatedResults;
internal Builder(Action<TValue> resultProcessor, TaskScheduler resultTaskScheduler) {
this.resultProcessor = resultProcessor;
this.resultTaskScheduler = resultTaskScheduler;
}
public Builder WithOutdatedResults() {
this.processOutdatedResults = true;
return this;
}
public AsyncValueComputer<TValue> Build() {
return new AsyncValueComputer<TValue>(resultProcessor, resultTaskScheduler, processOutdatedResults);
}
public Single BuildWithComputer(Func<Task<TValue>> resultComputer) {
return new Single(Build(), resultComputer);
}
}
}

View File

@@ -1,31 +0,0 @@
using System;
using System.Reactive.Subjects;
using System.Threading;
using System.Threading.Tasks;
namespace DHT.Utils.Tasks;
public sealed class ObservableThrottledTask<T> : IObservable<T>, IDisposable {
private readonly ReplaySubject<T> subject;
private readonly ThrottledTask<T> task;
public ObservableThrottledTask(TaskScheduler resultScheduler) {
this.subject = new ReplaySubject<T>(bufferSize: 1);
this.task = new ThrottledTask<T>(subject.OnNext, resultScheduler);
}
public void Post(Func<CancellationToken, Task<T>> resultComputer) {
task.Post(resultComputer);
}
public IDisposable Subscribe(IObserver<T> observer) {
return subject.Subscribe(observer);
}
public void Dispose() {
task.Dispose();
subject.OnCompleted();
subject.Dispose();
}
}

View File

@@ -0,0 +1,39 @@
using System;
using System.Threading;
namespace DHT.Utils.Tasks;
/// <summary>
/// Manages a pair of cancellation tokens that follow these rules:
/// <list type="number">
/// <item><description>If the soft token is cancelled, the hard token remains uncancelled.</description></item>
/// <item><description>If the hard token is cancelled, the soft token is also cancelled.</description></item>
/// </list>
/// </summary>
sealed class SoftHardCancellationToken : IDisposable {
private readonly CancellationTokenSource soft;
private readonly CancellationTokenSource hard;
public SoftHardCancellationToken() {
this.soft = new CancellationTokenSource();
this.hard = new CancellationTokenSource();
}
public bool IsCancelled(bool onlyHardCancellation) {
return (onlyHardCancellation ? hard : soft).IsCancellationRequested;
}
public void RequestSoftCancellation() {
soft.Cancel();
}
public void RequestHardCancellation() {
soft.Cancel();
hard.Cancel();
}
public void Dispose() {
soft.Dispose();
hard.Dispose();
}
}

View File

@@ -16,7 +16,6 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="JetBrains.Annotations" Version="2023.2.0" /> <PackageReference Include="JetBrains.Annotations" Version="2023.2.0" />
<PackageReference Include="System.Reactive" Version="6.0.0" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -8,5 +8,5 @@ using DHT.Utils;
namespace DHT.Utils; namespace DHT.Utils;
static class Version { static class Version {
public const string Tag = "40.0.0.0"; public const string Tag = "39.1.0.0";
} }